diff --git a/.gitignore b/.gitignore
index 51a9ca1d..315674fe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+.*
+!.gitignore
 *.pyc
 .*.swp
 *.egg
@@ -6,4 +8,8 @@ docs/_build
 build/
 dist/
 mongoengine.egg-info/
-env/
\ No newline at end of file
+env/
+.settings
+.project
+.pydevproject
+tests/bugfix.py
diff --git a/AUTHORS b/AUTHORS
index 93fe819e..b342830a 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -1,5 +1,69 @@
+The PRIMARY AUTHORS are (and/or have been):
+
 Harry Marr
 Matt Dennewitz
 Deepak Thukral
 Florian Schlachter
 Steve Challis
+Ross Lawley
+Wilson Júnior
+Dan Crosta https://github.com/dcrosta
+
+CONTRIBUTORS
+
+Derived from the git logs; inevitably incomplete, but these people and others
+have submitted patches, reported bugs and generally helped make MongoEngine
+that much better:
+
+ * Harry Marr
+ * Ross Lawley
+ * blackbrrr
+ * Florian Schlachter
+ * Vincent Driessen
+ * Steve Challis
+ * flosch
+ * Deepak Thukral
+ * Colin Howe
+ * Wilson Júnior
+ * Alistair Roche
+ * Dan Crosta
+ * Viktor Kerkez
+ * Stephan Jaekel
+ * Rached Ben Mustapha
+ * Greg Turner
+ * Daniel Hasselrot
+ * Mircea Pasoi
+ * Matt Chisholm
+ * James Punteney
+ * Timothée Peignier
+ * Stuart Rackham
+ * Serge Matveenko
+ * Matt Dennewitz
+ * Don Spaulding
+ * Ales Zoulek
+ * sshwsfc
+ * sib
+ * Samuel Clay
+ * Nick Vlku
+ * martin
+ * Flavio Amieiro
+ * Анхбаяр Лхагвадорж
+ * Zak Johnson
+ * Victor Farazdagi
+ * vandersonmota
+ * Theo Julienne
+ * sp
+ * Slavi Pantaleev
+ * Richard Henry
+ * Nicolas Perriault
+ * Nick Vlku Jr
+ * Michael Henson
+ * Leo Honkanen
+ * kuno
+ * Josh Ourisman
+ * Jaime
+ * Igor Ivanov
+ * Gregg Lind
+ * Gareth Lloyd
+ * Albert Choi
+ * John Arnfield
diff --git a/docs/apireference.rst b/docs/apireference.rst
index 34d4536d..2442803d 100644
--- a/docs/apireference.rst
+++ b/docs/apireference.rst
@@ -41,6 +41,8 @@ Fields
 
 .. autoclass:: mongoengine.URLField
 
+.. autoclass:: mongoengine.EmailField
+
 .. autoclass:: mongoengine.IntField
 
 .. autoclass:: mongoengine.FloatField
@@ -51,12 +53,16 @@ Fields
 
 .. autoclass:: mongoengine.DateTimeField
 
+.. autoclass:: mongoengine.ComplexDateTimeField
+
 .. autoclass:: mongoengine.EmbeddedDocumentField
 
 .. autoclass:: mongoengine.DictField
 
 .. autoclass:: mongoengine.ListField
 
+.. autoclass:: mongoengine.SortedListField
+
 .. autoclass:: mongoengine.BinaryField
 
 .. autoclass:: mongoengine.ObjectIdField
diff --git a/docs/changelog.rst b/docs/changelog.rst
index d7c6fe85..04235db6 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -2,6 +2,72 @@ Changelog
 =========
 
+Changes in dev
+==============
+
+- Added InvalidDocumentError - so Document core methods can't be overwritten
+- Added GenericEmbeddedDocument - so you can embed any type of embeddable document
+- Added within_polygon support - for those with mongodb 1.9
+- Updated sum / average to use map_reduce as db.eval doesn't work in sharded environments
+- Added where() filter - allows users to specify query expressions as Javascript
+- Added SequenceField - for creating sequential counters
+- Added update() convenience method to a document
+- Added cascading saves - so changes to Referenced documents are saved on .save()
+- Added select_related() support
+- Added support for the positional operator
+- Updated geo index checking to be recursive and check in embedded documents
+- Updated default collection naming convention
+- Added Document Mixin support
+- Fixed queryset __repr__ mid iteration
+- Added hint() support, so you can tell Mongo the proper index to use for the query
+- Fixed issue with inconsistent setting of _cls breaking inherited referencing
+- Added help_text and verbose_name to fields to help with some form libs
+- Updated item_frequencies to handle embedded document lookups
+- Added delta tracking - now only sets / unsets explicitly changed fields
+- Fixed saving so it sets updated values rather than overwriting
+- Added ComplexDateTimeField - handles datetimes correctly with microseconds
+- Added ComplexBaseField - for improved flexibility and performance
+- Added get_FIELD_display() method for easy choice field displaying
+- Added queryset.slave_okay(enabled) method
+- Updated queryset.timeout(enabled) and queryset.snapshot(enabled) to be chainable
+- Added insert method for bulk inserts
+- Added blinker signal support
+- Added query_counter context manager for tests
+- Added map_reduce method item_frequencies and set as default (as db.eval doesn't
+  work in sharded environments)
+- Added inline_map_reduce option to map_reduce
+- Updated connection exception so it provides more info on the cause
+- Added searching multiple levels deep in ``DictField``
+- Added ``DictField`` entries containing strings to use matching operators
+- Added ``MapField``, similar to ``DictField``
+- Added Abstract Base Classes
+- Added Custom Objects Managers
+- Added sliced subfields updating
+- Added ``NotRegistered`` exception if dereferencing ``Document`` not in the registry
+- Added a write concern for ``save``, ``update``, ``update_one`` and ``get_or_create``
+- Added slicing / subarray fetching controls
+- Fixed various unique index and other index issues
+- Fixed threaded connection issues
+- Added spherical geospatial query operators
+- Updated queryset to handle latest version of pymongo -
+  map_reduce now requires an output.
+- Added ``Document`` __hash__, __ne__ for pickling
+- Added ``FileField`` optional size arg for read method
+- Fixed ``FileField`` seek and tell methods for reading files
+- Added ``QuerySet.clone`` to support copying querysets
+- Fixed item_frequencies when using a name that's the same as a native js function
+- Added reverse delete rules
+- Fixed issue with unset operation
+- Fixed Q-object bug
+- Added ``QuerySet.all_fields`` - resets previous .only() and .exclude()
+- Added ``QuerySet.exclude``
+- Added django style choices
+- Fixed order and filter issue
+- Added ``QuerySet.only`` subfield support
+- Added creation_counter to ``BaseField`` allowing fields to be sorted in the
+  way the user has specified them
+- Fixed various errors
+- Added many tests
+
 Changes in v0.4
 ===============
 - Added ``GridFSStorage`` Django storage backend
@@ -32,7 +98,7 @@ Changes in v0.3
 ===============
 - Added MapReduce support
 - Added ``contains``, ``startswith`` and ``endswith`` query operators (and
-  case-insensitive versions that are prefixed with 'i')
+  case-insensitive versions that are prefixed with 'i')
 - Deprecated fields' ``name`` parameter, replaced with ``db_field``
 - Added ``QuerySet.only`` for only retrieving specific fields
 - Added ``QuerySet.in_bulk()`` for bulk querying using ids
@@ -79,7 +145,7 @@ Changes in v0.2
 ===============
 - Added ``Q`` class for building advanced queries
 - Added ``QuerySet`` methods for atomic updates to documents
-- Fields may now specify ``unique=True`` to enforce uniqueness across a
+- Fields may now specify ``unique=True`` to enforce uniqueness across a
   collection
 - Added option for default document ordering
 - Fixed bug in index definitions
@@ -87,7 +153,7 @@ Changes in v0.1.3
 =================
 - Added Django authentication backend
-- Added ``Document.meta`` support for indexes, which are ensured just before
+- Added ``Document.meta`` support for indexes, which are ensured just before
   querying takes place
 - A few minor bugfixes
 
diff --git a/docs/conf.py b/docs/conf.py
index 2541f49a..03ba047f 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -38,7 +38,7 @@ master_doc = 'index'
 
 # General information about the project.
 project = u'MongoEngine'
-copyright = u'2009-2010, Harry Marr'
+copyright = u'2009-2011, Harry Marr'
 
 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst
index 4c9de931..fd005e40 100644
--- a/docs/guide/defining-documents.rst
+++ b/docs/guide/defining-documents.rst
@@ -4,14 +4,14 @@ Defining documents
 In MongoDB, a **document** is roughly equivalent to a **row** in an RDBMS. When
 working with relational databases, rows are stored in **tables**, which have a
 strict **schema** that the rows follow. MongoDB stores documents in
-**collections** rather than tables - the principle difference is that no schema
-is enforced at a database level.
+**collections** rather than tables - the principal difference is that no schema
+is enforced at a database level.
 
 Defining a document's schema
 ============================
 MongoEngine allows you to define schemata for documents as this helps to reduce
 coding errors, and allows for utility methods to be defined on fields which may
-be present.
+be present.
 
 To define a schema for a document, create a class that inherits from
 :class:`~mongoengine.Document`.
Fields are specified by adding **field
@@ -19,7 +19,7 @@ objects** as class attributes to the document class::
 
     from mongoengine import *
     import datetime
-    
+
     class Page(Document):
         title = StringField(max_length=200, required=True)
         date_modified = DateTimeField(default=datetime.datetime.now)
@@ -31,31 +31,35 @@ By default, fields are not required. To make a field mandatory, set the
 validation constraints available (such as :attr:`max_length` in the example
 above). Fields may also take default values, which will be used if a value is
 not provided. Default values may optionally be a callable, which will be called
-to retrieve the value (such as in the above example). The field types available
+to retrieve the value (such as in the above example). The field types available
 are as follows:
 
 * :class:`~mongoengine.StringField`
 * :class:`~mongoengine.URLField`
+* :class:`~mongoengine.EmailField`
 * :class:`~mongoengine.IntField`
 * :class:`~mongoengine.FloatField`
 * :class:`~mongoengine.DecimalField`
 * :class:`~mongoengine.DateTimeField`
+* :class:`~mongoengine.ComplexDateTimeField`
 * :class:`~mongoengine.ListField`
+* :class:`~mongoengine.SortedListField`
 * :class:`~mongoengine.DictField`
+* :class:`~mongoengine.MapField`
 * :class:`~mongoengine.ObjectIdField`
-* :class:`~mongoengine.EmbeddedDocumentField`
 * :class:`~mongoengine.ReferenceField`
 * :class:`~mongoengine.GenericReferenceField`
+* :class:`~mongoengine.EmbeddedDocumentField`
+* :class:`~mongoengine.GenericEmbeddedDocumentField`
 * :class:`~mongoengine.BooleanField`
 * :class:`~mongoengine.FileField`
-* :class:`~mongoengine.EmailField`
-* :class:`~mongoengine.SortedListField`
 * :class:`~mongoengine.BinaryField`
 * :class:`~mongoengine.GeoPointField`
+* :class:`~mongoengine.SequenceField`
 
 Field arguments
 ---------------
-Each field type can be customized by keyword arguments. The following keyword
+Each field type can be customized by keyword arguments. The following keyword
 arguments can be set on all fields:
 
 :attr:`db_field` (Default: None)
@@ -74,7 +78,7 @@ arguments can be set on all fields:
 
     The definion of default parameters follow `the general rules on Python
     `__,
-    which means that some care should be taken when dealing with default mutable objects
+    which means that some care should be taken when dealing with default mutable objects
     (like in :class:`~mongoengine.ListField` or :class:`~mongoengine.DictField`)::
 
         class ExampleFirst(Document):
@@ -89,7 +93,7 @@ arguments can be set on all fields:
             # This can make an .append call to add values to the default (and all the following objects),
            # instead to just an object
            values = ListField(IntField(), default=[1,2,3])
-    
+
 
 :attr:`unique` (Default: False)
     When True, no documents in the collection will have the same value for this
@@ -104,7 +108,13 @@ arguments can be set on all fields:
 
 :attr:`choices` (Default: None)
     An iterable of choices to which the value of this field should be limited.
-    
+
+:attr:`help_text` (Default: None)
+    Optional help text to output with the field - used by form libraries
+
+:attr:`verbose_name` (Default: None)
+    An optional human-readable name for the field - used by form libraries
+
 List fields
 -----------
@@ -121,7 +131,7 @@ Embedded documents
 MongoDB has the ability to embed documents within other documents. Schemata
 may be defined for these embedded documents, just as they may be for regular
 documents.
To create an embedded document, just define a document as usual, but
-inherit from :class:`~mongoengine.EmbeddedDocument` rather than
+inherit from :class:`~mongoengine.EmbeddedDocument` rather than
 :class:`~mongoengine.Document`::
 
     class Comment(EmbeddedDocument):
         content = StringField()
@@ -144,7 +154,7 @@ Often, an embedded document may be used instead of a dictionary -- generally
 this is recommended as dictionaries don't support validation or custom field
 types. However, sometimes you will not know the structure of what you want to
 store; in this situation a :class:`~mongoengine.DictField` is appropriate::
-    
+
     class SurveyResponse(Document):
         date = DateTimeField()
         user = ReferenceField(User)
@@ -152,16 +162,19 @@ store; in this situation a :class:`~mongoengine.DictField` is appropriate::
 
     survey_response = SurveyResponse(date=datetime.now(), user=request.user)
     response_form = ResponseForm(request.POST)
-    survey_response.answers = response_form.cleaned_data()
+    survey_response.answers = response_form.cleaned_data()
     survey_response.save()
 
+Dictionaries can store complex data: other dictionaries, lists, references to
+other objects - making them the most flexible field type available.
+
 Reference fields
 ----------------
 References may be stored to other documents in the database using the
 :class:`~mongoengine.ReferenceField`. Pass in another document class as the
 first argument to the constructor, then simply assign document objects to the
 field::
-    
+
     class User(Document):
         name = StringField()
 
@@ -193,19 +206,72 @@ as the constructor's argument::
 
     class ProfilePage(Document):
         content = StringField()
+
+Dealing with deletion of referred documents
+'''''''''''''''''''''''''''''''''''''''''''
+By default, MongoDB doesn't check the integrity of your data, so deleting
+documents that other documents still hold references to will lead to
+consistency issues. MongoEngine's :class:`ReferenceField` adds some
+functionality to safeguard against these kinds of database integrity problems,
+providing each reference with a delete rule specification. A delete rule is
+specified by supplying the :attr:`reverse_delete_rule` attribute on the
+:class:`ReferenceField` definition, like this::
+
+    class Employee(Document):
+        ...
+        profile_page = ReferenceField('ProfilePage', reverse_delete_rule=mongoengine.NULLIFY)
+
+The declaration in this example means that when a :class:`ProfilePage` is
+removed, the :attr:`profile_page` field of any :class:`Employee` that
+references it is nullified. If a whole batch of profile pages is removed, the
+references on all affected employees are nullified as well.
+
+Its value can take any of the following constants:
+
+:const:`mongoengine.DO_NOTHING`
+    This is the default and won't do anything. Deletes are fast, but may cause
+    database inconsistency or dangling references.
+:const:`mongoengine.DENY`
+    Deletion is denied if there still exist references to the object being
+    deleted.
+:const:`mongoengine.NULLIFY`
+    Any object's fields still referring to the object being deleted are removed
+    (using MongoDB's "unset" operation), effectively nullifying the
+    relationship.
+:const:`mongoengine.CASCADE`
+    Any object containing fields that refer to the object being deleted is
+    deleted first.
+
+
+.. warning::
+   A safety note on setting up these delete rules! Since the delete rules are
+   not recorded on the database level by MongoDB itself, but instead at
+   runtime, in-memory, by the MongoEngine module, it is of the utmost
+   importance that the module that declares the relationship is loaded
+   **BEFORE** the delete is invoked.
+
+   If, for example, the :class:`Employee` object lives in the
+   :mod:`payroll` app, and the :class:`ProfilePage` in the :mod:`people`
+   app, it is extremely important that the :mod:`payroll` app is loaded
+   before any profile page is removed, because otherwise, MongoEngine could
+   never know this relationship exists.
+
+   In Django, be sure to put all apps that have such delete rule declarations in
+   their :file:`models.py` in the :const:`INSTALLED_APPS` tuple.
+
+
 Generic reference fields
 ''''''''''''''''''''''''
 A second kind of reference field also exists,
 :class:`~mongoengine.GenericReferenceField`. This allows you to reference any
-kind of :class:`~mongoengine.Document`, and hence doesn't take a
+kind of :class:`~mongoengine.Document`, and hence doesn't take a
 :class:`~mongoengine.Document` subclass as a constructor argument::
 
     class Link(Document):
         url = StringField()
-    
+
     class Post(Document):
         title = StringField()
-    
+
     class Bookmark(Document):
         bookmark_object = GenericReferenceField()
@@ -219,9 +285,10 @@ kind of :class:`~mongoengine.Document`, and hence doesn't take a
     Bookmark(bookmark_object=post).save()
 
 .. note::
+
    Using :class:`~mongoengine.GenericReferenceField`\ s is slightly less
    efficient than the standard :class:`~mongoengine.ReferenceField`\ s, so if
-   you will only be referencing one document type, prefer the standard
+   you will only be referencing one document type, prefer the standard
    :class:`~mongoengine.ReferenceField`.
 
 Uniqueness constraints
@@ -229,7 +296,7 @@ Uniqueness constraints
 MongoEngine allows you to specify that a field should be unique across a
 collection by providing ``unique=True`` to a :class:`~mongoengine.Field`\ 's
 constructor. If you try to save a document that has the same value for a unique
-field as a document that is already in the database, a
+field as a document that is already in the database, a
 :class:`~mongoengine.OperationError` will be raised. You may also specify
 multi-field uniqueness constraints by using :attr:`unique_with`, which may be
 either a single field name, or a list or tuple of field names::
@@ -241,14 +308,14 @@ either a single field name, or a list or tuple of field names::
 
 Skipping Document validation on save
 ------------------------------------
-You can also skip the whole document validation process by setting
-``validate=False`` when caling the :meth:`~mongoengine.document.Document.save`
+You can also skip the whole document validation process by setting
+``validate=False`` when calling the :meth:`~mongoengine.document.Document.save`
 method::
 
     class Recipient(Document):
         name = StringField()
         email = EmailField()
-    
+
     recipient = Recipient(name='admin', email='root@localhost')
     recipient.save()               # will raise a ValidationError while
     recipient.save(validate=False) # won't
@@ -276,7 +343,7 @@ A :class:`~mongoengine.Document` may use a **Capped Collection** by specifying
 stored in the collection, and :attr:`max_size` is the maximum size of the
 collection in bytes. If :attr:`max_size` is not specified and
 :attr:`max_documents` is, :attr:`max_size` defaults to 10000000 bytes (10MB).
-The following example shows a :class:`Log` document that will be limited to
+The following example shows a :class:`Log` document that will be limited to
 1000 entries and 2MB of disk space::
 
     class Log(Document):
@@ -288,9 +355,9 @@ Indexes
 You can specify indexes on collections to make querying faster. This is done
 by creating a list of index specifications called :attr:`indexes` in the
 :attr:`~mongoengine.Document.meta` dictionary, where an index specification may
-either be a single field name, or a tuple containing multiple field names. A
-direction may be specified on fields by prefixing the field name with a **+**
-or a **-** sign. Note that direction only matters on multi-field indexes. ::
+either be a single field name, a tuple containing multiple field names, or a
+dictionary containing a full index definition. A direction may be specified on
+fields by prefixing the field name with a **+** or a **-** sign. Note that
+direction only matters on multi-field indexes. ::
 
     class Page(Document):
         title = StringField()
@@ -299,10 +367,26 @@ or a **-** sign. Note that direction only matters on multi-field indexes. ::
             'indexes': ['title', ('title', '-rating')]
         }
 
+If a dictionary is passed then the following options are available:
+
+:attr:`fields` (Default: None)
+    The fields to index. Specified in the same format as described above.
+
+:attr:`types` (Default: True)
+    Whether the index should have the :attr:`_types` field added automatically
+    to the start of the index.
+
+:attr:`sparse` (Default: False)
+    Whether the index should be sparse.
+
+:attr:`unique` (Default: False)
+    Whether the index should be unique.
+
 .. note::
-   Geospatial indexes will be automatically created for all
+
+   Geospatial indexes will be automatically created for all
    :class:`~mongoengine.GeoPointField`\ s
-   
+
 Ordering
 ========
 A default ordering can be specified for your
@@ -324,7 +408,7 @@ subsequent calls to :meth:`~mongoengine.queryset.QuerySet.order_by`. ::
 
         blog_post_1 = BlogPost(title="Blog Post #1")
         blog_post_1.published_date = datetime(2010, 1, 5, 0, 0 ,0)
 
-        blog_post_2 = BlogPost(title="Blog Post #2")
+        blog_post_2 = BlogPost(title="Blog Post #2")
         blog_post_2.published_date = datetime(2010, 1, 6, 0, 0 ,0)
 
         blog_post_3 = BlogPost(title="Blog Post #3")
@@ -336,7 +420,7 @@ subsequent calls to :meth:`~mongoengine.queryset.QuerySet.order_by`. ::
 
         # get the "first" BlogPost using default ordering
         # from BlogPost.meta.ordering
-        latest_post = BlogPost.objects.first()
+        latest_post = BlogPost.objects.first()
         assert latest_post.title == "Blog Post #3"
 
         # override default ordering, order BlogPosts by "published_date"
@@ -365,7 +449,7 @@ Working with existing data
 To enable correct retrieval of documents involved in this kind of heirarchy,
 two extra attributes are stored on each document in the database: :attr:`_cls`
 and :attr:`_types`. These are hidden from the user through the MongoEngine
-interface, but may not be present if you are trying to use MongoEngine with
+interface, but may not be present if you are trying to use MongoEngine with
 an existing database. For this reason, you may disable this inheritance
 mechansim, removing the dependency of :attr:`_cls` and :attr:`_types`, enabling
 you to work with existing databases.
To disable inheritance on a document diff --git a/docs/guide/document-instances.rst b/docs/guide/document-instances.rst index 7b5d165b..317bfef1 100644 --- a/docs/guide/document-instances.rst +++ b/docs/guide/document-instances.rst @@ -4,12 +4,12 @@ Documents instances To create a new document object, create an instance of the relevant document class, providing values for its fields as its constructor keyword arguments. You may provide values for any of the fields on the document:: - + >>> page = Page(title="Test Page") >>> page.title 'Test Page' -You may also assign values to the document's fields using standard object +You may also assign values to the document's fields using standard object attribute syntax:: >>> page.title = "Example Page" @@ -18,10 +18,22 @@ attribute syntax:: Saving and deleting documents ============================= -To save the document to the database, call the -:meth:`~mongoengine.Document.save` method. If the document does not exist in -the database, it will be created. If it does already exist, it will be -updated. +MongoEngine tracks changes to documents to provide efficient saving. To save +the document to the database, call the :meth:`~mongoengine.Document.save` method. +If the document does not exist in the database, it will be created. If it does +already exist, then any changes will be updated atomically. For example:: + + >>> page = Page(title="Test Page") + >>> page.save() # Performs an insert + >>> page.title = "My Page" + >>> page.save() # Performs an atomic set on the title field. + +.. note:: + + Changes to documents are tracked and on the whole perform `set` operations. + + * ``list_field.pop(0)`` - *sets* the resulting list + * ``del(list_field)`` - *unsets* whole list To delete a document, call the :meth:`~mongoengine.Document.delete` method. Note that this will only work if the document exists in the database and has a @@ -67,6 +79,7 @@ is an alias to :attr:`id`:: >>> page.id == page.pk .. note:: + If you define your own primary key field, the field implicitly becomes required, so a :class:`ValidationError` will be thrown if you don't provide it. diff --git a/docs/guide/gridfs.rst b/docs/guide/gridfs.rst index 0cd06539..3abad775 100644 --- a/docs/guide/gridfs.rst +++ b/docs/guide/gridfs.rst @@ -66,6 +66,7 @@ Deleting stored files is achieved with the :func:`delete` method:: marmot.photo.delete() .. note:: + The FileField in a Document actually only stores the ID of a file in a separate GridFS collection. This means that deleting a document with a defined FileField does not actually delete the file. You must be diff --git a/docs/guide/index.rst b/docs/guide/index.rst index aac72469..d56e7479 100644 --- a/docs/guide/index.rst +++ b/docs/guide/index.rst @@ -11,3 +11,4 @@ User Guide document-instances querying gridfs + signals diff --git a/docs/guide/installing.rst b/docs/guide/installing.rst index 132f1079..f15d3dbb 100644 --- a/docs/guide/installing.rst +++ b/docs/guide/installing.rst @@ -1,31 +1,31 @@ ====================== Installing MongoEngine ====================== + To use MongoEngine, you will need to download `MongoDB `_ and ensure it is running in an accessible location. You will also need `PyMongo `_ to use MongoEngine, but if you install MongoEngine using setuptools, then the dependencies will be handled for you. -MongoEngine is available on PyPI, so to use it you can use -:program:`easy_install`: - +MongoEngine is available on PyPI, so to use it you can use :program:`pip`: + .. 
code-block:: console - # easy_install mongoengine + $ pip install mongoengine -Alternatively, if you don't have setuptools installed, `download it from PyPi +Alternatively, if you don't have setuptools installed, `download it from PyPi `_ and run .. code-block:: console - # python setup.py install + $ python setup.py install To use the bleeding-edge version of MongoEngine, you can get the source from `GitHub `_ and install it as above: - + .. code-block:: console - # git clone git://github.com/hmarr/mongoengine - # cd mongoengine - # python setup.py install + $ git clone git://github.com/hmarr/mongoengine + $ cd mongoengine + $ python setup.py install diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 832fed50..13a374cc 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -5,8 +5,8 @@ Querying the database is used for accessing the objects in the database associated with the class. The :attr:`objects` attribute is actually a :class:`~mongoengine.queryset.QuerySetManager`, which creates and returns a new -a new :class:`~mongoengine.queryset.QuerySet` object on access. The -:class:`~mongoengine.queryset.QuerySet` object may may be iterated over to +:class:`~mongoengine.queryset.QuerySet` object on access. The +:class:`~mongoengine.queryset.QuerySet` object may be iterated over to fetch documents from the database:: # Prints out the names of all the users in the database @@ -14,6 +14,7 @@ fetch documents from the database:: print user.name .. note:: + Once the iteration finishes (when :class:`StopIteration` is raised), :meth:`~mongoengine.queryset.QuerySet.rewind` will be called so that the :class:`~mongoengine.queryset.QuerySet` may be iterated over again. The @@ -23,7 +24,7 @@ fetch documents from the database:: Filtering queries ================= The query may be filtered by calling the -:class:`~mongoengine.queryset.QuerySet` object with field lookup keyword +:class:`~mongoengine.queryset.QuerySet` object with field lookup keyword arguments. The keys in the keyword arguments correspond to fields on the :class:`~mongoengine.Document` you are querying:: @@ -39,29 +40,6 @@ syntax:: # been written by a user whose 'country' field is set to 'uk' uk_pages = Page.objects(author__country='uk') -Querying lists --------------- -On most fields, this syntax will look up documents where the field specified -matches the given value exactly, but when the field refers to a -:class:`~mongoengine.ListField`, a single item may be provided, in which case -lists that contain that item will be matched:: - - class Page(Document): - tags = ListField(StringField()) - - # This will match all pages that have the word 'coding' as an item in the - # 'tags' list - Page.objects(tags='coding') - -Raw queries ------------ -It is possible to provide a raw PyMongo query as a query parameter, which will -be integrated directly into the query. This is done using the ``__raw__`` -keyword argument:: - - Page.objects(__raw__={'tags': 'coding'}) - -.. 
versionadded:: 0.4
 
 Query operators
 ===============
@@ -84,7 +62,7 @@ Available operators are as follows:
 * ``nin`` -- value is not in list (a list of values should be provided)
 * ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values
 * ``all`` -- every item in list of values provided is in array
-* ``size`` -- the size of the array is
+* ``size`` -- the size of the array matches the given value
 * ``exists`` -- value for field exists
 
 The following operators are available as shortcuts to querying with regular
@@ -99,26 +77,67 @@ expressions:
 * ``endswith`` -- string field ends with value
 * ``iendswith`` -- string field ends with value (case insensitive)
 
-.. versionadded:: 0.3
-
 There are a few special operators for performing geographical queries, that
 may used with :class:`~mongoengine.GeoPointField`\ s:
 
 * ``within_distance`` -- provide a list containing a point and a maximum
   distance (e.g. [(41.342, -87.653), 5])
+* ``within_spherical_distance`` -- same as above but using the spherical geo model
+  (e.g. [(41.342, -87.653), 5/earth_radius])
+* ``near`` -- order the documents by how close they are to a given point
+* ``near_sphere`` -- same as above but using the spherical geo model
 * ``within_box`` -- filter documents to those within a given bounding box (e.g.
   [(35.0, -125.0), (40.0, -100.0)])
-* ``near`` -- order the documents by how close they are to a given point
+* ``within_polygon`` -- filter documents to those within a given polygon (e.g.
+  [(41.91,-87.69), (41.92,-87.68), (41.91,-87.65), (41.89,-87.65)]).
+  .. note:: Requires Mongo Server 2.0
 
-.. versionadded:: 0.4
 
-Querying by position
-====================
+Querying lists
+--------------
+On most fields, this syntax will look up documents where the field specified
+matches the given value exactly, but when the field refers to a
+:class:`~mongoengine.ListField`, a single item may be provided, in which case
+lists that contain that item will be matched::
+
+    class Page(Document):
+        tags = ListField(StringField())
+
+    # This will match all pages that have the word 'coding' as an item in the
+    # 'tags' list
+    Page.objects(tags='coding')
+
 It is possible to query by position in a list by using a numerical value as a
 query operator. So if you wanted to find all pages whose first tag was ``db``,
 you could use the following query::
 
-    BlogPost.objects(tags__0='db')
+    Page.objects(tags__0='db')
+
+If you only want to fetch part of a list, e.g. to paginate it, then the
+`slice` operator is required::
+
+    # comments - skip 5, limit 10
+    Page.objects.fields(slice__comments=[5, 10])
+
+For updating documents, if you don't know the position in a list, you can use
+the $ positional operator::
+
+    Post.objects(comments__by="joe").update(**{'inc__comments__$__votes': 1})
+
+However, the $ doesn't map well to Python's keyword-argument syntax, so you
+can also use a capital S instead::
+
+    Post.objects(comments__by="joe").update(inc__comments__S__votes=1)
+
+.. note:: Due to a MongoDB limitation, the $ operator currently only applies
+   to the first matched item in the query.
+
+
+Raw queries
+-----------
+It is possible to provide a raw PyMongo query as a query parameter, which will
+be integrated directly into the query. This is done using the ``__raw__``
+keyword argument::
+
+    Page.objects(__raw__={'tags': 'coding'})
 
 .. versionadded:: 0.4
@@ -163,9 +182,9 @@ To retrieve a result that should be unique in the collection, use
 and :class:`~mongoengine.queryset.MultipleObjectsReturned` if more than one
 document matched the query.
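 
+For example, a minimal sketch of handling both exceptions when calling
+:meth:`~mongoengine.queryset.QuerySet.get` (the ``User`` class here is
+illustrative, as in the earlier examples)::
+
+    try:
+        user = User.objects.get(name='Test User')
+    except User.DoesNotExist:
+        user = None   # no document matched the query
+    except User.MultipleObjectsReturned:
+        raise         # more than one document matched the query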
-A variation of this method exists,
+A variation of this method exists,
 :meth:`~mongoengine.queryset.Queryset.get_or_create`, that will create a new
-document with the query arguments if no documents match the query. An
+document with the query arguments if no documents match the query. An
 additional keyword argument, :attr:`defaults` may be provided, which will be
 used as default values for the new document, in the case that it should need
 to be created::
@@ -175,6 +194,22 @@ to be created::
     >>> a.name == b.name and a.age == b.age
     True
 
+Dereferencing results
+---------------------
+When iterating the results of :class:`~mongoengine.ListField` or
+:class:`~mongoengine.DictField` we automatically dereference any
+:class:`~pymongo.dbref.DBRef` objects as efficiently as possible, reducing the
+number of queries to MongoDB.
+
+There are times when that efficiency is not enough; documents that have
+:class:`~mongoengine.ReferenceField` objects or
+:class:`~mongoengine.GenericReferenceField` objects at the top level are
+expensive, as the number of queries to MongoDB can quickly rise.
+
+To limit the number of queries use
+:func:`~mongoengine.queryset.QuerySet.select_related` which converts the
+QuerySet to a list and dereferences as efficiently as possible.
+
 Default Document queries
 ========================
 By default, the objects :attr:`~mongoengine.Document.objects` attribute on a
@@ -240,7 +275,7 @@ Javascript code that is executed on the database server.
 Counting results
 ----------------
 Just as with limiting and skipping results, there is a method on
-:class:`~mongoengine.queryset.QuerySet` objects --
+:class:`~mongoengine.queryset.QuerySet` objects --
 :meth:`~mongoengine.queryset.QuerySet.count`, but there is also a more Pythonic
 way of achieving this::
@@ -254,6 +289,7 @@ You may sum over the values of a specific field on documents using
 
     yearly_expense = Employee.objects.sum('salary')
 
 .. note::
+
    If the field isn't present on a document, that document will be ignored from
    the sum.
 
@@ -302,6 +338,11 @@ will be given::
     >>> f.rating # default value
     3
 
+.. note::
+
+   The :meth:`~mongoengine.queryset.QuerySet.exclude` method is the opposite
+   of :meth:`~mongoengine.queryset.QuerySet.only`; use it if you want to
+   exclude a field instead.
+
 If you later need the missing fields, just call
 :meth:`~mongoengine.Document.reload` on your document.
 
@@ -309,11 +350,11 @@ Advanced queries
 ================
 Sometimes calling a :class:`~mongoengine.queryset.QuerySet` object with keyword
 arguments can't fully express the query you want to use -- for example if you
-need to combine a number of constraints using *and* and *or*. This is made
+need to combine a number of constraints using *and* and *or*. This is made
 possible in MongoEngine through the :class:`~mongoengine.queryset.Q` class.
 A :class:`~mongoengine.queryset.Q` object represents part of a query, and
 can be initialised using the same keyword-argument syntax you use to query
-documents. To build a complex query, you may combine
+documents. To build a complex query, you may combine
 :class:`~mongoengine.queryset.Q` objects using the ``&`` (and) and ``|`` (or)
 operators. To use a :class:`~mongoengine.queryset.Q` object, pass it in as the
 first positional argument to :attr:`Document.objects` when you filter it by
@@ -325,11 +366,66 @@ calling it with keyword arguments::
 
     # Get published posts
     Post.objects(Q(published=True) | Q(publish_date__lte=datetime.now()))
 
     # Get top posts
     Post.objects((Q(featured=True) & Q(hits__gte=1000)) | Q(hits__gte=5000))
 
-.. 
warning::
-    Only use these advanced queries if absolutely necessary as they will execute
-    significantly slower than regular queries. This is because they are not
-    natively supported by MongoDB -- they are compiled to Javascript and sent
-    to the server for execution.
+.. _guide-atomic-updates:
+
+Atomic updates
+==============
+Documents may be updated atomically by using the
+:meth:`~mongoengine.queryset.QuerySet.update_one` and
+:meth:`~mongoengine.queryset.QuerySet.update` methods on a
+:meth:`~mongoengine.queryset.QuerySet`. There are several different "modifiers"
+that you may use with these methods:
+
+* ``set`` -- set a particular value
+* ``unset`` -- delete a particular value (since MongoDB v1.3+)
+* ``inc`` -- increment a value by a given amount
+* ``dec`` -- decrement a value by a given amount
+* ``push`` -- append a value to a list
+* ``push_all`` -- append several values to a list
+* ``pop`` -- remove the first or last element of a list
+* ``pull`` -- remove a value from a list
+* ``pull_all`` -- remove several values from a list
+* ``add_to_set`` -- add a value to a list only if it's not in the list already
+
+The syntax for atomic updates is similar to the querying syntax, but the
+modifier comes before the field, not after it::
+
+    >>> post = BlogPost(title='Test', page_views=0, tags=['database'])
+    >>> post.save()
+    >>> BlogPost.objects(id=post.id).update_one(inc__page_views=1)
+    >>> post.reload()  # the document has been changed, so we need to reload it
+    >>> post.page_views
+    1
+    >>> BlogPost.objects(id=post.id).update_one(set__title='Example Post')
+    >>> post.reload()
+    >>> post.title
+    'Example Post'
+    >>> BlogPost.objects(id=post.id).update_one(push__tags='nosql')
+    >>> post.reload()
+    >>> post.tags
+    ['database', 'nosql']
+
+.. note ::
+
+    In version 0.5 :meth:`~mongoengine.Document.save` runs atomic updates
+    on changed documents by tracking changes to that document.
+
+The positional operator allows you to update list items without knowing the
+index position, therefore making the update a single atomic operation. As we
+cannot use the `$` syntax in keyword arguments, it has been mapped to `S`::
+
+    >>> post = BlogPost(title='Test', page_views=0, tags=['database', 'mongo'])
+    >>> post.save()
+    >>> BlogPost.objects(id=post.id, tags='mongo').update(set__tags__S='mongodb')
+    >>> post.reload()
+    >>> post.tags
+    ['database', 'mongodb']
+
+.. note ::
+    Currently only top-level lists are handled; future versions of MongoDB /
+    PyMongo plan to support nested positional operators. See `The $ positional
+    operator `_.
 
 Server-side javascript execution
 ================================
@@ -433,43 +529,3 @@ following example shows how the substitutions are made::
         return comments;
     }
     """)
-
-.. _guide-atomic-updates:
-
-Atomic updates
-==============
-Documents may be updated atomically by using the
-:meth:`~mongoengine.queryset.QuerySet.update_one` and
-:meth:`~mongoengine.queryset.QuerySet.update` methods on a
-:meth:`~mongoengine.queryset.QuerySet`. 
There are several different "modifiers"
-that you may use with these methods:
-
-* ``set`` -- set a particular value
-* ``unset`` -- delete a particular value (since MongoDB v1.3+)
-* ``inc`` -- increment a value by a given amount
-* ``dec`` -- decrement a value by a given amount
-* ``pop`` -- remove the last item from a list
-* ``push`` -- append a value to a list
-* ``push_all`` -- append several values to a list
-* ``pop`` -- remove the first or last element of a list
-* ``pull`` -- remove a value from a list
-* ``pull_all`` -- remove several values from a list
-* ``add_to_set`` -- add value to a list only if its not in the list already
-
-The syntax for atomic updates is similar to the querying syntax, but the
-modifier comes before the field, not after it::
-
-    >>> post = BlogPost(title='Test', page_views=0, tags=['database'])
-    >>> post.save()
-    >>> BlogPost.objects(id=post.id).update_one(inc__page_views=1)
-    >>> post.reload()  # the document has been changed, so we need to reload it
-    >>> post.page_views
-    1
-    >>> BlogPost.objects(id=post.id).update_one(set__title='Example Post')
-    >>> post.reload()
-    >>> post.title
-    'Example Post'
-    >>> BlogPost.objects(id=post.id).update_one(push__tags='nosql')
-    >>> post.reload()
-    >>> post.tags
-    ['database', 'nosql']
diff --git a/docs/guide/signals.rst b/docs/guide/signals.rst
new file mode 100644
index 00000000..58b3d6ed
--- /dev/null
+++ b/docs/guide/signals.rst
@@ -0,0 +1,49 @@
+.. _signals:
+
+Signals
+=======
+
+.. versionadded:: 0.5
+
+Signal support is provided by the excellent `blinker`_ library; MongoEngine
+will gracefully fall back if it is not available.
+
+
+The following document signals exist in MongoEngine and are pretty
+self-explanatory:
+
+ * `mongoengine.signals.pre_init`
+ * `mongoengine.signals.post_init`
+ * `mongoengine.signals.pre_save`
+ * `mongoengine.signals.post_save`
+ * `mongoengine.signals.pre_delete`
+ * `mongoengine.signals.post_delete`
+
+Example usage::
+
+    from mongoengine import *
+    from mongoengine import signals
+
+    class Author(Document):
+        name = StringField()
+
+        def __unicode__(self):
+            return self.name
+
+        @classmethod
+        def pre_save(cls, sender, document, **kwargs):
+            logging.debug("Pre Save: %s" % document.name)
+
+        @classmethod
+        def post_save(cls, sender, document, **kwargs):
+            logging.debug("Post Save: %s" % document.name)
+            if 'created' in kwargs:
+                if kwargs['created']:
+                    logging.debug("Created")
+                else:
+                    logging.debug("Updated")
+
+    signals.pre_save.connect(Author.pre_save, sender=Author)
+    signals.post_save.connect(Author.post_save, sender=Author)
+
+
+.. _blinker: http://pypi.python.org/pypi/blinker
diff --git a/docs/index.rst b/docs/index.rst
index ccb7fbe2..920ddf60 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -2,34 +2,62 @@
 MongoEngine User Documentation
 ==============================
 
-MongoEngine is an Object-Document Mapper, written in Python for working with
+**MongoEngine** is an Object-Document Mapper, written in Python for working with
 MongoDB. To install it, simply run
 
 .. code-block:: console
 
     # pip install -U mongoengine
 
-The source is available on `GitHub `_.
+:doc:`tutorial`
+  Start here for a quick overview.
+
+:doc:`guide/index`
+  The full guide to MongoEngine.
+
+:doc:`apireference`
+  The complete API documentation.
+
+:doc:`django`
+  Using MongoEngine and Django.
+
+Community
+---------
 
 To get help with using MongoEngine, use the `MongoEngine Users mailing list
 `_ or come chat on the
 `#mongoengine IRC channel `_.
-If you are interested in contributing, join the developers' `mailing list
+Contributing
+------------
+
+The source is available on `GitHub `_ and
+contributions are always encouraged. Contributions can be as simple as
+minor tweaks to this documentation. To contribute, fork the project on
+`GitHub `_ and send a
+pull request.
+
+Also, you can join the developers' `mailing list
 `_.
 
+Changes
+-------
+See the :doc:`changelog` for a full list of changes to MongoEngine.
+
 .. toctree::
-    :maxdepth: 2
+    :hidden:
 
     tutorial
     guide/index
     apireference
     django
     changelog
+    upgrade
 
 Indices and tables
-==================
+------------------
 
 * :ref:`genindex`
+* :ref:`modindex`
 * :ref:`search`
diff --git a/docs/tutorial.rst b/docs/tutorial.rst
index 5db2c4df..6ce8d102 100644
--- a/docs/tutorial.rst
+++ b/docs/tutorial.rst
@@ -22,7 +22,7 @@ function. The only argument we need to provide is the name of the MongoDB
 database to use::
 
     from mongoengine import *
-    
+
     connect('tumblelog')
 
 For more information about connecting to MongoDB see :ref:`guide-connecting`.
@@ -112,7 +112,7 @@ link table, we can just store a list of tags in each post. So, for both
 efficiency and simplicity's sake, we'll store the tags as strings directly
 within the post, rather than storing references to tags in a separate
 collection. Especially as tags are generally very short (often even shorter
-than a document's id), this denormalisation won't impact very strongly on the
+than a document's id), this denormalisation won't impact very strongly on the
 size of our database. So let's take a look that the code our modified
 :class:`Post` class::
 
@@ -152,6 +152,21 @@ We can then store a list of comment documents in our post document::
         tags = ListField(StringField(max_length=30))
         comments = ListField(EmbeddedDocumentField(Comment))
 
+Handling deletions of references
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The :class:`~mongoengine.ReferenceField` object takes a keyword
+`reverse_delete_rule` for handling deletion rules if the reference is deleted.
+To delete all the posts if a user is deleted, set the rule::
+
+    class Post(Document):
+        title = StringField(max_length=120, required=True)
+        author = ReferenceField(User, reverse_delete_rule=CASCADE)
+        tags = ListField(StringField(max_length=30))
+        comments = ListField(EmbeddedDocumentField(Comment))
+
+See :class:`~mongoengine.ReferenceField` for more information.
+
 Adding data to our Tumblelog
 ============================
 Now that we've defined how our documents will be structured, let's start adding
@@ -250,5 +265,5 @@ the first matched by the query you provide. Aggregation functions may also be
 used on :class:`~mongoengine.queryset.QuerySet` objects::
 
     num_posts = Post.objects(tags='mongodb').count()
-    print 'Found % posts with tag "mongodb"' % num_posts
-    
+    print 'Found %d posts with tag "mongodb"' % num_posts
+
diff --git a/docs/upgrade.rst b/docs/upgrade.rst
new file mode 100644
index 00000000..c684c1ad
--- /dev/null
+++ b/docs/upgrade.rst
@@ -0,0 +1,97 @@
+=========
+Upgrading
+=========
+
+0.4 to 0.5
+===========
+
+There have been the following backwards incompatibilities from 0.4 to 0.5. The
+main areas of change are: choices in fields, map_reduce and collection names.
+
+Choice options
+--------------
+
+Choices are now expected to be an iterable of tuples, with the first element
+in each tuple being the actual value to be stored. The second element is the
+human-readable name for the option.
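+
+For example, a minimal sketch of the new format (the ``Shirt`` class and its
+values are illustrative)::
+
+    SIZE = (('S', 'Small'), ('M', 'Medium'), ('L', 'Large'))
+
+    class Shirt(Document):
+        size = StringField(max_length=1, choices=SIZE)
+
+    shirt = Shirt(size='S')
+    shirt.get_size_display()  # returns the human-readable name, 'Small',
+                              # via the get_FIELD_display() helper added in 0.5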
+
+
+PyMongo / MongoDB
+-----------------
+
+map_reduce now requires PyMongo 1.11+. The PyMongo merge_output and
+reduce_output parameters have been deprecated.
+
+More methods now use map_reduce, as db.eval is not supported for sharding;
+as such the following have been changed:
+
+    * :meth:`~mongoengine.queryset.QuerySet.sum`
+    * :meth:`~mongoengine.queryset.QuerySet.average`
+    * :meth:`~mongoengine.queryset.QuerySet.item_frequencies`
+
+
+Default collection naming
+-------------------------
+
+Previously the default collection name was just the lowercased class name; it
+is now more Pythonic and readable, using lowercase with underscores.
+Previously::
+
+    class MyAceDocument(Document):
+        pass
+
+    MyAceDocument._meta['collection'] == myacedocument
+
+In 0.5 this will change to::
+
+    class MyAceDocument(Document):
+        pass
+
+    MyAceDocument._get_collection_name() == my_ace_document
+
+To upgrade, use a mixin class to set meta like so::
+
+    class BaseMixin(object):
+        meta = {
+            'collection': lambda c: c.__name__.lower()
+        }
+
+    class MyAceDocument(Document, BaseMixin):
+        pass
+
+    MyAceDocument._get_collection_name() == myacedocument
+
+Alternatively, you can rename your collections, e.g.::
+
+    from mongoengine.connection import _get_db
+    from mongoengine.base import _document_registry
+
+    def rename_collections():
+        db = _get_db()
+
+        failure = False
+
+        collection_names = [d._get_collection_name() for d in _document_registry.values()]
+
+        for new_style_name in collection_names:
+            if not new_style_name:  # embedded documents don't have collections
+                continue
+            old_style_name = new_style_name.replace('_', '')
+
+            if old_style_name == new_style_name:
+                continue  # Nothing to do
+
+            existing = db.collection_names()
+            if old_style_name in existing:
+                if new_style_name in existing:
+                    failure = True
+                    print "FAILED to rename: %s to %s (already exists)" % (
+                        old_style_name, new_style_name)
+                else:
+                    db[old_style_name].rename(new_style_name)
+                    print "Renamed: %s to %s" % (old_style_name, new_style_name)
+
+        if failure:
+            print "Upgrading collection names failed"
+        else:
+            print "Upgraded collection names"
+
diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py
index 6d18ffe7..0d271783 100644
--- a/mongoengine/__init__.py
+++ b/mongoengine/__init__.py
@@ -6,13 +6,16 @@ import connection
 from connection import *
 import queryset
 from queryset import *
+import signals
+from signals import *
 
 __all__ = (document.__all__ + fields.__all__ + connection.__all__ +
-           queryset.__all__)
+           queryset.__all__ + signals.__all__)
 
 __author__ = 'Harry Marr'
 
-VERSION = (0, 4, 0)
+VERSION = (0, 4, 1)
+
 
 def get_version():
     version = '%s.%s' % (VERSION[0], VERSION[1])
@@ -21,4 +24,3 @@ def get_version():
     return version
 
 __version__ = get_version()
-
diff --git a/mongoengine/base.py b/mongoengine/base.py
index 6340e319..c4bcee1e 100644
--- a/mongoengine/base.py
+++ b/mongoengine/base.py
@@ -1,33 +1,68 @@
 from queryset import QuerySet, QuerySetManager
 from queryset import DoesNotExist, MultipleObjectsReturned
+from queryset import DO_NOTHING
+from mongoengine import signals
+
+import weakref
 
 import sys
 import pymongo
 import pymongo.objectid
+import operator
+from functools import partial
 
-_document_registry = {}
 
+class NotRegistered(Exception):
+    pass
 
-def get_document(name):
-    return _document_registry[name]
+
+class InvalidDocumentError(Exception):
+    pass
 
 
 class ValidationError(Exception):
     pass
 
 
+_document_registry = {}
+
+
+def get_document(name):
+    doc = _document_registry.get(name, None)
+    if not doc:
+        # Possible old style names
+        end = ".%s" % name
+        possible_match = [k for k in _document_registry.keys() if k.endswith(end)]
+        if len(possible_match) == 1:
+            doc = _document_registry.get(possible_match.pop(), None)
+    if not doc:
+        raise NotRegistered("""
+            `%s` has not been registered in the document registry.
+            Importing the document class automatically registers it - has it
+            been imported?
+        """.strip() % name)
+    return doc
+
+
 class BaseField(object):
     """A base class for fields in a MongoDB document. Instances of this class
     may be added to subclasses of `Document` to define a document's schema.
+
+    .. versionchanged:: 0.5 - added verbose_name and help_text
     """
 
-    # Fields may have _types inserted into indexes by default
+    # Fields may have _types inserted into indexes by default
    _index_with_types = True
    _geo_index = False
 
-    def __init__(self, db_field=None, name=None, required=False, default=None,
+    # These track each time a Field instance is created. Used to retain order.
+    # The auto_creation_counter is used for fields that MongoEngine implicitly
+    # creates, creation_counter is used for all user-specified fields.
+    creation_counter = 0
+    auto_creation_counter = -1
+
+    def __init__(self, db_field=None, name=None, required=False, default=None,
                  unique=False, unique_with=None, primary_key=False,
-                 validation=None, choices=None):
+                 validation=None, choices=None, verbose_name=None, help_text=None):
         self.db_field = (db_field or name) if not primary_key else '_id'
         if name:
             import warnings
@@ -41,9 +76,19 @@ class BaseField(object):
         self.primary_key = primary_key
         self.validation = validation
         self.choices = choices
+        self.verbose_name = verbose_name
+        self.help_text = help_text
+
+        # Adjust the appropriate creation counter, and save our local copy.
+        if self.db_field == '_id':
+            self.creation_counter = BaseField.auto_creation_counter
+            BaseField.auto_creation_counter -= 1
+        else:
+            self.creation_counter = BaseField.creation_counter
+            BaseField.creation_counter += 1
 
     def __get__(self, instance, owner):
-        """Descriptor for retrieving a value from a field in a document. Do
+        """Descriptor for retrieving a value from a field in a document. Do
         any necessary conversion between Python and MongoDB types.
         """
         if instance is None:
@@ -57,12 +102,19 @@ class BaseField(object):
             # Allow callable default values
             if callable(value):
                 value = value()
+
+        # Convert lists / values so we can watch for any changes on them
+        if isinstance(value, (list, tuple)) and not isinstance(value, BaseList):
+            value = BaseList(value, instance=instance, name=self.name)
+        elif isinstance(value, dict) and not isinstance(value, BaseDict):
+            value = BaseDict(value, instance=instance, name=self.name)
         return value
 
     def __set__(self, instance, value):
         """Descriptor for assigning a value to a field in a document.
         """
         instance._data[self.name] = value
+        instance._mark_as_changed(self.name)
 
     def to_python(self, value):
         """Convert a MongoDB-compatible type to a Python type.
@@ -87,9 +139,9 @@ class BaseField(object):
     def _validate(self, value):
         # check choices
         if self.choices is not None:
-            if value not in self.choices:
-                raise ValidationError("Value must be one of %s."
-                                      % unicode(self.choices))
+            option_keys = [option_key for option_key, option_value in self.choices]
+            if value not in option_keys:
+                raise ValidationError("Value must be one of %s." % unicode(option_keys))
 
         # check validation argument
         if self.validation is not None:
@@ -102,13 +154,159 @@ class BaseField(object):
         self.validate(value)
 
+
+class ComplexBaseField(BaseField):
+    """Handles complex fields, such as lists / dictionaries.
+ + Allows for nesting of embedded documents inside complex types. + Handles the lazy dereferencing of a queryset by lazily dereferencing all + items in a list / dict rather than one at a time. + + .. versionadded:: 0.5 + """ + + field = None + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + if instance is None: + # Document class being used rather than a document object + return self + + from dereference import dereference + instance._data[self.name] = dereference( + instance._data.get(self.name), max_depth=1, instance=instance, name=self.name, get=True + ) + return super(ComplexBaseField, self).__get__(instance, owner) + + def to_python(self, value): + """Convert a MongoDB-compatible type to a Python type. + """ + from mongoengine import Document + + if isinstance(value, basestring): + return value + + if hasattr(value, 'to_python'): + return value.to_python() + + is_list = False + if not hasattr(value, 'items'): + try: + is_list = True + value = dict([(k,v) for k,v in enumerate(value)]) + except TypeError: # Not iterable return the value + return value + + if self.field: + value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()]) + else: + value_dict = {} + for k,v in value.items(): + if isinstance(v, Document): + # We need the id from the saved object to create the DBRef + if v.pk is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + collection = v._get_collection_name() + value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) + elif hasattr(v, 'to_python'): + value_dict[k] = v.to_python() + else: + value_dict[k] = self.to_python(v) + + if is_list: # Convert back to a list + return [v for k,v in sorted(value_dict.items(), key=operator.itemgetter(0))] + return value_dict + + def to_mongo(self, value): + """Convert a Python type to a MongoDB-compatible type. + """ + from mongoengine import Document + + if isinstance(value, basestring): + return value + + if hasattr(value, 'to_mongo'): + return value.to_mongo() + + is_list = False + if not hasattr(value, 'items'): + try: + is_list = True + value = dict([(k,v) for k,v in enumerate(value)]) + except TypeError: # Not iterable return the value + return value + + if self.field: + value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()]) + else: + value_dict = {} + for k,v in value.items(): + if isinstance(v, Document): + # We need the id from the saved object to create the DBRef + if v.pk is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + + # If its a document that is not inheritable it won't have + # _types / _cls data so make it a generic reference allows + # us to dereference + meta = getattr(v, 'meta', getattr(v, '_meta', {})) + if meta and not meta['allow_inheritance'] and not self.field: + from fields import GenericReferenceField + value_dict[k] = GenericReferenceField().to_mongo(v) + else: + collection = v._get_collection_name() + value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) + elif hasattr(v, 'to_mongo'): + value_dict[k] = v.to_mongo() + else: + value_dict[k] = self.to_mongo(v) + + if is_list: # Convert back to a list + return [v for k,v in sorted(value_dict.items(), key=operator.itemgetter(0))] + return value_dict + + def validate(self, value): + """If field provided ensure the value is valid. 
+ """ + if self.field: + try: + if hasattr(value, 'iteritems'): + [self.field.validate(v) for k,v in value.iteritems()] + else: + [self.field.validate(v) for v in value] + except Exception, err: + raise ValidationError('Invalid %s item (%s)' % ( + self.field.__class__.__name__, str(v))) + + def prepare_query_value(self, op, value): + return self.to_mongo(value) + + def lookup_member(self, member_name): + if self.field: + return self.field.lookup_member(member_name) + return None + + def _set_owner_document(self, owner_document): + if self.field: + self.field.owner_document = owner_document + self._owner_document = owner_document + + def _get_owner_document(self, owner_document): + self._owner_document = owner_document + + owner_document = property(_get_owner_document, _set_owner_document) + + class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. """ def to_python(self, value): return value - # return unicode(value) def to_mongo(self, value): if not isinstance(value, pymongo.objectid.ObjectId): @@ -143,25 +341,30 @@ class DocumentMetaclass(type): class_name = [name] superclasses = {} simple_class = True + for base in bases: # Include all fields present in superclasses if hasattr(base, '_fields'): doc_fields.update(base._fields) - class_name.append(base._class_name) # Get superclasses from superclass superclasses[base._class_name] = base superclasses.update(base._superclasses) + else: # Add any mixin fields + attrs.update(dict([(k,v) for k,v in base.__dict__.items() + if issubclass(v.__class__, BaseField)])) - if hasattr(base, '_meta'): - # Ensure that the Document class may be subclassed - - # inheritance may be disabled to remove dependency on + if hasattr(base, '_meta') and not base._meta.get('abstract'): + # Ensure that the Document class may be subclassed - + # inheritance may be disabled to remove dependency on # additional fields _cls and _types + class_name.append(base._class_name) if base._meta.get('allow_inheritance', True) == False: raise ValueError('Document %s may not be subclassed' % base.__name__) else: simple_class = False + doc_class_name = '.'.join(reversed(class_name)) meta = attrs.get('_meta', attrs.get('meta', {})) if 'allow_inheritance' not in meta: @@ -169,12 +372,11 @@ class DocumentMetaclass(type): # Only simple classes - direct subclasses of Document - may set # allow_inheritance to False - if not simple_class and not meta['allow_inheritance']: + if not simple_class and not meta['allow_inheritance'] and not meta['abstract']: raise ValueError('Only direct subclasses of Document may set ' '"allow_inheritance" to False') attrs['_meta'] = meta - - attrs['_class_name'] = '.'.join(reversed(class_name)) + attrs['_class_name'] = doc_class_name attrs['_superclasses'] = superclasses # Add the document's fields to the _fields attribute @@ -186,26 +388,37 @@ class DocumentMetaclass(type): attr_value.db_field = attr_name doc_fields[attr_name] = attr_value attrs['_fields'] = doc_fields + attrs['_db_field_map'] = dict([(k, v.db_field) for k, v in doc_fields.items() if k!=v.db_field]) + attrs['_reverse_db_field_map'] = dict([(v, k) for k, v in attrs['_db_field_map'].items()]) + + from mongoengine import Document new_class = super_new(cls, name, bases, attrs) for field in new_class._fields.values(): field.owner_document = new_class + delete_rule = getattr(field, 'reverse_delete_rule', DO_NOTHING) + if delete_rule != DO_NOTHING: + field.document_type.register_delete_rule(new_class, field.name, + delete_rule) + + if field.name and hasattr(Document, 
field.name): + raise InvalidDocumentError("%s is a document method and not a valid field name" % field.name) module = attrs.get('__module__') - base_excs = tuple(base.DoesNotExist for base in bases + base_excs = tuple(base.DoesNotExist for base in bases if hasattr(base, 'DoesNotExist')) or (DoesNotExist,) exc = subclass_exception('DoesNotExist', base_excs, module) new_class.add_to_class('DoesNotExist', exc) - base_excs = tuple(base.MultipleObjectsReturned for base in bases + base_excs = tuple(base.MultipleObjectsReturned for base in bases if hasattr(base, 'MultipleObjectsReturned')) base_excs = base_excs or (MultipleObjectsReturned,) exc = subclass_exception('MultipleObjectsReturned', base_excs, module) new_class.add_to_class('MultipleObjectsReturned', exc) global _document_registry - _document_registry[name] = new_class + _document_registry[doc_class_name] = new_class return new_class @@ -220,15 +433,24 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): def __new__(cls, name, bases, attrs): super_new = super(TopLevelDocumentMetaclass, cls).__new__ - # Classes defined in this package are abstract and should not have + # Classes defined in this package are abstract and should not have # their own metadata with DB collection, etc. - # __metaclass__ is only set on the class with the __metaclass__ + # __metaclass__ is only set on the class with the __metaclass__ # attribute (i.e. it is not set on subclasses). This differentiates # 'real' documents from the 'Document' class - if attrs.get('__metaclass__') == TopLevelDocumentMetaclass: + # + # Also assume a class is abstract if it has abstract set to True in + # its meta dictionary. This allows custom Document superclasses. + if (attrs.get('__metaclass__') == TopLevelDocumentMetaclass or + ('meta' in attrs and attrs['meta'].get('abstract', False))): + # Make sure no base class was non-abstract + non_abstract_bases = [b for b in bases + if hasattr(b,'_meta') and not b._meta.get('abstract', False)] + if non_abstract_bases: + raise ValueError("Abstract document cannot have non-abstract base") return super_new(cls, name, bases, attrs) - collection = name.lower() + collection = ''.join('_%s' % c if c.isupper() else c for c in name).strip('_').lower() id_field = None base_indexes = [] @@ -236,28 +458,45 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Subclassed documents inherit collection from superclass for base in bases: - if hasattr(base, '_meta') and 'collection' in base._meta: - collection = base._meta['collection'] - + if hasattr(base, '_meta'): + if 'collection' in attrs.get('meta', {}) and not base._meta.get('abstract', False): + import warnings + msg = "Trying to set a collection on a subclass (%s)" % name + warnings.warn(msg, SyntaxWarning) + del(attrs['meta']['collection']) + if base._get_collection_name(): + collection = base._get_collection_name() # Propagate index options. 
for key in ('index_background', 'index_drop_dups', 'index_opts'): - if key in base._meta: - base_meta[key] = base._meta[key] + if key in base._meta: + base_meta[key] = base._meta[key] id_field = id_field or base._meta.get('id_field') base_indexes += base._meta.get('indexes', []) + # Propagate 'allow_inheritance' + if 'allow_inheritance' in base._meta: + base_meta['allow_inheritance'] = base._meta['allow_inheritance'] + if 'queryset_class' in base._meta: + base_meta['queryset_class'] = base._meta['queryset_class'] + try: + base_meta['objects'] = base.__getattribute__(base, 'objects') + except AttributeError: + pass meta = { + 'abstract': False, 'collection': collection, 'max_documents': None, 'max_size': None, - 'ordering': [], # default ordering applied at runtime - 'indexes': [], # indexes to be ensured at runtime + 'ordering': [], # default ordering applied at runtime + 'indexes': [], # indexes to be ensured at runtime 'id_field': id_field, 'index_background': False, 'index_drop_dups': False, 'index_opts': {}, 'queryset_class': QuerySet, + 'delete_rules': {}, + 'allow_inheritance': True } meta.update(base_meta) @@ -269,14 +508,44 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # DocumentMetaclass before instantiating CollectionManager object new_class = super_new(cls, name, bases, attrs) + collection = attrs['_meta'].get('collection', None) + if callable(collection): + new_class._meta['collection'] = collection(new_class) + # Provide a default queryset unless one has been manually provided - if not hasattr(new_class, 'objects'): - new_class.objects = QuerySetManager() + manager = attrs.get('objects', meta.get('objects', QuerySetManager())) + if hasattr(manager, 'queryset_class'): + meta['queryset_class'] = manager.queryset_class + new_class.objects = manager user_indexes = [QuerySet._build_index_spec(new_class, spec) for spec in meta['indexes']] + base_indexes new_class._meta['indexes'] = user_indexes + unique_indexes = cls._unique_with_indexes(new_class) + new_class._meta['unique_indexes'] = unique_indexes + + for field_name, field in new_class._fields.items(): + # Check for custom primary key + if field.primary_key: + current_pk = new_class._meta['id_field'] + if current_pk and current_pk != field_name: + raise ValueError('Cannot override primary key field') + + if not current_pk: + new_class._meta['id_field'] = field_name + # Make 'Document.id' an alias to the real primary key field + new_class.id = field + + if not new_class._meta['id_field']: + new_class._meta['id_field'] = 'id' + new_class._fields['id'] = ObjectIdField(db_field='_id') + new_class.id = new_class._fields['id'] + + return new_class + + @classmethod + def _unique_with_indexes(cls, new_class, namespace=""): unique_indexes = [] for field_name, field in new_class._fields.items(): # Generate a list of indexes needed by uniqueness constraints @@ -302,52 +571,50 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): unique_fields += unique_with # Add the new index to the list - index = [(f, pymongo.ASCENDING) for f in unique_fields] + index = [("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields] unique_indexes.append(index) - # Check for custom primary key - if field.primary_key: - current_pk = new_class._meta['id_field'] - if current_pk and current_pk != field_name: - raise ValueError('Cannot override primary key field') + # Grab any embedded document field unique indexes + if field.__class__.__name__ == "EmbeddedDocumentField": + field_namespace = "%s." 
% field_name + unique_indexes += cls._unique_with_indexes(field.document_type, + field_namespace) - if not current_pk: - new_class._meta['id_field'] = field_name - # Make 'Document.id' an alias to the real primary key field - new_class.id = field - - new_class._meta['unique_indexes'] = unique_indexes - - if not new_class._meta['id_field']: - new_class._meta['id_field'] = 'id' - new_class._fields['id'] = ObjectIdField(db_field='_id') - new_class.id = new_class._fields['id'] - - return new_class + return unique_indexes class BaseDocument(object): def __init__(self, **values): + signals.pre_init.send(self.__class__, document=self, values=values) + self._data = {} + self._initialised = False # Assign default values to instance - for attr_name in self._fields.keys(): - # Use default value if present + for attr_name, field in self._fields.items(): value = getattr(self, attr_name, None) setattr(self, attr_name, value) + # Assign initial values to instance for attr_name in values.keys(): try: - setattr(self, attr_name, values.pop(attr_name)) + value = values.pop(attr_name) + setattr(self, attr_name, value) except AttributeError: pass + # Set any get_fieldname_display methods + self.__set_field_display() + # Flag initialised + self._initialised = True + signals.post_init.send(self.__class__, document=self) + def validate(self): """Ensure that all fields' values are valid and that required fields are present. """ # Get a list of tuples of field names and their current values - fields = [(field, getattr(self, name)) + fields = [(field, getattr(self, name)) for name, field in self._fields.items()] # Ensure that each field is matched to a valid value @@ -356,11 +623,44 @@ class BaseDocument(object): try: field._validate(value) except (ValueError, AttributeError, AssertionError), e: - raise ValidationError('Invalid value for field of type "%s": %s' - % (field.__class__.__name__, value)) + raise ValidationError('Invalid value for field named "%s" of type "%s": %s' + % (field.name, field.__class__.__name__, value)) elif field.required: raise ValidationError('Field "%s" is required' % field.name) + @apply + def pk(): + """Primary key alias + """ + def fget(self): + return getattr(self, self._meta['id_field']) + def fset(self, value): + return setattr(self, self._meta['id_field'], value) + return property(fget, fset) + + def to_mongo(self): + """Return data dictionary ready for use with MongoDB. + """ + data = {} + for field_name, field in self._fields.items(): + value = getattr(self, field_name, None) + if value is not None: + data[field.db_field] = field.to_mongo(value) + # Only add _cls and _types if allow_inheritance is not False + if not (hasattr(self, '_meta') and + self._meta.get('allow_inheritance', True) == False): + data['_cls'] = self._class_name + data['_types'] = self._superclasses.keys() + [self._class_name] + if '_id' in data and data['_id'] is None: + del data['_id'] + return data + + @classmethod + def _get_collection_name(cls): + """Returns the collection name for this class. + """ + return cls._meta.get('collection', None) + @classmethod def _get_subclasses(cls): """Return a dictionary of all subclasses (found recursively). @@ -376,15 +676,184 @@ class BaseDocument(object): all_subclasses.update(subclass._get_subclasses()) return all_subclasses - @apply - def pk(): - """Primary key alias + @classmethod + def _from_son(cls, son): + """Create an instance of a Document (subclass) from a PyMongo SON. 
""" - def fget(self): - return getattr(self, self._meta['id_field']) - def fset(self, value): - return setattr(self, self._meta['id_field'], value) - return property(fget, fset) + # get the class name from the document, falling back to the given + # class if unavailable + class_name = son.get(u'_cls', cls._class_name) + data = dict((str(key), value) for key, value in son.items()) + + if '_types' in data: + del data['_types'] + + if '_cls' in data: + del data['_cls'] + + # Return correct subclass for document type + if class_name != cls._class_name: + subclasses = cls._get_subclasses() + if class_name not in subclasses: + # Type of document is probably more generic than the class + # that has been queried to return this SON + raise NotRegistered(""" + `%s` has not been registered in the document registry. + Importing the document class automatically registers it, + has it been imported? + """.strip() % class_name) + cls = subclasses[class_name] + + present_fields = data.keys() + for field_name, field in cls._fields.items(): + if field.db_field in data: + value = data[field.db_field] + data[field_name] = (value if value is None + else field.to_python(value)) + + obj = cls(**data) + obj._changed_fields = [] + return obj + + def _mark_as_changed(self, key): + """Marks a key as explicitly changed by the user + """ + if not key: + return + key = self._db_field_map.get(key, key) + if hasattr(self, '_changed_fields') and key not in self._changed_fields: + self._changed_fields.append(key) + + def _get_changed_fields(self, key=''): + """Returns a list of all fields that have explicitly been changed. + """ + from mongoengine import EmbeddedDocument + _changed_fields = [] + _changed_fields += getattr(self, '_changed_fields', []) + for field_name in self._fields: + db_field_name = self._db_field_map.get(field_name, field_name) + key = '%s.' % db_field_name + field = getattr(self, field_name, None) + if isinstance(field, EmbeddedDocument) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed + _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k] + elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents + # Determine the iterator to use + if not hasattr(field, 'items'): + iterator = enumerate(field) + else: + iterator = field.iteritems() + for index, value in iterator: + if not hasattr(value, '_get_changed_fields'): + continue + list_key = "%s%s." % (key, index) + _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k] + + return _changed_fields + + def _delta(self): + """Returns the delta (set, unset) of the changes for a document. + Gets any values that have been explicitly changed. + """ + # Handles cases where not loaded from_son but has _id + doc = self.to_mongo() + set_fields = self._get_changed_fields() + set_data = {} + unset_data = {} + if hasattr(self, '_changed_fields'): + set_data = {} + # Fetch each set item from its path + for path in set_fields: + parts = path.split('.') + d = doc + for p in parts: + if hasattr(d, '__getattr__'): + d = getattr(p, d) + elif p.isdigit(): + d = d[int(p)] + else: + d = d.get(p) + set_data[path] = d + else: + set_data = doc + if '_id' in set_data: + del(set_data['_id']) + + # Determine if any changed items were actually unset. + for path, value in set_data.items(): + if value: + continue + + # If we've set a value that ain't the default value dont unset it. 
+ def _delta(self): + """Returns the delta (set, unset) of the changes for a document. + Gets any values that have been explicitly changed. + """ + # Handles cases where not loaded from_son but has an _id + doc = self.to_mongo() + set_fields = self._get_changed_fields() + set_data = {} + unset_data = {} + if hasattr(self, '_changed_fields'): + set_data = {} + # Fetch each set item from its path + for path in set_fields: + parts = path.split('.') + d = doc + for p in parts: + if hasattr(d, '__getattr__'): + d = getattr(d, p) + elif p.isdigit(): + d = d[int(p)] + else: + d = d.get(p) + set_data[path] = d + else: + set_data = doc + if '_id' in set_data: + del(set_data['_id']) + + # Determine if any changed items were actually unset. + for path, value in set_data.items(): + if value: + continue + + # If we've set a value that isn't the default value, don't unset it. + default = None + + if path in self._fields: + default = self._fields[path].default + else: # Perform a full lookup for lists / embedded lookups + d = self + parts = path.split('.') + db_field_name = parts.pop() + for p in parts: + if p.isdigit(): + d = d[int(p)] + elif hasattr(d, '__getattribute__') and not isinstance(d, dict): + real_path = d._reverse_db_field_map.get(p, p) + d = getattr(d, real_path) + else: + d = d.get(p) + + if hasattr(d, '_fields'): + field_name = d._reverse_db_field_map.get(db_field_name, + db_field_name) + + default = d._fields[field_name].default + + if default is not None: + if callable(default): + default = default() + if default != value: + continue + + del(set_data[path]) + unset_data[path] = 1 + return set_data, unset_data + + @classmethod + def _geo_indices(cls, inspected_classes=None): + inspected_classes = inspected_classes or [] + geo_indices = [] + inspected_classes.append(cls) + for field in cls._fields.values(): + if hasattr(field, 'document_type'): + field_cls = field.document_type + if field_cls in inspected_classes: + continue + if hasattr(field_cls, '_geo_indices'): + geo_indices += field_cls._geo_indices(inspected_classes) + elif field._geo_index: + geo_indices.append(field) + return geo_indices + + def __getstate__(self): + removals = ["get_%s_display" % k for k,v in self._fields.items() if v.choices] + for k in removals: + if hasattr(self, k): + delattr(self, k) + return self.__dict__ + + def __setstate__(self, __dict__): + self.__dict__ = __dict__ + self.__set_field_display() + + def __set_field_display(self): + for attr_name, field in self._fields.items(): + if field.choices: # dynamically adds a way to get the display value for a field with choices + setattr(self, 'get_%s_display' % attr_name, partial(self.__get_field_display, field=field)) + + def __get_field_display(self, field): + """Returns the display value for a choice field""" + value = getattr(self, field.name) + return dict(field.choices).get(value, value) def __iter__(self): return iter(self._fields) @@ -429,60 +898,6 @@ class BaseDocument(object): return unicode(self).encode('utf-8') return '%s object' % self.__class__.__name__ - def to_mongo(self): - """Return data dictionary ready for use with MongoDB. - """ - data = {} - for field_name, field in self._fields.items(): - value = getattr(self, field_name, None) - if value is not None: - data[field.db_field] = field.to_mongo(value) - # Only add _cls and _types if allow_inheritance is not False - if not (hasattr(self, '_meta') and - self._meta.get('allow_inheritance', True) == False): - data['_cls'] = self._class_name - data['_types'] = self._superclasses.keys() + [self._class_name] - if data.has_key('_id') and not data['_id']: - del data['_id'] - return data - - @classmethod - def _from_son(cls, son): - """Create an instance of a Document (subclass) from a PyMongo SON. 
- """ - # get the class name from the document, falling back to the given - # class if unavailable - class_name = son.get(u'_cls', cls._class_name) - - data = dict((str(key), value) for key, value in son.items()) - - if '_types' in data: - del data['_types'] - - if '_cls' in data: - del data['_cls'] - - # Return correct subclass for document type - if class_name != cls._class_name: - subclasses = cls._get_subclasses() - if class_name not in subclasses: - # Type of document is probably more generic than the class - # that has been queried to return this SON - return None - cls = subclasses[class_name] - - present_fields = data.keys() - - for field_name, field in cls._fields.items(): - if field.db_field in data: - value = data[field.db_field] - data[field_name] = (value if value is None - else field.to_python(value)) - - obj = cls(**data) - obj._present_fields = present_fields - return obj - def __eq__(self, other): if isinstance(other, self.__class__) and hasattr(other, 'id'): if self.id == other.id: @@ -493,16 +908,115 @@ class BaseDocument(object): return not self.__eq__(other) def __hash__(self): - """ For list, dic key """ if self.pk is None: # For new object return super(BaseDocument,self).__hash__() else: return hash(self.pk) + +class BaseList(list): + """A special list so we can watch any changes + """ + + def __init__(self, list_items, instance, name): + self.instance = instance + self.name = name + super(BaseList, self).__init__(list_items) + + def __setitem__(self, *args, **kwargs): + self._mark_as_changed() + super(BaseList, self).__setitem__(*args, **kwargs) + + def __delitem__(self, *args, **kwargs): + self._mark_as_changed() + super(BaseList, self).__delitem__(*args, **kwargs) + + def append(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).append(*args, **kwargs) + + def extend(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).extend(*args, **kwargs) + + def insert(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).insert(*args, **kwargs) + + def pop(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).pop(*args, **kwargs) + + def remove(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).remove(*args, **kwargs) + + def reverse(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).reverse(*args, **kwargs) + + def sort(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).sort(*args, **kwargs) + + def _mark_as_changed(self): + """Marks a list as changed if has an instance and a name""" + if hasattr(self, 'instance') and hasattr(self, 'name'): + self.instance._mark_as_changed(self.name) + + +class BaseDict(dict): + """A special dict so we can watch any changes + """ + + def __init__(self, dict_items, instance, name): + self.instance = instance + self.name = name + super(BaseDict, self).__init__(dict_items) + + def __setitem__(self, *args, **kwargs): + self._mark_as_changed() + super(BaseDict, self).__setitem__(*args, **kwargs) + + def __setattr__(self, *args, **kwargs): + self._mark_as_changed() + super(BaseDict, self).__setattr__(*args, **kwargs) + + def __delete__(self, *args, **kwargs): + self._mark_as_changed() + super(BaseDict, self).__delete__(*args, **kwargs) + + def __delitem__(self, *args, **kwargs): + self._mark_as_changed() + super(BaseDict, self).__delitem__(*args, **kwargs) + + def __delattr__(self, *args, **kwargs): + self._mark_as_changed() + super(BaseDict, 
self).__delattr__(*args, **kwargs) + + def clear(self, *args, **kwargs): + self._mark_as_changed() + super(BaseDict, self).clear(*args, **kwargs) + + def pop(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).pop(*args, **kwargs) + + def popitem(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).popitem(*args, **kwargs) + + def _mark_as_changed(self): + """Marks a dict as changed if it has an instance and a name""" + if hasattr(self, 'instance') and hasattr(self, 'name'): + self.instance._mark_as_changed(self.name) + if sys.version_info < (2, 5): # Prior to Python 2.5, Exception was an old-style class + import types def subclass_exception(name, parents, unused): return types.ClassType(name, parents, {}) else: def subclass_exception(name, parents, module): diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 814fde13..7b5cd210 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,5 +1,6 @@ from pymongo import Connection import multiprocessing +import threading __all__ = ['ConnectionError', 'connect'] @@ -22,17 +23,22 @@ class ConnectionError(Exception): def _get_connection(reconnect=False): + """Handles the connection to the database. + """ global _connection identity = get_identity() # Connect to the database if not already connected if _connection.get(identity) is None or reconnect: try: _connection[identity] = Connection(**_connection_settings) - except: - raise ConnectionError('Cannot connect to the database') + except Exception, e: + raise ConnectionError("Cannot connect to the database:\n%s" % e) return _connection[identity] def _get_db(reconnect=False): + """Handles database connections and authentication based on the current + identity. + """ global _db, _connection identity = get_identity() # Connect if not already connected @@ -52,12 +58,17 @@ def _get_db(reconnect=False): return _db[identity] def get_identity(): + """Creates an identity key based on the current process and thread + identity. + """ identity = multiprocessing.current_process()._identity identity = 0 if not identity else identity[0] + + identity = (identity, threading.current_thread().ident) return identity - + def connect(db, username=None, password=None, **kwargs): - """Connect to the database specified by the 'db' argument. Connection + """Connect to the database specified by the 'db' argument. Connection settings may be provided here as well if the database is not running on the default port on localhost. If authentication is needed, provide username and password arguments as well. diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py new file mode 100644 index 00000000..7fe9ba2f --- /dev/null +++ b/mongoengine/dereference.py @@ -0,0 +1,184 @@ +import operator + +import pymongo + +from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass +from fields import ReferenceField +from connection import _get_db +from queryset import QuerySet +from document import Document + + +class DeReference(object): + + def __call__(self, items, max_depth=1, instance=None, name=None, get=False): + """ + Cheaply dereferences the items to a set depth. + Also handles the conversion of complex data types. + + :param items: The iterable (dict, list, queryset) to be dereferenced. 
+ :param max_depth: The maximum depth to recurse to + :param instance: The owning instance used for tracking changes by + :class:`~mongoengine.base.ComplexBaseField` + :param name: The name of the field, used for tracking changes by + :class:`~mongoengine.base.ComplexBaseField` + :param get: A boolean determining if being called by __get__ + """ + if items is None or isinstance(items, basestring): + return items + + # cheapest way to convert a queryset to a list + # list(queryset) uses a count() query to determine length + if isinstance(items, QuerySet): + items = [i for i in items] + + self.max_depth = max_depth + + doc_type = None + if instance and instance._fields: + doc_type = instance._fields[name].field + + if isinstance(doc_type, ReferenceField): + doc_type = doc_type.document_type + + self.reference_map = self._find_references(items) + self.object_map = self._fetch_objects(doc_type=doc_type) + return self._attach_objects(items, 0, instance, name, get) + + def _find_references(self, items, depth=0): + """ + Recursively finds all db references to be dereferenced + + :param items: The iterable (dict, list, queryset) + :param depth: The current depth of recursion + """ + reference_map = {} + if not items: + return reference_map + + # Determine the iterator to use + if not hasattr(items, 'items'): + iterator = enumerate(items) + else: + iterator = items.iteritems() + + # Recursively find dbreferences + for k, item in iterator: + if hasattr(item, '_fields'): + for field_name, field in item._fields.iteritems(): + v = item._data.get(field_name, None) + if isinstance(v, (pymongo.dbref.DBRef)): + reference_map.setdefault(field.document_type, []).append(v.id) + elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: + reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id) + elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: + field_cls = getattr(getattr(field, 'field', None), 'document_type', None) + references = self._find_references(v, depth) + for key, refs in references.iteritems(): + if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): + key = field_cls + reference_map.setdefault(key, []).extend(refs) + elif isinstance(item, (pymongo.dbref.DBRef)): + reference_map.setdefault(item.collection, []).append(item.id) + elif isinstance(item, (dict, pymongo.son.SON)) and '_ref' in item: + reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id) + elif isinstance(item, (dict, list, tuple)) and depth <= self.max_depth: + references = self._find_references(item, depth) + for key, refs in references.iteritems(): + reference_map.setdefault(key, []).extend(refs) + depth += 1 + return reference_map + + def _fetch_objects(self, doc_type=None): + """Fetch all references and convert to their document objects + """ + object_map = {} + for col, dbrefs in self.reference_map.iteritems(): + keys = object_map.keys() + refs = list(set([dbref for dbref in dbrefs if str(dbref) not in keys])) + if hasattr(col, 'objects'): # We have a document class for the refs + references = col.objects.in_bulk(refs) + for key, doc in references.iteritems(): + object_map[key] = doc + else: # Generic reference: use the refs data to convert to document + references = _get_db()[col].find({'_id': {'$in': refs}}) + for ref in references: + if '_cls' in ref: + doc = get_document(ref['_cls'])._from_son(ref) + else: + doc = doc_type._from_son(ref) + object_map[doc.id] = doc + return object_map + + def _attach_objects(self, items, depth=0, instance=None, 
name=None, get=False): + """ + Recursively finds all db references to be dereferenced + + :param items: The iterable (dict, list, queryset) + :param depth: The current depth of recursion + :param instance: The owning instance used for tracking changes by + :class:`~mongoengine.base.ComplexBaseField` + :param name: The name of the field, used for tracking changes by + :class:`~mongoengine.base.ComplexBaseField` + :param get: A boolean determining if being called by __get__ + """ + if not items: + if isinstance(items, (BaseDict, BaseList)): + return items + + if instance: + if isinstance(items, dict): + return BaseDict(items, instance=instance, name=name) + else: + return BaseList(items, instance=instance, name=name) + + if isinstance(items, (dict, pymongo.son.SON)): + if '_ref' in items: + return self.object_map.get(items['_ref'].id, items) + elif '_types' in items and '_cls' in items: + doc = get_document(items['_cls'])._from_son(items) + if not get: + doc._data = self._attach_objects(doc._data, depth, doc, name, get) + return doc + + if not hasattr(items, 'items'): + is_list = True + iterator = enumerate(items) + data = [] + else: + is_list = False + iterator = items.iteritems() + data = {} + + for k, v in iterator: + if is_list: + data.append(v) + else: + data[k] = v + + if k in self.object_map: + data[k] = self.object_map[k] + elif hasattr(v, '_fields'): + for field_name, field in v._fields.iteritems(): + v = data[k]._data.get(field_name, None) + if isinstance(v, (pymongo.dbref.DBRef)): + data[k]._data[field_name] = self.object_map.get(v.id, v) + elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: + data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) + elif isinstance(v, dict) and depth < self.max_depth: + data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get) + elif isinstance(v, (list, tuple)): + data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get) + elif isinstance(v, (dict, list, tuple)) and depth < self.max_depth: + data[k] = self._attach_objects(v, depth, instance=instance, name=name, get=get) + elif hasattr(v, 'id'): + data[k] = self.object_map.get(v.id, v) + + if instance and name: + if is_list: + return BaseList(data, instance=instance, name=name) + return BaseDict(data, instance=instance, name=name) + depth += 1 + return data + +dereference = DeReference() diff --git a/mongoengine/django/auth.py b/mongoengine/django/auth.py index 595852ef..38370cc5 100644 --- a/mongoengine/django/auth.py +++ b/mongoengine/django/auth.py @@ -3,6 +3,7 @@ from mongoengine import * from django.utils.hashcompat import md5_constructor, sha_constructor from django.utils.encoding import smart_str from django.contrib.auth.models import AnonymousUser +from django.utils.translation import ugettext_lazy as _ import datetime @@ -21,16 +22,38 @@ class User(Document): """A User document that aims to mirror most of the API specified by Django at http://docs.djangoproject.com/en/dev/topics/auth/#users """ - username = StringField(max_length=30, required=True) - first_name = StringField(max_length=30) - last_name = StringField(max_length=30) - email = StringField() - password = StringField(max_length=128) - is_staff = BooleanField(default=False) - is_active = BooleanField(default=True) - is_superuser = BooleanField(default=False) - last_login = DateTimeField(default=datetime.datetime.now) - date_joined = DateTimeField(default=datetime.datetime.now) + username = StringField(max_length=30, required=True, 
+ verbose_name=_('username'), + help_text=_("Required. 30 characters or fewer. Letters, numbers and @/./+/-/_ characters")) + + first_name = StringField(max_length=30, + verbose_name=_('first name')) + + last_name = StringField(max_length=30, + verbose_name=_('last name')) + email = EmailField(verbose_name=_('e-mail address')) + password = StringField(max_length=128, + verbose_name=_('password'), + help_text=_("Use '[algo]$[salt]$[hexdigest]' or use the change password form.")) + is_staff = BooleanField(default=False, + verbose_name=_('staff status'), + help_text=_("Designates whether the user can log into this admin site.")) + is_active = BooleanField(default=True, + verbose_name=_('active'), + help_text=_("Designates whether this user should be treated as active. Unselect this instead of deleting accounts.")) + is_superuser = BooleanField(default=False, + verbose_name=_('superuser status'), + help_text=_("Designates that this user has all permissions without explicitly assigning them.")) + last_login = DateTimeField(default=datetime.datetime.now, + verbose_name=_('last login')) + date_joined = DateTimeField(default=datetime.datetime.now, + verbose_name=_('date joined')) + + meta = { + 'indexes': [ + {'fields': ['username'], 'unique': True} + ] + } def __unicode__(self): return self.username @@ -86,7 +109,7 @@ class User(Document): else: email = '@'.join([email_name, domain_part.lower()]) - user = User(username=username, email=email, date_joined=now) + user = cls(username=username, email=email, date_joined=now) user.set_password(password) user.save() return user @@ -99,6 +122,10 @@ class MongoEngineBackend(object): """Authenticate using MongoEngine and mongoengine.django.auth.User. """ + supports_object_permissions = False + supports_anonymous_user = False + supports_inactive_user = False + def authenticate(self, username=None, password=None): user = User.objects(username=username).first() if user: diff --git a/mongoengine/django/shortcuts.py b/mongoengine/django/shortcuts.py index 29bc17a8..59a20741 100644 --- a/mongoengine/django/shortcuts.py +++ b/mongoengine/django/shortcuts.py @@ -1,6 +1,7 @@ from django.http import Http404 from mongoengine.queryset import QuerySet from mongoengine.base import BaseDocument +from mongoengine.base import ValidationError def _get_queryset(cls): """Inspired by django.shortcuts.*""" @@ -25,7 +26,7 @@ def get_document_or_404(cls, *args, **kwargs): queryset = _get_queryset(cls) try: return queryset.get(*args, **kwargs) - except queryset._document.DoesNotExist: + except (queryset._document.DoesNotExist, ValidationError): raise Http404('No %s matches the given query.' 
% queryset._document._class_name) def get_list_or_404(cls, *args, **kwargs): diff --git a/mongoengine/document.py b/mongoengine/document.py index f15e6836..3ccc4ddc 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,12 +1,17 @@ +from mongoengine import signals from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, - ValidationError) + ValidationError, BaseDict, BaseList) from queryset import OperationError from connection import _get_db import pymongo +__all__ = ['Document', 'EmbeddedDocument', 'ValidationError', + 'OperationError', 'InvalidCollectionError'] -__all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError'] + +class InvalidCollectionError(Exception): + pass class EmbeddedDocument(BaseDocument): @@ -18,6 +23,18 @@ class EmbeddedDocument(BaseDocument): __metaclass__ = DocumentMetaclass + def __delattr__(self, *args, **kwargs): + """Handle deletions of fields""" + field_name = args[0] + if field_name in self._fields: + default = self._fields[field_name].default + if callable(default): + default = default() + setattr(self, field_name, default) + else: + super(EmbeddedDocument, self).__delattr__(*args, **kwargs) + + class Document(BaseDocument): """The base class used for defining the structure and properties of @@ -40,44 +57,125 @@ class Document(BaseDocument): presence of `_cls` and `_types`, set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` dictionary. - A :class:`~mongoengine.Document` may use a **Capped Collection** by + A :class:`~mongoengine.Document` may use a **Capped Collection** by specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` dictionary. :attr:`max_documents` is the maximum number of documents that - is allowed to be stored in the collection, and :attr:`max_size` is the - maximum size of the collection in bytes. If :attr:`max_size` is not - specified and :attr:`max_documents` is, :attr:`max_size` defaults to + is allowed to be stored in the collection, and :attr:`max_size` is the + maximum size of the collection in bytes. If :attr:`max_size` is not + specified and :attr:`max_documents` is, :attr:`max_size` defaults to 10000000 bytes (10MB). Indexes may be created by specifying :attr:`indexes` in the :attr:`meta` - dictionary. The value should be a list of field names or tuples of field + dictionary. The value should be a list of field names or tuples of field names. Index direction may be specified by prefixing the field names with a **+** or **-** sign. - """ + By default, _types will be added to the start of every index (that + doesn't contain a list) if allow_inheritance is True. This can be + disabled by either setting types to False on the specific index or + by setting index_types to False on the meta dictionary for the document. + """
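A hypothetical document exercising the meta options described above - a capped collection with a descending index and inheritance disabled (all names are illustrative)::

    import datetime
    from mongoengine import Document, StringField, DateTimeField

    class LogEntry(Document):
        level = StringField()
        timestamp = DateTimeField(default=datetime.datetime.now)

        meta = {
            'max_documents': 1000,       # cap on the number of documents
            'max_size': 10000000,        # 10MB, the documented default
            'indexes': ['-timestamp'],   # descending index on timestamp
            'allow_inheritance': False,  # omit _cls / _types from documents
        }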
+ """ __metaclass__ = TopLevelDocumentMetaclass - def save(self, safe=True, force_insert=False, validate=True): + @classmethod + def _get_collection(self): + """Returns the collection for the document.""" + db = _get_db() + collection_name = self._get_collection_name() + + if not hasattr(self, '_collection') or self._collection is None: + # Create collection as a capped collection if specified + if self._meta['max_size'] or self._meta['max_documents']: + # Get max document limit and max byte size from meta + max_size = self._meta['max_size'] or 10000000 # 10MB default + max_documents = self._meta['max_documents'] + + if collection_name in db.collection_names(): + self._collection = db[collection_name] + # The collection already exists, check if its capped + # options match the specified capped options + options = self._collection.options() + if options.get('max') != max_documents or \ + options.get('size') != max_size: + msg = ('Cannot create collection "%s" as a capped ' + 'collection as it already exists') % self._collection + raise InvalidCollectionError(msg) + else: + # Create the collection as a capped collection + opts = {'capped': True, 'size': max_size} + if max_documents: + opts['max'] = max_documents + self._collection = db.create_collection( + collection_name, **opts + ) + else: + self._collection = db[collection_name] + return self._collection + + def save(self, safe=True, force_insert=False, validate=True, write_options=None, _refs=None): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. - If ``safe=True`` and the operation is unsuccessful, an + If ``safe=True`` and the operation is unsuccessful, an :class:`~mongoengine.OperationError` will be raised. :param safe: check if the operation succeeded before returning - :param force_insert: only try to create a new document, don't allow + :param force_insert: only try to create a new document, don't allow updates of existing documents :param validate: validates the document; set to ``False`` to skip. + :param write_options: Extra keyword arguments are passed down to + :meth:`~pymongo.collection.Collection.save` OR + :meth:`~pymongo.collection.Collection.insert` + which will be used as options for the resultant ``getLastError`` command. + For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers + have recorded the write and will force an fsync on each server being written to. + + .. versionchanged:: 0.5 + In existing documents it only saves changed fields using set / unset + Saves are cascaded and any :class:`~pymongo.dbref.DBRef` objects + that have changes are saved as well. 
""" + from fields import ReferenceField, GenericReferenceField + + signals.pre_save.send(self.__class__, document=self) + if validate: self.validate() + + if not write_options: + write_options = {} + doc = self.to_mongo() + + created = '_id' in doc + creation_mode = force_insert or not created try: collection = self.__class__.objects._collection - if force_insert: - object_id = collection.insert(doc, safe=safe) + if creation_mode: + if force_insert: + object_id = collection.insert(doc, safe=safe, **write_options) + else: + object_id = collection.save(doc, safe=safe, **write_options) else: - object_id = collection.save(doc, safe=safe) + object_id = doc['_id'] + updates, removals = self._delta() + if updates: + collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options) + if removals: + collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options) + + # Save any references / generic references + _refs = _refs or [] + for name, cls in self._fields.items(): + if isinstance(cls, (ReferenceField, GenericReferenceField)): + ref = getattr(self, name) + if ref and str(ref) not in _refs: + _refs.append(str(ref)) + ref.save(safe=safe, force_insert=force_insert, + validate=validate, write_options=write_options, + _refs=_refs) + except pymongo.errors.OperationFailure, err: message = 'Could not save document (%s)' if u'duplicate key' in unicode(err): @@ -86,12 +184,42 @@ class Document(BaseDocument): id_field = self._meta['id_field'] self[id_field] = self._fields[id_field].to_python(object_id) + def reset_changed_fields(doc, inspected_docs=None): + """Loop through and reset changed fields lists""" + + inspected_docs = inspected_docs or [] + inspected_docs.append(doc) + if hasattr(doc, '_changed_fields'): + doc._changed_fields = [] + + for field_name in doc._fields: + field = getattr(doc, field_name) + if field not in inspected_docs and hasattr(field, '_changed_fields'): + reset_changed_fields(field, inspected_docs) + + reset_changed_fields(self) + signals.post_save.send(self.__class__, document=self, created=creation_mode) + + def update(self, **kwargs): + """Performs an update on the :class:`~mongoengine.Document` + A convenience wrapper to :meth:`~mongoengine.QuerySet.update`. + + Raises :class:`OperationError` if called on an object that has not yet + been saved. + """ + if not self.pk: + raise OperationError('attempt to update a document not yet saved') + + return self.__class__.objects(pk=self.pk).update_one(**kwargs) + def delete(self, safe=False): """Delete the :class:`~mongoengine.Document` from the database. This will only take effect if the document has been previously saved. :param safe: check if the operation succeeded before returning """ + signals.pre_delete.send(self.__class__, document=self) + id_field = self._meta['id_field'] object_id = self._fields[id_field].to_mongo(self[id_field]) try: @@ -100,6 +228,18 @@ class Document(BaseDocument): message = u'Could not delete document (%s)' % err.message raise OperationError(message) + signals.post_delete.send(self.__class__, document=self) + + def select_related(self, max_depth=1): + """Handles dereferencing of :class:`~pymongo.dbref.DBRef` objects to + a maximum depth in order to cut down the number queries to mongodb. + + .. versionadded:: 0.5 + """ + from dereference import dereference + self._data = dereference(self._data, max_depth) + return self + def reload(self): """Reloads all attributes from the database. 
@@ -108,7 +248,37 @@ class Document(BaseDocument): id_field = self._meta['id_field'] obj = self.__class__.objects(**{id_field: self[id_field]}).first() for field in self._fields: - setattr(self, field, obj[field]) + setattr(self, field, self._reload(field, obj[field])) + self._changed_fields = [] + + def _reload(self, key, value): + """Used by :meth:`~mongoengine.Document.reload` to ensure the + correct instance is linked to self. + """ + if isinstance(value, BaseDict): + value = [(k, self._reload(k,v)) for k,v in value.items()] + value = BaseDict(value, instance=self, name=key) + elif isinstance(value, BaseList): + value = [self._reload(key, v) for v in value] + value = BaseList(value, instance=self, name=key) + elif isinstance(value, EmbeddedDocument): + value._changed_fields = [] + return value + + def to_dbref(self): + """Returns an instance of :class:`~pymongo.dbref.DBRef` useful in + `__raw__` queries.""" + if not self.pk: + msg = "Only saved documents can have a valid dbref" + raise OperationError(msg) + return pymongo.dbref.DBRef(self.__class__._get_collection_name(), self.pk) + + @classmethod + def register_delete_rule(cls, document_cls, field_name, rule): + """This method registers the delete rules to apply when removing this + object. + """ + cls._meta['delete_rules'][(document_cls, field_name)] = rule @classmethod def drop_collection(cls): @@ -116,16 +286,16 @@ class Document(BaseDocument): :class:`~mongoengine.Document` type from the database. """ db = _get_db() - db.drop_collection(cls._meta['collection']) + db.drop_collection(cls._get_collection_name()) class MapReduceDocument(object): """A document returned from a map/reduce query. :param collection: An instance of :class:`~pymongo.Collection` - :param key: Document/result key, often an instance of - :class:`~pymongo.objectid.ObjectId`. If supplied as - an ``ObjectId`` found in the given ``collection``, + :param key: Document/result key, often an instance of + :class:`~pymongo.objectid.ObjectId`. If supplied as + an ``ObjectId`` found in the given ``collection``, the object can be accessed via the ``object`` property. :param value: The result(s) for this key. @@ -140,7 +310,7 @@ class MapReduceDocument(object): @property def object(self): - """Lazy-load the object referenced by ``self.key``. ``self.key`` + """Lazy-load the object referenced by ``self.key``. ``self.key`` should be the ``primary_key``. 
""" id_field = self._document()._meta['id_field'] diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f39d6e6f..c5734430 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,4 +1,6 @@ -from base import BaseField, ObjectIdField, ValidationError, get_document +from base import (BaseField, ComplexBaseField, ObjectIdField, + ValidationError, get_document) +from queryset import DO_NOTHING from document import Document, EmbeddedDocument from connection import _get_db from operator import itemgetter @@ -8,18 +10,18 @@ import pymongo import pymongo.dbref import pymongo.son import pymongo.binary -import datetime +import datetime, time import decimal import gridfs -import warnings -import types __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', - 'ObjectIdField', 'ReferenceField', 'ValidationError', - 'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', - 'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField'] + 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField', + 'DecimalField', 'ComplexDateTimeField', 'URLField', + 'GenericReferenceField', 'FileField', 'BinaryField', + 'SortedListField', 'EmailField', 'GeoPointField', + 'SequenceField', 'GenericEmbeddedDocumentField'] RECURSIVE_REFERENCE_CONSTANT = 'self' @@ -118,8 +120,8 @@ class EmailField(StringField): EMAIL_REGEX = re.compile( r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom - r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain + r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain ) def validate(self, value): @@ -150,6 +152,9 @@ class IntField(BaseField): if self.max_value is not None and value > self.max_value: raise ValidationError('Integer value is too large') + def prepare_query_value(self, op, value): + return int(value) + class FloatField(BaseField): """An floating point number field. @@ -173,6 +178,10 @@ class FloatField(BaseField): if self.max_value is not None and value > self.max_value: raise ValidationError('Float value is too large') + def prepare_query_value(self, op, value): + return float(value) + + class DecimalField(BaseField): """A fixed-point decimal number field. @@ -222,15 +231,151 @@ class BooleanField(BaseField): class DateTimeField(BaseField): """A datetime field. + + Note: Microseconds are rounded to the nearest millisecond. + Pre UTC microsecond support is effecively broken. + Use :class:`~mongoengine.fields.ComplexDateTimeField` if you + need accurate microsecond support. """ def validate(self, value): - assert isinstance(value, datetime.datetime) + assert isinstance(value, (datetime.datetime, datetime.date)) + + def to_mongo(self, value): + return self.prepare_query_value(None, value) + + def prepare_query_value(self, op, value): + if value is None: + return value + if isinstance(value, datetime.datetime): + return value + if isinstance(value, datetime.date): + return datetime.datetime(value.year, value.month, value.day) + + # Attempt to parse a datetime: + # value = smart_str(value) + # split usecs, because they are not recognized by strptime. + if '.' 
+ class ComplexDateTimeField(StringField): + """ + ComplexDateTimeField handles microseconds exactly instead of rounding + like DateTimeField does. + + Derives from a StringField so you can do `gte` and `lte` filtering by + using lexicographical comparison when filtering / sorting strings. + + The stored string has the following format: + + YYYY,MM,DD,HH,MM,SS,NNNNNN + + Where NNNNNN is the number of microseconds of the represented `datetime`. + The `,` separator can easily be modified by passing the `separator` + keyword when initializing the field. + + .. versionadded:: 0.5 + """ + + def __init__(self, separator=',', **kwargs): + self.names = ['year', 'month', 'day', 'hour', 'minute', 'second', + 'microsecond'] + self.separator = separator + super(ComplexDateTimeField, self).__init__(**kwargs) + + def _leading_zero(self, number): + """ + Converts the given number to a string. + + If it has only one digit, a leading zero is added so that it always + has at least two digits. + """ + if int(number) < 10: + return "0%s" % number + else: + return str(number) + + def _convert_from_datetime(self, val): + """ + Convert a `datetime` object to a string representation (which will be + stored in MongoDB). This is the reverse function of + `_convert_from_string`. + + >>> a = datetime(2011, 6, 8, 20, 26, 24, 192284) + >>> ComplexDateTimeField()._convert_from_datetime(a) + '2011,06,08,20,26,24,192284' + """ + data = [] + for name in self.names: + data.append(self._leading_zero(getattr(val, name))) + return self.separator.join(data) + + def _convert_from_string(self, data): + """ + Convert a string representation to a `datetime` object (the object you + will manipulate). This is the reverse function of + `_convert_from_datetime`. + + >>> a = '2011,06,08,20,26,24,192284' + >>> ComplexDateTimeField()._convert_from_string(a) + datetime.datetime(2011, 6, 8, 20, 26, 24, 192284) + """ + data = data.split(self.separator) + data = map(int, data) + values = {} + for i in range(7): + values[self.names[i]] = data[i] + return datetime.datetime(**values) + + def __get__(self, instance, owner): + data = super(ComplexDateTimeField, self).__get__(instance, owner) + if data is None: + return datetime.datetime.now() + return self._convert_from_string(data) + + def __set__(self, instance, value): + value = self._convert_from_datetime(value) + return super(ComplexDateTimeField, self).__set__(instance, value) + + def validate(self, value): + if not isinstance(value, datetime.datetime): + raise ValidationError('Only datetime objects may be used in a \ + ComplexDateTimeField') + + def to_python(self, value): + return self._convert_from_string(value) + + def to_mongo(self, value): + return self._convert_from_datetime(value) + + def prepare_query_value(self, op, value): + return self._convert_from_datetime(value)
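Because every component is zero-padded into a fixed layout, the stored strings sort the same way the datetimes do, which is what makes lexicographical ``gte`` / ``lte`` filtering work; a quick check (note ``_leading_zero`` only pads single digits, so ordering on the microsecond component relies on values of equal width)::

    import datetime
    from mongoengine import ComplexDateTimeField

    f = ComplexDateTimeField()
    a = f._convert_from_datetime(datetime.datetime(2011, 6, 8, 20, 26, 24, 192284))
    b = f._convert_from_datetime(datetime.datetime(2011, 6, 8, 20, 26, 25, 192284))
    assert a == '2011,06,08,20,26,24,192284'
    assert a < b   # string order matches chronological order
    assert f._convert_from_string(a) == datetime.datetime(2011, 6, 8, 20, 26, 24, 192284)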
class EmbeddedDocumentField(BaseField): - """An embedded document field. Only valid values are subclasses of - :class:`~mongoengine.EmbeddedDocument`. + """An embedded document field - with a declared document_type. + Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. """ def __init__(self, document_type, **kwargs): @@ -256,6 +401,8 @@ ... return value def to_mongo(self, value): + if not isinstance(value, self.document_type): + return value return self.document_type.to_mongo(value) def validate(self, value): @@ -275,7 +422,41 @@ ... return self.to_mongo(value) -class ListField(BaseField): +class GenericEmbeddedDocumentField(BaseField): + """A generic embedded document field - allows any + :class:`~mongoengine.EmbeddedDocument` to be stored. + + Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. + """ + + def prepare_query_value(self, op, value): + return self.to_mongo(value) + + def to_python(self, value): + if isinstance(value, dict): + doc_cls = get_document(value['_cls']) + value = doc_cls._from_son(value) + + return value + + def validate(self, value): + if not isinstance(value, EmbeddedDocument): + raise ValidationError('Invalid embedded document instance ' + 'provided to a GenericEmbeddedDocumentField') + + value.validate() + + def to_mongo(self, document): + if document is None: + return None + + data = document.to_mongo() + if '_cls' not in data: + data['_cls'] = document._class_name + return data + + +class ListField(ComplexBaseField): """A list field that wraps a standard field, allowing multiple instances of the field to be used as a list in the database. """ @@ -283,84 +464,26 @@ ... _index_with_types = False - def __init__(self, field, **kwargs): - if not isinstance(field, BaseField): - raise ValidationError('Argument to ListField constructor must be ' - 'a valid field') + def __init__(self, field=None, **kwargs): self.field = field kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - if isinstance(self.field, ReferenceField): - referenced_type = self.field.document_type - # Get value from document instance if available - value_list = instance._data.get(self.name) - if value_list: - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) - deref_list.append(referenced_type._from_son(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list - - if isinstance(self.field, GenericReferenceField): - value_list = instance._data.get(self.name) - if value_list: - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (dict, pymongo.son.SON)): - deref_list.append(self.field.dereference(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list - - return super(ListField, self).__get__(instance, owner) - - def to_python(self, value): - return [self.field.to_python(item) for item in value] - - def to_mongo(self, value): - return [self.field.to_mongo(item) for item in value] - def validate(self, value): """Make sure that a list of valid fields is being used. 
""" if not isinstance(value, (list, tuple)): raise ValidationError('Only lists and tuples may be used in a ' 'list field') - - try: - [self.field.validate(item) for item in value] - except Exception, err: - raise ValidationError('Invalid ListField item (%s)' % str(item)) + super(ListField, self).validate(value) def prepare_query_value(self, op, value): - if op in ('set', 'unset'): - return [self.field.prepare_query_value(op, v) for v in value] - return self.field.prepare_query_value(op, value) - - def lookup_member(self, member_name): - return self.field.lookup_member(member_name) - - def _set_owner_document(self, owner_document): - self.field.owner_document = owner_document - self._owner_document = owner_document - - def _get_owner_document(self, owner_document): - self._owner_document = owner_document - - owner_document = property(_get_owner_document, _set_owner_document) + if self.field: + if op in ('set', 'unset') and (not isinstance(value, basestring) + and hasattr(value, '__iter__')): + return [self.field.prepare_query_value(op, v) for v in value] + return self.field.prepare_query_value(op, value) + return super(ListField, self).prepare_query_value(op, value) class SortedListField(ListField): @@ -379,20 +502,22 @@ class SortedListField(ListField): super(SortedListField, self).__init__(field, **kwargs) def to_mongo(self, value): + value = super(SortedListField, self).to_mongo(value) if self._ordering is not None: - return sorted([self.field.to_mongo(item) for item in value], - key=itemgetter(self._ordering)) - return sorted([self.field.to_mongo(item) for item in value]) + return sorted(value, key=itemgetter(self._ordering)) + return sorted(value) -class DictField(BaseField): +class DictField(ComplexBaseField): """A dictionary field that wraps a standard Python dictionary. This is similar to an embedded document, but the structure is not defined. .. versionadded:: 0.3 + .. versionchanged:: 0.5 - Can now handle complex / varying types of data """ - def __init__(self, basecls=None, *args, **kwargs): + def __init__(self, basecls=None, field=None, *args, **kwargs): + self.field = field self.basecls = basecls or BaseField assert issubclass(self.basecls, BaseField) kwargs.setdefault('default', lambda: {}) @@ -408,21 +533,67 @@ class DictField(BaseField): if any(('.' in k or '$' in k) for k in value): raise ValidationError('Invalid dictionary key name - keys may not ' 'contain "." or "$" characters') + super(DictField, self).validate(value) def lookup_member(self, member_name): - return self.basecls(db_field=member_name) + return DictField(basecls=self.basecls, db_field=member_name) + + def prepare_query_value(self, op, value): + match_operators = ['contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', + 'exact', 'iexact'] + + if op in match_operators and isinstance(value, basestring): + return StringField().prepare_query_value(op, value) + + return super(DictField, self).prepare_query_value(op, value) + + +class MapField(DictField): + """A field that maps a name to a specified field type. Similar to + a DictField, except the 'value' of each item must match the specified + field type. + + .. 
versionadded:: 0.5 + """ + + def __init__(self, field=None, *args, **kwargs): + if not isinstance(field, BaseField): + raise ValidationError('Argument to MapField constructor must be ' + 'a valid field') + super(MapField, self).__init__(field=field, *args, **kwargs) + + class ReferenceField(BaseField): """A reference to a document that will be automatically dereferenced on access (lazily). + + Use the `reverse_delete_rule` to handle what should happen if the document + the field is referencing is deleted. + + The options are: + + * DO_NOTHING - Don't do anything (default). + * NULLIFY - Updates the reference to null. + * CASCADE - Deletes the documents associated with the reference. + * DENY - Prevents the deletion of the referenced object. + + .. versionchanged:: 0.5 - added `reverse_delete_rule` """ - def __init__(self, document_type, **kwargs): + def __init__(self, document_type, reverse_delete_rule=DO_NOTHING, **kwargs): + """Initialises the Reference Field. + + :param reverse_delete_rule: Determines what to do when the referenced + document is deleted + """ if not isinstance(document_type, basestring): if not issubclass(document_type, (Document, basestring)): raise ValidationError('Argument to ReferenceField constructor ' 'must be a document class or a string') self.document_type_obj = document_type + self.reverse_delete_rule = reverse_delete_rule super(ReferenceField, self).__init__(**kwargs) @property @@ -465,7 +636,7 @@ class ReferenceField(BaseField): id_ = document id_ = id_field.to_mongo(id_) - collection = self.document_type._meta['collection'] + collection = self.document_type._get_collection_name() return pymongo.dbref.DBRef(collection, id_) def prepare_query_value(self, op, value): @@ -474,6 +645,11 @@ class ReferenceField(BaseField): def validate(self, value): assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) + if isinstance(value, Document) and value.id is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + + def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -482,6 +658,9 @@ class GenericReferenceField(BaseField): """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). + + .. note:: Any documents used as a generic reference must be registered in the + document registry. Importing the model will automatically register it. + + ..
versionadded:: 0.3 """ @@ -495,6 +674,15 @@ class GenericReferenceField(BaseField): return super(GenericReferenceField, self).__get__(instance, owner) + def validate(self, value): + if not isinstance(value, (Document, pymongo.dbref.DBRef)): + raise ValidationError('GenericReferences can only contain documents') + + # We need the id from the saved object to create the DBRef + if isinstance(value, Document) and value.id is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + def dereference(self, value): doc_cls = get_document(value['_cls']) reference = value['_ref'] @@ -504,6 +692,9 @@ class GenericReferenceField(BaseField): return doc def to_mongo(self, document): + if document is None: + return None + id_field_name = document.__class__._meta['id_field'] id_field = document.__class__._fields[id_field_name] @@ -517,9 +708,9 @@ class GenericReferenceField(BaseField): id_ = document id_ = id_field.to_mongo(id_) - collection = document._meta['collection'] + collection = document._get_collection_name() ref = pymongo.dbref.DBRef(collection, id_) - return {'_cls': document.__class__.__name__, '_ref': ref} + return {'_cls': document._class_name, '_ref': ref} def prepare_query_value(self, op, value): return self.to_mongo(value) @@ -555,12 +746,16 @@ class GridFSProxy(object): """Proxy object to handle writing and reading of files to and from GridFS .. versionadded:: 0.4 + .. versionchanged:: 0.5 - added optional size param to read """ - def __init__(self, grid_id=None): + def __init__(self, grid_id=None, key=None, instance=None): self.fs = gridfs.GridFS(_get_db()) # Filesystem instance self.newfile = None # Used for partial writes self.grid_id = grid_id # Store GridFS id for file + self.gridout = None + self.key = key + self.instance = instance def __getattr__(self, name): obj = self.get() @@ -571,11 +766,18 @@ class GridFSProxy(object): def __get__(self, instance, value): return self + def __nonzero__(self): + return bool(self.grid_id) + def get(self, id=None): if id: self.grid_id = id + if self.grid_id is None: + return None try: - return self.fs.get(id or self.grid_id) + if self.gridout is None: + self.gridout = self.fs.get(self.grid_id) + return self.gridout except: # File has been deleted return None @@ -584,11 +786,12 @@ class GridFSProxy(object): self.newfile = self.fs.new_file(**kwargs) self.grid_id = self.newfile._id - def put(self, file, **kwargs): + def put(self, file_obj, **kwargs): if self.grid_id: raise GridFSError('This document already has a file. 
Either delete ' 'it or call replace to overwrite it') - self.grid_id = self.fs.put(file, **kwargs) + self.grid_id = self.fs.put(file_obj, **kwargs) + self._mark_as_changed() def write(self, string): if self.grid_id: @@ -603,11 +806,11 @@ if not self.newfile: self.new_file() self.grid_id = self.newfile._id - self.newfile.writelines(lines) + self.newfile.writelines(lines) - def read(self): + def read(self, size=-1): try: - return self.get().read() + return self.get().read(size) except: return None @@ -615,20 +818,28 @@ # Delete file from GridFS, FileField still remains self.fs.delete(self.grid_id) self.grid_id = None + self.gridout = None + self._mark_as_changed() - def replace(self, file, **kwargs): + def replace(self, file_obj, **kwargs): self.delete() - self.put(file, **kwargs) + self.put(file_obj, **kwargs) def close(self): if self.newfile: self.newfile.close() + def _mark_as_changed(self): + """Inform the instance that `self.key` has been changed""" + if self.instance: + self.instance._mark_as_changed(self.key) + class FileField(BaseField): """A GridFS storage field. .. versionadded:: 0.4 + .. versionchanged:: 0.5 - added optional size param for read """ def __init__(self, **kwargs): @@ -641,11 +852,15 @@ # Check if a file already exists for this model grid_file = instance._data.get(self.name) self.grid_file = grid_file - if self.grid_file: + if isinstance(self.grid_file, GridFSProxy): + if not self.grid_file.key: + self.grid_file.key = self.name + self.grid_file.instance = instance return self.grid_file - return GridFSProxy() + return GridFSProxy(key=self.name, instance=instance) def __set__(self, instance, value): + key = self.name if isinstance(value, file) or isinstance(value, str): # using "FileField() = file/string" notation grid_file = instance._data.get(self.name) @@ -659,10 +874,12 @@ grid_file.put(value) else: # Create a new proxy object as we don't already have one - instance._data[self.name] = GridFSProxy() - instance._data[self.name].put(value) + instance._data[key] = GridFSProxy(key=key, instance=instance) + instance._data[key].put(value) else: - instance._data[self.name] = value + instance._data[key] = value + + instance._mark_as_changed(key) def to_mongo(self, value): # Store the GridFS file id in MongoDB @@ -700,3 +917,61 @@ class GeoPointField(BaseField): if (not isinstance(value[0], (float, int)) and not isinstance(value[1], (float, int))): raise ValidationError('Both values in point must be float or int.') + + +class SequenceField(IntField): + """Provides a sequential counter. + + .. note:: Although traditional databases often use increasing sequence + numbers for primary keys, in MongoDB the preferred approach is to + use Object IDs instead. The concept is that in a very large + cluster of machines, it is easier to create an object ID than have + global, uniformly increasing sequence numbers. + + ..
versionadded:: 0.5 + """ + def __init__(self, collection_name=None, *args, **kwargs): + self.collection_name = collection_name or 'mongoengine.counters' + return super(SequenceField, self).__init__(*args, **kwargs) + + def generate_new_value(self): + """ + Generate and Increment the counter + """ + sequence_id = "{0}.{1}".format(self.owner_document._get_collection_name(), + self.name) + collection = _get_db()[self.collection_name] + counter = collection.find_and_modify(query={"_id": sequence_id}, + update={"$inc": {"next": 1}}, + new=True, + upsert=True) + return counter['next'] + + def __get__(self, instance, owner): + + if instance is None: + return self + + if not instance._data: + return + + value = instance._data.get(self.name) + + if not value and instance._initialised: + value = self.generate_new_value() + instance._data[self.name] = value + instance._mark_as_changed(self.name) + + return value + + def __set__(self, instance, value): + + if value is None and instance._initialised: + value = self.generate_new_value() + + return super(SequenceField, self).__set__(instance, value) + + def to_python(self, value): + if value is None: + value = self.generate_new_value() + return value diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 0af8dead..a6626855 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -8,13 +8,21 @@ import pymongo.objectid import re import copy import itertools +import operator __all__ = ['queryset_manager', 'Q', 'InvalidQueryError', - 'InvalidCollectionError'] + 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] + # The maximum number of items to display in a QuerySet.__repr__ REPR_OUTPUT_SIZE = 20 +# Delete rules +DO_NOTHING = 0 +NULLIFY = 1 +CASCADE = 2 +DENY = 3 + class DoesNotExist(Exception): pass @@ -32,10 +40,6 @@ class OperationError(Exception): pass -class InvalidCollectionError(Exception): - pass - - RE_TYPE = type(re.compile('')) @@ -78,7 +82,7 @@ class SimplificationVisitor(QNodeVisitor): # to a single field intersection = ops.intersection(query_ops) if intersection: - msg = 'Duplicate query contitions: ' + msg = 'Duplicate query conditions: ' raise InvalidQueryError(msg + ', '.join(intersection)) query_ops.update(ops) @@ -118,7 +122,6 @@ class QueryTreeTransformerVisitor(QNodeVisitor): q_object = reduce(lambda a, b: a & b, and_parts, Q()) q_object = reduce(lambda a, b: a & b, or_group, q_object) clauses.append(q_object) - # Finally, $or the generated clauses in to one query. Each of the # clauses is sufficient for the query to succeed. return reduce(lambda a, b: a | b, clauses, Q()) @@ -179,7 +182,7 @@ class QueryCompilerVisitor(QNodeVisitor): # once to a single field intersection = current_ops.intersection(new_ops) if intersection: - msg = 'Duplicate query contitions: ' + msg = 'Duplicate query conditions: ' raise InvalidQueryError(msg + ', '.join(intersection)) # Right! We've got two non-overlapping dicts of operations! 
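A minimal usage sketch (not part of the patch itself) of the SequenceField and the delete-rule constants introduced above. The Ticket, Author and Book classes are hypothetical, and the sketch assumes the imports shown are exported at the package level and that a connection is available.

from mongoengine import (Document, StringField, SequenceField,
                         ReferenceField, CASCADE, connect)

connect(db='example')

class Ticket(Document):
    # Saving a new Ticket pulls the next counter value from the
    # 'mongoengine.counters' collection via find_and_modify.
    number = SequenceField()

class Author(Document):
    name = StringField()

class Book(Document):
    # CASCADE: deleting an Author also deletes the Books that reference it.
    author = ReferenceField(Author, reverse_delete_rule=CASCADE)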
@@ -243,7 +246,8 @@ class QCombination(QNode): def accept(self, visitor): for i in range(len(self.children)): - self.children[i] = self.children[i].accept(visitor) + if isinstance(self.children[i], QNode): + self.children[i] = self.children[i].accept(visitor) return visitor.visit_combination(self) @@ -268,37 +272,102 @@ class Q(QNode): return not bool(self.query) +class QueryFieldList(object): + """Object that handles combinations of .only() and .exclude() calls""" + ONLY = True + EXCLUDE = False + + def __init__(self, fields=[], value=ONLY, always_include=[]): + self.value = value + self.fields = set(fields) + self.always_include = set(always_include) + + def as_dict(self): + return dict((field, self.value) for field in self.fields) + + def __add__(self, f): + if not self.fields: + self.fields = f.fields + self.value = f.value + elif self.value is self.ONLY and f.value is self.ONLY: + self.fields = self.fields.intersection(f.fields) + elif self.value is self.EXCLUDE and f.value is self.EXCLUDE: + self.fields = self.fields.union(f.fields) + elif self.value is self.ONLY and f.value is self.EXCLUDE: + self.fields -= f.fields + elif self.value is self.EXCLUDE and f.value is self.ONLY: + self.value = self.ONLY + self.fields = f.fields - self.fields + + if self.always_include: + if self.value is self.ONLY and self.fields: + self.fields = self.fields.union(self.always_include) + else: + self.fields -= self.always_include + return self + + def reset(self): + self.fields = set([]) + self.value = self.ONLY + + def __nonzero__(self): + return bool(self.fields) + + class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, providing :class:`~mongoengine.Document` objects as the results. """ + __already_indexed = set() + def __init__(self, document, collection): self._document = document self._collection_obj = collection - self._accessed_collection = False self._mongo_query = None self._query_obj = Q() self._initial_query = {} self._where_clause = None - self._loaded_fields = [] + self._loaded_fields = QueryFieldList() self._ordering = [] self._snapshot = False self._timeout = True + self._class_check = True + self._slave_okay = False # If inheritance is allowed, only return instances and instances of # subclasses of the class being used if document._meta.get('allow_inheritance'): self._initial_query = {'_types': self._document._class_name} + self._loaded_fields = QueryFieldList(always_include=['_cls']) self._cursor_obj = None self._limit = None self._skip = None + self._hint = -1 # Using -1 as None is a valid value for hint + + def clone(self): + """Creates a copy of the current :class:`~mongoengine.queryset.QuerySet` + + .. 
versionadded:: 0.5 + """ + c = self.__class__(self._document, self._collection_obj) + + copy_props = ('_initial_query', '_query_obj', '_where_clause', + '_loaded_fields', '_ordering', '_snapshot', + '_timeout', '_limit', '_skip', '_slave_okay', '_hint') + + for prop in copy_props: + val = getattr(self, prop) + setattr(c, prop, copy.deepcopy(val)) + + return c @property def _query(self): if self._mongo_query is None: self._mongo_query = self._query_obj.to_query(self._document) - self._mongo_query.update(self._initial_query) + if self._class_check: + self._mongo_query.update(self._initial_query) return self._mongo_query def ensure_index(self, key_or_list, drop_dups=False, background=False, @@ -309,21 +378,27 @@ class QuerySet(object): construct a multi-field index); keys may be prefixed with a **+** or a **-** to determine the index ordering """ - index_list = QuerySet._build_index_spec(self._document, key_or_list) - self._collection.ensure_index(index_list, drop_dups=drop_dups, - background=background) + index_spec = QuerySet._build_index_spec(self._document, key_or_list) + self._collection.ensure_index( + index_spec['fields'], + drop_dups=drop_dups, + background=background, + sparse=index_spec.get('sparse', False), + unique=index_spec.get('unique', False)) return self @classmethod - def _build_index_spec(cls, doc_cls, key_or_list): + def _build_index_spec(cls, doc_cls, spec): """Build a PyMongo index spec from a MongoEngine index spec. """ - if isinstance(key_or_list, basestring): - key_or_list = [key_or_list] + if isinstance(spec, basestring): + spec = {'fields': [spec]} + if isinstance(spec, (list, tuple)): + spec = {'fields': spec} index_list = [] use_types = doc_cls._meta.get('allow_inheritance', True) - for key in key_or_list: + for key in spec['fields']: # Get direction from + or - direction = pymongo.ASCENDING if key.startswith("-"): @@ -344,12 +419,26 @@ class QuerySet(object): use_types = False # If _types is being used, prepend it to every specified index - if doc_cls._meta.get('allow_inheritance') and use_types: + index_types = doc_cls._meta.get('index_types', True) + allow_inheritance = doc_cls._meta.get('allow_inheritance') + if spec.get('types', index_types) and allow_inheritance and use_types: index_list.insert(0, ('_types', 1)) - return index_list + spec['fields'] = index_list - def __call__(self, q_obj=None, **query): + if spec.get('sparse', False) and len(spec['fields']) > 1: + raise ValueError( + 'Sparse indexes can only have one field in them. ' + 'See https://jira.mongodb.org/browse/SERVER-2193') + + return spec + + @classmethod + def _reset_already_indexed(cls): + """Helper to reset already indexed, can be useful for testing purposes""" + cls.__already_indexed = set() + + def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query): """Filter the selected documents by calling the :class:`~mongoengine.queryset.QuerySet` with a query. @@ -357,16 +446,19 @@ class QuerySet(object): the query; the :class:`~mongoengine.queryset.QuerySet` is filtered multiple times with different :class:`~mongoengine.queryset.Q` objects, only the last one will be used + :param class_check: If set to False bypass class name check when + querying collection + :param slave_okay: if True, allows this query to be run against a + replica secondary. 
:param query: Django-style query keyword arguments """ - #if q_obj: - #self._where_clause = q_obj.as_js(self._document) query = Q(**query) if q_obj: query &= q_obj self._query_obj &= query self._mongo_query = None self._cursor_obj = None + self._class_check = class_check return self def filter(self, *q_objs, **query): @@ -383,55 +475,83 @@ """Property that returns the collection object. This allows us to perform operations only if the collection is accessed. """ - if not self._accessed_collection: - self._accessed_collection = True + if self._document not in QuerySet.__already_indexed: + QuerySet.__already_indexed.add(self._document) background = self._document._meta.get('index_background', False) drop_dups = self._document._meta.get('index_drop_dups', False) index_opts = self._document._meta.get('index_options', {}) + index_types = self._document._meta.get('index_types', True) - # Ensure document-defined indexes are created - if self._document._meta['indexes']: - for key_or_list in self._document._meta['indexes']: - self._collection.ensure_index(key_or_list, - background=background, **index_opts) + # Determine if an index we are creating includes + # _types as its first field; if so, we can avoid creating + # an extra index on _types, as MongoDB will use the existing + # index to service queries against _types + types_indexed = False + def includes_types(fields): + first_field = None + if len(fields): + if isinstance(fields[0], basestring): + first_field = fields[0] + elif isinstance(fields[0], (list, tuple)) and len(fields[0]): + first_field = fields[0][0] + return first_field == '_types' # Ensure indexes created by uniqueness constraints for index in self._document._meta['unique_indexes']: + types_indexed = types_indexed or includes_types(index) self._collection.ensure_index(index, unique=True, background=background, drop_dups=drop_dups, **index_opts) - # If _types is being used (for polymorphism), it needs an index - if '_types' in self._query: + # Ensure document-defined indexes are created + if self._document._meta['indexes']: + for spec in self._document._meta['indexes']: + types_indexed = types_indexed or includes_types(spec['fields']) + opts = index_opts.copy() + opts['unique'] = spec.get('unique', False) + opts['sparse'] = spec.get('sparse', False) + self._collection.ensure_index(spec['fields'], + background=background, **opts) + + # If _types is being used (for polymorphism), it needs an index, + # only if another index doesn't begin with _types + if index_types and '_types' in self._query and not types_indexed: self._collection.ensure_index('_types', background=background, **index_opts) - # Ensure all needed field indexes are created - for field in self._document._fields.values(): - if field.__class__._geo_index: - index_spec = [(field.db_field, pymongo.GEO2D)] - self._collection.ensure_index(index_spec, - background=background, **index_opts) + # Add geo indices + for field in self._document._geo_indices(): + index_spec = [(field.db_field, pymongo.GEO2D)] + self._collection.ensure_index(index_spec, + background=background, **index_opts) return self._collection_obj + @property + def _cursor_args(self): + cursor_args = { + 'snapshot': self._snapshot, + 'timeout': self._timeout, + 'slave_okay': self._slave_okay + } + if self._loaded_fields: + cursor_args['fields'] = self._loaded_fields.as_dict() + return cursor_args + @property def _cursor(self): if self._cursor_obj is None: - cursor_args = { - 'snapshot': self._snapshot, - 'timeout': self._timeout, }
- if self._loaded_fields: - cursor_args['fields'] = self._loaded_fields - self._cursor_obj = self._collection.find(self._query, - **cursor_args) + + self._cursor_obj = self._collection.find(self._query, + **self._cursor_args) # Apply where clauses to cursor if self._where_clause: self._cursor_obj.where(self._where_clause) # apply default ordering - if self._document._meta['ordering']: + if self._ordering: + self._cursor_obj.sort(self._ordering) + elif self._document._meta['ordering']: self.order_by(*self._document._meta['ordering']) if self._limit is not None: @@ -440,6 +560,9 @@ class QuerySet(object): if self._skip is not None: self._cursor_obj.skip(self._skip) + if self._hint != -1: + self._cursor_obj.hint(self._hint) + return self._cursor_obj @classmethod @@ -451,7 +574,17 @@ class QuerySet(object): parts = [parts] fields = [] field = None + for field_name in parts: + # Handle ListField indexing: + if field_name.isdigit(): + try: + new_field = field.field + except AttributeError, err: + raise InvalidQueryError( + "Can't use index on unsubscriptable field (%s)" % err) + fields.append(field_name) + continue if field is None: # Look up first field from the document if field_name == 'pk': @@ -460,11 +593,17 @@ class QuerySet(object): field = document._fields[field_name] else: # Look up subfield on the previous field - field = field.lookup_member(field_name) - if field is None: + new_field = field.lookup_member(field_name) + from base import ComplexBaseField + if not new_field and isinstance(field, ComplexBaseField): + fields.append(field_name) + continue + elif not new_field: raise InvalidQueryError('Cannot resolve field "%s"' - % field_name) + % field_name) + field = new_field # update field to the new field type fields.append(field) + return fields @classmethod @@ -476,14 +615,14 @@ class QuerySet(object): return '.'.join(parts) @classmethod - def _transform_query(cls, _doc_cls=None, **query): + def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): """Transform a query from Django-style format to Mongo format. 
""" operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 'all', 'size', 'exists', 'not'] - geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'near', 'near_sphere'] - match_operators = ['contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', + geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'within_polygon', 'near', 'near_sphere'] + match_operators = ['contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact'] mongo_query = {} @@ -508,14 +647,33 @@ class QuerySet(object): if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] fields = QuerySet._lookup_field(_doc_cls, parts) - parts = [field.db_field for field in fields] + parts = [] + + cleaned_fields = [] + append_field = True + for field in fields: + if isinstance(field, str): + parts.append(field) + append_field = False + else: + parts.append(field.db_field) + if append_field: + cleaned_fields.append(field) # Convert value to proper value - field = fields[-1] + field = cleaned_fields[-1] + singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops += match_operators if op in singular_ops: - value = field.prepare_query_value(op, value) + if isinstance(field, basestring): + if op in match_operators and isinstance(value, basestring): + from mongoengine import StringField + value = StringField().prepare_query_value(op, value) + else: + value = field + else: + value = field.prepare_query_value(op, value) elif op in ('in', 'nin', 'all', 'near'): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] @@ -527,6 +685,8 @@ class QuerySet(object): value = {'$within': {'$center': value}} elif op == "within_spherical_distance": value = {'$within': {'$centerSphere': value}} + elif op == "within_polygon": + value = {'$within': {'$polygon': value}} elif op == "near": value = {'$near': value} elif op == "near_sphere": @@ -572,9 +732,9 @@ class QuerySet(object): raise self._document.DoesNotExist("%s matching query does not exist." % self._document._class_name) - def get_or_create(self, *q_objs, **query): - """Retrieve unique object or create, if it doesn't exist. Returns a tuple of - ``(object, created)``, where ``object`` is the retrieved or created object + def get_or_create(self, write_options=None, *q_objs, **query): + """Retrieve unique object or create, if it doesn't exist. Returns a tuple of + ``(object, created)``, where ``object`` is the retrieved or created object and ``created`` is a boolean specifying whether a new object was created. Raises :class:`~mongoengine.queryset.MultipleObjectsReturned` or `DocumentName.MultipleObjectsReturned` if multiple results are found. @@ -582,6 +742,10 @@ class QuerySet(object): dictionary of default values for the new document may be provided as a keyword argument called :attr:`defaults`. + :param write_options: optional extra keyword arguments used if we + have to create a new document. + Passes any write_options onto :meth:`~mongoengine.Document.save` + .. 
versionadded:: 0.3 """ defaults = query.get('defaults', {}) @@ -593,7 +757,7 @@ class QuerySet(object): if count == 0: query.update(defaults) doc = self._document(**query) - doc.save() + doc.save(write_options=write_options) return doc, True elif count == 1: return self.first(), False @@ -619,18 +783,52 @@ class QuerySet(object): result = None return result + def insert(self, doc_or_docs, load_bulk=True): + """Bulk insert documents. + + :param doc_or_docs: a document or list of documents to be inserted + :param load_bulk (optional): if True, returns the list of document instances + + By default returns document instances, set ``load_bulk`` to False to + return just ``ObjectIds`` + + .. versionadded:: 0.5 + """ + from document import Document + + docs = doc_or_docs + return_one = False + if isinstance(docs, Document) or issubclass(docs.__class__, Document): + return_one = True + docs = [docs] + + raw = [] + for doc in docs: + if not isinstance(doc, self._document): + msg = "Some documents inserted aren't instances of %s" % str(self._document) + raise OperationError(msg) + if doc.pk: + msg = "Some documents have ObjectIds, use doc.update() instead" + raise OperationError(msg) + raw.append(doc.to_mongo()) + + ids = self._collection.insert(raw) + + if not load_bulk: + return return_one and ids[0] or ids + + documents = self.in_bulk(ids) + results = [] + for obj_id in ids: + results.append(documents.get(obj_id)) + return return_one and results[0] or results + def with_id(self, object_id): """Retrieve the object matching the id provided. :param object_id: the value for the id of the document to look up """ - id_field = self._document._meta['id_field'] - object_id = self._document._fields[id_field].to_mongo(object_id) - - result = self._collection.find_one({'_id': object_id}) - if result is not None: - result = self._document._from_son(result) - return result + return self._document.objects(pk=object_id).first() def in_bulk(self, object_ids): """Retrieve a set of documents by their ids. @@ -643,7 +841,8 @@ """ doc_map = {} - docs = self._collection.find({'_id': {'$in': object_ids}}) + docs = self._collection.find({'_id': {'$in': object_ids}}, + **self._cursor_args) for doc in docs: doc_map[doc['_id']] = self._document._from_son(doc) @@ -677,8 +876,8 @@ def __len__(self): return self.count() - def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None, - scope=None, keep_temp=False): + def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, + scope=None): """Perform a map/reduce query using the current query spec and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, it must be the last call made, as it does not return a malleable @@ -691,26 +890,33 @@ :param map_f: map function, as :class:`~pymongo.code.Code` or string :param reduce_f: reduce function, as :class:`~pymongo.code.Code` or string + :param output: output collection name, if set to 'inline' will try to + use :meth:`~pymongo.collection.Collection.inline_map_reduce` :param finalize_f: finalize function, an optional function that performs any post-reduction processing. :param scope: values to insert into map/reduce global scope. Optional. :param limit: number of objects from current query to provide to map/reduce method - :param keep_temp: keep temporary table (boolean, default ``True``) Returns an iterator yielding :class:`~mongoengine.document.MapReduceDocument`. - .. note:: Map/Reduce requires server version **>= 1.1.1**.
The PyMongo - :meth:`~pymongo.collection.Collection.map_reduce` helper requires - PyMongo version **>= 1.2**. + .. note:: + + Map/Reduce changed in server version **>= 1.7.4**. The PyMongo + :meth:`~pymongo.collection.Collection.map_reduce` helper requires + PyMongo version **>= 1.11**. + + .. versionchanged:: 0.5 + - removed ``keep_temp`` keyword argument, which was only relevant + for MongoDB server versions older than 1.7.4 .. versionadded:: 0.3 """ from document import MapReduceDocument if not hasattr(self._collection, "map_reduce"): - raise NotImplementedError("Requires MongoDB >= 1.1.1") + raise NotImplementedError("Requires MongoDB >= 1.7.1") map_f_scope = {} if isinstance(map_f, pymongo.code.Code): @@ -725,7 +931,7 @@ reduce_f_code = self._sub_js_fields(reduce_f) reduce_f = pymongo.code.Code(reduce_f_code, reduce_f_scope) - mr_args = {'query': self._query, 'keeptemp': keep_temp} + mr_args = {'query': self._query} if finalize_f: finalize_f_scope = {} @@ -742,8 +948,16 @@ if limit: mr_args['limit'] = limit - results = self._collection.map_reduce(map_f, reduce_f, **mr_args) - results = results.find() + if output == 'inline' and not self._ordering: + map_reduce_function = 'inline_map_reduce' + else: + map_reduce_function = 'map_reduce' + mr_args['out'] = output + + results = getattr(self._collection, map_reduce_function)(map_f, reduce_f, **mr_args) + + if map_reduce_function == 'map_reduce': + results = results.find() if self._ordering: results = results.sort(self._ordering) @@ -777,6 +991,23 @@ self._skip = n return self + def hint(self, index=None): + """Provide a hint, telling Mongo the proper index to use for the + query. + + Judicious use of hints can greatly improve query performance. When doing + a query on multiple fields (at least one of which is indexed) pass the + indexed field as a hint to the query. + + Hinting will not do anything if the corresponding index does not exist. + The last hint applied to this cursor takes precedence over all others. + + .. versionadded:: 0.5 + """ + self._cursor.hint(index) + self._hint = index + return self + def __getitem__(self, key): """Support skip and limit using getitem and slicing syntax. """ @@ -787,7 +1018,7 @@ self._skip, self._limit = key.start, key.stop except IndexError, err: # PyMongo raises an error if key.start == key.stop, catch it, - # bin it, kill it. + # bin it, kill it. start = key.start or 0 if start >= 0 and key.stop >= 0 and key.step is None: if start == key.stop: @@ -814,26 +1045,82 @@ def only(self, *fields): """Load only a subset of this document's fields. :: - post = BlogPost.objects(...).only("title") + post = BlogPost.objects(...).only("title", "author.name") :param fields: fields to include .. versionadded:: 0.3 + .. versionchanged:: 0.5 - Added subfield support """ - self._loaded_fields = [] - for field in fields: - if '.' in field: - raise InvalidQueryError('Subfields cannot be used as ' - 'arguments to QuerySet.only') - # Translate field name - field = QuerySet._lookup_field(self._document, field)[-1].db_field - self._loaded_fields.append(field) + fields = dict([(f, QueryFieldList.ONLY) for f in fields]) + return self.fields(**fields) - # _cls is needed for polymorphism - if self._document._meta.get('allow_inheritance'): - self._loaded_fields += ['_cls'] + def exclude(self, *fields): + """Opposite to .only(), exclude some document's fields.
:: + + post = BlogPost.objects(...).exclude("comments") + + :param fields: fields to exclude + + .. versionadded:: 0.5 + """ + fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) + return self.fields(**fields) + + def fields(self, **kwargs): + """Manipulate how you load this document's fields. Used by `.only()` + and `.exclude()` to manipulate which fields to retrieve. Fields also + allows for a greater level of control, for example: + + Retrieving a Subrange of Array Elements: + + You can use the $slice operator to retrieve a subrange of elements in + an array :: + + post = BlogPost.objects(...).fields(slice__comments=5) # first 5 comments + + :param kwargs: A dictionary identifying what to include + + .. versionadded:: 0.5 + """ + + # Check for an operator and transform to mongo-style if there is + operators = ["slice"] + cleaned_fields = [] + for key, value in kwargs.items(): + parts = key.split('__') + op = None + if parts[0] in operators: + op = parts.pop(0) + value = {'$' + op: value} + key = '.'.join(parts) + cleaned_fields.append((key, value)) + + fields = sorted(cleaned_fields, key=operator.itemgetter(1)) + for value, group in itertools.groupby(fields, lambda x: x[1]): + fields = [field for field, value in group] + fields = self._fields_to_dbfields(fields) + self._loaded_fields += QueryFieldList(fields, value=value) return self + def all_fields(self): + """Include all fields. Reset all previous calls of .only() and .exclude(). :: + + post = BlogPost.objects(...).exclude("comments").only("title").all_fields() + + .. versionadded:: 0.5 + """ + self._loaded_fields = QueryFieldList(always_include=self._loaded_fields.always_include) + return self + + def _fields_to_dbfields(self, fields): + """Translate field paths to their db equivalents""" + ret = [] + for field in fields: + field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.'))) + ret.append(field) + return ret + def order_by(self, *keys): """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The order may be specified by prepending each of the keys by a + or a -. @@ -851,6 +1138,10 @@ if key[0] in ('-', '+'): key = key[1:] key = key.replace('__', '.') + try: + key = QuerySet._translate_field_name(self._document, key) + except: + pass key_list.append((key, direction)) self._ordering = key_list @@ -873,21 +1164,57 @@ """Enable or disable snapshot mode when querying. :param enabled: whether or not snapshot mode is enabled + + .. versionchanged:: 0.5 - made chainable """ self._snapshot = enabled + return self def timeout(self, enabled): """Enable or disable the default mongod timeout when querying. :param enabled: whether or not the timeout is used + + .. versionchanged:: 0.5 - made chainable """ self._timeout = enabled + return self + + def slave_okay(self, enabled): + """Enable or disable the slave_okay when querying. + + :param enabled: whether or not the slave_okay is enabled + """ + self._slave_okay = enabled + return self def delete(self, safe=False): """Delete the documents matched by the query.
:param safe: check if the operation succeeded before returning """ + doc = self._document + + # Check for DENY rules before actually deleting/nullifying any other + # references + for rule_entry in doc._meta['delete_rules']: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == DENY and document_cls.objects(**{field_name + '__in': self}).count() > 0: + msg = u'Could not delete document (at least %s.%s refers to it)' % \ + (document_cls.__name__, field_name) + raise OperationError(msg) + + for rule_entry in doc._meta['delete_rules']: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == CASCADE: + document_cls.objects(**{field_name + '__in': self}).delete(safe=safe) + elif rule == NULLIFY: + document_cls.objects(**{field_name + '__in': self}).update( + safe_update=safe, + **{'unset__%s' % field_name: 1}) + self._collection.remove(self._query, safe=safe) @classmethod @@ -919,12 +1246,26 @@ if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] fields = QuerySet._lookup_field(_doc_cls, parts) - parts = [field.db_field for field in fields] + parts = [] + + cleaned_fields = [] + append_field = True + for field in fields: + if isinstance(field, str): + # Convert the S operator to $ + if field == 'S': + field = '$' + parts.append(field) + append_field = False + else: + parts.append(field.db_field) + if append_field: + cleaned_fields.append(field) # Convert value to proper value - field = fields[-1] - if op in (None, 'set', 'unset', 'pop', 'push', 'pull', - 'addToSet'): + field = cleaned_fields[-1] + + if op in (None, 'set', 'push', 'pull', 'addToSet'): value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] @@ -942,22 +1283,27 @@ return mongo_update - def update(self, safe_update=True, upsert=False, **update): - """Perform an atomic update on the fields matched by the query. When + def update(self, safe_update=True, upsert=False, multi=True, write_options=None, **update): + """Perform an atomic update on the fields matched by the query. When ``safe_update`` is used, the number of affected documents is returned. - :param safe: check if the operation succeeded before returning - :param update: Django-style update keyword arguments + :param safe_update: check if the operation succeeded before returning + :param upsert: insert if document doesn't exist (default ``False``) + :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update` .. versionadded:: 0.2 """ - if pymongo.version < '1.1.1': - raise OperationError('update() method requires PyMongo 1.1.1+') + if not update: + raise OperationError("No update parameters, would remove data") + + if not write_options: + write_options = {} update = QuerySet._transform_update(self._document, **update) try: - ret = self._collection.update(self._query, update, multi=True, - upsert=upsert, safe=safe_update) + ret = self._collection.update(self._query, update, multi=multi, + upsert=upsert, safe=safe_update, + **write_options) if ret is not None and 'n' in ret: return ret['n'] except pymongo.errors.OperationFailure, err: @@ -966,26 +1312,30 @@ raise OperationError(message) raise OperationError(u'Update failed (%s)' % unicode(err)) - def update_one(self, safe_update=True, upsert=False, **update): - """Perform an atomic update on first field matched by the query.
When + def update_one(self, safe_update=True, upsert=False, write_options=None, **update): + """Perform an atomic update on the first document matched by the query. When ``safe_update`` is used, the number of affected documents is returned. - :param safe: check if the operation succeeded before returning + :param safe_update: check if the operation succeeded before returning + :param upsert: insert if document doesn't exist (default ``False``) + :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update` :param update: Django-style update keyword arguments .. versionadded:: 0.2 """ + if not update: + raise OperationError("No update parameters, would remove data") + + if not write_options: + write_options = {} update = QuerySet._transform_update(self._document, **update) try: # Explicitly provide 'multi=False' to newer versions of PyMongo # as the default may change to 'True' - if pymongo.version >= '1.1.1': - ret = self._collection.update(self._query, update, multi=False, - upsert=upsert, safe=safe_update) - else: - # Older versions of PyMongo don't support 'multi' - ret = self._collection.update(self._query, update, - safe=safe_update) + ret = self._collection.update(self._query, update, multi=False, + upsert=upsert, safe=safe_update, + **write_options) + if ret is not None and 'n' in ret: return ret['n'] except pymongo.errors.OperationFailure, e: @@ -995,8 +1345,8 @@ class QuerySet(object): return self def _sub_js_fields(self, code): - """When fields are specified with [~fieldname] syntax, where - *fieldname* is the Python name of a field, *fieldname* will be + """When fields are specified with [~fieldname] syntax, where + *fieldname* is the Python name of a field, *fieldname* will be substituted for the MongoDB name of the field (specified using the :attr:`name` keyword argument in a field's constructor). """ @@ -1007,7 +1357,16 @@ # Substitute the correct name for the field into the javascript return u'["%s"]' % fields[-1].db_field - return re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) + def field_path_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = QuerySet._lookup_field(self._document, field_name) + # Substitute the correct name for the field into the javascript + return ".".join([f.db_field for f in fields]) + + code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) + code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, code) + return code def exec_js(self, code, *fields, **options): """Execute a Javascript function on the server. A list of fields may be @@ -1019,9 +1378,9 @@ options specified as keyword arguments. As fields in MongoEngine may use different names in the database (set - using the :attr:`db_field` keyword argument to a :class:`Field` + using the :attr:`db_field` keyword argument to a :class:`Field` constructor), a mechanism exists for replacing MongoEngine field names - with the database field names in Javascript code. When accessing a + with the database field names in Javascript code. When accessing a field, use square-bracket notation, and prefix the MongoEngine field name with a tilde (~).
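A short sketch (not part of the patch) of the field-name substitution described above, as used with exec_js. BlogPost and its db_field mapping are hypothetical, and the sketch assumes an active connection; '[~fieldname]' is rewritten to the bracketed db_field, while the new '{{~dotted.path}}' form handled by field_path_sub is rewritten to a dotted db_field path.

from mongoengine import Document, StringField, connect

connect(db='example')

class BlogPost(Document):
    title = StringField(db_field='t')  # stored under the short name 't'

# '[~title]' is substituted to '["t"]' before the function is sent to the
# server, so the Javascript below actually reads doc["t"]. The 'collection'
# and 'query' variables are provided in scope by exec_js itself.
code = """
function() {
    var titles = [];
    db[collection].find(query).forEach(function(doc) {
        titles.push(doc[~title]);
    });
    return titles;
}
"""
titles = BlogPost.objects.exec_js(code)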
@@ -1035,7 +1394,7 @@ class QuerySet(object): fields = [QuerySet._translate_field_name(self._document, f) for f in fields] - collection = self._document._meta['collection'] + collection = self._document._get_collection_name() scope = { 'collection': collection, @@ -1052,62 +1411,170 @@ db = _get_db() return db.eval(code, *fields) + def where(self, where_clause): + """Filter ``QuerySet`` results with a ``$where`` clause (a Javascript + expression). Performs automatic field name substitution like + :meth:`mongoengine.queryset.QuerySet.exec_js`. + + .. note:: When using this mode of query, the database will call your + function, or evaluate your predicate clause, for each object + in the collection. + + .. versionadded:: 0.5 + """ + where_clause = self._sub_js_fields(where_clause) + self._where_clause = where_clause + return self + def sum(self, field): """Sum over the values of the specified field. :param field: the field to sum over; use dot-notation to refer to embedded document fields + + .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesn't work + with sharding. """ - sum_func = """ - function(sumField) { - var total = 0.0; - db[collection].find(query).forEach(function(doc) { - total += (doc[sumField] || 0.0); - }); - return total; + map_func = pymongo.code.Code(""" + function() { + emit(1, this[field] || 0); } - """ - return self.exec_js(sum_func, field) + """, scope={'field': field}) + + reduce_func = pymongo.code.Code(""" + function(key, values) { + var sum = 0; + for (var i in values) { + sum += values[i]; + } + return sum; + } + """) + + for result in self.map_reduce(map_func, reduce_func, output='inline'): + return result.value + else: + return 0 def average(self, field): """Average over the values of the specified field. :param field: the field to average over; use dot-notation to refer to embedded document fields - """ - average_func = """ - function(averageField) { - var total = 0.0; - var num = 0; - db[collection].find(query).forEach(function(doc) { - if (doc[averageField] !== undefined) { - total += doc[averageField]; - num += 1; - } - }); - return total / num; - } - """ - return self.exec_js(average_func, field) - def item_frequencies(self, field, normalize=False): + .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesn't work + with sharding. + """ + map_func = pymongo.code.Code(""" + function() { + if (this.hasOwnProperty(field)) + emit(1, {t: this[field] || 0, c: 1}); + } + """, scope={'field': field}) + + reduce_func = pymongo.code.Code(""" + function(key, values) { + var out = {t: 0, c: 0}; + for (var i in values) { + var value = values[i]; + out.t += value.t; + out.c += value.c; + } + return out; + } + """) + + finalize_func = pymongo.code.Code(""" + function(key, value) { + return value.t / value.c; + } + """) + + for result in self.map_reduce(map_func, reduce_func, finalize_f=finalize_func, output='inline'): + return result.value + else: + return 0 + + def item_frequencies(self, field, normalize=False, map_reduce=True): """Returns a dictionary of all items present in a field across the whole queried set of documents, and their corresponding frequency. This is useful for generating tag clouds, or searching documents. + .. note:: + + Can only do direct simple mappings and cannot map across + :class:`~mongoengine.ReferenceField` or + :class:`~mongoengine.GenericReferenceField`; for more complex + counting, a manual map reduce call is required.
+ If the field is a :class:`~mongoengine.ListField`, the items within each list will be counted individually. :param field: the field to use :param normalize: normalize the results so they add to 1.0 + :param map_reduce: Use map_reduce over exec_js + + .. versionchanged:: 0.5 - defaults to map_reduce and can handle embedded document lookups """ + if map_reduce: + return self._item_frequencies_map_reduce(field, normalize=normalize) + return self._item_frequencies_exec_js(field, normalize=normalize) + + def _item_frequencies_map_reduce(self, field, normalize=False): + map_func = """ + function() { + path = '{{~%(field)s}}'.split('.'); + field = this; + for (p in path) { field = field[path[p]]; } + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(item, 1); + }); + } else { + emit(field, 1); + } + } + """ % dict(field=field) + reduce_func = """ + function(key, values) { + var total = 0; + var valuesSize = values.length; + for (var i=0; i < valuesSize; i++) { + total += parseInt(values[i], 10); + } + return total; + } + """ + values = self.map_reduce(map_func, reduce_func, 'inline') + frequencies = {} + for f in values: + key = f.key + if isinstance(key, float): + if int(key) == key: + key = int(key) + key = str(key) + frequencies[key] = f.value + + if normalize: + count = sum(frequencies.values()) + frequencies = dict([(k, v / count) for k, v in frequencies.items()]) + + return frequencies + + def _item_frequencies_exec_js(self, field, normalize=False): + """Uses exec_js to execute""" freq_func = """ - function(field) { + function(path) { + path = path.split('.'); + if (options.normalize) { var total = 0.0; db[collection].find(query).forEach(function(doc) { - if (doc[field].constructor == Array) { - total += doc[field].length; + field = doc; + for (p in path) { field = field[path[p]]; } + if (field && field.constructor == Array) { + total += field.length; } else { total++; } @@ -1120,34 +1587,55 @@ inc /= total; } db[collection].find(query).forEach(function(doc) { - if (doc[field].constructor == Array) { - doc[field].forEach(function(item) { - frequencies[item] = inc + (frequencies[item] || 0); + field = doc; + for (p in path) { field = field[path[p]]; } + if (field && field.constructor == Array) { + field.forEach(function(item) { + frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); }); } else { - var item = doc[field]; - frequencies[item] = inc + (frequencies[item] || 0); + var item = field; + frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); } }); return frequencies; } """ - return self.exec_js(freq_func, field, normalize=normalize) + data = self.exec_js(freq_func, field, normalize=normalize) + if 'undefined' in data: + data[None] = data['undefined'] + del(data['undefined']) + return data def __repr__(self): limit = REPR_OUTPUT_SIZE + 1 if self._limit is not None and self._limit < limit: limit = self._limit - data = list(self[self._skip:limit]) + try: + data = list(self[self._skip:limit]) + except pymongo.errors.InvalidOperation: + return ".. queryset mid-iteration .." if len(data) > REPR_OUTPUT_SIZE: data[-1] = "...(remaining elements truncated)..." return repr(data) + def select_related(self, max_depth=1): + """Handles dereferencing of :class:`~pymongo.dbref.DBRef` objects to + a maximum depth in order to cut down the number of queries to MongoDB. + + ..
versionadded:: 0.5 + """ + from dereference import dereference + return dereference(self, max_depth=max_depth) + class QuerySetManager(object): - def __init__(self, manager_func=None): - self._manager_func = manager_func + get_queryset = None + + def __init__(self, queryset_func=None): + if queryset_func: + self.get_queryset = queryset_func self._collections = {} def __get__(self, instance, owner): @@ -1158,44 +1646,14 @@ class QuerySetManager(object): # Document class being used rather than a document object return self - db = _get_db() - collection = owner._meta['collection'] - if (db, collection) not in self._collections: - # Create collection as a capped collection if specified - if owner._meta['max_size'] or owner._meta['max_documents']: - # Get max document limit and max byte size from meta - max_size = owner._meta['max_size'] or 10000000 # 10MB default - max_documents = owner._meta['max_documents'] - - if collection in db.collection_names(): - self._collections[(db, collection)] = db[collection] - # The collection already exists, check if its capped - # options match the specified capped options - options = self._collections[(db, collection)].options() - if options.get('max') != max_documents or \ - options.get('size') != max_size: - msg = ('Cannot create collection "%s" as a capped ' - 'collection as it already exists') % collection - raise InvalidCollectionError(msg) - else: - # Create the collection as a capped collection - opts = {'capped': True, 'size': max_size} - if max_documents: - opts['max'] = max_documents - self._collections[(db, collection)] = db.create_collection( - collection, **opts - ) - else: - self._collections[(db, collection)] = db[collection] - # owner is the document that contains the QuerySetManager queryset_class = owner._meta['queryset_class'] or QuerySet - queryset = queryset_class(owner, self._collections[(db, collection)]) - if self._manager_func: - if self._manager_func.func_code.co_argcount == 1: - queryset = self._manager_func(queryset) + queryset = queryset_class(owner, owner._get_collection()) + if self.get_queryset: + if self.get_queryset.func_code.co_argcount == 1: + queryset = self.get_queryset(queryset) else: - queryset = self._manager_func(owner, queryset) + queryset = self.get_queryset(owner, queryset) return queryset diff --git a/mongoengine/signals.py b/mongoengine/signals.py new file mode 100644 index 00000000..0a697534 --- /dev/null +++ b/mongoengine/signals.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save', + 'pre_delete', 'post_delete'] + +signals_available = False +try: + from blinker import Namespace + signals_available = True +except ImportError: + class Namespace(object): + def signal(self, name, doc=None): + return _FakeSignal(name, doc) + + class _FakeSignal(object): + """If blinker is unavailable, create a fake class with the same + interface that allows sending of signals but will fail with an + error on anything else. Instead of doing anything on send, it + will just ignore the arguments and do nothing instead. + """ + + def __init__(self, name, doc=None): + self.name = name + self.__doc__ = doc + + def _fail(self, *args, **kwargs): + raise RuntimeError('signalling support is unavailable ' + 'because the blinker library is ' + 'not installed.') + send = lambda *a, **kw: None + connect = disconnect = has_receivers_for = receivers_for = \ + temporarily_connected_to = _fail + del _fail + +# the namespace for code signals. 
If you are not mongoengine code, do +# not put signals in here. Create your own namespace instead. +_signals = Namespace() + +pre_init = _signals.signal('pre_init') +post_init = _signals.signal('post_init') +pre_save = _signals.signal('pre_save') +post_save = _signals.signal('post_save') +pre_delete = _signals.signal('pre_delete') +post_delete = _signals.signal('post_delete') diff --git a/mongoengine/tests.py b/mongoengine/tests.py new file mode 100644 index 00000000..9584bc7c --- /dev/null +++ b/mongoengine/tests.py @@ -0,0 +1,59 @@ +from mongoengine.connection import _get_db + + +class query_counter(object): + """ Query_counter contextmanager to get the number of queries. """ + + def __init__(self): + """ Construct the query_counter. """ + self.counter = 0 + self.db = _get_db() + + def __enter__(self): + """ On every with block we need to drop the profile collection. """ + self.db.set_profiling_level(0) + self.db.system.profile.drop() + self.db.set_profiling_level(2) + return self + + def __exit__(self, t, value, traceback): + """ Reset the profiling level. """ + self.db.set_profiling_level(0) + + def __eq__(self, value): + """ == Compare querycounter. """ + return value == self._get_count() + + def __ne__(self, value): + """ != Compare querycounter. """ + return not self.__eq__(value) + + def __lt__(self, value): + """ < Compare querycounter. """ + return self._get_count() < value + + def __le__(self, value): + """ <= Compare querycounter. """ + return self._get_count() <= value + + def __gt__(self, value): + """ > Compare querycounter. """ + return self._get_count() > value + + def __ge__(self, value): + """ >= Compare querycounter. """ + return self._get_count() >= value + + def __int__(self): + """ int representation. """ + return self._get_count() + + def __repr__(self): + """ repr query_counter as the number of queries. """ + return u"%s" % self._get_count() + + def _get_count(self): + """ Get the number of queries. """ + count = self.db.system.profile.find().count() - self.counter + self.counter += 1 + return count diff --git a/setup.py b/setup.py index e0585b7c..6877b625 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ def get_version(version_tuple): version = '%s.%s' % (version, version_tuple[2]) return version -# Dirty hack to get version number from monogengine/__init__.py - we can't +# Dirty hack to get version number from monogengine/__init__.py - we can't # import it as it depends on PyMongo and PyMongo isn't installed until this # file is read init = os.path.join(os.path.dirname(__file__), 'mongoengine', '__init__.py') @@ -47,4 +47,5 @@ setup(name='mongoengine', classifiers=CLASSIFIERS, install_requires=['pymongo'], test_suite='tests', + tests_require=['blinker', 'django==1.3'] ) diff --git a/tests/dereference.py b/tests/dereference.py new file mode 100644 index 00000000..a98267fd --- /dev/null +++ b/tests/dereference.py @@ -0,0 +1,658 @@ +import unittest + +from mongoengine import * +from mongoengine.connection import _get_db +from mongoengine.tests import query_counter + + +class FieldTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = _get_db() + + def test_list_item_dereference(self): + """Ensure that DBRef items in ListFields are dereferenced. 
+ """ + class User(Document): + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + user = User(name='user %s' % i) + user.save() + + group = Group(members=User.objects) + group.save() + + group = Group(members=User.objects) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 2) + [m for m in group_obj.members] + self.assertEqual(q, 2) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + group_objs = Group.objects.select_related() + self.assertEqual(q, 2) + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 2) + + User.drop_collection() + Group.drop_collection() + + def test_recursive_reference(self): + """Ensure that ReferenceFields can reference their own documents. + """ + class Employee(Document): + name = StringField() + boss = ReferenceField('self') + friends = ListField(ReferenceField('self')) + + Employee.drop_collection() + + bill = Employee(name='Bill Lumbergh') + bill.save() + + michael = Employee(name='Michael Bolton') + michael.save() + + samir = Employee(name='Samir Nagheenanajar') + samir.save() + + friends = [michael, samir] + peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) + peter.save() + + Employee(name='Funky Gibbon', boss=bill, friends=friends).save() + Employee(name='Funky Gibbon', boss=bill, friends=friends).save() + Employee(name='Funky Gibbon', boss=bill, friends=friends).save() + + with query_counter() as q: + self.assertEqual(q, 0) + + peter = Employee.objects.with_id(peter.id) + self.assertEqual(q, 1) + + peter.boss + self.assertEqual(q, 2) + + peter.friends + self.assertEqual(q, 3) + + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + peter = Employee.objects.with_id(peter.id).select_related() + self.assertEqual(q, 2) + + self.assertEquals(peter.boss, bill) + self.assertEqual(q, 2) + + self.assertEquals(peter.friends, friends) + self.assertEqual(q, 2) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + employees = Employee.objects(boss=bill).select_related() + self.assertEqual(q, 2) + + for employee in employees: + self.assertEquals(employee.boss, bill) + self.assertEqual(q, 2) + + self.assertEquals(employee.friends, friends) + self.assertEqual(q, 2) + + def test_generic_reference(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + group = Group(members=members) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + 
[m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 4) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_list_field_complex(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField() + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + group = Group(members=members) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 4) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_map_field_reference(self): + + class User(Document): + name = StringField() + + class Group(Document): + members = MapField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + user = User(name='user %s' % i) + user.save() + members.append(user) + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + 
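+ # the MapField values come back as dereferenced User documents, not DBRefs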
self.assertTrue(isinstance(m, User)) + + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, User)) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 2) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, User)) + + User.drop_collection() + Group.drop_collection() + + def test_dict_field(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = DictField() + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 4) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + + Group.objects.delete() + Group().save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 1) + self.assertEqual(group_obj.members, {}) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_dict_field_no_field_inheritance(self): + + class UserA(Document): + name = StringField() + meta = {'allow_inheritance': False} + + class Group(Document): + members = DictField() + + UserA.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + members += [a] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + 
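+ # so far only the Group itself has been fetched; since just one + # collection (UserA) is referenced, dereferencing the members should + # cost a single extra query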
self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, UserA)) + + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, UserA)) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 2) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, UserA)) + + UserA.drop_collection() + Group.drop_collection() + + def test_generic_reference_map_field(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = MapField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + + # Document select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first().select_related() + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + + # Queryset select_related + with query_counter() as q: + self.assertEqual(q, 0) + + group_objs = Group.objects.select_related() + self.assertEqual(q, 4) + + for group_obj in group_objs: + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + + Group.objects.delete() + Group().save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 1) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() diff --git a/tests/django_tests.py b/tests/django_tests.py new file mode 100644 index 00000000..9c7e3280 --- /dev/null +++ b/tests/django_tests.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- + +import unittest + +from mongoengine import * +from mongoengine.django.shortcuts import get_document_or_404 + +from 
django.http import Http404 +from django.template import Context, Template +from django.conf import settings +settings.configure() + +class QuerySetTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + + class Person(Document): + name = StringField() + age = IntField() + self.Person = Person + + def test_order_by_in_django_template(self): + """Ensure that QuerySets are properly ordered in a Django template. + """ + self.Person.drop_collection() + + self.Person(name="A", age=20).save() + self.Person(name="D", age=10).save() + self.Person(name="B", age=40).save() + self.Person(name="C", age=30).save() + + t = Template("{% for o in ol %}{{ o.name }}-{{ o.age }}:{% endfor %}") + + d = {"ol": self.Person.objects.order_by('-name')} + self.assertEqual(t.render(Context(d)), u'D-10:C-30:B-40:A-20:') + d = {"ol": self.Person.objects.order_by('+name')} + self.assertEqual(t.render(Context(d)), u'A-20:B-40:C-30:D-10:') + d = {"ol": self.Person.objects.order_by('-age')} + self.assertEqual(t.render(Context(d)), u'B-40:C-30:A-20:D-10:') + d = {"ol": self.Person.objects.order_by('+age')} + self.assertEqual(t.render(Context(d)), u'D-10:A-20:C-30:B-40:') + + self.Person.drop_collection() + + def test_q_object_filter_in_template(self): + + self.Person.drop_collection() + + self.Person(name="A", age=20).save() + self.Person(name="D", age=10).save() + self.Person(name="B", age=40).save() + self.Person(name="C", age=30).save() + + t = Template("{% for o in ol %}{{ o.name }}-{{ o.age }}:{% endfor %}") + + d = {"ol": self.Person.objects.filter(Q(age=10) | Q(name="C"))} + self.assertEqual(t.render(Context(d)), 'D-10:C-30:') + + # Check double rendering doesn't throw an error + self.assertEqual(t.render(Context(d)), 'D-10:C-30:') + + def test_get_document_or_404(self): + p = self.Person(name="G404") + p.save() + + self.assertRaises(Http404, get_document_or_404, self.Person, pk='1234') + self.assertEqual(p, get_document_or_404(self.Person, pk=p.pk)) + diff --git a/tests/document.py b/tests/document.py index 280b671e..95f37748 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1,13 +1,21 @@ -import unittest -from datetime import datetime import pymongo +import unittest +import warnings + +from datetime import datetime + +import pickle +import weakref + +from fixtures import Base, Mixin, PickleEmbedded, PickleTest from mongoengine import * +from mongoengine.base import _document_registry, NotRegistered, InvalidDocumentError from mongoengine.connection import _get_db class DocumentTest(unittest.TestCase): - + def setUp(self): connect(db='mongoenginetest') self.db = _get_db() @@ -17,12 +27,15 @@ class DocumentTest(unittest.TestCase): age = IntField() self.Person = Person + def tearDown(self): + self.Person.drop_collection() + def test_drop_collection(self): """Ensure that the collection may be dropped from the database.
""" self.Person(name='Test').save() - collection = self.Person._meta['collection'] + collection = self.Person._get_collection_name() self.assertTrue(collection in self.db.collection_names()) self.Person.drop_collection() @@ -38,7 +51,7 @@ class DocumentTest(unittest.TestCase): name = name_field age = age_field non_field = True - + self.assertEqual(Person._fields['name'], name_field) self.assertEqual(Person._fields['age'], age_field) self.assertFalse('non_field' in Person._fields) @@ -49,6 +62,73 @@ class DocumentTest(unittest.TestCase): # Ensure Document isn't treated like an actual document self.assertFalse(hasattr(Document, '_fields')) + def test_collection_name(self): + """Ensure that a collection with a specified name may be used. + """ + + class DefaultNamingTest(Document): + pass + self.assertEquals('default_naming_test', DefaultNamingTest._get_collection_name()) + + class CustomNamingTest(Document): + meta = {'collection': 'pimp_my_collection'} + + self.assertEquals('pimp_my_collection', CustomNamingTest._get_collection_name()) + + class DynamicNamingTest(Document): + meta = {'collection': lambda c: "DYNAMO"} + self.assertEquals('DYNAMO', DynamicNamingTest._get_collection_name()) + + # Use Abstract class to handle backwards compatibility + class BaseDocument(Document): + meta = { + 'abstract': True, + 'collection': lambda c: c.__name__.lower() + } + + class OldNamingConvention(BaseDocument): + pass + self.assertEquals('oldnamingconvention', OldNamingConvention._get_collection_name()) + + class InheritedAbstractNamingTest(BaseDocument): + meta = {'collection': 'wibble'} + self.assertEquals('wibble', InheritedAbstractNamingTest._get_collection_name()) + + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + + class NonAbstractBase(Document): + pass + + class InheritedDocumentFailTest(NonAbstractBase): + meta = {'collection': 'fail'} + + self.assertTrue(issubclass(w[0].category, SyntaxWarning)) + self.assertEquals('non_abstract_base', InheritedDocumentFailTest._get_collection_name()) + + # Mixin tests + class BaseMixin(object): + meta = { + 'collection': lambda c: c.__name__.lower() + } + + class OldMixinNamingConvention(Document, BaseMixin): + pass + self.assertEquals('oldmixinnamingconvention', OldMixinNamingConvention._get_collection_name()) + + class BaseMixin(object): + meta = { + 'collection': lambda c: c.__name__.lower() + } + + class BaseDocument(Document, BaseMixin): + pass + + class MyDocument(BaseDocument): + pass + self.assertEquals('mydocument', OldMixinNamingConvention._get_collection_name()) + def test_get_superclasses(self): """Ensure that the correct list of superclasses is assembled. """ @@ -60,7 +140,7 @@ class DocumentTest(unittest.TestCase): mammal_superclasses = {'Animal': Animal} self.assertEqual(Mammal._superclasses, mammal_superclasses) - + dog_superclasses = { 'Animal': Animal, 'Animal.Mammal': Mammal, @@ -68,7 +148,7 @@ class DocumentTest(unittest.TestCase): self.assertEqual(Dog._superclasses, dog_superclasses) def test_get_subclasses(self): - """Ensure that the correct list of subclasses is retrieved by the + """Ensure that the correct list of subclasses is retrieved by the _get_subclasses method. 
""" class Animal(Document): pass @@ -78,19 +158,64 @@ class DocumentTest(unittest.TestCase): class Dog(Mammal): pass mammal_subclasses = { - 'Animal.Mammal.Dog': Dog, + 'Animal.Mammal.Dog': Dog, 'Animal.Mammal.Human': Human } self.assertEqual(Mammal._get_subclasses(), mammal_subclasses) - + animal_subclasses = { 'Animal.Fish': Fish, 'Animal.Mammal': Mammal, - 'Animal.Mammal.Dog': Dog, + 'Animal.Mammal.Dog': Dog, 'Animal.Mammal.Human': Human } self.assertEqual(Animal._get_subclasses(), animal_subclasses) + def test_external_super_and_sub_classes(self): + """Ensure that the correct list of sub and super classes is assembled. + when importing part of the model + """ + class Animal(Base): pass + class Fish(Animal): pass + class Mammal(Animal): pass + class Human(Mammal): pass + class Dog(Mammal): pass + + mammal_superclasses = {'Base': Base, 'Base.Animal': Animal} + self.assertEqual(Mammal._superclasses, mammal_superclasses) + + dog_superclasses = { + 'Base': Base, + 'Base.Animal': Animal, + 'Base.Animal.Mammal': Mammal, + } + self.assertEqual(Dog._superclasses, dog_superclasses) + + animal_subclasses = { + 'Base.Animal.Fish': Fish, + 'Base.Animal.Mammal': Mammal, + 'Base.Animal.Mammal.Dog': Dog, + 'Base.Animal.Mammal.Human': Human + } + self.assertEqual(Animal._get_subclasses(), animal_subclasses) + + mammal_subclasses = { + 'Base.Animal.Mammal.Dog': Dog, + 'Base.Animal.Mammal.Human': Human + } + self.assertEqual(Mammal._get_subclasses(), mammal_subclasses) + + Base.drop_collection() + + h = Human() + h.save() + + self.assertEquals(Human.objects.count(), 1) + self.assertEquals(Mammal.objects.count(), 1) + self.assertEquals(Animal.objects.count(), 1) + self.assertEquals(Base.objects.count(), 1) + Base.drop_collection() + def test_polymorphic_queries(self): """Ensure that the correct subclasses are returned from a query""" class Animal(Document): pass @@ -99,6 +224,8 @@ class DocumentTest(unittest.TestCase): class Human(Mammal): pass class Dog(Mammal): pass + Animal.drop_collection() + Animal().save() Fish().save() Mammal().save() @@ -116,6 +243,77 @@ class DocumentTest(unittest.TestCase): Animal.drop_collection() + def test_polymorphic_references(self): + """Ensure that the correct subclasses are returned from a query when + using references / generic references + """ + class Animal(Document): pass + class Fish(Animal): pass + class Mammal(Animal): pass + class Human(Mammal): pass + class Dog(Mammal): pass + + class Zoo(Document): + animals = ListField(ReferenceField(Animal)) + + Zoo.drop_collection() + Animal.drop_collection() + + Animal().save() + Fish().save() + Mammal().save() + Human().save() + Dog().save() + + # Save a reference to each animal + zoo = Zoo(animals=Animal.objects) + zoo.save() + zoo.reload() + + classes = [a.__class__ for a in Zoo.objects.first().animals] + self.assertEqual(classes, [Animal, Fish, Mammal, Human, Dog]) + + Zoo.drop_collection() + + class Zoo(Document): + animals = ListField(GenericReferenceField(Animal)) + + # Save a reference to each animal + zoo = Zoo(animals=Animal.objects) + zoo.save() + zoo.reload() + + classes = [a.__class__ for a in Zoo.objects.first().animals] + self.assertEqual(classes, [Animal, Fish, Mammal, Human, Dog]) + + Zoo.drop_collection() + Animal.drop_collection() + + def test_reference_inheritance(self): + class Stats(Document): + created = DateTimeField(default=datetime.now) + + meta = {'allow_inheritance': False} + + class CompareStats(Document): + generated = DateTimeField(default=datetime.now) + stats = 
ListField(ReferenceField(Stats)) + + Stats.drop_collection() + CompareStats.drop_collection() + + list_stats = [] + + for i in xrange(10): + s = Stats() + s.save() + list_stats.append(s) + + cmp_stats = CompareStats(stats=list_stats) + cmp_stats.save() + + self.assertEqual(list_stats, CompareStats.objects.first().stats) + def test_inheritance(self): """Ensure that document may inherit fields from a superclass document. """ @@ -124,8 +322,8 @@ class DocumentTest(unittest.TestCase): self.assertTrue('name' in Employee._fields) self.assertTrue('salary' in Employee._fields) - self.assertEqual(Employee._meta['collection'], - self.Person._meta['collection']) + self.assertEqual(Employee._get_collection_name(), + self.Person._get_collection_name()) # Ensure that MRO error is not raised class A(Document): pass @@ -136,21 +334,21 @@ class DocumentTest(unittest.TestCase): """Ensure that inheritance may be disabled on simple classes and that _cls and _types will not be used. """ + class Animal(Document): - meta = {'allow_inheritance': False} name = StringField() + meta = {'allow_inheritance': False} Animal.drop_collection() - def create_dog_class(): class Dog(Animal): pass self.assertRaises(ValueError, create_dog_class) - + # Check that _cls etc aren't present on simple documents dog = Animal(name='dog') dog.save() - collection = self.db[Animal._meta['collection']] + collection = self.db[Animal._get_collection_name()] obj = collection.find_one() self.assertFalse('_cls' in obj) self.assertFalse('_types' in obj) @@ -161,7 +359,7 @@ class DocumentTest(unittest.TestCase): class Employee(self.Person): meta = {'allow_inheritance': False} self.assertRaises(ValueError, create_employee_class) - + # Test the same for embedded documents class Comment(EmbeddedDocument): content = StringField() @@ -176,6 +374,123 @@ class DocumentTest(unittest.TestCase): self.assertFalse('_cls' in comment.to_mongo()) self.assertFalse('_types' in comment.to_mongo()) + def test_allow_inheritance_abstract_document(self): + """Ensure that abstract documents can set inheritance rules and that + _cls and _types will not be used. + """ + class FinalDocument(Document): + meta = {'abstract': True, + 'allow_inheritance': False} + + class Animal(FinalDocument): + name = StringField() + + Animal.drop_collection() + def create_dog_class(): + class Dog(Animal): + pass + self.assertRaises(ValueError, create_dog_class) + + # Check that _cls etc aren't present on simple documents + dog = Animal(name='dog') + dog.save() + collection = self.db[Animal._get_collection_name()] + obj = collection.find_one() + self.assertFalse('_cls' in obj) + self.assertFalse('_types' in obj) + + Animal.drop_collection() + + def test_how_to_turn_off_inheritance(self): + """Demonstrates migrating from allow_inheritance = True to False. 
+ """ + class Animal(Document): + name = StringField() + meta = { + 'indexes': ['name'] + } + + Animal.drop_collection() + + dog = Animal(name='dog') + dog.save() + + collection = self.db[Animal._get_collection_name()] + obj = collection.find_one() + self.assertTrue('_cls' in obj) + self.assertTrue('_types' in obj) + + info = collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertEquals([[(u'_id', 1)], [(u'_types', 1), (u'name', 1)]], info) + + # Turn off inheritance + class Animal(Document): + name = StringField() + meta = { + 'allow_inheritance': False, + 'indexes': ['name'] + } + collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, multi=True) + + # Confirm extra data is removed + obj = collection.find_one() + self.assertFalse('_cls' in obj) + self.assertFalse('_types' in obj) + + info = collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertEquals([[(u'_id', 1)], [(u'_types', 1), (u'name', 1)]], info) + + info = collection.index_information() + indexes_to_drop = [key for key, value in info.iteritems() if '_types' in dict(value['key'])] + for index in indexes_to_drop: + collection.drop_index(index) + + info = collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertEquals([[(u'_id', 1)]], info) + + # Recreate indexes + dog = Animal.objects.first() + dog.save() + info = collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertEquals([[(u'_id', 1)], [(u'name', 1),]], info) + + Animal.drop_collection() + + def test_abstract_documents(self): + """Ensure that a document superclass can be marked as abstract + thereby not using it as the name for the collection.""" + + class Animal(Document): + name = StringField() + meta = {'abstract': True} + + class Fish(Animal): pass + class Guppy(Fish): pass + + class Mammal(Animal): + meta = {'abstract': True} + class Human(Mammal): pass + + self.assertFalse('collection' in Animal._meta) + self.assertFalse('collection' in Mammal._meta) + + self.assertEqual(Animal._get_collection_name(), None) + self.assertEqual(Mammal._get_collection_name(), None) + + self.assertEqual(Fish._get_collection_name(), 'fish') + self.assertEqual(Guppy._get_collection_name(), 'fish') + self.assertEqual(Human._get_collection_name(), 'human') + + def create_bad_abstract(): + class EvilHuman(Human): + evil = BooleanField(default=True) + meta = {'abstract': True} + self.assertRaises(ValueError, create_bad_abstract) + def test_collection_name(self): """Ensure that a collection with a specified name may be used. """ @@ -186,7 +501,7 @@ class DocumentTest(unittest.TestCase): class Person(Document): name = StringField() meta = {'collection': collection} - + user = Person(name="Test User") user.save() self.assertTrue(collection in self.db.collection_names()) @@ -200,17 +515,40 @@ class DocumentTest(unittest.TestCase): Person.drop_collection() self.assertFalse(collection in self.db.collection_names()) + def test_collection_name_and_primary(self): + """Ensure that a collection with a specified name may be used. + """ + + class Person(Document): + name = StringField(primary_key=True) + meta = {'collection': 'app'} + + user = Person(name="Test User") + user.save() + + user_obj = Person.objects[0] + self.assertEqual(user_obj.name, "Test User") + + Person.drop_collection() + def test_inherited_collections(self): """Ensure that subclassed documents don't override parents' collections. 
""" - class Drink(Document): - name = StringField() + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") - class AlcoholicDrink(Drink): - meta = {'collection': 'booze'} + class Drink(Document): + name = StringField() - class Drinker(Document): - drink = GenericReferenceField() + class AlcoholicDrink(Drink): + meta = {'collection': 'booze'} + + class Drinker(Document): + drink = GenericReferenceField() + + # Confirm we triggered a SyntaxWarning + assert issubclass(w[0].category, SyntaxWarning) Drink.drop_collection() AlcoholicDrink.drop_collection() @@ -224,7 +562,6 @@ class DocumentTest(unittest.TestCase): beer = AlcoholicDrink(name='Beer') beer.save() - real_person = Drinker(drink=beer) real_person.save() @@ -280,7 +617,7 @@ class DocumentTest(unittest.TestCase): tags = ListField(StringField()) meta = { 'indexes': [ - '-date', + '-date', 'tags', ('category', '-date') ], @@ -289,19 +626,22 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() info = BlogPost.objects._collection.index_information() - # _id, types, '-date', 'tags', ('cat', 'date') - self.assertEqual(len(info), 5) + # _id, '-date', 'tags', ('cat', 'date') + # NB: there is no index on _types by itself, since + # the indices on -date and tags will both contain + # _types as first element in the key + self.assertEqual(len(info), 4) # Indexes are lazy so use list() to perform query list(BlogPost.objects) info = BlogPost.objects._collection.index_information() info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] + self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('addDate', -1)] in info) # tags is a list field so it shouldn't have _types in the index self.assertTrue([('tags', 1)] in info) - + class ExtendedBlogPost(BlogPost): title = StringField() meta = {'indexes': ['title']} @@ -311,13 +651,137 @@ class DocumentTest(unittest.TestCase): list(ExtendedBlogPost.objects) info = ExtendedBlogPost.objects._collection.index_information() info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] + self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('title', 1)] in info) BlogPost.drop_collection() + def test_dictionary_indexes(self): + """Ensure that indexes are used when meta[indexes] contains dictionaries + instead of lists. 
+ """ + class BlogPost(Document): + date = DateTimeField(db_field='addDate', default=datetime.now) + category = StringField() + tags = ListField(StringField()) + meta = { + 'indexes': [ + { 'fields': ['-date'], 'unique': True, + 'sparse': True, 'types': False }, + ], + } + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + # _id, '-date' + self.assertEqual(len(info), 3) + + # Indexes are lazy so use list() to perform query + list(BlogPost.objects) + info = BlogPost.objects._collection.index_information() + info = [(value['key'], + value.get('unique', False), + value.get('sparse', False)) + for key, value in info.iteritems()] + self.assertTrue(([('addDate', -1)], True, True) in info) + + BlogPost.drop_collection() + + def test_embedded_document_index(self): + """Tests settings an index on an embedded document + """ + class Date(EmbeddedDocument): + year = IntField(db_field='yr') + + class BlogPost(Document): + title = StringField() + date = EmbeddedDocumentField(Date) + + meta = { + 'indexes': [ + '-date.year' + ], + } + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + self.assertEqual(info.keys(), ['_types_1_date.yr_-1', '_id_']) + BlogPost.drop_collection() + + def test_list_embedded_document_index(self): + """Ensure list embedded documents can be indexed + """ + class Tag(EmbeddedDocument): + name = StringField(db_field='tag') + + class BlogPost(Document): + title = StringField() + tags = ListField(EmbeddedDocumentField(Tag)) + + meta = { + 'indexes': [ + 'tags.name' + ], + } + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + # we don't use _types in with list fields by default + self.assertEqual(info.keys(), ['_id_', '_types_1', 'tags.tag_1']) + + post1 = BlogPost(title="Embedded Indexes tests in place", + tags=[Tag(name="about"), Tag(name="time")] + ) + post1.save() + BlogPost.drop_collection() + + def test_geo_indexes_recursion(self): + + class User(Document): + channel = ReferenceField('Channel') + location = GeoPointField() + + class Channel(Document): + user = ReferenceField('User') + location = GeoPointField() + + self.assertEquals(len(User._geo_indices()), 2) + + def test_hint(self): + + class BlogPost(Document): + tags = ListField(StringField()) + meta = { + 'indexes': [ + 'tags', + ], + } + + BlogPost.drop_collection() + + for i in xrange(0, 10): + tags = [("tag %i" % n) for n in xrange(0, i % 2)] + BlogPost(tags=tags).save() + + self.assertEquals(BlogPost.objects.count(), 10) + self.assertEquals(BlogPost.objects.hint().count(), 10) + self.assertEquals(BlogPost.objects.hint([('tags', 1)]).count(), 10) + + self.assertEquals(BlogPost.objects.hint([('ZZ', 1)]).count(), 10) + + def invalid_index(): + BlogPost.objects.hint('tags') + self.assertRaises(TypeError, invalid_index) + + def invalid_index_2(): + return BlogPost.objects.hint(('tags', 1)) + self.assertRaises(TypeError, invalid_index_2) + def test_unique(self): """Ensure that uniqueness constraints are applied to fields. """ @@ -334,6 +798,9 @@ class DocumentTest(unittest.TestCase): post2 = BlogPost(title='test2', slug='test') self.assertRaises(OperationError, post2.save) + def test_unique_with(self): + """Ensure that unique_with constraints are applied to fields. 
+ """ class Date(EmbeddedDocument): year = IntField(db_field='yr') @@ -357,6 +824,108 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() + def test_unique_embedded_document(self): + """Ensure that uniqueness constraints are applied to fields on embedded documents. + """ + class SubDocument(EmbeddedDocument): + year = IntField(db_field='yr') + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField() + sub = EmbeddedDocumentField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) + post1.save() + + # sub.slug is different so won't raise exception + post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) + post2.save() + + # Now there will be two docs with the same sub.slug + post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) + self.assertRaises(OperationError, post3.save) + + BlogPost.drop_collection() + + def test_unique_with_embedded_document_and_embedded_unique(self): + """Ensure that uniqueness constraints are applied to fields on + embedded documents. And work with unique_with as well. + """ + class SubDocument(EmbeddedDocument): + year = IntField(db_field='yr') + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField(unique_with='sub.year') + sub = EmbeddedDocumentField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) + post1.save() + + # sub.slug is different so won't raise exception + post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) + post2.save() + + # Now there will be two docs with the same sub.slug + post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) + self.assertRaises(OperationError, post3.save) + + # Now there will be two docs with the same title and year + post3 = BlogPost(title='test1', sub=SubDocument(year=2009, slug='test-1')) + self.assertRaises(OperationError, post3.save) + + BlogPost.drop_collection() + + def test_unique_and_indexes(self): + """Ensure that 'unique' constraints aren't overridden by + meta.indexes. + """ + class Customer(Document): + cust_id = IntField(unique=True, required=True) + meta = { + 'indexes': ['cust_id'], + 'allow_inheritance': False, + } + + Customer.drop_collection() + cust = Customer(cust_id=1) + cust.save() + + cust_dupe = Customer(cust_id=1) + try: + cust_dupe.save() + raise AssertionError, "We saved a dupe!" + except OperationError: + pass + Customer.drop_collection() + + def test_unique_and_primary(self): + """If you set a field as primary, then unexpected behaviour can occur. + You won't create a duplicate but you will update an existing document. + """ + + class User(Document): + name = StringField(primary_key=True, unique=True) + password = StringField() + + User.drop_collection() + + user = User(name='huangz', password='secret') + user.save() + + user = User(name='huangz', password='secret2') + user.save() + + self.assertEqual(User.objects.count(), 1) + self.assertEqual(User.objects.get().password, 'secret2') + + User.drop_collection() + def test_custom_id_field(self): """Ensure that documents may be created with custom primary keys. 
""" @@ -380,7 +949,7 @@ class DocumentTest(unittest.TestCase): class EmailUser(User): email = StringField() - + user = User(username='test', name='test user') user.save() @@ -391,22 +960,48 @@ class DocumentTest(unittest.TestCase): user_son = User.objects._collection.find_one() self.assertEqual(user_son['_id'], 'test') self.assertTrue('username' not in user_son['_id']) - + User.drop_collection() - + user = User(pk='mongo', name='mongo user') user.save() - + user_obj = User.objects.first() self.assertEqual(user_obj.id, 'mongo') self.assertEqual(user_obj.pk, 'mongo') - + user_son = User.objects._collection.find_one() self.assertEqual(user_son['_id'], 'mongo') self.assertTrue('username' not in user_son['_id']) - + User.drop_collection() + + def test_document_not_registered(self): + + class Place(Document): + name = StringField() + + class NicePlace(Place): + pass + + Place.drop_collection() + + Place(name="London").save() + NicePlace(name="Buckingham Palace").save() + + # Mimic Place and NicePlace definitions being in a different file + # and the NicePlace model not being imported in at query time. + @classmethod + def _get_subclasses(cls): + return {} + Place._get_subclasses = _get_subclasses + + def query_without_importing_nice_place(): + print Place.objects.all() + self.assertRaises(NotRegistered, query_without_importing_nice_place) + + def test_creation(self): """Ensure that document may be created using keyword arguments. """ @@ -414,6 +1009,14 @@ class DocumentTest(unittest.TestCase): self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 30) + def test_to_dbref(self): + """Ensure that you can get a dbref of a document""" + person = self.Person(name="Test User", age=30) + self.assertRaises(OperationError, person.to_dbref) + person.save() + + person.to_dbref() + def test_reload(self): """Ensure that attributes may be reloaded. """ @@ -432,6 +1035,47 @@ class DocumentTest(unittest.TestCase): self.assertEqual(person.name, "Mr Test User") self.assertEqual(person.age, 21) + def test_reload_referencing(self): + """Ensures reloading updates weakrefs correctly + """ + class Embedded(EmbeddedDocument): + dict_field = DictField() + list_field = ListField() + + class Doc(Document): + dict_field = DictField() + list_field = ListField() + embedded_field = EmbeddedDocumentField(Embedded) + + Doc.drop_collection() + doc = Doc() + doc.dict_field = {'hello': 'world'} + doc.list_field = ['1', 2, {'hello': 'world'}] + + embedded_1 = Embedded() + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + doc.save() + + doc.reload() + doc.list_field.append(1) + doc.dict_field['woot'] = "woot" + doc.embedded_field.list_field.append(1) + doc.embedded_field.dict_field['woot'] = "woot" + + self.assertEquals(doc._get_changed_fields(), [ + 'list_field', 'dict_field', 'embedded_field.list_field', + 'embedded_field.dict_field']) + doc.save() + + doc.reload() + self.assertEquals(doc._get_changed_fields(), []) + self.assertEquals(len(doc.list_field), 4) + self.assertEquals(len(doc.dict_field), 2) + self.assertEquals(len(doc.embedded_field.list_field), 4) + self.assertEquals(len(doc.embedded_field.dict_field), 2) + def test_dictionary_access(self): """Ensure that dictionary-style field access works properly. 
""" @@ -457,18 +1101,18 @@ class DocumentTest(unittest.TestCase): """ class Comment(EmbeddedDocument): content = StringField() - + self.assertTrue('content' in Comment._fields) self.assertFalse('id' in Comment._fields) self.assertFalse('collection' in Comment._meta) - + def test_embedded_document_validation(self): """Ensure that embedded documents may be validated. """ class Comment(EmbeddedDocument): date = DateTimeField() content = StringField(required=True) - + comment = Comment() self.assertRaises(ValidationError, comment.validate) @@ -488,7 +1132,7 @@ class DocumentTest(unittest.TestCase): person = self.Person(name='Test User', age=30) person.save() # Ensure that the object is in the database - collection = self.db[self.Person._meta['collection']] + collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(person_obj['name'], 'Test User') self.assertEqual(person_obj['age'], 30) @@ -496,13 +1140,774 @@ class DocumentTest(unittest.TestCase): # Test skipping validation on save class Recipient(Document): email = EmailField(required=True) - + recipient = Recipient(email='root@localhost') self.assertRaises(ValidationError, recipient.save) try: recipient.save(validate=False) except ValidationError: - fail() + self.fail() + + def test_save_to_a_value_that_equates_to_false(self): + + class Thing(EmbeddedDocument): + count = IntField() + + class User(Document): + thing = EmbeddedDocumentField(Thing) + + User.drop_collection() + + user = User(thing=Thing(count=1)) + user.save() + user.reload() + + user.thing.count = 0 + user.save() + + user.reload() + self.assertEquals(user.thing.count, 0) + + def test_save_max_recursion_not_hit(self): + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + friend = ReferenceField('self') + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.parent = None + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p1.friend = p2 + p1.save() + + # Confirm can save and it resets the changed fields without hitting + # max recursion error + p0 = Person.objects.first() + p0.name = 'wpjunior' + p0.save() + + def test_save_cascades(self): + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.parent = None + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertEquals(p1.name, p.parent.name) + + def test_save_cascades_generically(self): + + class Person(Document): + name = StringField() + parent = GenericReferenceField() + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertEquals(p1.name, p.parent.name) + + def test_update(self): + """Ensure that an existing document is updated instead of be overwritten. 
+ """ + # Create person object and save it to the database + person = self.Person(name='Test User', age=30) + person.save() + + # Create same person object, with same id, without age + same_person = self.Person(name='Test') + same_person.id = person.id + same_person.save() + + # Confirm only one object + self.assertEquals(self.Person.objects.count(), 1) + + # reload + person.reload() + same_person.reload() + + # Confirm the same + self.assertEqual(person, same_person) + self.assertEqual(person.name, same_person.name) + self.assertEqual(person.age, same_person.age) + + # Confirm the saved values + self.assertEqual(person.name, 'Test') + self.assertEqual(person.age, 30) + + # Test only / exclude only updates included fields + person = self.Person.objects.only('name').get() + person.name = 'User' + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 30) + + # test exclude only updates set fields + person = self.Person.objects.exclude('name').get() + person.age = 21 + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 21) + + # Test only / exclude can set non excluded / included fields + person = self.Person.objects.only('name').get() + person.name = 'Test' + person.age = 30 + person.save() + + person.reload() + self.assertEqual(person.name, 'Test') + self.assertEqual(person.age, 30) + + # test exclude only updates set fields + person = self.Person.objects.exclude('name').get() + person.name = 'User' + person.age = 21 + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 21) + + # Confirm does remove unrequired fields + person = self.Person.objects.exclude('name').get() + person.age = None + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, None) + + person = self.Person.objects.get() + person.name = None + person.age = None + person.save() + + person.reload() + self.assertEqual(person.name, None) + self.assertEqual(person.age, None) + + def test_document_update(self): + + def update_not_saved_raises(): + person = self.Person(name='dcrosta') + person.update(set__name='Dan Crosta') + + self.assertRaises(OperationError, update_not_saved_raises) + + author = self.Person(name='dcrosta') + author.save() + + author.update(set__name='Dan Crosta') + author.reload() + + p1 = self.Person.objects.first() + self.assertEquals(p1.name, author.name) + + def update_no_value_raises(): + person = self.Person.objects.first() + person.update() + + self.assertRaises(OperationError, update_no_value_raises) + + def test_embedded_update(self): + """ + Test update on `EmbeddedDocumentField` fields + """ + + class Page(EmbeddedDocument): + log_message = StringField(verbose_name="Log message", + required=True) + + class Site(Document): + page = EmbeddedDocumentField(Page) + + + Site.drop_collection() + site = Site(page=Page(log_message="Warning: Dummy message")) + site.save() + + # Update + site = Site.objects.first() + site.page.log_message = "Error: Dummy message" + site.save() + + site = Site.objects.first() + self.assertEqual(site.page.log_message, "Error: Dummy message") + + def test_embedded_update_db_field(self): + """ + Test update on `EmbeddedDocumentField` fields when db_field is other + than default. 
+ """ + + class Page(EmbeddedDocument): + log_message = StringField(verbose_name="Log message", + db_field="page_log_message", + required=True) + + class Site(Document): + page = EmbeddedDocumentField(Page) + + + Site.drop_collection() + + site = Site(page=Page(log_message="Warning: Dummy message")) + site.save() + + # Update + site = Site.objects.first() + site.page.log_message = "Error: Dummy message" + site.save() + + site = Site.objects.first() + self.assertEqual(site.page.log_message, "Error: Dummy message") + + def test_delta(self): + + class Doc(Document): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEquals(doc._get_changed_fields(), []) + self.assertEquals(doc._delta(), ({}, {})) + + doc.string_field = 'hello' + self.assertEquals(doc._get_changed_fields(), ['string_field']) + self.assertEquals(doc._delta(), ({'string_field': 'hello'}, {})) + + doc._changed_fields = [] + doc.int_field = 1 + self.assertEquals(doc._get_changed_fields(), ['int_field']) + self.assertEquals(doc._delta(), ({'int_field': 1}, {})) + + doc._changed_fields = [] + dict_value = {'hello': 'world', 'ping': 'pong'} + doc.dict_field = dict_value + self.assertEquals(doc._get_changed_fields(), ['dict_field']) + self.assertEquals(doc._delta(), ({'dict_field': dict_value}, {})) + + doc._changed_fields = [] + list_value = ['1', 2, {'hello': 'world'}] + doc.list_field = list_value + self.assertEquals(doc._get_changed_fields(), ['list_field']) + self.assertEquals(doc._delta(), ({'list_field': list_value}, {})) + + # Test unsetting + doc._changed_fields = [] + doc.dict_field = {} + self.assertEquals(doc._get_changed_fields(), ['dict_field']) + self.assertEquals(doc._delta(), ({}, {'dict_field': 1})) + + doc._changed_fields = [] + doc.list_field = [] + self.assertEquals(doc._get_changed_fields(), ['list_field']) + self.assertEquals(doc._delta(), ({}, {'list_field': 1})) + + def test_delta_recursive(self): + + class Embedded(EmbeddedDocument): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + + class Doc(Document): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + embedded_field = EmbeddedDocumentField(Embedded) + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEquals(doc._get_changed_fields(), []) + self.assertEquals(doc._delta(), ({}, {})) + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + + self.assertEquals(doc._get_changed_fields(), ['embedded_field']) + + embedded_delta = { + '_types': ['Embedded'], + '_cls': 'Embedded', + 'string_field': 'hello', + 'int_field': 1, + 'dict_field': {'hello': 'world'}, + 'list_field': ['1', 2, {'hello': 'world'}] + } + self.assertEquals(doc.embedded_field._delta(), (embedded_delta, {})) + self.assertEquals(doc._delta(), ({'embedded_field': embedded_delta}, {})) + + doc.save() + doc.reload() + + doc.embedded_field.dict_field = {} + self.assertEquals(doc._get_changed_fields(), ['embedded_field.dict_field']) + self.assertEquals(doc.embedded_field._delta(), ({}, {'dict_field': 1})) + self.assertEquals(doc._delta(), ({}, {'embedded_field.dict_field': 1})) + doc.save() + doc.reload() + 
self.assertEquals(doc.embedded_field.dict_field, {}) + + doc.embedded_field.list_field = [] + self.assertEquals(doc._get_changed_fields(), ['embedded_field.list_field']) + self.assertEquals(doc.embedded_field._delta(), ({}, {'list_field': 1})) + self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field': 1})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field, []) + + embedded_2 = Embedded() + embedded_2.string_field = 'hello' + embedded_2.int_field = 1 + embedded_2.dict_field = {'hello': 'world'} + embedded_2.list_field = ['1', 2, {'hello': 'world'}] + + doc.embedded_field.list_field = ['1', 2, embedded_2] + self.assertEquals(doc._get_changed_fields(), ['embedded_field.list_field']) + self.assertEquals(doc.embedded_field._delta(), ({ + 'list_field': ['1', 2, { + '_cls': 'Embedded', + '_types': ['Embedded'], + 'string_field': 'hello', + 'dict_field': {'hello': 'world'}, + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + + self.assertEquals(doc._delta(), ({ + 'embedded_field.list_field': ['1', 2, { + '_cls': 'Embedded', + '_types': ['Embedded'], + 'string_field': 'hello', + 'dict_field': {'hello': 'world'}, + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + doc.save() + doc.reload() + + self.assertEquals(doc.embedded_field.list_field[0], '1') + self.assertEquals(doc.embedded_field.list_field[1], 2) + for k in doc.embedded_field.list_field[2]._fields: + self.assertEquals(doc.embedded_field.list_field[2][k], embedded_2[k]) + + doc.embedded_field.list_field[2].string_field = 'world' + self.assertEquals(doc._get_changed_fields(), ['embedded_field.list_field.2.string_field']) + self.assertEquals(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) + self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world') + + # Test multiple assignments + doc.embedded_field.list_field[2].string_field = 'hello world' + doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] + self.assertEquals(doc._get_changed_fields(), ['embedded_field.list_field']) + self.assertEquals(doc.embedded_field._delta(), ({ + 'list_field': ['1', 2, { + '_types': ['Embedded'], + '_cls': 'Embedded', + 'string_field': 'hello world', + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + 'dict_field': {'hello': 'world'}}]}, {})) + self.assertEquals(doc._delta(), ({ + 'embedded_field.list_field': ['1', 2, { + '_types': ['Embedded'], + '_cls': 'Embedded', + 'string_field': 'hello world', + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + 'dict_field': {'hello': 'world'}} + ]}, {})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].string_field, 'hello world') + + # Test list native methods + doc.embedded_field.list_field[2].list_field.pop(0) + self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {})) + doc.save() + doc.reload() + + doc.embedded_field.list_field[2].list_field.append(1) + self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) + + doc.embedded_field.list_field[2].list_field.sort() + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 
'world'}]) + + del(doc.embedded_field.list_field[2].list_field[2]['hello']) + self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) + doc.save() + doc.reload() + + del(doc.embedded_field.list_field[2].list_field) + self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1})) + + doc.save() + doc.reload() + + doc.dict_field['Embedded'] = embedded_1 + doc.save() + doc.reload() + + doc.dict_field['Embedded'].string_field = 'Hello World' + self.assertEquals(doc._get_changed_fields(), ['dict_field.Embedded.string_field']) + self.assertEquals(doc._delta(), ({'dict_field.Embedded.string_field': 'Hello World'}, {})) + + + def test_delta_db_field(self): + + class Doc(Document): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEquals(doc._get_changed_fields(), []) + self.assertEquals(doc._delta(), ({}, {})) + + doc.string_field = 'hello' + self.assertEquals(doc._get_changed_fields(), ['db_string_field']) + self.assertEquals(doc._delta(), ({'db_string_field': 'hello'}, {})) + + doc._changed_fields = [] + doc.int_field = 1 + self.assertEquals(doc._get_changed_fields(), ['db_int_field']) + self.assertEquals(doc._delta(), ({'db_int_field': 1}, {})) + + doc._changed_fields = [] + dict_value = {'hello': 'world', 'ping': 'pong'} + doc.dict_field = dict_value + self.assertEquals(doc._get_changed_fields(), ['db_dict_field']) + self.assertEquals(doc._delta(), ({'db_dict_field': dict_value}, {})) + + doc._changed_fields = [] + list_value = ['1', 2, {'hello': 'world'}] + doc.list_field = list_value + self.assertEquals(doc._get_changed_fields(), ['db_list_field']) + self.assertEquals(doc._delta(), ({'db_list_field': list_value}, {})) + + # Test unsetting + doc._changed_fields = [] + doc.dict_field = {} + self.assertEquals(doc._get_changed_fields(), ['db_dict_field']) + self.assertEquals(doc._delta(), ({}, {'db_dict_field': 1})) + + doc._changed_fields = [] + doc.list_field = [] + self.assertEquals(doc._get_changed_fields(), ['db_list_field']) + self.assertEquals(doc._delta(), ({}, {'db_list_field': 1})) + + # Test it saves that data + doc = Doc() + doc.save() + + doc.string_field = 'hello' + doc.int_field = 1 + doc.dict_field = {'hello': 'world'} + doc.list_field = ['1', 2, {'hello': 'world'}] + doc.save() + doc.reload() + + self.assertEquals(doc.string_field, 'hello') + self.assertEquals(doc.int_field, 1) + self.assertEquals(doc.dict_field, {'hello': 'world'}) + self.assertEquals(doc.list_field, ['1', 2, {'hello': 'world'}]) + + def test_delta_recursive_db_field(self): + + class Embedded(EmbeddedDocument): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + + class Doc(Document): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + embedded_field = EmbeddedDocumentField(Embedded, db_field='db_embedded_field') + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEquals(doc._get_changed_fields(), []) + self.assertEquals(doc._delta(), ({}, {})) + + 
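+ # mirrors test_delta_recursive, but every field declares a custom db_field, + # so the deltas below are keyed by db_field names (e.g. 'db_embedded_field') + # rather than by attribute names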
embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field']) + + embedded_delta = { + '_types': ['Embedded'], + '_cls': 'Embedded', + 'db_string_field': 'hello', + 'db_int_field': 1, + 'db_dict_field': {'hello': 'world'}, + 'db_list_field': ['1', 2, {'hello': 'world'}] + } + self.assertEquals(doc.embedded_field._delta(), (embedded_delta, {})) + self.assertEquals(doc._delta(), ({'db_embedded_field': embedded_delta}, {})) + + doc.save() + doc.reload() + + doc.embedded_field.dict_field = {} + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_dict_field']) + self.assertEquals(doc.embedded_field._delta(), ({}, {'db_dict_field': 1})) + self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_dict_field': 1})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.dict_field, {}) + + doc.embedded_field.list_field = [] + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) + self.assertEquals(doc.embedded_field._delta(), ({}, {'db_list_field': 1})) + self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field': 1})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field, []) + + embedded_2 = Embedded() + embedded_2.string_field = 'hello' + embedded_2.int_field = 1 + embedded_2.dict_field = {'hello': 'world'} + embedded_2.list_field = ['1', 2, {'hello': 'world'}] + + doc.embedded_field.list_field = ['1', 2, embedded_2] + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) + self.assertEquals(doc.embedded_field._delta(), ({ + 'db_list_field': ['1', 2, { + '_cls': 'Embedded', + '_types': ['Embedded'], + 'db_string_field': 'hello', + 'db_dict_field': {'hello': 'world'}, + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + + self.assertEquals(doc._delta(), ({ + 'db_embedded_field.db_list_field': ['1', 2, { + '_cls': 'Embedded', + '_types': ['Embedded'], + 'db_string_field': 'hello', + 'db_dict_field': {'hello': 'world'}, + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + doc.save() + doc.reload() + + self.assertEquals(doc.embedded_field.list_field[0], '1') + self.assertEquals(doc.embedded_field.list_field[1], 2) + for k in doc.embedded_field.list_field[2]._fields: + self.assertEquals(doc.embedded_field.list_field[2][k], embedded_2[k]) + + doc.embedded_field.list_field[2].string_field = 'world' + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field.2.db_string_field']) + self.assertEquals(doc.embedded_field._delta(), ({'db_list_field.2.db_string_field': 'world'}, {})) + self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, {})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world') + + # Test multiple assignments + doc.embedded_field.list_field[2].string_field = 'hello world' + doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) + self.assertEquals(doc.embedded_field._delta(), ({ + 'db_list_field': ['1', 2, { + '_types': ['Embedded'], + '_cls': 'Embedded', + 'db_string_field': 'hello world', + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + 
            'db_dict_field': {'hello': 'world'}}]}, {}))
+        self.assertEquals(doc._delta(), ({
+            'db_embedded_field.db_list_field': ['1', 2, {
+                '_types': ['Embedded'],
+                '_cls': 'Embedded',
+                'db_string_field': 'hello world',
+                'db_int_field': 1,
+                'db_list_field': ['1', 2, {'hello': 'world'}],
+                'db_dict_field': {'hello': 'world'}}
+            ]}, {}))
+        doc.save()
+        doc.reload()
+        self.assertEquals(doc.embedded_field.list_field[2].string_field, 'hello world')
+
+        # Test list native methods
+        doc.embedded_field.list_field[2].list_field.pop(0)
+        self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}]}, {}))
+        doc.save()
+        doc.reload()
+
+        doc.embedded_field.list_field[2].list_field.append(1)
+        self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}, 1]}, {}))
+        doc.save()
+        doc.reload()
+        self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1])
+
+        doc.embedded_field.list_field[2].list_field.sort()
+        doc.save()
+        doc.reload()
+        self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}])
+
+        del(doc.embedded_field.list_field[2].list_field[2]['hello'])
+        self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [1, 2, {}]}, {}))
+        doc.save()
+        doc.reload()
+
+        del(doc.embedded_field.list_field[2].list_field)
+        self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1}))
+
+    def test_save_only_changed_fields(self):
+        """Ensure save only sets / unsets changed fields
+        """
+
+        class User(self.Person):
+            active = BooleanField(default=True)
+
+        User.drop_collection()
+
+        # Create a person object and save it to the database
+        user = User(name='Test User', age=30, active=True)
+        user.save()
+        user.reload()
+
+        # Simulate a race condition with a second, stale copy
+        same_person = self.Person.objects.get()
+        same_person.active = False
+
+        user.age = 21
+        user.save()
+
+        same_person.name = 'User'
+        same_person.save()
+
+        person = self.Person.objects.get()
+        self.assertEquals(person.name, 'User')
+        self.assertEquals(person.age, 21)
+        self.assertEquals(person.active, False)
+
+    def test_save_only_changed_fields_recursive(self):
+        """Ensure save only sets / unsets changed fields, recursively
+        """
+
+        class Comment(EmbeddedDocument):
+            published = BooleanField(default=True)
+
+        class User(self.Person):
+            comments_dict = DictField()
+            comments = ListField(EmbeddedDocumentField(Comment))
+            active = BooleanField(default=True)
+
+        User.drop_collection()
+
+        # Create a person object and save it to the database
+        person = User(name='Test User', age=30, active=True)
+        person.comments.append(Comment())
+        person.save()
+        person.reload()
+
+        person = self.Person.objects.get()
+        self.assertTrue(person.comments[0].published)
+
+        person.comments[0].published = False
+        person.save()
+
+        person = self.Person.objects.get()
+        self.assertFalse(person.comments[0].published)
+
+        # Simple dict holding an embedded comment
+        person.comments_dict['first_post'] = Comment()
+        person.save()
+
+        person = self.Person.objects.get()
+        self.assertTrue(person.comments_dict['first_post'].published)
+
+        person.comments_dict['first_post'].published = False
+        person.save()
+
+        person = self.Person.objects.get()
+        self.assertFalse(person.comments_dict['first_post'].published)

     def test_delete(self):
         """Ensure that document may be deleted using the delete method.
@@ -517,23 +1922,23 @@ class DocumentTest(unittest.TestCase):
         """Ensure that a document may be saved with a custom _id.
""" # Create person object and save it to the database - person = self.Person(name='Test User', age=30, + person = self.Person(name='Test User', age=30, id='497ce96f395f2f052a494fd4') person.save() # Ensure that the object is in the database with the correct _id - collection = self.db[self.Person._meta['collection']] + collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') - + def test_save_custom_pk(self): """Ensure that a document may be saved with a custom _id using pk alias. """ # Create person object and save it to the database - person = self.Person(name='Test User', age=30, + person = self.Person(name='Test User', age=30, pk='497ce96f395f2f052a494fd4') person.save() # Ensure that the object is in the database with the correct _id - collection = self.db[self.Person._meta['collection']] + collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') @@ -556,7 +1961,7 @@ class DocumentTest(unittest.TestCase): post.comments = comments post.save() - collection = self.db[BlogPost._meta['collection']] + collection = self.db[BlogPost._get_collection_name()] post_obj = collection.find_one() self.assertEqual(post_obj['tags'], tags) for comment_obj, comment in zip(post_obj['comments'], comments): @@ -564,8 +1969,60 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() + def test_list_search_by_embedded(self): + class User(Document): + username = StringField(required=True) + + meta = {'allow_inheritance': False} + + class Comment(EmbeddedDocument): + comment = StringField() + user = ReferenceField(User, + required=True) + + meta = {'allow_inheritance': False} + + class Page(Document): + comments = ListField(EmbeddedDocumentField(Comment)) + meta = {'allow_inheritance': False, + 'indexes': [ + {'fields': ['comments.user']} + ]} + + User.drop_collection() + Page.drop_collection() + + u1 = User(username="wilson") + u1.save() + + u2 = User(username="rozza") + u2.save() + + u3 = User(username="hmarr") + u3.save() + + p1 = Page(comments = [Comment(user=u1, comment="Its very good"), + Comment(user=u2, comment="Hello world"), + Comment(user=u3, comment="Ping Pong"), + Comment(user=u1, comment="I like a beer")]) + p1.save() + + p2 = Page(comments = [Comment(user=u1, comment="Its very good"), + Comment(user=u2, comment="Hello world")]) + p2.save() + + p3 = Page(comments = [Comment(user=u3, comment="Its very good")]) + p3.save() + + p4 = Page(comments = [Comment(user=u2, comment="Heavy Metal song")]) + p4.save() + + self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1))) + self.assertEqual([p1, p2, p4], list(Page.objects.filter(comments__user=u2))) + self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3))) + def test_save_embedded_document(self): - """Ensure that a document with an embedded document field may be + """Ensure that a document with an embedded document field may be saved in the database. 
""" class EmployeeDetails(EmbeddedDocument): @@ -581,17 +2038,83 @@ class DocumentTest(unittest.TestCase): employee.save() # Ensure that the object is in the database - collection = self.db[self.Person._meta['collection']] + collection = self.db[self.Person._get_collection_name()] employee_obj = collection.find_one({'name': 'Test Employee'}) self.assertEqual(employee_obj['name'], 'Test Employee') self.assertEqual(employee_obj['age'], 50) # Ensure that the 'details' embedded object saved correctly self.assertEqual(employee_obj['details']['position'], 'Developer') + def test_updating_an_embedded_document(self): + """Ensure that a document with an embedded document field may be + saved in the database. + """ + class EmployeeDetails(EmbeddedDocument): + position = StringField() + + class Employee(self.Person): + salary = IntField() + details = EmbeddedDocumentField(EmployeeDetails) + + # Create employee object and save it to the database + employee = Employee(name='Test Employee', age=50, salary=20000) + employee.details = EmployeeDetails(position='Developer') + employee.save() + + # Test updating an embedded document + promoted_employee = Employee.objects.get(name='Test Employee') + promoted_employee.details.position = 'Senior Developer' + promoted_employee.save() + + promoted_employee.reload() + self.assertEqual(promoted_employee.name, 'Test Employee') + self.assertEqual(promoted_employee.age, 50) + + # Ensure that the 'details' embedded object saved correctly + self.assertEqual(promoted_employee.details.position, 'Senior Developer') + + # Test removal + promoted_employee.details = None + promoted_employee.save() + + promoted_employee.reload() + self.assertEqual(promoted_employee.details, None) + + def test_mixins_dont_add_to_types(self): + + class Bob(Document): name = StringField() + + Bob.drop_collection() + + p = Bob(name="Rozza") + p.save() + Bob.drop_collection() + + class Person(Document, Mixin): + pass + + Person.drop_collection() + + p = Person(name="Rozza") + p.save() + self.assertEquals(p._fields.keys(), ['name', 'id']) + + collection = self.db[Person._get_collection_name()] + obj = collection.find_one() + self.assertEquals(obj['_cls'], 'Person') + self.assertEquals(obj['_types'], ['Person']) + + + + self.assertEquals(Person.objects.count(), 1) + rozza = Person.objects.get(name="Rozza") + + Person.drop_collection() + def test_save_reference(self): """Ensure that a document reference field may be saved in the database. """ - + class BlogPost(Document): meta = {'collection': 'blogpost_1'} content = StringField() @@ -610,7 +2133,7 @@ class DocumentTest(unittest.TestCase): post_obj = BlogPost.objects.first() # Test laziness - self.assertTrue(isinstance(post_obj._data['author'], + self.assertTrue(isinstance(post_obj._data['author'], pymongo.dbref.DBRef)) self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertEqual(post_obj.author.name, 'Test User') @@ -624,8 +2147,129 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() - def tearDown(self): + + def test_reverse_delete_rule_cascade_and_nullify(self): + """Ensure that a referenced document is also deleted upon deletion. 
+ """ + + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) + reviewer = ReferenceField(self.Person, reverse_delete_rule=NULLIFY) + self.Person.drop_collection() + BlogPost.drop_collection() + + author = self.Person(name='Test User') + author.save() + + reviewer = self.Person(name='Re Viewer') + reviewer.save() + + post = BlogPost(content = 'Watched some TV') + post.author = author + post.reviewer = reviewer + post.save() + + reviewer.delete() + self.assertEqual(len(BlogPost.objects), 1) # No effect on the BlogPost + self.assertEqual(BlogPost.objects.get().reviewer, None) + + # Delete the Person, which should lead to deletion of the BlogPost, too + author.delete() + self.assertEqual(len(BlogPost.objects), 0) + + def test_reverse_delete_rule_cascade_recurs(self): + """Ensure that a chain of documents is also deleted upon cascaded + deletion. + """ + + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) + + class Comment(Document): + text = StringField() + post = ReferenceField(BlogPost, reverse_delete_rule=CASCADE) + + self.Person.drop_collection() + BlogPost.drop_collection() + Comment.drop_collection() + + author = self.Person(name='Test User') + author.save() + + post = BlogPost(content = 'Watched some TV') + post.author = author + post.save() + + comment = Comment(text = 'Kudos.') + comment.post = post + comment.save() + + # Delete the Person, which should lead to deletion of the BlogPost, and, + # recursively to the Comment, too + author.delete() + self.assertEqual(len(Comment.objects), 0) + + self.Person.drop_collection() + BlogPost.drop_collection() + Comment.drop_collection() + + def test_reverse_delete_rule_deny(self): + """Ensure that a document cannot be referenced if there are still + documents referring to it. 
+ """ + + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, reverse_delete_rule=DENY) + + self.Person.drop_collection() + BlogPost.drop_collection() + + author = self.Person(name='Test User') + author.save() + + post = BlogPost(content = 'Watched some TV') + post.author = author + post.save() + + # Delete the Person should be denied + self.assertRaises(OperationError, author.delete) # Should raise denied error + self.assertEqual(len(BlogPost.objects), 1) # No objects may have been deleted + self.assertEqual(len(self.Person.objects), 1) + + # Other users, that don't have BlogPosts must be removable, like normal + author = self.Person(name='Another User') + author.save() + + self.assertEqual(len(self.Person.objects), 2) + author.delete() + self.assertEqual(len(self.Person.objects), 1) + + self.Person.drop_collection() + BlogPost.drop_collection() + + def subclasses_and_unique_keys_works(self): + + class A(Document): + pass + + class B(A): + foo = BooleanField(unique=True) + + A.drop_collection() + B.drop_collection() + + A().save() + A().save() + B(foo=True).save() + + self.assertEquals(A.objects.count(), 2) + self.assertEquals(B.objects.count(), 1) + A.drop_collection() + B.drop_collection() def test_document_hash(self): """Test document in list, dict, set @@ -635,7 +2279,7 @@ class DocumentTest(unittest.TestCase): class BlogPost(Document): pass - + # Clear old datas User.drop_collection() BlogPost.drop_collection() @@ -672,9 +2316,35 @@ class DocumentTest(unittest.TestCase): # in Set all_user_set = set(User.objects.all()) - + self.assertTrue(u1 in all_user_set ) - + + def test_picklable(self): + + pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) + pickle_doc.embedded = PickleEmbedded() + pickle_doc.save() + + pickled_doc = pickle.dumps(pickle_doc) + resurrected = pickle.loads(pickled_doc) + + self.assertEquals(resurrected, pickle_doc) + + resurrected.string = "Two" + resurrected.save() + + pickle_doc.reload() + self.assertEquals(resurrected, pickle_doc) + + def throw_invalid_document_error(self): + + # test handles people trying to upsert + def throw_invalid_document_error(): + class Blog(Document): + validate = DictField() + + self.assertRaises(InvalidDocumentError, throw_invalid_document_error) + if __name__ == '__main__': unittest.main() diff --git a/tests/fields.py b/tests/fields.py index 5602cdec..dc53eae3 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -7,6 +7,7 @@ import gridfs from mongoengine import * from mongoengine.connection import _get_db +from mongoengine.base import _document_registry, NotRegistered class FieldTest(unittest.TestCase): @@ -20,12 +21,15 @@ class FieldTest(unittest.TestCase): """ class Person(Document): name = StringField() - age = IntField(default=30) - userid = StringField(default=lambda: 'test') + age = IntField(default=30, help_text="Your real age") + userid = StringField(default=lambda: 'test', verbose_name="User Identity") person = Person(name='Test Person') self.assertEqual(person._data['age'], 30) self.assertEqual(person._data['userid'], 'test') + self.assertEqual(person._fields['name'].help_text, None) + self.assertEqual(person._fields['age'].help_text, "Your real age") + self.assertEqual(person._fields['userid'].verbose_name, "User Identity") def test_required_values(self): """Ensure that required field constraints are enforced. 
@@ -45,7 +49,7 @@ class FieldTest(unittest.TestCase):
         """
         class Person(Document):
             name = StringField()
-
+
         person = Person(name='Test User')
         self.assertEqual(person.id, None)

@@ -95,7 +99,7 @@ class FieldTest(unittest.TestCase):

         link.url = 'http://www.google.com:8080'
         link.validate()
-
+
     def test_int_validation(self):
         """Ensure that invalid values cannot be assigned to int fields.
         """
@@ -129,12 +133,12 @@ class FieldTest(unittest.TestCase):
         self.assertRaises(ValidationError, person.validate)
         person.height = 4.0
         self.assertRaises(ValidationError, person.validate)
-
+
     def test_decimal_validation(self):
         """Ensure that invalid values cannot be assigned to decimal fields.
         """
         class Person(Document):
-            height = DecimalField(min_value=Decimal('0.1'),
+            height = DecimalField(min_value=Decimal('0.1'),
                                   max_value=Decimal('3.5'))

         Person.drop_collection()
@@ -181,11 +185,184 @@ class FieldTest(unittest.TestCase):
         log.time = datetime.datetime.now()
         log.validate()

+        log.time = datetime.date.today()
+        log.validate()
+
         log.time = -1
         self.assertRaises(ValidationError, log.validate)

         log.time = '1pm'
         self.assertRaises(ValidationError, log.validate)

+    def test_datetime(self):
+        """Tests showing how pymongo's datetime fields handle microseconds.
+        Microseconds are rounded to the nearest millisecond, and pre-UTC
+        handling is wonky.
+
+        See: http://api.mongodb.org/python/current/api/bson/son.html#dt
+        """
+        class LogEntry(Document):
+            date = DateTimeField()
+
+        LogEntry.drop_collection()
+
+        # Test can save dates
+        log = LogEntry()
+        log.date = datetime.date.today()
+        log.save()
+        log.reload()
+        self.assertEquals(log.date.date(), datetime.date.today())
+
+        LogEntry.drop_collection()
+
+        # Post UTC - microseconds are rounded (down) to the nearest millisecond and dropped
+        d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
+        d2 = datetime.datetime(1970, 01, 01, 00, 00, 01)
+        log = LogEntry()
+        log.date = d1
+        log.save()
+        log.reload()
+        self.assertNotEquals(log.date, d1)
+        self.assertEquals(log.date, d2)
+
+        # Post UTC - microseconds are rounded (down) to the nearest millisecond
+        d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
+        d2 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9000)
+        log.date = d1
+        log.save()
+        log.reload()
+        self.assertNotEquals(log.date, d1)
+        self.assertEquals(log.date, d2)
+
+        # Pre-UTC dates: microseconds below 1000 are dropped
+        d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
+        d2 = datetime.datetime(1969, 12, 31, 23, 59, 59)
+        log.date = d1
+        log.save()
+        log.reload()
+        self.assertNotEquals(log.date, d1)
+        self.assertEquals(log.date, d2)
+
+        # Pre-UTC microseconds above 1000 are wonky.
+        # log.date has an invalid microsecond value so I can't construct
+        # a date to compare.
+        #
+        # However, the timedelta is predictable with pre-UTC timestamps
+        # It always adds 16 seconds and [777216-776217] microseconds
+        for i in xrange(1001, 3113, 33):
+            d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i)
+            log.date = d1
+            log.save()
+            log.reload()
+            self.assertNotEquals(log.date, d1)
+
+            delta = log.date - d1
+            self.assertEquals(delta.seconds, 16)
+            microseconds = 777216 - (i % 1000)
+            self.assertEquals(delta.microseconds, microseconds)
+
+        LogEntry.drop_collection()
+
+    def test_complexdatetime_storage(self):
+        """Tests for complex datetime fields - which can handle microseconds
+        without rounding.
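+
+        (ComplexDateTimeField appears to sidestep BSON's millisecond
+        precision by storing the value as a formatted string rather than a
+        native date, so the full microsecond value survives the round trip.)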
+ """ + class LogEntry(Document): + date = ComplexDateTimeField() + + LogEntry.drop_collection() + + # Post UTC - microseconds are rounded (down) nearest millisecond and dropped - with default datetimefields + d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) + log = LogEntry() + log.date = d1 + log.save() + log.reload() + self.assertEquals(log.date, d1) + + # Post UTC - microseconds are rounded (down) nearest millisecond - with default datetimefields + d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999) + log.date = d1 + log.save() + log.reload() + self.assertEquals(log.date, d1) + + # Pre UTC dates microseconds below 1000 are dropped - with default datetimefields + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) + log.date = d1 + log.save() + log.reload() + self.assertEquals(log.date, d1) + + # Pre UTC microseconds above 1000 is wonky - with default datetimefields + # log.date has an invalid microsecond value so I can't construct + # a date to compare. + for i in xrange(1001, 3113, 33): + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) + log.date = d1 + log.save() + log.reload() + self.assertEquals(log.date, d1) + log1 = LogEntry.objects.get(date=d1) + self.assertEqual(log, log1) + + LogEntry.drop_collection() + + def test_complexdatetime_usage(self): + """Tests for complex datetime fields - which can handle microseconds + without rounding. + """ + class LogEntry(Document): + date = ComplexDateTimeField() + + LogEntry.drop_collection() + + d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) + log = LogEntry() + log.date = d1 + log.save() + + log1 = LogEntry.objects.get(date=d1) + self.assertEquals(log, log1) + + LogEntry.drop_collection() + + # create 60 log entries + for i in xrange(1950, 2010): + d = datetime.datetime(i, 01, 01, 00, 00, 01, 999) + LogEntry(date=d).save() + + self.assertEqual(LogEntry.objects.count(), 60) + + # Test ordering + logs = LogEntry.objects.order_by("date") + count = logs.count() + i = 0 + while i == count-1: + self.assertTrue(logs[i].date <= logs[i+1].date) + i +=1 + + logs = LogEntry.objects.order_by("-date") + count = logs.count() + i = 0 + while i == count-1: + self.assertTrue(logs[i].date >= logs[i+1].date) + i +=1 + + # Test searching + logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980,1,1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980,1,1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter( + date__lte=datetime.datetime(2011,1,1), + date__gte=datetime.datetime(2000,1,1), + ) + self.assertEqual(logs.count(), 10) + + LogEntry.drop_collection() + def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements. 
""" @@ -200,6 +377,7 @@ class FieldTest(unittest.TestCase): comments = ListField(EmbeddedDocumentField(Comment)) tags = ListField(StringField()) authors = ListField(ReferenceField(User)) + generic = ListField(GenericReferenceField()) post = BlogPost(content='Went for a walk today...') post.validate() @@ -227,8 +405,28 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, post.validate) post.authors = [User()] + self.assertRaises(ValidationError, post.validate) + + user = User() + user.save() + post.authors = [user] post.validate() + post.generic = [1, 2] + self.assertRaises(ValidationError, post.validate) + + post.generic = [User(), Comment()] + self.assertRaises(ValidationError, post.validate) + + post.generic = [Comment()] + self.assertRaises(ValidationError, post.validate) + + post.generic = [user] + post.validate() + + User.drop_collection() + BlogPost.drop_collection() + def test_sorted_list_sorting(self): """Ensure that a sorted list field properly sorts values. """ @@ -249,7 +447,7 @@ class FieldTest(unittest.TestCase): post.save() post.reload() self.assertEqual(post.tags, ['fun', 'leisure']) - + comment1 = Comment(content='Good for you', order=1) comment2 = Comment(content='Yay.', order=0) comments = [comment1, comment2] @@ -261,12 +459,116 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() - def test_dict_validation(self): + def test_list_field(self): + """Ensure that list types work as expected. + """ + class BlogPost(Document): + info = ListField() + + BlogPost.drop_collection() + + post = BlogPost() + post.info = 'my post' + self.assertRaises(ValidationError, post.validate) + + post.info = {'title': 'test'} + self.assertRaises(ValidationError, post.validate) + + post.info = ['test'] + post.save() + + post = BlogPost() + post.info = [{'test': 'test'}] + post.save() + + post = BlogPost() + post.info = [{'test': 3}] + post.save() + + + self.assertEquals(BlogPost.objects.count(), 3) + self.assertEquals(BlogPost.objects.filter(info__exact='test').count(), 1) + self.assertEquals(BlogPost.objects.filter(info__0__test='test').count(), 1) + + # Confirm handles non strings or non existing keys + self.assertEquals(BlogPost.objects.filter(info__0__test__exact='5').count(), 0) + self.assertEquals(BlogPost.objects.filter(info__100__test__exact='test').count(), 0) + BlogPost.drop_collection() + + def test_list_field_strict(self): + """Ensure that list field handles validation if provided a strict field type.""" + + class Simple(Document): + mapping = ListField(field=IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping = [1] + e.save() + + def create_invalid_mapping(): + e.mapping = ["abc"] + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Simple.drop_collection() + + def test_list_field_complex(self): + """Ensure that the list fields can handle the complex types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Simple(Document): + mapping = ListField() + + Simple.drop_collection() + e = Simple() + e.mapping.append(StringSetting(value='foo')) + e.mapping.append(IntegerSetting(value=42)) + e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001, + 'complex': IntegerSetting(value=42), 'list': + [IntegerSetting(value=42), StringSetting(value='foo')]}) + e.save() + + e2 = Simple.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping[0], StringSetting)) + 
self.assertTrue(isinstance(e2.mapping[1], IntegerSetting)) + + # Test querying + self.assertEquals(Simple.objects.filter(mapping__1__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__number=1).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__complex__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) + + # Confirm can update + Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) + self.assertEquals(Simple.objects.filter(mapping__1__value=10).count(), 1) + + Simple.objects().update( + set__mapping__2__list__1=StringSetting(value='Boo')) + self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) + self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) + + Simple.drop_collection() + + def test_dict_field(self): """Ensure that dict types work as expected. """ class BlogPost(Document): info = DictField() + BlogPost.drop_collection() + post = BlogPost() post.info = 'my post' self.assertRaises(ValidationError, post.validate) @@ -281,7 +583,149 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, post.validate) post.info = {'title': 'test'} - post.validate() + post.save() + + post = BlogPost() + post.info = {'details': {'test': 'test'}} + post.save() + + post = BlogPost() + post.info = {'details': {'test': 3}} + post.save() + + self.assertEquals(BlogPost.objects.count(), 3) + self.assertEquals(BlogPost.objects.filter(info__title__exact='test').count(), 1) + self.assertEquals(BlogPost.objects.filter(info__details__test__exact='test').count(), 1) + + # Confirm handles non strings or non existing keys + self.assertEquals(BlogPost.objects.filter(info__details__test__exact=5).count(), 0) + self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) + BlogPost.drop_collection() + + def test_dictfield_strict(self): + """Ensure that dict field handles validation if provided a strict field type.""" + + class Simple(Document): + mapping = DictField(field=IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + def create_invalid_mapping(): + e.mapping['somestring'] = "abc" + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Simple.drop_collection() + + def test_dictfield_complex(self): + """Ensure that the dict field can handle the complex types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Simple(Document): + mapping = DictField() + + Simple.drop_collection() + e = Simple() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001, + 'complex': IntegerSetting(value=42), 'list': + [IntegerSetting(value=42), StringSetting(value='foo')]} + e.save() + + e2 = Simple.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) + self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) + + # Test querying + self.assertEquals(Simple.objects.filter(mapping__someint__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) + 
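        # Each double-underscore hop descends one level: a dict key, a list
+        # index, or an embedded document field.
+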
self.assertEquals(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) + + # Confirm can update + Simple.objects().update( + set__mapping={"someint": IntegerSetting(value=10)}) + Simple.objects().update( + set__mapping__nested_dict__list__1=StringSetting(value='Boo')) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) + + Simple.drop_collection() + + def test_mapfield(self): + """Ensure that the MapField handles the declared type.""" + + class Simple(Document): + mapping = MapField(IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + def create_invalid_mapping(): + e.mapping['somestring'] = "abc" + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + def create_invalid_class(): + class NoDeclaredType(Document): + mapping = MapField() + + self.assertRaises(ValidationError, create_invalid_class) + + Simple.drop_collection() + + def test_complex_mapfield(self): + """Ensure that the MapField can handle complex declared types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Extensible(Document): + mapping = MapField(EmbeddedDocumentField(SettingBase)) + + Extensible.drop_collection() + + e = Extensible() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.save() + + e2 = Extensible.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) + self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) + + def create_invalid_mapping(): + e.mapping['someint'] = 123 + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Extensible.drop_collection() def test_embedded_document_validation(self): """Ensure that invalid embedded documents cannot be assigned to @@ -315,7 +759,7 @@ class FieldTest(unittest.TestCase): person.validate() def test_embedded_document_inheritance(self): - """Ensure that subclasses of embedded documents may be provided to + """Ensure that subclasses of embedded documents may be provided to EmbeddedDocumentFields of the superclass' type. """ class User(EmbeddedDocument): @@ -327,7 +771,7 @@ class FieldTest(unittest.TestCase): class BlogPost(Document): content = StringField() author = EmbeddedDocumentField(User) - + post = BlogPost(content='What I did today...') post.author = User(name='Test User') post.author = PowerUser(name='Test User', power=47) @@ -370,7 +814,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() BlogPost.drop_collection() - + def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. 
""" @@ -434,7 +878,8 @@ class FieldTest(unittest.TestCase): class TreeNode(EmbeddedDocument): name = StringField() children = ListField(EmbeddedDocumentField('self')) - + + Tree.drop_collection() tree = Tree(name="Tree") first_child = TreeNode(name="Child 1") @@ -442,18 +887,51 @@ class FieldTest(unittest.TestCase): second_child = TreeNode(name="Child 2") first_child.children.append(second_child) - - third_child = TreeNode(name="Child 3") - first_child.children.append(third_child) - tree.save() - tree_obj = Tree.objects.first() + tree = Tree.objects.first() + self.assertEqual(len(tree.children), 1) + + self.assertEqual(len(tree.children[0].children), 1) + + third_child = TreeNode(name="Child 3") + tree.children[0].children.append(third_child) + tree.save() + self.assertEqual(len(tree.children), 1) self.assertEqual(tree.children[0].name, first_child.name) self.assertEqual(tree.children[0].children[0].name, second_child.name) self.assertEqual(tree.children[0].children[1].name, third_child.name) + # Test updating + tree.children[0].name = 'I am Child 1' + tree.children[0].children[0].name = 'I am Child 2' + tree.children[0].children[1].name = 'I am Child 3' + tree.save() + + self.assertEqual(tree.children[0].name, 'I am Child 1') + self.assertEqual(tree.children[0].children[0].name, 'I am Child 2') + self.assertEqual(tree.children[0].children[1].name, 'I am Child 3') + + # Test removal + self.assertEqual(len(tree.children[0].children), 2) + del(tree.children[0].children[1]) + + tree.save() + self.assertEqual(len(tree.children[0].children), 1) + + tree.children[0].children.pop(0) + tree.save() + self.assertEqual(len(tree.children[0].children), 0) + self.assertEqual(tree.children[0].children, []) + + tree.children[0].children.insert(0, third_child) + tree.children[0].children.insert(0, second_child) + tree.save() + self.assertEqual(len(tree.children[0].children), 2) + self.assertEqual(tree.children[0].children[0].name, second_child.name) + self.assertEqual(tree.children[0].children[1].name, third_child.name) + def test_undefined_reference(self): """Ensure that ReferenceFields may reference undefined Documents. """ @@ -506,46 +984,46 @@ class FieldTest(unittest.TestCase): Member.drop_collection() BlogPost.drop_collection() - + def test_generic_reference(self): """Ensure that a GenericReferenceField properly dereferences items. 
""" class Link(Document): title = StringField() meta = {'allow_inheritance': False} - + class Post(Document): title = StringField() - + class Bookmark(Document): bookmark_object = GenericReferenceField() - + Link.drop_collection() Post.drop_collection() Bookmark.drop_collection() - + link_1 = Link(title="Pitchfork") link_1.save() - + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") post_1.save() - + bm = Bookmark(bookmark_object=post_1) bm.save() - + bm = Bookmark.objects(bookmark_object=post_1).first() - + self.assertEqual(bm.bookmark_object, post_1) self.assertTrue(isinstance(bm.bookmark_object, Post)) - + bm.bookmark_object = link_1 bm.save() - + bm = Bookmark.objects(bookmark_object=link_1).first() - + self.assertEqual(bm.bookmark_object, link_1) self.assertTrue(isinstance(bm.bookmark_object, Link)) - + Link.drop_collection() Post.drop_collection() Bookmark.drop_collection() @@ -555,35 +1033,81 @@ class FieldTest(unittest.TestCase): """ class Link(Document): title = StringField() - + class Post(Document): title = StringField() - + class User(Document): bookmarks = ListField(GenericReferenceField()) - + Link.drop_collection() Post.drop_collection() User.drop_collection() - + link_1 = Link(title="Pitchfork") link_1.save() - + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") post_1.save() - + user = User(bookmarks=[post_1, link_1]) user.save() - + user = User.objects(bookmarks__all=[post_1, link_1]).first() - + self.assertEqual(user.bookmarks[0], post_1) self.assertEqual(user.bookmarks[1], link_1) - + Link.drop_collection() Post.drop_collection() User.drop_collection() + + def test_generic_reference_document_not_registered(self): + """Ensure dereferencing out of the document registry throws a + `NotRegistered` error. + """ + class Link(Document): + title = StringField() + + class User(Document): + bookmarks = ListField(GenericReferenceField()) + + Link.drop_collection() + User.drop_collection() + + link_1 = Link(title="Pitchfork") + link_1.save() + + user = User(bookmarks=[link_1]) + user.save() + + # Mimic User and Link definitions being in a different file + # and the Link model not being imported in the User file. + del(_document_registry["Link"]) + + user = User.objects.first() + try: + user.bookmarks + raise AssertionError, "Link was removed from the registry" + except NotRegistered: + pass + + Link.drop_collection() + User.drop_collection() + + def test_generic_reference_is_none(self): + + class Person(Document): + name = StringField() + city = GenericReferenceField() + + Person.drop_collection() + Person(name="Wilson Jr").save() + + self.assertEquals(repr(Person.objects(city=None)), + "[]") + def test_binary_fields(self): """Ensure that binary fields can be stored and retrieved. """ @@ -644,7 +1168,8 @@ class FieldTest(unittest.TestCase): """Ensure that value is in a container of allowed values. """ class Shirt(Document): - size = StringField(max_length=3, choices=('S','M','L','XL','XXL')) + size = StringField(max_length=3, choices=(('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), + ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) Shirt.drop_collection() @@ -659,6 +1184,35 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + def test_choices_get_field_display(self): + """Test dynamic helper for returning the display value of a choices field. 
+ """ + class Shirt(Document): + size = StringField(max_length=3, choices=(('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), + ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) + style = StringField(max_length=3, choices=(('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') + + Shirt.drop_collection() + + shirt = Shirt() + + self.assertEqual(shirt.get_size_display(), None) + self.assertEqual(shirt.get_style_display(), 'Small') + + shirt.size = "XXL" + shirt.style = "B" + self.assertEqual(shirt.get_size_display(), 'Extra Extra Large') + self.assertEqual(shirt.get_style_display(), 'Baggy') + + # Set as Z - an invalid choice + shirt.size = "Z" + shirt.style = "Z" + self.assertEqual(shirt.get_size_display(), 'Z') + self.assertEqual(shirt.get_style_display(), 'Z') + self.assertRaises(ValidationError, shirt.validate) + + Shirt.drop_collection() + def test_file_fields(self): """Ensure that file fields can be written to and their data retrieved """ @@ -700,6 +1254,12 @@ class FieldTest(unittest.TestCase): self.assertTrue(streamfile == result) self.assertEquals(result.file.read(), text + more_text) self.assertEquals(result.file.content_type, content_type) + result.file.seek(0) + self.assertEquals(result.file.tell(), 0) + self.assertEquals(result.file.read(len(text)), text) + self.assertEquals(result.file.tell(), len(text)) + self.assertEquals(result.file.read(len(more_text)), more_text) + self.assertEquals(result.file.tell(), len(text + more_text)) result.file.delete() # Ensure deleted file returns None @@ -720,7 +1280,7 @@ class FieldTest(unittest.TestCase): result = SetFile.objects.first() self.assertTrue(setfile == result) self.assertEquals(result.file.read(), more_text) - result.file.delete() + result.file.delete() PutFile.drop_collection() StreamFile.drop_collection() @@ -753,6 +1313,21 @@ class FieldTest(unittest.TestCase): TestFile.drop_collection() + def test_file_boolean(self): + """Ensure that a boolean test of a FileField indicates its presence + """ + class TestFile(Document): + file = FileField() + + testfile = TestFile() + self.assertFalse(bool(testfile.file)) + testfile.file = 'Hello, World!' + testfile.file.content_type = 'text/plain' + testfile.save() + self.assertTrue(bool(testfile.file)) + + TestFile.drop_collection() + def test_geo_indexes(self): """Ensure that indexes are created automatically for GeoPointFields. """ @@ -771,6 +1346,27 @@ class FieldTest(unittest.TestCase): Event.drop_collection() + def test_geo_embedded_indexes(self): + """Ensure that indexes are created automatically for GeoPointFields on + embedded documents. 
+ """ + class Venue(EmbeddedDocument): + location = GeoPointField() + name = StringField() + + class Event(Document): + title = StringField() + venue = EmbeddedDocumentField(Venue) + + Event.drop_collection() + venue = Venue(name="Double Door", location=[41.909889, -87.677137]) + event = Event(title="Coltrane Motion", venue=venue) + event.save() + + info = Event.objects._collection.index_information() + self.assertTrue(u'location_2d' in info) + self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')]) + def test_ensure_unique_default_instances(self): """Ensure that every field has it's own unique default instance.""" class D(Document): @@ -784,5 +1380,139 @@ class FieldTest(unittest.TestCase): self.assertEqual(d2.data, {}) self.assertEqual(d2.data2, {}) + def test_sequence_field(self): + class Person(Document): + id = SequenceField(primary_key=True) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in xrange(10): + p = Person(name="Person %s" % x) + p.save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + def test_multiple_sequence_fields(self): + class Person(Document): + id = SequenceField(primary_key=True) + counter = SequenceField() + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in xrange(10): + p = Person(name="Person %s" % x) + p.save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + counters = [i.counter for i in Person.objects] + self.assertEqual(counters, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + def test_sequence_fields_reload(self): + class Animal(Document): + counter = SequenceField() + type = StringField() + + self.db['mongoengine.counters'].drop() + Animal.drop_collection() + + a = Animal(type="Boi") + a.save() + + self.assertEqual(a.counter, 1) + a.reload() + self.assertEqual(a.counter, 1) + + a.counter = None + self.assertEqual(a.counter, 2) + a.save() + + self.assertEqual(a.counter, 2) + + a = Animal.objects.first() + self.assertEqual(a.counter, 2) + a.reload() + self.assertEqual(a.counter, 2) + + def test_multiple_sequence_fields_on_docs(self): + + class Animal(Document): + id = SequenceField(primary_key=True) + + class Person(Document): + id = SequenceField(primary_key=True) + + self.db['mongoengine.counters'].drop() + Animal.drop_collection() + Person.drop_collection() + + for x in xrange(10): + a = Animal(name="Animal %s" % x) + a.save() + p = Person(name="Person %s" % x) + p.save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, range(1, 11)) + + id = [i.id for i in Animal.objects] + self.assertEqual(id, range(1, 11)) + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) + self.assertEqual(c['next'], 10) + + + def 
test_generic_embedded_document(self): + class Car(EmbeddedDocument): + name = StringField() + + class Dish(EmbeddedDocument): + food = StringField(required=True) + number = IntField() + + class Person(Document): + name = StringField() + like = GenericEmbeddedDocumentField() + + person = Person(name='Test User') + person.like = Car(name='Fiat') + person.save() + + person = Person.objects.first() + self.assertTrue(isinstance(person.like, Car)) + + person.like = Dish(food="arroz", number=15) + person.save() + + person = Person.objects.first() + self.assertTrue(isinstance(person.like, Dish)) + if __name__ == '__main__': unittest.main() diff --git a/tests/fixtures.py b/tests/fixtures.py new file mode 100644 index 00000000..5aaba556 --- /dev/null +++ b/tests/fixtures.py @@ -0,0 +1,25 @@ +from datetime import datetime +import pymongo + +from mongoengine import * +from mongoengine.base import BaseField +from mongoengine.connection import _get_db + + +class PickleEmbedded(EmbeddedDocument): + date = DateTimeField(default=datetime.now) + + +class PickleTest(Document): + number = IntField() + string = StringField(choices=(('One', '1'), ('Two', '2'))) + embedded = EmbeddedDocumentField(PickleEmbedded) + lists = ListField(StringField()) + + +class Mixin(object): + name = StringField() + + +class Base(Document): + pass diff --git a/tests/queryset.py b/tests/queryset.py index 72623b89..f093f6ab 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1,20 +1,21 @@ # -*- coding: utf-8 -*- - - import unittest import pymongo from datetime import datetime, timedelta -from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, - DoesNotExist) +from mongoengine.queryset import (QuerySet, QuerySetManager, + MultipleObjectsReturned, DoesNotExist, + QueryFieldList) from mongoengine import * +from mongoengine.connection import _get_connection +from mongoengine.tests import query_counter class QuerySetTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') - + class Person(Document): name = StringField() age = IntField() @@ -25,7 +26,7 @@ class QuerySetTest(unittest.TestCase): """ self.assertTrue(isinstance(self.Person.objects, QuerySet)) self.assertEqual(self.Person.objects._collection.name, - self.Person._meta['collection']) + self.Person._get_collection_name()) self.assertTrue(isinstance(self.Person.objects._collection, pymongo.collection.Collection)) @@ -105,6 +106,10 @@ class QuerySetTest(unittest.TestCase): people = list(self.Person.objects[1:1]) self.assertEqual(len(people), 0) + # Test slice out of range + people = list(self.Person.objects[80000:80001]) + self.assertEqual(len(people), 0) + def test_find_one(self): """Ensure that a query using find_one returns a valid result. """ @@ -162,7 +167,7 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age__lt=30) self.assertEqual(person.name, "User A") - + def test_find_array_position(self): """Ensure that query by array position works. 
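
        e.g. Blog.objects(tags__0='a') matches blogs whose first tag is 'a'.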
""" @@ -177,7 +182,7 @@ class QuerySetTest(unittest.TestCase): posts = ListField(EmbeddedDocumentField(Post)) Blog.drop_collection() - + Blog.objects.create(tags=['a', 'b']) self.assertEqual(len(Blog.objects(tags__0='a')), 1) self.assertEqual(len(Blog.objects(tags__0='b')), 0) @@ -207,6 +212,200 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() + def test_update_write_options(self): + """Test that passing write_options works""" + + self.Person.drop_collection() + + write_options = {"fsync": True} + + author, created = self.Person.objects.get_or_create( + name='Test User', write_options=write_options) + author.save(write_options=write_options) + + self.Person.objects.update(set__name='Ross', write_options=write_options) + + author = self.Person.objects.first() + self.assertEquals(author.name, 'Ross') + + self.Person.objects.update_one(set__name='Test User', write_options=write_options) + author = self.Person.objects.first() + self.assertEquals(author.name, 'Test User') + + def test_update_update_has_a_value(self): + """Test to ensure that update is passed a value to update to""" + self.Person.drop_collection() + + author = self.Person(name='Test User') + author.save() + + def update_raises(): + self.Person.objects(pk=author.pk).update({}) + + def update_one_raises(): + self.Person.objects(pk=author.pk).update_one({}) + + self.assertRaises(OperationError, update_raises) + self.assertRaises(OperationError, update_one_raises) + + def test_update_array_position(self): + """Ensure that updating by array position works. + + Check update() and update_one() can take syntax like: + set__posts__1__comments__1__name="testc" + Check that it only works for ListFields. + """ + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + # Update all of the first comments of second posts of all blogs + blog = Blog.objects().update(set__posts__1__comments__0__name="testc") + testc_blogs = Blog.objects(posts__1__comments__0__name="testc") + self.assertEqual(len(testc_blogs), 2) + + Blog.drop_collection() + + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + # Update only the first blog returned by the query + blog = Blog.objects().update_one( + set__posts__1__comments__1__name="testc") + testc_blogs = Blog.objects(posts__1__comments__1__name="testc") + self.assertEqual(len(testc_blogs), 1) + + # Check that using this indexing syntax on a non-list fails + def non_list_indexing(): + Blog.objects().update(set__posts__1__comments__0__name__1="asdf") + self.assertRaises(InvalidQueryError, non_list_indexing) + + Blog.drop_collection() + + def test_update_using_positional_operator(self): + """Ensure that the list fields can be updated using the positional + operator.""" + + class Comment(EmbeddedDocument): + by = StringField() + votes = IntField() + + class BlogPost(Document): + title = StringField() + comments = ListField(EmbeddedDocumentField(Comment)) + + BlogPost.drop_collection() + + c1 = Comment(by="joe", votes=3) + c2 = Comment(by="jane", votes=7) + + 
BlogPost(title="ABC", comments=[c1, c2]).save() + + BlogPost.objects(comments__by="joe").update(inc__comments__S__votes=1) + + post = BlogPost.objects.first() + self.assertEquals(post.comments[0].by, 'joe') + self.assertEquals(post.comments[0].votes, 4) + + # Currently the $ operator only applies to the first matched item in + # the query + + class Simple(Document): + x = ListField() + + Simple.drop_collection() + Simple(x=[1, 2, 3, 2]).save() + Simple.objects(x=2).update(inc__x__S=1) + + simple = Simple.objects.first() + self.assertEquals(simple.x, [1, 3, 3, 2]) + Simple.drop_collection() + + # You can set multiples + Simple.drop_collection() + Simple(x=[1, 2, 3, 4]).save() + Simple(x=[2, 3, 4, 5]).save() + Simple(x=[3, 4, 5, 6]).save() + Simple(x=[4, 5, 6, 7]).save() + Simple.objects(x=3).update(set__x__S=0) + + s = Simple.objects() + self.assertEquals(s[0].x, [1, 2, 0, 4]) + self.assertEquals(s[1].x, [2, 0, 4, 5]) + self.assertEquals(s[2].x, [0, 4, 5, 6]) + self.assertEquals(s[3].x, [4, 5, 6, 7]) + + # Using "$unset" with an expression like this "array.$" will result in + # the array item becoming None, not being removed. + Simple.drop_collection() + Simple(x=[1, 2, 3, 4, 3, 2, 3, 4]).save() + Simple.objects(x=3).update(unset__x__S=1) + simple = Simple.objects.first() + self.assertEquals(simple.x, [1, 2, None, 4, 3, 2, 3, 4]) + + # Nested updates arent supported yet.. + def update_nested(): + Simple.drop_collection() + Simple(x=[{'test': [1, 2, 3, 4]}]).save() + Simple.objects(x__test=2).update(set__x__S__test__S=3) + self.assertEquals(simple.x, [1, 2, 3, 4]) + + self.assertRaises(OperationError, update_nested) + Simple.drop_collection() + + def test_mapfield_update(self): + """Ensure that the MapField can be updated.""" + class Member(EmbeddedDocument): + gender = StringField() + age = IntField() + + class Club(Document): + members = MapField(EmbeddedDocumentField(Member)) + + Club.drop_collection() + + club = Club() + club.members['John'] = Member(gender="M", age=13) + club.save() + + Club.objects().update( + set__members={"John": Member(gender="F", age=14)}) + + club = Club.objects().first() + self.assertEqual(club.members['John'].gender, "F") + self.assertEqual(club.members['John'].age, 14) + + def test_dictfield_update(self): + """Ensure that the DictField can be updated.""" + class Club(Document): + members = DictField() + + club = Club() + club.members['John'] = dict(gender="M", age=13) + club.save() + + Club.objects().update( + set__members={"John": dict(gender="F", age=14)}) + + club = Club.objects().first() + self.assertEqual(club.members['John']['gender'], "F") + self.assertEqual(club.members['John']['age'], 14) + def test_get_or_create(self): """Ensure that ``get_or_create`` returns one result or creates a new document. @@ -226,19 +425,138 @@ class QuerySetTest(unittest.TestCase): person, created = self.Person.objects.get_or_create(age=30) self.assertEqual(person.name, "User B") self.assertEqual(created, False) - + person, created = self.Person.objects.get_or_create(age__lt=30) self.assertEqual(person.name, "User A") self.assertEqual(created, False) - + # Try retrieving when no objects exists - new doc should be created kwargs = dict(age=50, defaults={'name': 'User C'}) person, created = self.Person.objects.get_or_create(**kwargs) self.assertEqual(created, True) - + person = self.Person.objects.get(age=50) self.assertEqual(person.name, "User C") + def test_bulk_insert(self): + """Ensure that query by array position works. 
+ """ + + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + title = StringField() + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() + + with query_counter() as q: + self.assertEqual(q, 0) + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + + blogs = [] + for i in xrange(1, 100): + blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) + + Blog.objects.insert(blogs, load_bulk=False) + self.assertEqual(q, 2) # 1 for the inital connection and 1 for the insert + + Blog.objects.insert(blogs) + self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog(title="code", posts=[post1, post2]) + blog2 = Blog(title="mongodb", posts=[post2, post1]) + blog1, blog2 = Blog.objects.insert([blog1, blog2]) + self.assertEqual(blog1.title, "code") + self.assertEqual(blog2.title, "mongodb") + + self.assertEqual(Blog.objects.count(), 2) + + # test handles people trying to upsert + def throw_operation_error(): + blogs = Blog.objects + Blog.objects.insert(blogs) + + self.assertRaises(OperationError, throw_operation_error) + + # test handles other classes being inserted + def throw_operation_error_wrong_doc(): + class Author(Document): + pass + Blog.objects.insert(Author()) + + self.assertRaises(OperationError, throw_operation_error_wrong_doc) + + def throw_operation_error_not_a_document(): + Blog.objects.insert("HELLO WORLD") + + self.assertRaises(OperationError, throw_operation_error_not_a_document) + + Blog.drop_collection() + + blog1 = Blog(title="code", posts=[post1, post2]) + blog1 = Blog.objects.insert(blog1) + self.assertEqual(blog1.title, "code") + self.assertEqual(Blog.objects.count(), 1) + + Blog.drop_collection() + blog1 = Blog(title="code", posts=[post1, post2]) + obj_id = Blog.objects.insert(blog1, load_bulk=False) + self.assertEquals(obj_id.__class__.__name__, 'ObjectId') + + def test_slave_okay(self): + """Ensures that a query can take slave_okay syntax + """ + person1 = self.Person(name="User A", age=20) + person1.save() + person2 = self.Person(name="User B", age=30) + person2.save() + + # Retrieve the first person from the database + person = self.Person.objects.slave_okay(True).first() + self.assertTrue(isinstance(person, self.Person)) + self.assertEqual(person.name, "User A") + self.assertEqual(person.age, 20) + + def test_cursor_args(self): + """Ensures the cursor args can be set as expected + """ + p = self.Person.objects + # Check default + self.assertEqual(p._cursor_args, + {'snapshot': False, 'slave_okay': False, 'timeout': True}) + + p.snapshot(False).slave_okay(False).timeout(False) + self.assertEqual(p._cursor_args, + {'snapshot': False, 'slave_okay': False, 'timeout': False}) + + p.snapshot(True).slave_okay(False).timeout(False) + self.assertEqual(p._cursor_args, + {'snapshot': True, 'slave_okay': False, 'timeout': False}) + + p.snapshot(True).slave_okay(True).timeout(False) + self.assertEqual(p._cursor_args, + {'snapshot': True, 'slave_okay': True, 'timeout': False}) + + p.snapshot(True).slave_okay(True).timeout(True) + self.assertEqual(p._cursor_args, + {'snapshot': True, 'slave_okay': 
+
     def test_repeated_iteration(self):
         """Ensure that a QuerySet rewinds itself once an iteration finishes.
@@ -251,6 +569,18 @@ class QuerySetTest(unittest.TestCase):

         self.assertEqual(people1, people2)

+    def test_repr_iteration(self):
+        """Ensure that QuerySet __repr__ can handle loops
+        """
+        self.Person(name='Person 1').save()
+        self.Person(name='Person 2').save()
+
+        queryset = self.Person.objects
+        self.assertEquals('[<Person: Person object>, <Person: Person object>]', repr(queryset))
+        for person in queryset:
+            self.assertEquals('.. queryset mid-iteration ..', repr(queryset))
+
+
     def test_regex_query_shortcuts(self):
         """Ensure that contains, startswith, endswith, etc work.
@@ -328,7 +658,7 @@ class QuerySetTest(unittest.TestCase):
         self.assertEqual(obj, person)
         obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first()
         self.assertEqual(obj, None)
-
+
         # Test unsafe expressions
         person = self.Person(name='Guido van Rossum [.\'Geek\']')
         person.save()
@@ -351,8 +681,6 @@ class QuerySetTest(unittest.TestCase):
     def test_filter_chaining(self):
         """Ensure filters can be chained together.
         """
-        from datetime import datetime
-
         class BlogPost(Document):
             title = StringField()
             is_published = BooleanField()
@@ -452,6 +780,224 @@ class QuerySetTest(unittest.TestCase):
         self.assertEqual(obj.salary, employee.salary)
         self.assertEqual(obj.name, None)

+    def test_only_with_subfields(self):
+        class User(EmbeddedDocument):
+            name = StringField()
+            email = StringField()
+
+        class Comment(EmbeddedDocument):
+            title = StringField()
+            text = StringField()
+
+        class BlogPost(Document):
+            content = StringField()
+            author = EmbeddedDocumentField(User)
+            comments = ListField(EmbeddedDocumentField(Comment))
+
+        BlogPost.drop_collection()
+
+        post = BlogPost(content='Had a good coffee today...')
+        post.author = User(name='Test User')
+        post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')]
+        post.save()
+
+        obj = BlogPost.objects.only('author.name',).get()
+        self.assertEqual(obj.content, None)
+        self.assertEqual(obj.author.email, None)
+        self.assertEqual(obj.author.name, 'Test User')
+        self.assertEqual(obj.comments, [])
+
+        obj = BlogPost.objects.only('content', 'comments.title',).get()
+        self.assertEqual(obj.content, 'Had a good coffee today...')
+        self.assertEqual(obj.author, None)
+        self.assertEqual(obj.comments[0].title, 'I aggree')
+        self.assertEqual(obj.comments[1].title, 'Coffee')
+        self.assertEqual(obj.comments[0].text, None)
+        self.assertEqual(obj.comments[1].text, None)
+
+        obj = BlogPost.objects.only('comments',).get()
+        self.assertEqual(obj.content, None)
+        self.assertEqual(obj.author, None)
+        self.assertEqual(obj.comments[0].title, 'I aggree')
+        self.assertEqual(obj.comments[1].title, 'Coffee')
+        self.assertEqual(obj.comments[0].text, 'Great post!')
+        self.assertEqual(obj.comments[1].text, 'I hate coffee')
+
+        BlogPost.drop_collection()
+
+    def test_exclude(self):
+        class User(EmbeddedDocument):
+            name = StringField()
+            email = StringField()
+
+        class Comment(EmbeddedDocument):
+            title = StringField()
+            text = StringField()
+
+        class BlogPost(Document):
+            content = StringField()
+            author = EmbeddedDocumentField(User)
+            comments = ListField(EmbeddedDocumentField(Comment))
+
+        BlogPost.drop_collection()
+
+        post = BlogPost(content='Had a good coffee today...')
+        post.author = User(name='Test User')
+        post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')]
+        post.save()
+
+        obj =
+    def test_exclude_only_combining(self):
+        class Attachment(EmbeddedDocument):
+            name = StringField()
+            content = StringField()
+
+        class Email(Document):
+            sender = StringField()
+            to = StringField()
+            subject = StringField()
+            body = StringField()
+            content_type = StringField()
+            attachments = ListField(EmbeddedDocumentField(Attachment))
+
+        Email.drop_collection()
+        email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain')
+        email.attachments = [
+            Attachment(name='file1.doc', content='ABC'),
+            Attachment(name='file2.doc', content='XYZ'),
+        ]
+        email.save()
+
+        obj = Email.objects.exclude('content_type').exclude('body').get()
+        self.assertEqual(obj.sender, 'me')
+        self.assertEqual(obj.to, 'you')
+        self.assertEqual(obj.subject, 'From Russia with Love')
+        self.assertEqual(obj.body, None)
+        self.assertEqual(obj.content_type, None)
+
+        obj = Email.objects.only('sender', 'to').exclude('body', 'sender').get()
+        self.assertEqual(obj.sender, None)
+        self.assertEqual(obj.to, 'you')
+        self.assertEqual(obj.subject, None)
+        self.assertEqual(obj.body, None)
+        self.assertEqual(obj.content_type, None)
+
+        obj = Email.objects.exclude('attachments.content').exclude('body').only('to', 'attachments.name').get()
+        self.assertEqual(obj.attachments[0].name, 'file1.doc')
+        self.assertEqual(obj.attachments[0].content, None)
+        self.assertEqual(obj.sender, None)
+        self.assertEqual(obj.to, 'you')
+        self.assertEqual(obj.subject, None)
+        self.assertEqual(obj.body, None)
+        self.assertEqual(obj.content_type, None)
+
+        Email.drop_collection()
+
+    def test_all_fields(self):
+
+        class Email(Document):
+            sender = StringField()
+            to = StringField()
+            subject = StringField()
+            body = StringField()
+            content_type = StringField()
+
+        Email.drop_collection()
+
+        email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain')
+        email.save()
+
+        obj = Email.objects.exclude('content_type', 'body').only('to', 'body').all_fields().get()
+        self.assertEqual(obj.sender, 'me')
+        self.assertEqual(obj.to, 'you')
+        self.assertEqual(obj.subject, 'From Russia with Love')
+        self.assertEqual(obj.body, 'Hello!')
+        self.assertEqual(obj.content_type, 'text/plain')
+
+        Email.drop_collection()
+
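# --- editor's note (illustrative Python sketch, not part of the patch) ---
# only()/exclude() calls accumulate on the queryset, and all_fields() (tested
# above) discards the accumulated projection so the query returns complete
# documents again. `Note` is a hypothetical model for this sketch.
from mongoengine import Document, StringField, connect

connect(db='mongoenginetest')

class Note(Document):
    title = StringField()
    body = StringField()

Note.drop_collection()
Note(title='hi', body='text').save()

partial = Note.objects.only('title').get()
assert partial.body is None               # projected away

full = Note.objects.only('title').all_fields().get()
assert full.body == 'text'                # all_fields() resets the projection
# --- end editor's note ---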
+ """ + class Numbers(Document): + n = ListField(IntField()) + + Numbers.drop_collection() + + numbers = Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) + numbers.save() + + # first three + numbers = Numbers.objects.fields(slice__n=3).get() + self.assertEquals(numbers.n, [0, 1, 2]) + + # last three + numbers = Numbers.objects.fields(slice__n=-3).get() + self.assertEquals(numbers.n, [-3, -2, -1]) + + # skip 2, limit 3 + numbers = Numbers.objects.fields(slice__n=[2, 3]).get() + self.assertEquals(numbers.n, [2, 3, 4]) + + # skip to fifth from last, limit 4 + numbers = Numbers.objects.fields(slice__n=[-5, 4]).get() + self.assertEquals(numbers.n, [-5, -4, -3, -2]) + + # skip to fifth from last, limit 10 + numbers = Numbers.objects.fields(slice__n=[-5, 10]).get() + self.assertEquals(numbers.n, [-5, -4, -3, -2, -1]) + + # skip to fifth from last, limit 10 dict method + numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get() + self.assertEquals(numbers.n, [-5, -4, -3, -2, -1]) + + def test_slicing_nested_fields(self): + """Ensure that query slicing an embedded array works. + """ + + class EmbeddedNumber(EmbeddedDocument): + n = ListField(IntField()) + + class Numbers(Document): + embedded = EmbeddedDocumentField(EmbeddedNumber) + + Numbers.drop_collection() + + numbers = Numbers() + numbers.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) + numbers.save() + + # first three + numbers = Numbers.objects.fields(slice__embedded__n=3).get() + self.assertEquals(numbers.embedded.n, [0, 1, 2]) + + # last three + numbers = Numbers.objects.fields(slice__embedded__n=-3).get() + self.assertEquals(numbers.embedded.n, [-3, -2, -1]) + + # skip 2, limit 3 + numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get() + self.assertEquals(numbers.embedded.n, [2, 3, 4]) + + # skip to fifth from last, limit 4 + numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2]) + + # skip to fifth from last, limit 10 + numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1]) + + # skip to fifth from last, limit 10 dict method + numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1]) + def test_find_embedded(self): """Ensure that an embedded document is properly returned from a query. """ @@ -494,33 +1040,47 @@ class QuerySetTest(unittest.TestCase): """Ensure that Q objects may be used to query for documents. 
""" class BlogPost(Document): + title = StringField() publish_date = DateTimeField() published = BooleanField() BlogPost.drop_collection() - post1 = BlogPost(publish_date=datetime(2010, 1, 8), published=False) + post1 = BlogPost(title='Test 1', publish_date=datetime(2010, 1, 8), published=False) post1.save() - post2 = BlogPost(publish_date=datetime(2010, 1, 15), published=True) + post2 = BlogPost(title='Test 2', publish_date=datetime(2010, 1, 15), published=True) post2.save() - post3 = BlogPost(published=True) + post3 = BlogPost(title='Test 3', published=True) post3.save() - post4 = BlogPost(publish_date=datetime(2010, 1, 8)) + post4 = BlogPost(title='Test 4', publish_date=datetime(2010, 1, 8)) post4.save() - post5 = BlogPost(publish_date=datetime(2010, 1, 15)) + post5 = BlogPost(title='Test 1', publish_date=datetime(2010, 1, 15)) post5.save() - post6 = BlogPost(published=False) + post6 = BlogPost(title='Test 1', published=False) post6.save() # Check ObjectId lookup works obj = BlogPost.objects(id=post1.id).first() self.assertEqual(obj, post1) + # Check Q object combination with one does not exist + q = BlogPost.objects(Q(title='Test 5') | Q(published=True)) + posts = [post.id for post in q] + + published_posts = (post2, post3) + self.assertTrue(all(obj.id in posts for obj in published_posts)) + + q = BlogPost.objects(Q(title='Test 1') | Q(published=True)) + posts = [post.id for post in q] + published_posts = (post1, post2, post3, post5, post6) + self.assertTrue(all(obj.id in posts for obj in published_posts)) + + # Check Q object combination date = datetime(2010, 1, 10) q = BlogPost.objects(Q(publish_date__lte=date) | Q(published=True)) @@ -559,7 +1119,7 @@ class QuerySetTest(unittest.TestCase): obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() self.assertEqual(obj, person) - + obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() self.assertEqual(obj, None) @@ -631,7 +1191,7 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): name = StringField(db_field='doc-name') - comments = ListField(EmbeddedDocumentField(Comment), + comments = ListField(EmbeddedDocumentField(Comment), db_field='cmnts') BlogPost.drop_collection() @@ -674,6 +1234,11 @@ class QuerySetTest(unittest.TestCase): ] self.assertEqual(results, expected_results) + # Test template style + code = "{{~comments.content}}" + sub_code = BlogPost.objects._sub_js_fields(code) + self.assertEquals("cmnts.body", sub_code) + BlogPost.drop_collection() def test_delete(self): @@ -691,6 +1256,71 @@ class QuerySetTest(unittest.TestCase): self.Person.objects.delete() self.assertEqual(len(self.Person.objects), 0) + def test_reverse_delete_rule_cascade(self): + """Ensure cascading deletion of referring documents from the database. + """ + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) + BlogPost.drop_collection() + + me = self.Person(name='Test User') + me.save() + someoneelse = self.Person(name='Some-one Else') + someoneelse.save() + + BlogPost(content='Watching TV', author=me).save() + BlogPost(content='Chilling out', author=me).save() + BlogPost(content='Pro Testing', author=someoneelse).save() + + self.assertEqual(3, BlogPost.objects.count()) + self.Person.objects(name='Test User').delete() + self.assertEqual(1, BlogPost.objects.count()) + + def test_reverse_delete_rule_nullify(self): + """Ensure nullification of references to deleted documents. 
+ """ + class Category(Document): + name = StringField() + + class BlogPost(Document): + content = StringField() + category = ReferenceField(Category, reverse_delete_rule=NULLIFY) + + BlogPost.drop_collection() + Category.drop_collection() + + lameness = Category(name='Lameness') + lameness.save() + + post = BlogPost(content='Watching TV', category=lameness) + post.save() + + self.assertEqual(1, BlogPost.objects.count()) + self.assertEqual('Lameness', BlogPost.objects.first().category.name) + Category.objects.delete() + self.assertEqual(1, BlogPost.objects.count()) + self.assertEqual(None, BlogPost.objects.first().category) + + def test_reverse_delete_rule_deny(self): + """Ensure deletion gets denied on documents that still have references + to them. + """ + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, reverse_delete_rule=DENY) + + BlogPost.drop_collection() + self.Person.drop_collection() + + me = self.Person(name='Test User') + me.save() + + post = BlogPost(content='Watching TV', author=me) + post.save() + + self.assertRaises(OperationError, self.Person.objects.delete) + def test_update(self): """Ensure that atomic updates work properly. """ @@ -733,7 +1363,12 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects.update_one(add_to_set__tags='unique') post.reload() self.assertEqual(post.tags.count('unique'), 1) - + + self.assertNotEqual(post.hits, None) + BlogPost.objects.update_one(unset__hits=1) + post.reload() + self.assertEqual(post.hits, None) + BlogPost.drop_collection() def test_update_pull(self): @@ -751,6 +1386,69 @@ class QuerySetTest(unittest.TestCase): self.assertTrue('code' not in post.tags) self.assertEqual(len(post.tags), 1) + def test_update_one_pop_generic_reference(self): + + class BlogTag(Document): + name = StringField(required=True) + + class BlogPost(Document): + slug = StringField() + tags = ListField(ReferenceField(BlogTag), required=True) + + BlogPost.drop_collection() + BlogTag.drop_collection() + + tag_1 = BlogTag(name='code') + tag_1.save() + tag_2 = BlogTag(name='mongodb') + tag_2.save() + + post = BlogPost(slug="test", tags=[tag_1]) + post.save() + + post = BlogPost(slug="test-2", tags=[tag_1, tag_2]) + post.save() + self.assertEqual(len(post.tags), 2) + + BlogPost.objects(slug="test-2").update_one(pop__tags=-1) + + post.reload() + self.assertEqual(len(post.tags), 1) + + BlogPost.drop_collection() + BlogTag.drop_collection() + + def test_editting_embedded_objects(self): + + class BlogTag(EmbeddedDocument): + name = StringField(required=True) + + class BlogPost(Document): + slug = StringField() + tags = ListField(EmbeddedDocumentField(BlogTag), required=True) + + BlogPost.drop_collection() + + tag_1 = BlogTag(name='code') + tag_2 = BlogTag(name='mongodb') + + post = BlogPost(slug="test", tags=[tag_1]) + post.save() + + post = BlogPost(slug="test-2", tags=[tag_1, tag_2]) + post.save() + self.assertEqual(len(post.tags), 2) + + BlogPost.objects(slug="test-2").update_one(set__tags__0__name="python") + post.reload() + self.assertEquals(post.tags[0].name, 'python') + + BlogPost.objects(slug="test-2").update_one(pop__tags=-1) + post.reload() + self.assertEqual(len(post.tags), 1) + + BlogPost.drop_collection() + def test_order_by(self): """Ensure that QuerySets may be ordered. 
""" @@ -770,6 +1468,29 @@ class QuerySetTest(unittest.TestCase): ages = [p.age for p in self.Person.objects.order_by('-name')] self.assertEqual(ages, [30, 40, 20]) + def test_confirm_order_by_reference_wont_work(self): + """Ordering by reference is not possible. Use map / reduce.. or + denormalise""" + + class Author(Document): + author = ReferenceField(self.Person) + + Author.drop_collection() + + person_a = self.Person(name="User A", age=20) + person_a.save() + person_b = self.Person(name="User B", age=40) + person_b.save() + person_c = self.Person(name="User C", age=30) + person_c.save() + + Author(author=person_a).save() + Author(author=person_b).save() + Author(author=person_c).save() + + names = [a.author.name for a in Author.objects.order_by('-author__age')] + self.assertEqual(names, ['User A', 'User B', 'User C']) + def test_map_reduce(self): """Ensure map/reduce is both mapping and reducing. """ @@ -802,7 +1523,7 @@ class QuerySetTest(unittest.TestCase): """ # run a map/reduce operation spanning all posts - results = BlogPost.objects.map_reduce(map_f, reduce_f) + results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) self.assertEqual(len(results), 4) @@ -813,7 +1534,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(film.value, 3) BlogPost.drop_collection() - + def test_map_reduce_with_custom_object_ids(self): """Ensure that QuerySet.map_reduce works properly with custom primary keys. @@ -822,24 +1543,24 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): title = StringField(primary_key=True) tags = ListField(StringField()) - + post1 = BlogPost(title="Post #1", tags=["mongodb", "mongoengine"]) post2 = BlogPost(title="Post #2", tags=["django", "mongodb"]) post3 = BlogPost(title="Post #3", tags=["hitchcock films"]) - + post1.save() post2.save() post3.save() - + self.assertEqual(BlogPost._fields['title'].db_field, '_id') self.assertEqual(BlogPost._meta['id_field'], 'title') - + map_f = """ function() { emit(this._id, 1); } """ - + # reduce to a list of tag ids and counts reduce_f = """ function(key, values) { @@ -850,10 +1571,10 @@ class QuerySetTest(unittest.TestCase): return total; } """ - - results = BlogPost.objects.map_reduce(map_f, reduce_f) + + results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) - + self.assertEqual(results[0].object, post1) self.assertEqual(results[1].object, post2) self.assertEqual(results[2].object, post3) @@ -943,7 +1664,7 @@ class QuerySetTest(unittest.TestCase): finalize_f = """ function(key, value) { - // f(sec_since_epoch,y,z) = + // f(sec_since_epoch,y,z) = // log10(z) + ((y*sec_since_epoch) / 45000) z_10 = Math.log(value.z) / Math.log(10); weight = z_10 + ((value.y * value.t_s) / 45000); @@ -962,6 +1683,7 @@ class QuerySetTest(unittest.TestCase): results = Link.objects.order_by("-value") results = results.map_reduce(map_f, reduce_f, + "myresults", finalize_f=finalize_f, scope=scope) results = list(results) @@ -983,39 +1705,141 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() - BlogPost(hits=1, tags=['music', 'film', 'actors']).save() - BlogPost(hits=2, tags=['music']).save() + BlogPost(hits=1, tags=['music', 'film', 'actors', 'watch']).save() + BlogPost(hits=2, tags=['music', 'watch']).save() BlogPost(hits=2, tags=['music', 'actors']).save() - f = BlogPost.objects.item_frequencies('tags') - f = dict((key, int(val)) for key, val in f.items()) - self.assertEqual(set(['music', 'film', 'actors']), set(f.keys())) - 
-        self.assertEqual(f['music'], 3)
-        self.assertEqual(f['actors'], 2)
-        self.assertEqual(f['film'], 1)
+        def test_assertions(f):
+            f = dict((key, int(val)) for key, val in f.items())
+            self.assertEqual(set(['music', 'film', 'actors', 'watch']), set(f.keys()))
+            self.assertEqual(f['music'], 3)
+            self.assertEqual(f['actors'], 2)
+            self.assertEqual(f['watch'], 2)
+            self.assertEqual(f['film'], 1)
+
+        exec_js = BlogPost.objects.item_frequencies('tags')
+        map_reduce = BlogPost.objects.item_frequencies('tags', map_reduce=True)
+        test_assertions(exec_js)
+        test_assertions(map_reduce)
 
         # Ensure query is taken into account
-        f = BlogPost.objects(hits__gt=1).item_frequencies('tags')
-        f = dict((key, int(val)) for key, val in f.items())
-        self.assertEqual(set(['music', 'actors']), set(f.keys()))
-        self.assertEqual(f['music'], 2)
-        self.assertEqual(f['actors'], 1)
+        def test_assertions(f):
+            f = dict((key, int(val)) for key, val in f.items())
+            self.assertEqual(set(['music', 'actors', 'watch']), set(f.keys()))
+            self.assertEqual(f['music'], 2)
+            self.assertEqual(f['actors'], 1)
+            self.assertEqual(f['watch'], 1)
+
+        exec_js = BlogPost.objects(hits__gt=1).item_frequencies('tags')
+        map_reduce = BlogPost.objects(hits__gt=1).item_frequencies('tags', map_reduce=True)
+        test_assertions(exec_js)
+        test_assertions(map_reduce)
 
         # Check that normalization works
-        f = BlogPost.objects.item_frequencies('tags', normalize=True)
-        self.assertAlmostEqual(f['music'], 3.0/6.0)
-        self.assertAlmostEqual(f['actors'], 2.0/6.0)
-        self.assertAlmostEqual(f['film'], 1.0/6.0)
+        def test_assertions(f):
+            self.assertAlmostEqual(f['music'], 3.0/8.0)
+            self.assertAlmostEqual(f['actors'], 2.0/8.0)
+            self.assertAlmostEqual(f['watch'], 2.0/8.0)
+            self.assertAlmostEqual(f['film'], 1.0/8.0)
+
+        exec_js = BlogPost.objects.item_frequencies('tags', normalize=True)
+        map_reduce = BlogPost.objects.item_frequencies('tags', normalize=True, map_reduce=True)
+        test_assertions(exec_js)
+        test_assertions(map_reduce)
 
         # Check item_frequencies works for non-list fields
-        f = BlogPost.objects.item_frequencies('hits')
-        f = dict((key, int(val)) for key, val in f.items())
-        self.assertEqual(set(['1', '2']), set(f.keys()))
-        self.assertEqual(f['1'], 1)
-        self.assertEqual(f['2'], 2)
+        def test_assertions(f):
+            self.assertEqual(set(['1', '2']), set(f.keys()))
+            self.assertEqual(f['1'], 1)
+            self.assertEqual(f['2'], 2)
+
+        exec_js = BlogPost.objects.item_frequencies('hits')
+        map_reduce = BlogPost.objects.item_frequencies('hits', map_reduce=True)
+        test_assertions(exec_js)
+        test_assertions(map_reduce)
 
         BlogPost.drop_collection()
 
+    def test_item_frequencies_on_embedded(self):
+        """Ensure that item frequencies are properly generated from embedded
+        document fields.
+ """ + + class Phone(EmbeddedDocument): + number = StringField() + + class Person(Document): + name = StringField() + phone = EmbeddedDocumentField(Phone) + + Person.drop_collection() + + doc = Person(name="Guido") + doc.phone = Phone(number='62-3331-1656') + doc.save() + + doc = Person(name="Marr") + doc.phone = Phone(number='62-3331-1656') + doc.save() + + doc = Person(name="WP Junior") + doc.phone = Phone(number='62-3332-1656') + doc.save() + + + def test_assertions(f): + f = dict((key, int(val)) for key, val in f.items()) + self.assertEqual(set(['62-3331-1656', '62-3332-1656']), set(f.keys())) + self.assertEqual(f['62-3331-1656'], 2) + self.assertEqual(f['62-3332-1656'], 1) + + exec_js = Person.objects.item_frequencies('phone.number') + map_reduce = Person.objects.item_frequencies('phone.number', map_reduce=True) + test_assertions(exec_js) + test_assertions(map_reduce) + + # Ensure query is taken into account + def test_assertions(f): + f = dict((key, int(val)) for key, val in f.items()) + self.assertEqual(set(['62-3331-1656']), set(f.keys())) + self.assertEqual(f['62-3331-1656'], 2) + + exec_js = Person.objects(phone__number='62-3331-1656').item_frequencies('phone.number') + map_reduce = Person.objects(phone__number='62-3331-1656').item_frequencies('phone.number', map_reduce=True) + test_assertions(exec_js) + test_assertions(map_reduce) + + # Check that normalization works + def test_assertions(f): + self.assertEqual(f['62-3331-1656'], 2.0/3.0) + self.assertEqual(f['62-3332-1656'], 1.0/3.0) + + exec_js = Person.objects.item_frequencies('phone.number', normalize=True) + map_reduce = Person.objects.item_frequencies('phone.number', normalize=True, map_reduce=True) + test_assertions(exec_js) + test_assertions(map_reduce) + + def test_item_frequencies_null_values(self): + + class Person(Document): + name = StringField() + city = StringField() + + Person.drop_collection() + + Person(name="Wilson Snr", city="CRB").save() + Person(name="Wilson Jr").save() + + freq = Person.objects.item_frequencies('city') + self.assertEquals(freq, {'CRB': 1.0, None: 1.0}) + freq = Person.objects.item_frequencies('city', normalize=True) + self.assertEquals(freq, {'CRB': 0.5, None: 0.5}) + + + freq = Person.objects.item_frequencies('city', map_reduce=True) + self.assertEquals(freq, {'CRB': 1.0, None: 1.0}) + freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True) + self.assertEquals(freq, {'CRB': 0.5, None: 0.5}) + def test_average(self): """Ensure that field can be averaged correctly. 
""" @@ -1064,6 +1888,7 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): tags = ListField(StringField()) deleted = BooleanField(default=False) + date = DateTimeField(default=datetime.now) @queryset_manager def objects(doc_cls, queryset): @@ -1071,7 +1896,7 @@ class QuerySetTest(unittest.TestCase): @queryset_manager def music_posts(doc_cls, queryset): - return queryset(tags='music', deleted=False) + return queryset(tags='music', deleted=False).order_by('-date') BlogPost.drop_collection() @@ -1087,7 +1912,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual([p.id for p in BlogPost.objects], [post1.id, post2.id, post3.id]) self.assertEqual([p.id for p in BlogPost.music_posts], - [post1.id, post2.id]) + [post2.id, post1.id]) BlogPost.drop_collection() @@ -1208,6 +2033,22 @@ class QuerySetTest(unittest.TestCase): self.assertTrue([('_types', 1)] in info) self.assertTrue([('_types', 1), ('date', -1)] in info) + def test_dont_index_types(self): + """Ensure that index_types will, when disabled, prevent _types + being added to all indices. + """ + class BlogPost(Document): + date = DateTimeField() + meta = {'index_types': False, + 'indexes': ['-date']} + + # Indexes are lazy so use list() to perform query + list(BlogPost.objects) + info = BlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertTrue([('_types', 1)] not in info) + self.assertTrue([('date', -1)] in info) + BlogPost.drop_collection() class BlogPost(Document): @@ -1227,10 +2068,12 @@ class QuerySetTest(unittest.TestCase): class Test(Document): testdict = DictField() + Test.drop_collection() + t = Test(testdict={'f': 'Value'}) t.save() - self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 0) + self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 1) self.assertEqual(len(Test.objects(testdict__f='Value')), 1) Test.drop_collection() @@ -1289,12 +2132,12 @@ class QuerySetTest(unittest.TestCase): title = StringField() date = DateTimeField() location = GeoPointField() - + def __unicode__(self): return self.title - + Event.drop_collection() - + event1 = Event(title="Coltrane Motion @ Double Door", date=datetime.now() - timedelta(days=1), location=[41.909889, -87.677137]) @@ -1304,7 +2147,7 @@ class QuerySetTest(unittest.TestCase): event3 = Event(title="Coltrane Motion @ Empty Bottle", date=datetime.now(), location=[41.900474, -87.686638]) - + event1.save() event2.save() event3.save() @@ -1324,24 +2167,24 @@ class QuerySetTest(unittest.TestCase): self.assertTrue(event2 not in events) self.assertTrue(event1 in events) self.assertTrue(event3 in events) - + # ensure ordering is respected by "near" events = Event.objects(location__near=[41.9120459, -87.67892]) events = events.order_by("-date") self.assertEqual(events.count(), 3) self.assertEqual(list(events), [event3, event1, event2]) - + # find events within 10 degrees of san francisco point_and_distance = [[37.7566023, -122.415579], 10] events = Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 1) self.assertEqual(events[0], event2) - + # find events within 1 degree of greenpoint, broolyn, nyc, ny point_and_distance = [[40.7237134, -73.9509714], 1] events = Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 0) - + # ensure ordering is respected by "within_distance" point_and_distance = [[41.9120459, -87.67892], 10] events = Event.objects(location__within_distance=point_and_distance) @@ -1355,6 
@@ -1355,6 +2198,31 @@ class QuerySetTest(unittest.TestCase):
         self.assertEqual(events.count(), 1)
         self.assertEqual(events[0].id, event2.id)
 
+        # check that polygon works for users who have a server >= 1.9
+        server_version = tuple(
+            _get_connection().server_info()['version'].split('.')
+        )
+        required_version = tuple("1.9.0".split("."))
+        if server_version >= required_version:
+            polygon = [
+                (41.912114, -87.694445),
+                (41.919395, -87.69084),
+                (41.927186, -87.681742),
+                (41.911731, -87.654276),
+                (41.898061, -87.656164),
+            ]
+            events = Event.objects(location__within_polygon=polygon)
+            self.assertEqual(events.count(), 1)
+            self.assertEqual(events[0].id, event1.id)
+
+            polygon2 = [
+                (54.033586, -1.742249),
+                (52.792797, -1.225891),
+                (53.389881, -4.40094)
+            ]
+            events = Event.objects(location__within_polygon=polygon2)
+            self.assertEqual(events.count(), 0)
+
         Event.drop_collection()
 
     def test_spherical_geospatial_operators(self):
@@ -1429,6 +2297,103 @@ class QuerySetTest(unittest.TestCase):
 
         Post.drop_collection()
 
+    def test_custom_querysets_set_manager_directly(self):
+        """Ensure that a custom QuerySet class may be set on a manager directly.
+        """
+
+        class CustomQuerySet(QuerySet):
+            def not_empty(self):
+                return len(self) > 0
+
+        class CustomQuerySetManager(QuerySetManager):
+            queryset_class = CustomQuerySet
+
+        class Post(Document):
+            objects = CustomQuerySetManager()
+
+        Post.drop_collection()
+
+        self.assertTrue(isinstance(Post.objects, CustomQuerySet))
+        self.assertFalse(Post.objects.not_empty())
+
+        Post().save()
+        self.assertTrue(Post.objects.not_empty())
+
+        Post.drop_collection()
+
+    def test_custom_querysets_managers_directly(self):
+        """Ensure that a custom QuerySetManager may be used directly.
+        """
+
+        class CustomQuerySetManager(QuerySetManager):
+
+            @staticmethod
+            def get_queryset(doc_cls, queryset):
+                return queryset(is_published=True)
+
+        class Post(Document):
+            is_published = BooleanField(default=False)
+            published = CustomQuerySetManager()
+
+        Post.drop_collection()
+
+        Post().save()
+        Post(is_published=True).save()
+        self.assertEquals(Post.objects.count(), 2)
+        self.assertEquals(Post.published.count(), 1)
+
+        Post.drop_collection()
+
+    def test_custom_querysets_inherited(self):
+        """Ensure that custom QuerySet classes are inherited from abstract bases.
+        """
+
+        class CustomQuerySet(QuerySet):
+            def not_empty(self):
+                return len(self) > 0
+
+        class Base(Document):
+            meta = {'abstract': True, 'queryset_class': CustomQuerySet}
+
+        class Post(Base):
+            pass
+
+        Post.drop_collection()
+        self.assertTrue(isinstance(Post.objects, CustomQuerySet))
+        self.assertFalse(Post.objects.not_empty())
+
+        Post().save()
+        self.assertTrue(Post.objects.not_empty())
+
+        Post.drop_collection()
+
+    def test_custom_querysets_inherited_direct(self):
+        """Ensure that custom QuerySet managers are inherited from abstract bases.
+        """
+
+        class CustomQuerySet(QuerySet):
+            def not_empty(self):
+                return len(self) > 0
+
+        class CustomQuerySetManager(QuerySetManager):
+            queryset_class = CustomQuerySet
+
+        class Base(Document):
+            meta = {'abstract': True}
+            objects = CustomQuerySetManager()
+
+        class Post(Base):
+            pass
+
+        Post.drop_collection()
+        self.assertTrue(isinstance(Post.objects, CustomQuerySet))
+        self.assertFalse(Post.objects.not_empty())
+
+        Post().save()
+        self.assertTrue(Post.objects.not_empty())
+
+        Post.drop_collection()
+
     def test_call_after_limits_set(self):
         """Ensure that re-filtering after slicing works
         """
@@ -1447,8 +2412,169 @@ class QuerySetTest(unittest.TestCase):
 
         Post.drop_collection()
 
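# --- editor's note (illustrative Python sketch, not part of the patch) ---
# The meta['queryset_class'] hook shown above lets a document expose extra
# chainable query methods without writing a manager. `Story`/`LiveQuerySet`
# are hypothetical, and setting queryset_class directly on a concrete
# document is assumed to work the same way as on the abstract base tested
# above.
from mongoengine import *
from mongoengine.queryset import QuerySet  # assumed import path

connect(db='mongoenginetest')

class LiveQuerySet(QuerySet):
    def live(self):
        # filter() returns a queryset of the same class, so this chains
        return self.filter(is_live=True)

class Story(Document):
    is_live = BooleanField(default=False)
    meta = {'queryset_class': LiveQuerySet}

Story.drop_collection()
Story().save()
Story(is_live=True).save()

assert Story.objects.live().count() == 1
# --- end editor's note ---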
+ """ + class Number(Document): + n = IntField() + + Number.drop_collection() + + n2 = Number.objects.create(n=2) + n1 = Number.objects.create(n=1) + + self.assertEqual(list(Number.objects), [n2, n1]) + self.assertEqual(list(Number.objects.order_by('n')), [n1, n2]) + self.assertEqual(list(Number.objects.order_by('n').filter()), [n1, n2]) + + Number.drop_collection() + + def test_clone(self): + """Ensure that cloning clones complex querysets + """ + class Number(Document): + n = IntField() + + Number.drop_collection() + + for i in xrange(1, 101): + t = Number(n=i) + t.save() + + test = Number.objects + test2 = test.clone() + self.assertFalse(test == test2) + self.assertEqual(test.count(), test2.count()) + + test = test.filter(n__gt=11) + test2 = test.clone() + self.assertFalse(test == test2) + self.assertEqual(test.count(), test2.count()) + + test = test.limit(10) + test2 = test.clone() + self.assertFalse(test == test2) + self.assertEqual(test.count(), test2.count()) + + Number.drop_collection() + + def test_unset_reference(self): + class Comment(Document): + text = StringField() + + class Post(Document): + comment = ReferenceField(Comment) + + Comment.drop_collection() + Post.drop_collection() + + comment = Comment.objects.create(text='test') + post = Post.objects.create(comment=comment) + + self.assertEqual(post.comment, comment) + Post.objects.update(unset__comment=1) + post.reload() + self.assertEqual(post.comment, None) + + Comment.drop_collection() + Post.drop_collection() + + def test_order_works_with_custom_db_field_names(self): + class Number(Document): + n = IntField(db_field='number') + + Number.drop_collection() + + n2 = Number.objects.create(n=2) + n1 = Number.objects.create(n=1) + + self.assertEqual(list(Number.objects), [n2,n1]) + self.assertEqual(list(Number.objects.order_by('n')), [n1,n2]) + + Number.drop_collection() + + def test_order_works_with_primary(self): + """Ensure that order_by and primary work. + """ + class Number(Document): + n = IntField(primary_key=True) + + Number.drop_collection() + + Number(n=1).save() + Number(n=2).save() + Number(n=3).save() + + numbers = [n.n for n in Number.objects.order_by('-n')] + self.assertEquals([3, 2, 1], numbers) + + numbers = [n.n for n in Number.objects.order_by('+n')] + self.assertEquals([1, 2, 3], numbers) + Number.drop_collection() + + + def test_ensure_index(self): + """Ensure that manual creation of indexes works. + """ + class Comment(Document): + message = StringField() + + Comment.objects.ensure_index('message') + + info = Comment.objects._collection.index_information() + info = [(value['key'], + value.get('unique', False), + value.get('sparse', False)) + for key, value in info.iteritems()] + self.assertTrue(([('_types', 1), ('message', 1)], False, False) in info) + + def test_where(self): + """Ensure that where clauses work. 
+ """ + + class IntPair(Document): + fielda = IntField() + fieldb = IntField() + + IntPair.objects._collection.remove() + + a = IntPair(fielda=1, fieldb=1) + b = IntPair(fielda=1, fieldb=2) + c = IntPair(fielda=2, fieldb=1) + a.save() + b.save() + c.save() + + query = IntPair.objects.where('this[~fielda] >= this[~fieldb]') + self.assertEqual('this["fielda"] >= this["fieldb"]', query._where_clause) + results = list(query) + self.assertEqual(2, len(results)) + self.assertTrue(a in results) + self.assertTrue(c in results) + + query = IntPair.objects.where('this[~fielda] == this[~fieldb]') + results = list(query) + self.assertEqual(1, len(results)) + self.assertTrue(a in results) + + query = IntPair.objects.where('function() { return this[~fielda] >= this[~fieldb] }') + self.assertEqual('function() { return this["fielda"] >= this["fieldb"] }', query._where_clause) + results = list(query) + self.assertEqual(2, len(results)) + self.assertTrue(a in results) + self.assertTrue(c in results) + + def invalid_where(): + list(IntPair.objects.where(fielda__gte=3)) + + self.assertRaises(TypeError, invalid_where) + + class QTest(unittest.TestCase): + def setUp(self): + connect(db='mongoenginetest') + def test_empty_q(self): """Ensure that empty Q objects won't hurt. """ @@ -1467,7 +2593,7 @@ class QTest(unittest.TestCase): query = {'age': {'$gte': 18}, 'name': 'test'} self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) - + def test_q_with_dbref(self): """Ensure Q objects handle DBRefs correctly""" connect(db='mongoenginetest') @@ -1509,7 +2635,7 @@ class QTest(unittest.TestCase): query = Q(x__lt=100) & Q(y__ne='NotMyString') query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) mongo_query = { - 'x': {'$lt': 100, '$gt': -100}, + 'x': {'$lt': 100, '$gt': -100}, 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, } self.assertEqual(query.to_query(TestDoc), mongo_query) @@ -1584,5 +2710,86 @@ class QTest(unittest.TestCase): self.assertTrue(condition in query['$or']) + def test_q_clone(self): + + class TestDoc(Document): + x = IntField() + + TestDoc.drop_collection() + for i in xrange(1, 101): + t = TestDoc(x=i) + t.save() + + # Check normal cases work without an error + test = TestDoc.objects(Q(x__lt=7) & Q(x__gt=3)) + + self.assertEqual(test.count(), 3) + + test2 = test.clone() + self.assertEqual(test2.count(), 3) + self.assertFalse(test2 == test) + + test2.filter(x=6) + self.assertEqual(test2.count(), 1) + self.assertEqual(test.count(), 3) + +class QueryFieldListTest(unittest.TestCase): + def test_empty(self): + q = QueryFieldList() + self.assertFalse(q) + + q = QueryFieldList(always_include=['_cls']) + self.assertFalse(q) + + def test_include_include(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'a': True, 'b': True}) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'b': True}) + + def test_include_exclude(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'a': True, 'b': True}) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': True}) + + def test_exclude_exclude(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': False, 'b': False}) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), 
{'a': False, 'b': False, 'c': False}) + + def test_exclude_include(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': False, 'b': False}) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'c': True}) + + def test_always_include(self): + q = QueryFieldList(always_include=['x', 'y']) + q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) + + def test_reset(self): + q = QueryFieldList(always_include=['x', 'y']) + q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) + q.reset() + self.assertFalse(q) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) + + def test_using_a_slice(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a'], value={"$slice": 5}) + self.assertEqual(q.as_dict(), {'a': {"$slice": 5}}) + + if __name__ == '__main__': unittest.main() diff --git a/tests/signals.py b/tests/signals.py new file mode 100644 index 00000000..9c413379 --- /dev/null +++ b/tests/signals.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import * +from mongoengine import signals + +signal_output = [] + + +class SignalTests(unittest.TestCase): + """ + Testing signals before/after saving and deleting. + """ + + def get_signal_output(self, fn, *args, **kwargs): + # Flush any existing signal output + global signal_output + signal_output = [] + fn(*args, **kwargs) + return signal_output + + def setUp(self): + connect(db='mongoenginetest') + class Author(Document): + name = StringField() + + def __unicode__(self): + return self.name + + @classmethod + def pre_init(cls, sender, document, *args, **kwargs): + signal_output.append('pre_init signal, %s' % cls.__name__) + signal_output.append(str(kwargs['values'])) + + @classmethod + def post_init(cls, sender, document, **kwargs): + signal_output.append('post_init signal, %s' % document) + + @classmethod + def pre_save(cls, sender, document, **kwargs): + signal_output.append('pre_save signal, %s' % document) + + @classmethod + def post_save(cls, sender, document, **kwargs): + signal_output.append('post_save signal, %s' % document) + if 'created' in kwargs: + if kwargs['created']: + signal_output.append('Is created') + else: + signal_output.append('Is updated') + + @classmethod + def pre_delete(cls, sender, document, **kwargs): + signal_output.append('pre_delete signal, %s' % document) + + @classmethod + def post_delete(cls, sender, document, **kwargs): + signal_output.append('post_delete signal, %s' % document) + self.Author = Author + + + class Another(Document): + name = StringField() + + def __unicode__(self): + return self.name + + @classmethod + def pre_init(cls, sender, document, **kwargs): + signal_output.append('pre_init Another signal, %s' % cls.__name__) + signal_output.append(str(kwargs['values'])) + + @classmethod + def post_init(cls, sender, document, **kwargs): + signal_output.append('post_init Another signal, %s' % document) + + @classmethod + def pre_save(cls, sender, document, **kwargs): + signal_output.append('pre_save Another signal, %s' % document) + + @classmethod + 
def post_save(cls, sender, document, **kwargs):
+            signal_output.append('post_save Another signal, %s' % document)
+            if 'created' in kwargs:
+                if kwargs['created']:
+                    signal_output.append('Is created')
+                else:
+                    signal_output.append('Is updated')
+
+        @classmethod
+        def pre_delete(cls, sender, document, **kwargs):
+            signal_output.append('pre_delete Another signal, %s' % document)
+
+        @classmethod
+        def post_delete(cls, sender, document, **kwargs):
+            signal_output.append('post_delete Another signal, %s' % document)
+
+        self.Another = Another
+
+        # Record the number of connected signals so that we can check at the
+        # end that all the signals we register are properly unregistered
+        self.pre_signals = (
+            len(signals.pre_init.receivers),
+            len(signals.post_init.receivers),
+            len(signals.pre_save.receivers),
+            len(signals.post_save.receivers),
+            len(signals.pre_delete.receivers),
+            len(signals.post_delete.receivers)
+        )
+
+        signals.pre_init.connect(Author.pre_init, sender=Author)
+        signals.post_init.connect(Author.post_init, sender=Author)
+        signals.pre_save.connect(Author.pre_save, sender=Author)
+        signals.post_save.connect(Author.post_save, sender=Author)
+        signals.pre_delete.connect(Author.pre_delete, sender=Author)
+        signals.post_delete.connect(Author.post_delete, sender=Author)
+
+        signals.pre_init.connect(Another.pre_init, sender=Another)
+        signals.post_init.connect(Another.post_init, sender=Another)
+        signals.pre_save.connect(Another.pre_save, sender=Another)
+        signals.post_save.connect(Another.post_save, sender=Another)
+        signals.pre_delete.connect(Another.pre_delete, sender=Another)
+        signals.post_delete.connect(Another.post_delete, sender=Another)
+
+    def tearDown(self):
+        signals.pre_init.disconnect(self.Author.pre_init)
+        signals.post_init.disconnect(self.Author.post_init)
+        signals.post_delete.disconnect(self.Author.post_delete)
+        signals.pre_delete.disconnect(self.Author.pre_delete)
+        signals.post_save.disconnect(self.Author.post_save)
+        signals.pre_save.disconnect(self.Author.pre_save)
+
+        signals.pre_init.disconnect(self.Another.pre_init)
+        signals.post_init.disconnect(self.Another.post_init)
+        signals.post_delete.disconnect(self.Another.post_delete)
+        signals.pre_delete.disconnect(self.Another.pre_delete)
+        signals.post_save.disconnect(self.Another.post_save)
+        signals.pre_save.disconnect(self.Another.pre_save)
+
+        # Check that all our signals got disconnected properly.
+        post_signals = (
+            len(signals.pre_init.receivers),
+            len(signals.post_init.receivers),
+            len(signals.pre_save.receivers),
+            len(signals.post_save.receivers),
+            len(signals.pre_delete.receivers),
+            len(signals.post_delete.receivers)
+        )
+
+        self.assertEqual(self.pre_signals, post_signals)
+
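# --- editor's note (illustrative Python sketch, not part of the patch) ---
# Outside of tests, handlers are usually plain functions connected per
# sender, mirroring the setUp() above. `Book`/`audit_save` are hypothetical
# names; the blinker-backed signals accept any callable.
from mongoengine import *
from mongoengine import signals

connect(db='mongoenginetest')

class Book(Document):
    title = StringField()

def audit_save(sender, document, **kwargs):
    # called after every Book save
    print 'saved %s' % document.title

signals.post_save.connect(audit_save, sender=Book)

Book.drop_collection()
Book(title='Hamlet').save()               # triggers audit_save
signals.post_save.disconnect(audit_save)  # tidy up, as tearDown() does
# --- end editor's note ---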
""" + + def create_author(): + a1 = self.Author(name='Bill Shakespeare') + + self.assertEqual(self.get_signal_output(create_author), [ + "pre_init signal, Author", + "{'name': 'Bill Shakespeare'}", + "post_init signal, Bill Shakespeare", + ]) + + a1 = self.Author(name='Bill Shakespeare') + self.assertEqual(self.get_signal_output(a1.save), [ + "pre_save signal, Bill Shakespeare", + "post_save signal, Bill Shakespeare", + "Is created" + ]) + + a1.reload() + a1.name='William Shakespeare' + self.assertEqual(self.get_signal_output(a1.save), [ + "pre_save signal, William Shakespeare", + "post_save signal, William Shakespeare", + "Is updated" + ]) + + self.assertEqual(self.get_signal_output(a1.delete), [ + 'pre_delete signal, William Shakespeare', + 'post_delete signal, William Shakespeare', + ]) \ No newline at end of file