diff --git a/.travis.yml b/.travis.yml index 7fb55e36..e78bda5a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,18 +2,20 @@ language: python services: mongodb python: - - "2.5" - "2.6" - "2.7" - "3.2" - "3.3" env: - - PYMONGO=dev - - PYMONGO=2.5 - - PYMONGO=2.4.2 + - PYMONGO=dev DJANGO=1.5.1 + - PYMONGO=dev DJANGO=1.4.2 + - PYMONGO=2.5 DJANGO=1.5.1 + - PYMONGO=2.5 DJANGO=1.4.2 + - PYMONGO=2.4.2 DJANGO=1.4.2 install: - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi + - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install django==$DJANGO --use-mirrors ; true; fi - if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi - python setup.py install @@ -24,4 +26,3 @@ notifications: branches: only: - master - - "0.8" diff --git a/AUTHORS b/AUTHORS index aeb672cb..44e19bf6 100644 --- a/AUTHORS +++ b/AUTHORS @@ -124,6 +124,7 @@ that much better: * Stefan Wójcik * dimonb * Garry Polley + * James Slagle * Adrian Scott * Peter Teichman * Jakub Kot @@ -131,4 +132,28 @@ that much better: * Aleksandr Sorokoumov * Yohan Graterol * bool-dev - * Russ Weeks \ No newline at end of file + * Russ Weeks + * Paul Swartz + * Sundar Raman + * Benoit Louy + * lraucy + * hellysmile + * Jaepil Jeong + * Daniil Sharou + * Stefan Wójcik + * Pete Campton + * Martyn Smith + * Marcelo Anton + * Aleksey Porfirov + * Nicolas Trippar + * Manuel Hermann + * Gustavo Gawryszewski + * Max Countryman + * caitifbrito + * lcya86 刘春洋 + * Martin Alderete (https://github.com/malderete) + * Nick Joyce + * Jared Forsyth + * Kenneth Falck + * Lukasz Balcerzak + * Nicolas Cortot diff --git a/benchmark.py b/benchmark.py index 0197e1d7..16b2fd47 100644 --- a/benchmark.py +++ b/benchmark.py @@ -86,17 +86,43 @@ def main(): 
---------------------------------------------------------------------------------------------------- Creating 10000 dictionaries - MongoEngine, force=True 8.36906409264 + 0.8.X + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - Pymongo + 3.69964408875 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - Pymongo write_concern={"w": 0} + 3.5526599884 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - MongoEngine + 7.00959801674 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries without continual assign - MongoEngine + 5.60943293571 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade=True + 6.715102911 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True + 5.50644683838 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False + 4.69851183891 + ---------------------------------------------------------------------------------------------------- + Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False + 4.68946313858 + ---------------------------------------------------------------------------------------------------- """ setup = """ -from pymongo import Connection -connection = Connection() +from pymongo import MongoClient +connection = MongoClient() connection.drop_database('timeit_test') """ stmt = 
""" -from pymongo import Connection -connection = Connection() +from pymongo import MongoClient +connection = MongoClient() db = connection.timeit_test noddy = db.noddy @@ -106,7 +132,7 @@ for i in xrange(10000): for j in range(20): example['fields']["key"+str(j)] = "value "+str(j) - noddy.insert(example) + noddy.save(example) myNoddys = noddy.find() [n for n in myNoddys] # iterate @@ -117,9 +143,32 @@ myNoddys = noddy.find() t = timeit.Timer(stmt=stmt, setup=setup) print t.timeit(1) + stmt = """ +from pymongo import MongoClient +connection = MongoClient() + +db = connection.timeit_test +noddy = db.noddy + +for i in xrange(10000): + example = {'fields': {}} + for j in range(20): + example['fields']["key"+str(j)] = "value "+str(j) + + noddy.save(example, write_concern={"w": 0}) + +myNoddys = noddy.find() +[n for n in myNoddys] # iterate +""" + + print "-" * 100 + print """Creating 10000 dictionaries - Pymongo write_concern={"w": 0}""" + t = timeit.Timer(stmt=stmt, setup=setup) + print t.timeit(1) + setup = """ -from pymongo import Connection -connection = Connection() +from pymongo import MongoClient +connection = MongoClient() connection.drop_database('timeit_test') connection.disconnect() @@ -149,33 +198,18 @@ myNoddys = Noddy.objects() stmt = """ for i in xrange(10000): noddy = Noddy() + fields = {} for j in range(20): - noddy.fields["key"+str(j)] = "value "+str(j) - noddy.save(safe=False, validate=False) + fields["key"+str(j)] = "value "+str(j) + noddy.fields = fields + noddy.save() myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, safe=False, validate=False""" - t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) - - - stmt = """ -for i in xrange(10000): - noddy = Noddy() - for j in range(20): - noddy.fields["key"+str(j)] = "value "+str(j) - noddy.save(safe=False, validate=False, cascade=False) - -myNoddys = Noddy.objects() -[n for n in myNoddys] # iterate -""" - - print 
"-" * 100 - print """Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False""" + print """Creating 10000 dictionaries without continual assign - MongoEngine""" t = timeit.Timer(stmt=stmt, setup=setup) print t.timeit(1) @@ -184,16 +218,65 @@ for i in xrange(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) - noddy.save(force_insert=True, safe=False, validate=False, cascade=False) + noddy.save(write_concern={"w": 0}, cascade=True) myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, force=True""" + print """Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True""" t = timeit.Timer(stmt=stmt, setup=setup) print t.timeit(1) + stmt = """ +for i in xrange(10000): + noddy = Noddy() + for j in range(20): + noddy.fields["key"+str(j)] = "value "+str(j) + noddy.save(write_concern={"w": 0}, validate=False, cascade=True) + +myNoddys = Noddy.objects() +[n for n in myNoddys] # iterate +""" + + print "-" * 100 + print """Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True""" + t = timeit.Timer(stmt=stmt, setup=setup) + print t.timeit(1) + + stmt = """ +for i in xrange(10000): + noddy = Noddy() + for j in range(20): + noddy.fields["key"+str(j)] = "value "+str(j) + noddy.save(validate=False, write_concern={"w": 0}) + +myNoddys = Noddy.objects() +[n for n in myNoddys] # iterate +""" + + print "-" * 100 + print """Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False""" + t = timeit.Timer(stmt=stmt, setup=setup) + print t.timeit(1) + + stmt = """ +for i in xrange(10000): + noddy = Noddy() + for j in range(20): + noddy.fields["key"+str(j)] = "value "+str(j) + noddy.save(force_insert=True, write_concern={"w": 0}, validate=False) + +myNoddys = Noddy.objects() +[n for n in myNoddys] # iterate +""" + + print "-" * 100 + print """Creating 10000 
dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False""" + t = timeit.Timer(stmt=stmt, setup=setup) + print t.timeit(1) + + if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/docs/apireference.rst b/docs/apireference.rst index 0f8901a1..3a156299 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -34,6 +34,13 @@ Documents .. autoclass:: mongoengine.ValidationError :members: +Context Managers +================ + +.. autoclass:: mongoengine.context_managers.switch_db +.. autoclass:: mongoengine.context_managers.no_dereference +.. autoclass:: mongoengine.context_managers.query_counter + Querying ======== @@ -47,28 +54,33 @@ Querying Fields ====== -.. autoclass:: mongoengine.BinaryField -.. autoclass:: mongoengine.BooleanField -.. autoclass:: mongoengine.ComplexDateTimeField -.. autoclass:: mongoengine.DateTimeField -.. autoclass:: mongoengine.DecimalField -.. autoclass:: mongoengine.DictField -.. autoclass:: mongoengine.DynamicField -.. autoclass:: mongoengine.EmailField -.. autoclass:: mongoengine.EmbeddedDocumentField -.. autoclass:: mongoengine.FileField -.. autoclass:: mongoengine.FloatField -.. autoclass:: mongoengine.GenericEmbeddedDocumentField -.. autoclass:: mongoengine.GenericReferenceField -.. autoclass:: mongoengine.GeoPointField -.. autoclass:: mongoengine.ImageField -.. autoclass:: mongoengine.IntField -.. autoclass:: mongoengine.ListField -.. autoclass:: mongoengine.MapField -.. autoclass:: mongoengine.ObjectIdField -.. autoclass:: mongoengine.ReferenceField -.. autoclass:: mongoengine.SequenceField -.. autoclass:: mongoengine.SortedListField -.. autoclass:: mongoengine.StringField -.. autoclass:: mongoengine.URLField -.. autoclass:: mongoengine.UUIDField +.. autoclass:: mongoengine.fields.StringField +.. autoclass:: mongoengine.fields.URLField +.. autoclass:: mongoengine.fields.EmailField +.. autoclass:: mongoengine.fields.IntField +.. 
autoclass:: mongoengine.fields.LongField +.. autoclass:: mongoengine.fields.FloatField +.. autoclass:: mongoengine.fields.DecimalField +.. autoclass:: mongoengine.fields.BooleanField +.. autoclass:: mongoengine.fields.DateTimeField +.. autoclass:: mongoengine.fields.ComplexDateTimeField +.. autoclass:: mongoengine.fields.EmbeddedDocumentField +.. autoclass:: mongoengine.fields.GenericEmbeddedDocumentField +.. autoclass:: mongoengine.fields.DynamicField +.. autoclass:: mongoengine.fields.ListField +.. autoclass:: mongoengine.fields.SortedListField +.. autoclass:: mongoengine.fields.DictField +.. autoclass:: mongoengine.fields.MapField +.. autoclass:: mongoengine.fields.ReferenceField +.. autoclass:: mongoengine.fields.GenericReferenceField +.. autoclass:: mongoengine.fields.BinaryField +.. autoclass:: mongoengine.fields.FileField +.. autoclass:: mongoengine.fields.ImageField +.. autoclass:: mongoengine.fields.GeoPointField +.. autoclass:: mongoengine.fields.SequenceField +.. autoclass:: mongoengine.fields.ObjectIdField +.. autoclass:: mongoengine.fields.UUIDField +.. autoclass:: mongoengine.fields.GridFSError +.. autoclass:: mongoengine.fields.GridFSProxy +.. autoclass:: mongoengine.fields.ImageGridFsProxy +.. 
autoclass:: mongoengine.fields.ImproperlyConfigured diff --git a/docs/changelog.rst b/docs/changelog.rst index 65a5aaf1..f786c1d6 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,8 +2,78 @@ Changelog ========= +Changes in 0.8.X +================ +- Document serialization uses field order to ensure a strict order is set (#296) +- DecimalField now stores as float not string (#289) +- UUIDField now stores as a binary by default (#292) +- Added Custom User Model for Django 1.5 (#285) +- Cascading saves now default to off (#291) +- ReferenceField now store ObjectId's by default rather than DBRef (#290) +- Added ImageField support for inline replacements (#86) +- Added SequenceField.set_next_value(value) helper (#159) +- Updated .only() behaviour - now like exclude it is chainable (#202) +- Added with_limit_and_skip support to count() (#235) +- Removed __len__ from queryset (#247) +- Objects queryset manager now inherited (#256) +- Updated connection to use MongoClient (#262, #274) +- Fixed db_alias and inherited Documents (#143) +- Documentation update for document errors (#124) +- Deprecated `get_or_create` (#35) +- Updated inheritable objects created by upsert now contain _cls (#118) +- Added support for creating documents with embedded documents in a single operation (#6) +- Added to_json and from_json to Document (#1) +- Added to_json and from_json to QuerySet (#131) +- Updated index creation now tied to Document class (#102) +- Added none() to queryset (#127) +- Updated SequenceFields to allow post processing of the calculated counter value (#141) +- Added clean method to documents for pre validation data cleaning (#60) +- Added support setting for read prefrence at a query level (#157) +- Added _instance to EmbeddedDocuments pointing to the parent (#139) +- Inheritance is off by default (#122) +- Remove _types and just use _cls for inheritance (#148) +- Only allow QNode instances to be passed as query objects (#199) +- Dynamic fields are now 
validated on save (#153) (#154) +- Added support for multiple slices and made slicing chainable. (#170) (#190) (#191) +- Fixed GridFSProxy __getattr__ behaviour (#196) +- Fix Django timezone support (#151) +- Simplified Q objects, removed QueryTreeTransformerVisitor (#98) (#171) +- FileFields now copyable (#198) +- Querysets now return clones and are no longer edit in place (#56) +- Added support for $maxDistance (#179) +- Uses getlasterror to test created on updated saves (#163) +- Fixed inheritance and unique index creation (#140) +- Fixed reverse delete rule with inheritance (#197) +- Fixed validation for GenericReferences which havent been dereferenced +- Added switch_db context manager (#106) +- Added switch_db method to document instances (#106) +- Added no_dereference context manager (#82) (#61) +- Added switch_collection context manager (#220) +- Added switch_collection method to document instances (#220) +- Added support for compound primary keys (#149) (#121) +- Fixed overriding objects with custom manager (#58) +- Added no_dereference method for querysets (#82) (#61) +- Undefined data should not override instance methods (#49) +- Added Django Group and Permission (#142) +- Added Doc class and pk to Validation messages (#69) +- Fixed Documents deleted via a queryset don't call any signals (#105) +- Added the "get_decoded" method to the MongoSession class (#216) +- Fixed invalid choices error bubbling (#214) +- Updated Save so it calls $set and $unset in a single operation (#211) +- Fixed inner queryset looping (#204) + Changes in 0.7.10 ================= +- Fix UnicodeEncodeError for dbref (#278) +- Allow construction using positional parameters (#268) +- Updated EmailField length to support long domains (#243) +- Added 64-bit integer support (#251) +- Added Django sessions TTL support (#224) +- Fixed issue with numerical keys in MapField(EmbeddedDocumentField()) (#240) +- Fixed clearing _changed_fields for complex nested embedded documents (#237, #239, 
#242) +- Added "id" back to _data dictionary (#255) +- Only mark a field as changed if the value has changed (#258) +- Explicitly check for Document instances when dereferencing (#261) - Fixed order_by chaining issue (#265) - Added dereference support for tuples (#250) - Resolve field name to db field name when using distinct(#260, #264, #269) @@ -19,12 +89,12 @@ Changes in 0.7.9 Changes in 0.7.8 ================ -- Fix sequence fields in embedded documents (MongoEngine/mongoengine#166) -- Fix query chaining with .order_by() (MongoEngine/mongoengine#176) -- Added optional encoding and collection config for Django sessions (MongoEngine/mongoengine#180, MongoEngine/mongoengine#181, MongoEngine/mongoengine#183) -- Fixed EmailField so can add extra validation (MongoEngine/mongoengine#173, MongoEngine/mongoengine#174, MongoEngine/mongoengine#187) -- Fixed bulk inserts can now handle custom pk's (MongoEngine/mongoengine#192) -- Added as_pymongo method to return raw or cast results from pymongo (MongoEngine/mongoengine#193) +- Fix sequence fields in embedded documents (#166) +- Fix query chaining with .order_by() (#176) +- Added optional encoding and collection config for Django sessions (#180, #181, #183) +- Fixed EmailField so can add extra validation (#173, #174, #187) +- Fixed bulk inserts can now handle custom pk's (#192) +- Added as_pymongo method to return raw or cast results from pymongo (#193) Changes in 0.7.7 ================ @@ -32,70 +102,70 @@ Changes in 0.7.7 Changes in 0.7.6 ================ -- Unicode fix for repr (MongoEngine/mongoengine#133) -- Allow updates with match operators (MongoEngine/mongoengine#144) -- Updated URLField - now can have a override the regex (MongoEngine/mongoengine#136) +- Unicode fix for repr (#133) +- Allow updates with match operators (#144) +- Updated URLField - now can have a override the regex (#136) - Allow Django AuthenticationBackends to work with Django user (hmarr/mongoengine#573) -- Fixed reload issue with 
ReferenceField where dbref=False (MongoEngine/mongoengine#138) +- Fixed reload issue with ReferenceField where dbref=False (#138) Changes in 0.7.5 ================ -- ReferenceFields with dbref=False use ObjectId instead of strings (MongoEngine/mongoengine#134) - See ticket for upgrade notes (https://github.com/MongoEngine/mongoengine/issues/134) +- ReferenceFields with dbref=False use ObjectId instead of strings (#134) + See ticket for upgrade notes (#134) Changes in 0.7.4 ================ -- Fixed index inheritance issues - firmed up testcases (MongoEngine/mongoengine#123) (MongoEngine/mongoengine#125) +- Fixed index inheritance issues - firmed up testcases (#123) (#125) Changes in 0.7.3 ================ -- Reverted EmbeddedDocuments meta handling - now can turn off inheritance (MongoEngine/mongoengine#119) +- Reverted EmbeddedDocuments meta handling - now can turn off inheritance (#119) Changes in 0.7.2 ================ -- Update index spec generation so its not destructive (MongoEngine/mongoengine#113) +- Update index spec generation so its not destructive (#113) Changes in 0.7.1 ================= -- Fixed index spec inheritance (MongoEngine/mongoengine#111) +- Fixed index spec inheritance (#111) Changes in 0.7.0 ================= -- Updated queryset.delete so you can use with skip / limit (MongoEngine/mongoengine#107) -- Updated index creation allows kwargs to be passed through refs (MongoEngine/mongoengine#104) -- Fixed Q object merge edge case (MongoEngine/mongoengine#109) +- Updated queryset.delete so you can use with skip / limit (#107) +- Updated index creation allows kwargs to be passed through refs (#104) +- Fixed Q object merge edge case (#109) - Fixed reloading on sharded documents (hmarr/mongoengine#569) -- Added NotUniqueError for duplicate keys (MongoEngine/mongoengine#62) -- Added custom collection / sequence naming for SequenceFields (MongoEngine/mongoengine#92) -- Fixed UnboundLocalError in composite index with pk field 
(MongoEngine/mongoengine#88) +- Added NotUniqueError for duplicate keys (#62) +- Added custom collection / sequence naming for SequenceFields (#92) +- Fixed UnboundLocalError in composite index with pk field (#88) - Updated ReferenceField's to optionally store ObjectId strings - this will become the default in 0.8 (MongoEngine/mongoengine#89) + this will become the default in 0.8 (#89) - Added FutureWarning - save will default to `cascade=False` in 0.8 -- Added example of indexing embedded document fields (MongoEngine/mongoengine#75) -- Fixed ImageField resizing when forcing size (MongoEngine/mongoengine#80) -- Add flexibility for fields handling bad data (MongoEngine/mongoengine#78) +- Added example of indexing embedded document fields (#75) +- Fixed ImageField resizing when forcing size (#80) +- Add flexibility for fields handling bad data (#78) - Embedded Documents no longer handle meta definitions -- Use weakref proxies in base lists / dicts (MongoEngine/mongoengine#74) +- Use weakref proxies in base lists / dicts (#74) - Improved queryset filtering (hmarr/mongoengine#554) - Fixed Dynamic Documents and Embedded Documents (hmarr/mongoengine#561) -- Fixed abstract classes and shard keys (MongoEngine/mongoengine#64) +- Fixed abstract classes and shard keys (#64) - Fixed Python 2.5 support - Added Python 3 support (thanks to Laine Heron) Changes in 0.6.20 ================= -- Added support for distinct and db_alias (MongoEngine/mongoengine#59) +- Added support for distinct and db_alias (#59) - Improved support for chained querysets when constraining the same fields (hmarr/mongoengine#554) -- Fixed BinaryField lookup re (MongoEngine/mongoengine#48) +- Fixed BinaryField lookup re (#48) Changes in 0.6.19 ================= -- Added Binary support to UUID (MongoEngine/mongoengine#47) -- Fixed MapField lookup for fields without declared lookups (MongoEngine/mongoengine#46) -- Fixed BinaryField python value issue (MongoEngine/mongoengine#48) -- Fixed SequenceField non 
numeric value lookup (MongoEngine/mongoengine#41) -- Fixed queryset manager issue (MongoEngine/mongoengine#52) +- Added Binary support to UUID (#47) +- Fixed MapField lookup for fields without declared lookups (#46) +- Fixed BinaryField python value issue (#48) +- Fixed SequenceField non numeric value lookup (#41) +- Fixed queryset manager issue (#52) - Fixed FileField comparision (hmarr/mongoengine#547) Changes in 0.6.18 diff --git a/docs/code/tumblelog.py b/docs/code/tumblelog.py index 6ba1eee2..0e40e899 100644 --- a/docs/code/tumblelog.py +++ b/docs/code/tumblelog.py @@ -45,7 +45,7 @@ print 'ALL POSTS' print for post in Post.objects: print post.title - print '=' * len(post.title) + print '=' * post.title.count() if isinstance(post, TextPost): print post.content diff --git a/docs/conf.py b/docs/conf.py index 62fa1505..8bcb9ec9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,7 +16,7 @@ import sys, os # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.append(os.path.abspath('..')) +sys.path.insert(0, os.path.abspath('..')) # -- General configuration ----------------------------------------------------- @@ -38,7 +38,7 @@ master_doc = 'index' # General information about the project. project = u'MongoEngine' -copyright = u'2009-2012, MongoEngine Authors' +copyright = u'2009, MongoEngine Authors' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -173,8 +173,8 @@ latex_paper_size = 'a4' # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). 
latex_documents = [ - ('index', 'MongoEngine.tex', u'MongoEngine Documentation', - u'Harry Marr', 'manual'), + ('index', 'MongoEngine.tex', 'MongoEngine Documentation', + 'Ross Lawley', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -193,3 +193,6 @@ latex_documents = [ # If false, no module index is generated. #latex_use_modindex = True + +autoclass_content = 'both' + diff --git a/docs/django.rst b/docs/django.rst index 144baab5..d60e55d9 100644 --- a/docs/django.rst +++ b/docs/django.rst @@ -2,7 +2,7 @@ Using MongoEngine with Django ============================= -.. note :: Updated to support Django 1.4 +.. note:: Updated to support Django 1.4 Connecting ========== @@ -10,6 +10,16 @@ In your **settings.py** file, ignore the standard database settings (unless you also plan to use the ORM in your project), and instead call :func:`~mongoengine.connect` somewhere in the settings module. +.. note:: + If you are not using another Database backend you may need to add a dummy + database backend to ``settings.py`` eg:: + + DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.dummy' + } + } + Authentication ============== MongoEngine includes a Django authentication backend, which uses MongoDB. The @@ -32,6 +42,42 @@ The :mod:`~mongoengine.django.auth` module also contains a .. versionadded:: 0.1.3 +Custom User model +================= +Django 1.5 introduced `Custom user Models +` +which can be used as an alternative the Mongoengine authentication backend. + +The main advantage of this option is that other components relying on +:mod:`django.contrib.auth` and supporting the new swappable user model are more +likely to work. For example, you can use the ``createsuperuser`` management +command as usual. + +To enable the custom User model in Django, add ``'mongoengine.django.mongo_auth'`` +in your ``INSTALLED_APPS`` and set ``'mongo_auth.MongoUser'`` as the custom user +user model to use. 
In your **settings.py** file you will have:: + + INSTALLED_APPS = ( + ... + 'django.contrib.auth', + 'mongoengine.django.mongo_auth', + ... + ) + + AUTH_USER_MODEL = 'mongo_auth.MongoUser' + +An additional ``MONGOENGINE_USER_DOCUMENT`` setting enables you to replace the +:class:`~mongoengine.django.auth.User` class with another class of your choice:: + + MONGOENGINE_USER_DOCUMENT = 'mongoengine.django.auth.User' + +The custom :class:`User` must be a :class:`~mongoengine.Document` class, but +otherwise has the same requirements as a standard custom user model, +as specified in the `Django Documentation +`. +In particular, the custom class must define :attr:`USERNAME_FIELD` and +:attr:`REQUIRED_FIELDS` attributes. + Sessions ======== Django allows the use of different backend stores for its sessions. MongoEngine @@ -45,11 +91,14 @@ into you settings module:: SESSION_ENGINE = 'mongoengine.django.sessions' +Django provides session cookie, which expires after ```SESSION_COOKIE_AGE``` seconds, but doesnt delete cookie at sessions backend, so ``'mongoengine.django.sessions'`` supports `mongodb TTL +`_. + .. versionadded:: 0.2.1 Storage ======= -With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`, +With MongoEngine's support for GridFS via the :class:`~mongoengine.fields.FileField`, it is useful to have a Django file storage backend that wraps this. The new storage module is called :class:`~mongoengine.django.storage.GridFSStorage`. Using it is very similar to using the default FileSystemStorage.:: diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index bc45dbfe..8674b5eb 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -6,20 +6,23 @@ Connecting to MongoDB To connect to a running instance of :program:`mongod`, use the :func:`~mongoengine.connect` function. The first argument is the name of the -database to connect to. If the database does not exist, it will be created. 
If -the database requires authentication, :attr:`username` and :attr:`password` -arguments may be provided:: +database to connect to:: from mongoengine import connect - connect('project1', username='webapp', password='pwd123') + connect('project1') By default, MongoEngine assumes that the :program:`mongod` instance is running -on **localhost** on port **27017**. If MongoDB is running elsewhere, you may -provide :attr:`host` and :attr:`port` arguments to +on **localhost** on port **27017**. If MongoDB is running elsewhere, you should +provide the :attr:`host` and :attr:`port` arguments to :func:`~mongoengine.connect`:: connect('project1', host='192.168.1.35', port=12345) +If the database requires authentication, :attr:`username` and :attr:`password` +arguments should be provided:: + + connect('project1', username='webapp', password='pwd123') + Uri style connections are also supported as long as you include the database name - just supply the uri as the :attr:`host` to :func:`~mongoengine.connect`:: @@ -29,10 +32,16 @@ name - just supply the uri as the :attr:`host` to ReplicaSets =========== -MongoEngine now supports :func:`~pymongo.replica_set_connection.ReplicaSetConnection` +MongoEngine supports :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` to use them please use a URI style connection and provide the `replicaSet` name in the connection kwargs. +Read preferences are supported throught the connection or via individual +queries by passing the read_preference :: + + Bar.objects().read_preference(ReadPreference.PRIMARY) + Bar.objects(read_preference=ReadPreference.PRIMARY) + Multiple Databases ================== @@ -63,3 +72,28 @@ to point across databases and collections. 
Below is an example schema, using book = ReferenceField(Book) meta = {"db_alias": "users-books-db"} + + +Switch Database Context Manager +=============================== + +Sometimes you may want to switch the database to query against for a class +for example, archiving older data into a separate database for performance +reasons. + +The :class:`~mongoengine.context_managers.switch_db` context manager allows +you to change the database alias for a given class allowing quick and easy +access to the same User document across databases.eg :: + + from mongoengine.context_managers import switch_db + + class User(Document): + name = StringField() + + meta = {"db_alias": "user-db"} + + with switch_db(User, 'archive-user-db') as User: + User(name="Ross").save() # Saves the 'archive-user-db' + +.. note:: Make sure any aliases have been registered with + :func:`~mongoengine.register_connection` before using the context manager. diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 3ee77961..36e0efea 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -24,6 +24,9 @@ objects** as class attributes to the document class:: title = StringField(max_length=200, required=True) date_modified = DateTimeField(default=datetime.datetime.now) +As BSON (the binary format for storing data in mongodb) is order dependent, +documents are serialized based on their field order. + Dynamic document schemas ======================== One of the benefits of MongoDb is dynamic schemas for a collection, whilst data @@ -47,10 +50,11 @@ be saved :: >>> Page.objects(tags='mongoengine').count() >>> 1 -..note:: +.. note:: There is one caveat on Dynamic Documents: fields cannot start with `_` +Dynamic fields are stored in alphabetical order *after* any declared fields. Fields ====== @@ -62,31 +66,31 @@ not provided. Default values may optionally be a callable, which will be called to retrieve the value (such as in the above example). 
The field types available are as follows: -* :class:`~mongoengine.BinaryField` -* :class:`~mongoengine.BooleanField` -* :class:`~mongoengine.ComplexDateTimeField` -* :class:`~mongoengine.DateTimeField` -* :class:`~mongoengine.DecimalField` -* :class:`~mongoengine.DictField` -* :class:`~mongoengine.DynamicField` -* :class:`~mongoengine.EmailField` -* :class:`~mongoengine.EmbeddedDocumentField` -* :class:`~mongoengine.FileField` -* :class:`~mongoengine.FloatField` -* :class:`~mongoengine.GenericEmbeddedDocumentField` -* :class:`~mongoengine.GenericReferenceField` -* :class:`~mongoengine.GeoPointField` -* :class:`~mongoengine.ImageField` -* :class:`~mongoengine.IntField` -* :class:`~mongoengine.ListField` -* :class:`~mongoengine.MapField` -* :class:`~mongoengine.ObjectIdField` -* :class:`~mongoengine.ReferenceField` -* :class:`~mongoengine.SequenceField` -* :class:`~mongoengine.SortedListField` -* :class:`~mongoengine.StringField` -* :class:`~mongoengine.URLField` -* :class:`~mongoengine.UUIDField` +* :class:`~mongoengine.fields.BinaryField` +* :class:`~mongoengine.fields.BooleanField` +* :class:`~mongoengine.fields.ComplexDateTimeField` +* :class:`~mongoengine.fields.DateTimeField` +* :class:`~mongoengine.fields.DecimalField` +* :class:`~mongoengine.fields.DictField` +* :class:`~mongoengine.fields.DynamicField` +* :class:`~mongoengine.fields.EmailField` +* :class:`~mongoengine.fields.EmbeddedDocumentField` +* :class:`~mongoengine.fields.FileField` +* :class:`~mongoengine.fields.FloatField` +* :class:`~mongoengine.fields.GenericEmbeddedDocumentField` +* :class:`~mongoengine.fields.GenericReferenceField` +* :class:`~mongoengine.fields.GeoPointField` +* :class:`~mongoengine.fields.ImageField` +* :class:`~mongoengine.fields.IntField` +* :class:`~mongoengine.fields.ListField` +* :class:`~mongoengine.fields.MapField` +* :class:`~mongoengine.fields.ObjectIdField` +* :class:`~mongoengine.fields.ReferenceField` +* :class:`~mongoengine.fields.SequenceField` +* 
:class:`~mongoengine.fields.SortedListField` +* :class:`~mongoengine.fields.StringField` +* :class:`~mongoengine.fields.URLField` +* :class:`~mongoengine.fields.UUIDField` Field arguments --------------- @@ -110,7 +114,7 @@ arguments can be set on all fields: The definion of default parameters follow `the general rules on Python `__, which means that some care should be taken when dealing with default mutable objects - (like in :class:`~mongoengine.ListField` or :class:`~mongoengine.DictField`):: + (like in :class:`~mongoengine.fields.ListField` or :class:`~mongoengine.fields.DictField`):: class ExampleFirst(Document): # Default an empty list @@ -135,7 +139,8 @@ arguments can be set on all fields: field, will not have two documents in the collection with the same value. :attr:`primary_key` (Default: False) - When True, use this field as a primary key for the collection. + When True, use this field as a primary key for the collection. `DictField` + and `EmbeddedDocuments` both support being the primary key for a document. :attr:`choices` (Default: None) An iterable (e.g. a list or tuple) of choices to which the value of this @@ -171,8 +176,8 @@ arguments can be set on all fields: List fields ----------- MongoDB allows the storage of lists of items. To add a list of items to a -:class:`~mongoengine.Document`, use the :class:`~mongoengine.ListField` field -type. :class:`~mongoengine.ListField` takes another field object as its first +:class:`~mongoengine.Document`, use the :class:`~mongoengine.fields.ListField` field +type. 
:class:`~mongoengine.fields.ListField` takes another field object as its first argument, which specifies which type elements may be stored within the list:: class Page(Document): @@ -190,7 +195,7 @@ inherit from :class:`~mongoengine.EmbeddedDocument` rather than content = StringField() To embed the document within another document, use the -:class:`~mongoengine.EmbeddedDocumentField` field type, providing the embedded +:class:`~mongoengine.fields.EmbeddedDocumentField` field type, providing the embedded document class as the first argument:: class Page(Document): @@ -205,7 +210,7 @@ Dictionary Fields Often, an embedded document may be used instead of a dictionary -- generally this is recommended as dictionaries don't support validation or custom field types. However, sometimes you will not know the structure of what you want to -store; in this situation a :class:`~mongoengine.DictField` is appropriate:: +store; in this situation a :class:`~mongoengine.fields.DictField` is appropriate:: class SurveyResponse(Document): date = DateTimeField() @@ -223,7 +228,7 @@ other objects, so are the most flexible field type available. Reference fields ---------------- References may be stored to other documents in the database using the -:class:`~mongoengine.ReferenceField`. Pass in another document class as the +:class:`~mongoengine.fields.ReferenceField`. Pass in another document class as the first argument to the constructor, then simply assign document objects to the field:: @@ -244,9 +249,9 @@ field:: The :class:`User` object is automatically turned into a reference behind the scenes, and dereferenced when the :class:`Page` object is retrieved. -To add a :class:`~mongoengine.ReferenceField` that references the document +To add a :class:`~mongoengine.fields.ReferenceField` that references the document being defined, use the string ``'self'`` in place of the document class as the -argument to :class:`~mongoengine.ReferenceField`'s constructor. 
To reference a +argument to :class:`~mongoengine.fields.ReferenceField`'s constructor. To reference a document that has not yet been defined, use the name of the undefined document as the constructor's argument:: @@ -324,7 +329,7 @@ Its value can take any of the following constants: :const:`mongoengine.PULL` Removes the reference to the object (using MongoDB's "pull" operation) from any object's fields of - :class:`~mongoengine.ListField` (:class:`~mongoengine.ReferenceField`). + :class:`~mongoengine.fields.ListField` (:class:`~mongoengine.fields.ReferenceField`). .. warning:: @@ -351,7 +356,7 @@ Its value can take any of the following constants: Generic reference fields '''''''''''''''''''''''' A second kind of reference field also exists, -:class:`~mongoengine.GenericReferenceField`. This allows you to reference any +:class:`~mongoengine.fields.GenericReferenceField`. This allows you to reference any kind of :class:`~mongoengine.Document`, and hence doesn't take a :class:`~mongoengine.Document` subclass as a constructor argument:: @@ -375,15 +380,15 @@ kind of :class:`~mongoengine.Document`, and hence doesn't take a .. note:: - Using :class:`~mongoengine.GenericReferenceField`\ s is slightly less - efficient than the standard :class:`~mongoengine.ReferenceField`\ s, so if + Using :class:`~mongoengine.fields.GenericReferenceField`\ s is slightly less + efficient than the standard :class:`~mongoengine.fields.ReferenceField`\ s, so if you will only be referencing one document type, prefer the standard - :class:`~mongoengine.ReferenceField`. + :class:`~mongoengine.fields.ReferenceField`. Uniqueness constraints ---------------------- MongoEngine allows you to specify that a field should be unique across a -collection by providing ``unique=True`` to a :class:`~mongoengine.Field`\ 's +collection by providing ``unique=True`` to a :class:`~mongoengine.fields.Field`\ 's constructor. 
If you try to save a document that has the same value for a unique field as a document that is already in the database, a :class:`~mongoengine.OperationError` will be raised. You may also specify @@ -441,6 +446,7 @@ The following example shows a :class:`Log` document that will be limited to Indexes ======= + You can specify indexes on collections to make querying faster. This is done by creating a list of index specifications called :attr:`indexes` in the :attr:`~mongoengine.Document.meta` dictionary, where an index specification may @@ -461,9 +467,11 @@ If a dictionary is passed then the following options are available: :attr:`fields` (Default: None) The fields to index. Specified in the same format as described above. -:attr:`types` (Default: True) - Whether the index should have the :attr:`_types` field added automatically - to the start of the index. +:attr:`cls` (Default: True) + If you have polymorphic models that inherit and have + :attr:`allow_inheritance` turned on, you can configure whether the index + should have the :attr:`_cls` field added automatically to the start of the + index. :attr:`sparse` (Default: False) Whether the index should be sparse. @@ -471,26 +479,28 @@ If a dictionary is passed then the following options are available: :attr:`unique` (Default: False) Whether the index should be unique. -.. note :: +.. note:: - To index embedded files / dictionary fields use 'dot' notation eg: - `rank.title` + Inheritance adds extra fields indices see: :ref:`document-inheritance`. -.. warning:: +Compound Indexes and Indexing sub documents +------------------------------------------- - Inheritance adds extra indices. - If don't need inheritance for a document turn inheritance off - - see :ref:`document-inheritance`. +Compound indexes can be created by adding the Embedded field or dictionary +field name to the index definition. 
+Sometimes it's more efficient to index parts of Embedded / dictionary fields, +in this case use 'dot' notation to identify the value to index eg: `rank.title` Geospatial indexes ---------------------------- +------------------ + Geospatial indexes will be automatically created for all -:class:`~mongoengine.GeoPointField`\ s +:class:`~mongoengine.fields.GeoPointField`\ s It is also possible to explicitly define geospatial indexes. This is useful if you need to define a geospatial index on a subfield of a -:class:`~mongoengine.DictField` or a custom field that contains a +:class:`~mongoengine.fields.DictField` or a custom field that contains a point. To create a geospatial index you must prefix the field with the ***** sign. :: @@ -572,7 +582,9 @@ defined, you may subclass it and add any extra fields or methods you may need. As this is new class is not a direct subclass of :class:`~mongoengine.Document`, it will not be stored in its own collection; it will use the same collection as its superclass uses. This allows for more -convenient and efficient retrieval of related documents:: +convenient and efficient retrieval of related documents - all you need to do is +set :attr:`allow_inheritance` to True in the :attr:`meta` data for a +document.:: # Stored in a collection named 'page' class Page(Document): @@ -584,25 +596,26 @@ convenient and efficient retrieval of related documents:: class DatedPage(Page): date = DateTimeField() -.. note:: From 0.7 onwards you must declare `allow_inheritance` in the document meta. +.. note:: From 0.8 onwards :attr:`allow_inheritance` defaults + to False, meaning you must set it to True to use inheritance. Working with existing data -------------------------- -To enable correct retrieval of documents involved in this kind of heirarchy, -two extra attributes are stored on each document in the database: :attr:`_cls` -and :attr:`_types`. 
These are hidden from the user through the MongoEngine -interface, but may not be present if you are trying to use MongoEngine with -an existing database. For this reason, you may disable this inheritance -mechansim, removing the dependency of :attr:`_cls` and :attr:`_types`, enabling -you to work with existing databases. To disable inheritance on a document -class, set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` -dictionary:: +As MongoEngine no longer defaults to needing :attr:`_cls` you can quickly and +easily get working with existing data. Just define the document to match +the expected schema in your database :: # Will work with data in an existing collection named 'cmsPage' class Page(Document): title = StringField(max_length=200, required=True) meta = { - 'collection': 'cmsPage', - 'allow_inheritance': False, + 'collection': 'cmsPage' } + +If you have wildly varying schemas then using a +:class:`~mongoengine.DynamicDocument` might be more appropriate, instead of +defining all possible field types. + +If you use :class:`~mongoengine.Document` and the database contains data that +isn't defined then that data will be stored in the `document._data` dictionary. diff --git a/docs/guide/document-instances.rst b/docs/guide/document-instances.rst index 54fa804b..f9a6610f 100644 --- a/docs/guide/document-instances.rst +++ b/docs/guide/document-instances.rst @@ -30,21 +30,53 @@ already exist, then any changes will be updated atomically. For example:: .. note:: - Changes to documents are tracked and on the whole perform `set` operations. + Changes to documents are tracked and on the whole perform ``set`` operations. - * ``list_field.pop(0)`` - *sets* the resulting list + * ``list_field.push(0)`` - *sets* the resulting list * ``del(list_field)`` - *unsets* whole list + With lists its preferable to use ``Doc.update(push__list_field=0)`` as + this stops the whole list being updated - stopping any race conditions. + .. 
seealso:: :ref:`guide-atomic-updates` +Pre save data validation and cleaning +------------------------------------- +MongoEngine allows you to create custom cleaning rules for your documents when +calling :meth:`~mongoengine.Document.save`. By providing a custom +:meth:`~mongoengine.Document.clean` method you can do any pre validation / data +cleaning. + +This might be useful if you want to ensure a default value based on other +document values for example:: + + class Essay(Document): + status = StringField(choices=('Published', 'Draft'), required=True) + pub_date = DateTimeField() + + def clean(self): + """Ensures that only published essays have a `pub_date` and + automatically sets the pub_date if published and not set""" + if self.status == 'Draft' and self.pub_date is not None: + msg = 'Draft entries should not have a publication date.' + raise ValidationError(msg) + # Set the pub_date for published items if not set. + if self.status == 'Published' and self.pub_date is None: + self.pub_date = datetime.now() + +.. note:: + Cleaning is only called if validation is turned on and when calling + :meth:`~mongoengine.Document.save`. + Cascading Saves --------------- -If your document contains :class:`~mongoengine.ReferenceField` or -:class:`~mongoengine.GenericReferenceField` objects, then by default the -:meth:`~mongoengine.Document.save` method will automatically save any changes to -those objects as well. If this is not desired passing :attr:`cascade` as False -to the save method turns this feature off. +If your document contains :class:`~mongoengine.fields.ReferenceField` or +:class:`~mongoengine.fields.GenericReferenceField` objects, then by default the +:meth:`~mongoengine.Document.save` method will not save any changes to +those objects. If you want all references to also be saved, noting each +save is a separate query, then passing :attr:`cascade` as True +to the save method will cascade any saves. 
Deleting documents ------------------ diff --git a/docs/guide/gridfs.rst b/docs/guide/gridfs.rst index 9c80a99e..d81bb922 100644 --- a/docs/guide/gridfs.rst +++ b/docs/guide/gridfs.rst @@ -7,7 +7,7 @@ GridFS Writing ------- -GridFS support comes in the form of the :class:`~mongoengine.FileField` field +GridFS support comes in the form of the :class:`~mongoengine.fields.FileField` field object. This field acts as a file-like object and provides a couple of different ways of inserting and retrieving data. Arbitrary metadata such as content type can also be stored alongside the files. In the following example, @@ -18,26 +18,16 @@ a document is created to store details about animals, including a photo:: family = StringField() photo = FileField() - marmot = Animal('Marmota', 'Sciuridae') - - marmot_photo = open('marmot.jpg', 'r') # Retrieve a photo from disk - marmot.photo = marmot_photo # Store photo in the document - marmot.photo.content_type = 'image/jpeg' # Store metadata - - marmot.save() - -Another way of writing to a :class:`~mongoengine.FileField` is to use the -:func:`put` method. This allows for metadata to be stored in the same call as -the file:: - - marmot.photo.put(marmot_photo, content_type='image/jpeg') + marmot = Animal(genus='Marmota', family='Sciuridae') + marmot_photo = open('marmot.jpg', 'r') + marmot.photo.put(marmot_photo, content_type = 'image/jpeg') marmot.save() Retrieval --------- -So using the :class:`~mongoengine.FileField` is just like using any other +So using the :class:`~mongoengine.fields.FileField` is just like using any other field. The file can also be retrieved just as easily:: marmot = Animal.objects(genus='Marmota').first() @@ -47,7 +37,7 @@ field. The file can also be retrieved just as easily:: Streaming --------- -Streaming data into a :class:`~mongoengine.FileField` is achieved in a +Streaming data into a :class:`~mongoengine.fields.FileField` is achieved in a slightly different manner. 
First, a new file must be created by calling the :func:`new_file` method. Data can then be written using :func:`write`:: diff --git a/docs/guide/installing.rst b/docs/guide/installing.rst index f15d3dbb..e93f0485 100644 --- a/docs/guide/installing.rst +++ b/docs/guide/installing.rst @@ -22,10 +22,10 @@ Alternatively, if you don't have setuptools installed, `download it from PyPi $ python setup.py install To use the bleeding-edge version of MongoEngine, you can get the source from -`GitHub `_ and install it as above: +`GitHub `_ and install it as above: .. code-block:: console - $ git clone git://github.com/hmarr/mongoengine + $ git clone git://github.com/mongoengine/mongoengine $ cd mongoengine $ python setup.py install diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 14498017..3a25c286 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -79,7 +79,7 @@ expressions: * ``match`` -- performs an $elemMatch so you can match an entire document within an array There are a few special operators for performing geographical queries, that -may used with :class:`~mongoengine.GeoPointField`\ s: +may used with :class:`~mongoengine.fields.GeoPointField`\ s: * ``within_distance`` -- provide a list containing a point and a maximum distance (e.g. [(41.342, -87.653), 5]) @@ -92,13 +92,15 @@ may used with :class:`~mongoengine.GeoPointField`\ s: * ``within_polygon`` -- filter documents to those within a given polygon (e.g. [(41.91,-87.69), (41.92,-87.68), (41.91,-87.65), (41.89,-87.65)]). .. note:: Requires Mongo Server 2.0 +* ``max_distance`` -- can be added to your location queries to set a maximum + distance. 
Querying lists -------------- On most fields, this syntax will look up documents where the field specified matches the given value exactly, but when the field refers to a -:class:`~mongoengine.ListField`, a single item may be provided, in which case +:class:`~mongoengine.fields.ListField`, a single item may be provided, in which case lists that contain that item will be matched:: class Page(Document): @@ -179,9 +181,11 @@ Retrieving unique results ------------------------- To retrieve a result that should be unique in the collection, use :meth:`~mongoengine.queryset.QuerySet.get`. This will raise -:class:`~mongoengine.queryset.DoesNotExist` if no document matches the query, -and :class:`~mongoengine.queryset.MultipleObjectsReturned` if more than one -document matched the query. +:class:`~mongoengine.queryset.DoesNotExist` if +no document matches the query, and +:class:`~mongoengine.queryset.MultipleObjectsReturned` +if more than one document matched the query. These exceptions are merged into +your document definitions eg: `MyDoc.DoesNotExist` A variation of this method exists, :meth:`~mongoengine.queryset.Queryset.get_or_create`, that will create a new @@ -315,7 +319,7 @@ Retrieving a subset of fields Sometimes a subset of fields on a :class:`~mongoengine.Document` is required, and for efficiency only these should be retrieved from the database. This issue is especially important for MongoDB, as fields may often be extremely large -(e.g. a :class:`~mongoengine.ListField` of +(e.g. a :class:`~mongoengine.fields.ListField` of :class:`~mongoengine.EmbeddedDocument`\ s, which represent the comments on a blog post. 
To select only a subset of fields, use :meth:`~mongoengine.queryset.QuerySet.only`, specifying the fields you want to @@ -347,14 +351,14 @@ If you later need the missing fields, just call Getting related data -------------------- -When iterating the results of :class:`~mongoengine.ListField` or -:class:`~mongoengine.DictField` we automatically dereference any +When iterating the results of :class:`~mongoengine.fields.ListField` or +:class:`~mongoengine.fields.DictField` we automatically dereference any :class:`~pymongo.dbref.DBRef` objects as efficiently as possible, reducing the number the queries to mongo. There are times when that efficiency is not enough, documents that have -:class:`~mongoengine.ReferenceField` objects or -:class:`~mongoengine.GenericReferenceField` objects at the top level are +:class:`~mongoengine.fields.ReferenceField` objects or +:class:`~mongoengine.fields.GenericReferenceField` objects at the top level are expensive as the number of queries to MongoDB can quickly rise. To limit the number of queries use @@ -365,8 +369,30 @@ references to the depth of 1 level. If you have more complicated documents and want to dereference more of the object at once then increasing the :attr:`max_depth` will dereference more levels of the document. +Turning off dereferencing +------------------------- + +Sometimes for performance reasons you don't want to automatically dereference +data. To turn off dereferencing of the results of a query use +:func:`~mongoengine.queryset.QuerySet.no_dereference` on the queryset like so:: + + post = Post.objects.no_dereference().first() + assert(isinstance(post.author, ObjectId)) + +You can also turn off all dereferencing for a fixed period by using the +:class:`~mongoengine.context_managers.no_dereference` context manager:: + + with no_dereference(Post) as Post: + post = Post.objects.first() + assert(isinstance(post.author, ObjectId)) + + # Outside the context manager dereferencing occurs. 
+ assert(isinstance(post.author, User)) + + Advanced queries ================ + Sometimes calling a :class:`~mongoengine.queryset.QuerySet` object with keyword arguments can't fully express the query you want to use -- for example if you need to combine a number of constraints using *and* and *or*. This is made @@ -385,6 +411,11 @@ calling it with keyword arguments:: # Get top posts Post.objects((Q(featured=True) & Q(hits__gte=1000)) | Q(hits__gte=5000)) +.. warning:: You have to use bitwise operators. You cannot use ``or``, ``and`` + to combine queries as ``Q(a=a) or Q(b=b)`` is not the same as + ``Q(a=a) | Q(b=b)``. As ``Q(a=a)`` equates to true ``Q(a=a) or Q(b=b)`` is + the same as ``Q(a=a)``. + .. _guide-atomic-updates: Atomic updates @@ -425,7 +456,7 @@ modifier comes before the field, not after it:: >>> post.tags ['database', 'nosql'] -.. note :: +.. note:: In version 0.5 the :meth:`~mongoengine.Document.save` runs atomic updates on changed documents by tracking changes to that document. @@ -441,7 +472,7 @@ cannot use the `$` syntax in keyword arguments it has been mapped to `S`:: >>> post.tags ['database', 'mongodb'] -.. note :: +.. note:: Currently only top level lists are handled, future versions of mongodb / pymongo plan to support nested positional operators. See `The $ positional operator `_. @@ -510,7 +541,7 @@ Javascript code. When accessing a field on a collection object, use square-bracket notation, and prefix the MongoEngine field name with a tilde. The field name that follows the tilde will be translated to the name used in the database. Note that when referring to fields on embedded documents, -the name of the :class:`~mongoengine.EmbeddedDocumentField`, followed by a dot, +the name of the :class:`~mongoengine.fields.EmbeddedDocumentField`, followed by a dot, should be used before the name of the field on the embedded document. 
The following example shows how the substitutions are made:: diff --git a/docs/index.rst b/docs/index.rst index f6d44b51..4aca82da 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,16 +7,18 @@ MongoDB. To install it, simply run .. code-block:: console - # pip install -U mongoengine + $ pip install -U mongoengine :doc:`tutorial` - Start here for a quick overview. + A quick tutorial building a tumblelog to get you up and running with + MongoEngine. :doc:`guide/index` - The Full guide to MongoEngine + The Full guide to MongoEngine - from modeling documents to storing files, + from querying for data to firing signals and *everything* between. :doc:`apireference` - The complete API documentation. + The complete API documentation --- the innards of documents, querysets and fields. :doc:`upgrade` How to upgrade MongoEngine. @@ -28,35 +30,40 @@ Community --------- To get help with using MongoEngine, use the `MongoEngine Users mailing list -`_ or come chat on the -`#mongoengine IRC channel `_. +`_ or the ever popular +`stackoverflow `_. Contributing ------------ -The source is available on `GitHub `_ and -contributions are always encouraged. Contributions can be as simple as -minor tweaks to this documentation. To contribute, fork the project on +**Yes please!** We are always looking for contributions, additions and improvements. + +The source is available on `GitHub `_ +and contributions are always encouraged. Contributions can be as simple as +minor tweaks to this documentation, the website or the core. + +To contribute, fork the project on `GitHub `_ and send a pull request. -Also, you can join the developers' `mailing list -`_. - Changes ------- + See the :doc:`changelog` for a full list of changes to MongoEngine and :doc:`upgrade` for upgrade information. -.. toctree:: - :hidden: +.. note:: Always read and test the `upgrade `_ documentation before + putting updates live in production **;)** - tutorial - guide/index - apireference - django - changelog - upgrade +.. 
toctree:: + :hidden: + + tutorial + guide/index + apireference + django + changelog + upgrade Indices and tables ------------------ diff --git a/docs/tutorial.rst b/docs/tutorial.rst index a5284c8f..c2f481b9 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -1,6 +1,7 @@ ======== Tutorial ======== + This tutorial introduces **MongoEngine** by means of example --- we will walk through how to create a simple **Tumblelog** application. A Tumblelog is a type of blog where posts are not constrained to being conventional text-based posts. @@ -12,23 +13,29 @@ interface. Getting started =============== + Before we start, make sure that a copy of MongoDB is running in an accessible location --- running it locally will be easier, but if that is not an option -then it may be run on a remote server. +then it may be run on a remote server. If you haven't installed mongoengine, +simply use pip to install it like so:: + + $ pip install mongoengine Before we can start using MongoEngine, we need to tell it how to connect to our instance of :program:`mongod`. For this we use the :func:`~mongoengine.connect` -function. The only argument we need to provide is the name of the MongoDB -database to use:: +function. If running locally the only argument we need to provide is the name +of the MongoDB database to use:: from mongoengine import * connect('tumblelog') -For more information about connecting to MongoDB see :ref:`guide-connecting`. +There are lots of options for connecting to MongoDB, for more information about +them see the :ref:`guide-connecting` guide. Defining our documents ====================== + MongoDB is *schemaless*, which means that no schema is enforced by the database --- we may add and remove fields however we want and MongoDB won't complain. 
This makes life a lot easier in many regards, especially when there is a change @@ -39,17 +46,19 @@ define utility methods on our documents in the same way that traditional In our Tumblelog application we need to store several different types of information. We will need to have a collection of **users**, so that we may -link posts to an individual. We also need to store our different types -**posts** (text, image and link) in the database. To aid navigation of our +link posts to an individual. We also need to store our different types of +**posts** (eg: text, image and link) in the database. To aid navigation of our Tumblelog, posts may have **tags** associated with them, so that the list of posts shown to the user may be limited to posts that have been assigned a -specified tag. Finally, it would be nice if **comments** could be added to -posts. We'll start with **users**, as the others are slightly more involved. +specific tag. Finally, it would be nice if **comments** could be added to +posts. We'll start with **users**, as the other document models are slightly +more involved. Users ----- + Just as if we were using a relational database with an ORM, we need to define -which fields a :class:`User` may have, and what their types will be:: +which fields a :class:`User` may have, and what types of data they might store:: class User(Document): email = StringField(required=True) @@ -58,11 +67,13 @@ which fields a :class:`User` may have, and what their types will be:: This looks similar to how a the structure of a table would be defined in a regular ORM. The key difference is that this schema will never be passed on to -MongoDB --- this will only be enforced at the application level. Also, the User -documents will be stored in a MongoDB *collection* rather than a table. +MongoDB --- this will only be enforced at the application level, making future +changes easy to manage. Also, the User documents will be stored in a +MongoDB *collection* rather than a table. 
Posts, Comments and Tags ------------------------ + Now we'll think about how to store the rest of the information. If we were using a relational database, we would most likely have a table of **posts**, a table of **comments** and a table of **tags**. To associate the comments with @@ -75,21 +86,25 @@ of them stand out as particularly intuitive solutions. Posts ^^^^^ -But MongoDB *isn't* a relational database, so we're not going to do it that + +Happily mongoDB *isn't* a relational database, so we're not going to do it that way. As it turns out, we can use MongoDB's schemaless nature to provide us with -a much nicer solution. We will store all of the posts in *one collection* --- -each post type will just have the fields it needs. If we later want to add +a much nicer solution. We will store all of the posts in *one collection* and +each post type will only store the fields it needs. If we later want to add video posts, we don't have to modify the collection at all, we just *start using* the new fields we need to support video posts. This fits with the Object-Oriented principle of *inheritance* nicely. We can think of :class:`Post` as a base class, and :class:`TextPost`, :class:`ImagePost` and :class:`LinkPost` as subclasses of :class:`Post`. In fact, MongoEngine supports -this kind of modelling out of the box:: +this kind of modelling out of the box --- all you need do is turn on inheritance +by setting :attr:`allow_inheritance` to True in the :attr:`meta`:: class Post(Document): title = StringField(max_length=120, required=True) author = ReferenceField(User) + meta = {'allow_inheritance': True} + class TextPost(Post): content = StringField() @@ -100,12 +115,13 @@ this kind of modelling out of the box:: link_url = StringField() We are storing a reference to the author of the posts using a -:class:`~mongoengine.ReferenceField` object. These are similar to foreign key +:class:`~mongoengine.fields.ReferenceField` object. 
These are similar to foreign key fields in traditional ORMs, and are automatically translated into references when they are saved, and dereferenced when they are loaded. Tags ^^^^ + Now that we have our Post models figured out, how will we attach tags to them? MongoDB allows us to store lists of items natively, so rather than having a link table, we can just store a list of tags in each post. So, for both @@ -121,13 +137,16 @@ size of our database. So let's take a look that the code our modified author = ReferenceField(User) tags = ListField(StringField(max_length=30)) -The :class:`~mongoengine.ListField` object that is used to define a Post's tags +The :class:`~mongoengine.fields.ListField` object that is used to define a Post's tags takes a field object as its first argument --- this means that you can have -lists of any type of field (including lists). Note that we don't need to -modify the specialised post types as they all inherit from :class:`Post`. +lists of any type of field (including lists). + +.. note:: We don't need to modify the specialised post types as they all + inherit from :class:`Post`. Comments ^^^^^^^^ + A comment is typically associated with *one* post. In a relational database, to display a post with its comments, we would have to retrieve the post from the database, then query the database again for the comments associated with the @@ -155,7 +174,7 @@ We can then store a list of comment documents in our post document:: Handling deletions of references ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The :class:`~mongoengine.ReferenceField` object takes a keyword +The :class:`~mongoengine.fields.ReferenceField` object takes a keyword `reverse_delete_rule` for handling deletion rules if the reference is deleted. 
To delete all the posts if a user is deleted set the rule:: @@ -165,9 +184,9 @@ To delete all the posts if a user is deleted set the rule:: tags = ListField(StringField(max_length=30)) comments = ListField(EmbeddedDocumentField(Comment)) -See :class:`~mongoengine.ReferenceField` for more information. +See :class:`~mongoengine.fields.ReferenceField` for more information. -..note:: +.. note:: MapFields and DictFields currently don't support automatic handling of deleted references @@ -178,15 +197,15 @@ Now that we've defined how our documents will be structured, let's start adding some documents to the database. Firstly, we'll need to create a :class:`User` object:: - john = User(email='jdoe@example.com', first_name='John', last_name='Doe') - john.save() + ross = User(email='ross@example.com', first_name='Ross', last_name='Lawley').save() -Note that we could have also defined our user using attribute syntax:: +.. note:: + We could have also defined our user using attribute syntax:: - john = User(email='jdoe@example.com') - john.first_name = 'John' - john.last_name = 'Doe' - john.save() + ross = User(email='ross@example.com') + ross.first_name = 'Ross' + ross.last_name = 'Lawley' + ross.save() Now that we've got our user in the database, let's add a couple of posts:: @@ -195,16 +214,17 @@ Now that we've got our user in the database, let's add a couple of posts:: post1.tags = ['mongodb', 'mongoengine'] post1.save() - post2 = LinkPost(title='MongoEngine Documentation', author=john) - post2.link_url = 'http://tractiondigital.com/labs/mongoengine/docs' + post2 = LinkPost(title='MongoEngine Documentation', author=ross) + post2.link_url = 'http://docs.mongoengine.com/' post2.tags = ['mongoengine'] post2.save() -Note that if you change a field on a object that has already been saved, then -call :meth:`save` again, the document will be updated. +.. 
note:: If you change a field on a object that has already been saved, then + call :meth:`save` again, the document will be updated. Accessing our data ================== + So now we've got a couple of posts in our database, how do we display them? Each document class (i.e. any class that inherits either directly or indirectly from :class:`~mongoengine.Document`) has an :attr:`objects` attribute, which is @@ -216,6 +236,7 @@ class. So let's see how we can get our posts' titles:: Retrieving type-specific information ------------------------------------ + This will print the titles of our posts, one on each line. But What if we want to access the type-specific data (link_url, content, etc.)? One way is simply to use the :attr:`objects` attribute of a subclass of :class:`Post`:: @@ -254,6 +275,7 @@ text post, and "Link: " if it was a link post. Searching our posts by tag -------------------------- + The :attr:`objects` attribute of a :class:`~mongoengine.Document` is actually a :class:`~mongoengine.queryset.QuerySet` object. This lazily queries the database only when you need the data. It may also be filtered to narrow down @@ -272,3 +294,9 @@ used on :class:`~mongoengine.queryset.QuerySet` objects:: num_posts = Post.objects(tags='mongodb').count() print 'Found %d posts with tag "mongodb"' % num_posts +Learning more about mongoengine +------------------------------- + +If you got this far you've made a great start, so well done! The next step on +your mongoengine journey is the `full user guide `_, where you +can learn indepth about how to use mongoengine and mongodb. \ No newline at end of file diff --git a/docs/upgrade.rst b/docs/upgrade.rst index 901c251d..bb5705ca 100644 --- a/docs/upgrade.rst +++ b/docs/upgrade.rst @@ -1,12 +1,331 @@ -========= +######### Upgrading -========= +######### -0.6 to 0.7 +0.7 to 0.8 +********** + +There have been numerous backwards breaking changes in 0.8. 
The reasons for +these are ensure that MongoEngine has sane defaults going forward and +performs the best it can out the box. Where possible there have been +FutureWarnings to help get you ready for the change, but that hasn't been +possible for the whole of the release. + +.. warning:: Breaking changes - test upgrading on a test system before putting + live. There maybe multiple manual steps in migrating and these are best honed + on a staging / test system. + +Python +======= + +Support for python 2.5 has been dropped. + +Data Model ========== +Inheritance +----------- + +The inheritance model has changed, we no longer need to store an array of +:attr:`types` with the model we can just use the classname in :attr:`_cls`. +This means that you will have to update your indexes for each of your +inherited classes like so: :: + + # 1. Declaration of the class + class Animal(Document): + name = StringField() + meta = { + 'allow_inheritance': True, + 'indexes': ['name'] + } + + # 2. Remove _types + collection = Animal._get_collection() + collection.update({}, {"$unset": {"_types": 1}}, multi=True) + + # 3. Confirm extra data is removed + count = collection.find({'_types': {"$exists": True}}).count() + assert count == 0 + + # 4. Remove indexes + info = collection.index_information() + indexes_to_drop = [key for key, value in info.iteritems() + if '_types' in dict(value['key'])] + for index in indexes_to_drop: + collection.drop_index(index) + + # 5. Recreate indexes + Animal.ensure_indexes() + + +Document Definition +------------------- + +The default for inheritance has changed - its now off by default and +:attr:`_cls` will not be stored automatically with the class. 
So if you extend +your :class:`~mongoengine.Document` or :class:`~mongoengine.EmbeddedDocuments` +you will need to declare :attr:`allow_inheritance` in the meta data like so: :: + + class Animal(Document): + name = StringField() + + meta = {'allow_inheritance': True} + +Previously, if you had data the database that wasn't defined in the Document +definition, it would set it as an attribute on the document. This is no longer +the case and the data is set only in the ``document._data`` dictionary: :: + + >>> from mongoengine import * + >>> class Animal(Document): + ... name = StringField() + ... + >>> cat = Animal(name="kit", size="small") + + # 0.7 + >>> cat.size + u'small' + + # 0.8 + >>> cat.size + Traceback (most recent call last): + File "", line 1, in + AttributeError: 'Animal' object has no attribute 'size' + +ReferenceField +-------------- + +ReferenceFields now store ObjectId's by default - this is more efficient than +DBRefs as we already know what Document types they reference:: + + # Old code + class Animal(Document): + name = ReferenceField('self') + + # New code to keep dbrefs + class Animal(Document): + name = ReferenceField('self', dbref=True) + +To migrate all the references you need to touch each object and mark it as dirty +eg:: + + # Doc definition + class Person(Document): + name = StringField() + parent = ReferenceField('self') + friends = ListField(ReferenceField('self')) + + # Mark all ReferenceFields as dirty and save + for p in Person.objects: + p._mark_as_dirty('parent') + p._mark_as_dirty('friends') + p.save() + +`An example test migration for ReferenceFields is available on github +`_. 
+ +UUIDField +--------- + +UUIDFields now default to storing binary values:: + + # Old code + class Animal(Document): + uuid = UUIDField() + + # New code + class Animal(Document): + uuid = UUIDField(binary=False) + +To migrate all the uuid's you need to touch each object and mark it as dirty +eg:: + + # Doc definition + class Animal(Document): + uuid = UUIDField() + + # Mark all ReferenceFields as dirty and save + for a in Animal.objects: + a._mark_as_dirty('uuid') + a.save() + +`An example test migration for UUIDFields is available on github +`_. + +DecimalField +------------ + +DecimalField now store floats - previous it was storing strings and that +made it impossible to do comparisons when querying correctly.:: + + # Old code + class Person(Document): + balance = DecimalField() + + # New code + class Person(Document): + balance = DecimalField(force_string=True) + +To migrate all the uuid's you need to touch each object and mark it as dirty +eg:: + + # Doc definition + class Person(Document): + balance = DecimalField() + + # Mark all ReferenceFields as dirty and save + for p in Person.objects: + p._mark_as_dirty('balance') + p.save() + +.. note:: DecimalField's have also been improved with the addition of precision + and rounding. See :class:`~mongoengine.fields.DecimalField` for more information. + +`An example test migration for DecimalFields is available on github +`_. + +Cascading Saves +--------------- +To improve performance document saves will no longer automatically cascade. +Any changes to a Documents references will either have to be saved manually or +you will have to explicitly tell it to cascade on save:: + + # At the class level: + class Person(Document): + meta = {'cascade': True} + + # Or on save: + my_document.save(cascade=True) + +Storage +------- + +Document and Embedded Documents are now serialized based on declared field order. 
+Previously, the data was passed to mongodb as a dictionary and which meant that +order wasn't guaranteed - so things like ``$addToSet`` operations on +:class:`~mongoengine.EmbeddedDocument` could potentially fail in unexpected +ways. + +If this impacts you, you may want to rewrite the objects using the +``doc.mark_as_dirty('field')`` pattern described above. If you are using a +compound primary key then you will need to ensure the order is fixed and match +your EmbeddedDocument to that order. + +Querysets +========= + +Attack of the clones +-------------------- + +Querysets now return clones and should no longer be considered editable in +place. This brings us in line with how Django's querysets work and removes a +long running gotcha. If you edit your querysets inplace you will have to +update your code like so: :: + + # Old code: + mammals = Animal.objects(type="mammal") + mammals.filter(order="Carnivora") # Returns a cloned queryset that isn't assigned to anything - so this will break in 0.8 + [m for m in mammals] # This will return all mammals in 0.8 as the 2nd filter returned a new queryset + + # Update example a) assign queryset after a change: + mammals = Animal.objects(type="mammal") + carnivores = mammals.filter(order="Carnivora") # Reassign the new queryset so fitler can be applied + [m for m in carnivores] # This will return all carnivores + + # Update example b) chain the queryset: + mammals = Animal.objects(type="mammal").filter(order="Carnivora") # The final queryset is assgined to mammals + [m for m in mammals] # This will return all carnivores + +No more len +----------- + +If you ever did len(queryset) it previously did a count() under the covers, this +caused some unusual issues - so now it has been removed in favour of the +explicit `queryset.count()` to update:: + + # Old code + len(Animal.objects(type="mammal")) + + # New code + Animal.objects(type="mammal").count()) + + +.only() now inline with .exclude() +---------------------------------- + 
+The behaviour of `.only()` was highly ambious, now it works in the mirror fashion +to `.exclude()`. Chaining `.only()` calls will increase the fields required:: + + # Old code + Animal.objects().only(['type', 'name']).only('name', 'order') # Would have returned just `name` + + # New code + Animal.objects().only('name') + + # Note: + Animal.objects().only(['name']).only('order') # Now returns `name` *and* `order` + + +Client +====== +PyMongo 2.4 came with a new connection client; MongoClient_ and started the +depreciation of the old :class:`~pymongo.connection.Connection`. MongoEngine +now uses the latest `MongoClient` for connections. By default operations were +`safe` but if you turned them off or used the connection directly this will +impact your queries. + +Querysets +--------- + +Safe +^^^^ + +`safe` has been depreciated in the new MongoClient connection. Please use +`write_concern` instead. As `safe` always defaulted as `True` normally no code +change is required. To disable confirmation of the write just pass `{"w": 0}` +eg: :: + + # Old + Animal(name="Dinasour").save(safe=False) + + # new code: + Animal(name="Dinasour").save(write_concern={"w": 0}) + +Write Concern +^^^^^^^^^^^^^ + +`write_options` has been replaced with `write_concern` to bring it inline with +pymongo. To upgrade simply rename any instances where you used the `write_option` +keyword to `write_concern` like so:: + + # Old code: + Animal(name="Dinasour").save(write_options={"w": 2}) + + # new code: + Animal(name="Dinasour").save(write_concern={"w": 2}) + + +Indexes +======= + +Index methods are no longer tied to querysets but rather to the document class. +Although `QuerySet._ensure_indexes` and `QuerySet.ensure_index` still exist. +They should be replaced with :func:`~mongoengine.Document.ensure_indexes` / +:func:`~mongoengine.Document.ensure_index`. 
+ +SequenceFields +============== + +:class:`~mongoengine.fields.SequenceField` now inherits from `BaseField` to +allow flexible storage of the calculated value. As such MIN and MAX settings +are no longer handled. + +.. _MongoClient: http://blog.mongodb.org/post/36666163412/introducing-mongoclient + +0.6 to 0.7 +********** + Cascade saves -------------- +============= Saves will raise a `FutureWarning` if they cascade and cascade hasn't been set to True. This is because in 0.8 it will default to False. If you require @@ -20,11 +339,11 @@ via `save` eg :: # Or in code: my_document.save(cascade=True) -.. note :: +.. note:: Remember: cascading saves **do not** cascade through lists. ReferenceFields ---------------- +=============== ReferenceFields now can store references as ObjectId strings instead of DBRefs. This will become the default in 0.8 and if `dbref` is not set a `FutureWarning` @@ -53,7 +372,7 @@ migrate :: item_frequencies ----------------- +================ In the 0.6 series we added support for null / zero / false values in item_frequencies. A side effect was to return keys in the value they are @@ -62,14 +381,14 @@ updated to handle native types rather than strings keys for the results of item frequency queries. BinaryFields ------------- +============ Binary fields have been updated so that they are native binary types. If you previously were doing `str` comparisons with binary field values you will have to update and wrap the value in a `str`. 0.5 to 0.6 -========== +********** Embedded Documents - if you had a `pk` field you will have to rename it from `_id` to `pk` as pk is no longer a property of Embedded Documents. @@ -84,18 +403,18 @@ Document.objects.with_id - now raises an InvalidQueryError if used with a filter. FutureWarning - A future warning has been added to all inherited classes that -don't define `allow_inheritance` in their meta. +don't define :attr:`allow_inheritance` in their meta. 
You may need to update pyMongo to 2.0 for use with Sharding. 0.4 to 0.5 -=========== +********** There have been the following backwards incompatibilities from 0.4 to 0.5. The main areas of changed are: choices in fields, map_reduce and collection names. Choice options: ---------------- +=============== Are now expected to be an iterable of tuples, with the first element in each tuple being the actual value to be stored. The second element is the @@ -103,7 +422,7 @@ human-readable name for the option. PyMongo / MongoDB ------------------ +================= map reduce now requires pymongo 1.11+- The pymongo `merge_output` and `reduce_output` parameters, have been depreciated. @@ -117,7 +436,7 @@ such the following have been changed: Default collection naming -------------------------- +========================= Previously it was just lowercase, its now much more pythonic and readable as its lowercase and underscores, previously :: @@ -187,5 +506,5 @@ Alternatively, you can rename your collections eg :: mongodb 1.8 > 2.0 + =================== -Its been reported that indexes may need to be recreated to the newer version of indexes. +Its been reported that indexes may need to be recreated to the newer version of indexes. To do this drop indexes and call ``ensure_indexes`` on each model. 
diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index b67512d7..6fe6d088 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -8,11 +8,14 @@ import queryset from queryset import * import signals from signals import * +from errors import * +import errors +import django -__all__ = (document.__all__ + fields.__all__ + connection.__all__ + - queryset.__all__ + signals.__all__) +__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + + list(queryset.__all__) + signals.__all__ + list(errors.__all__)) -VERSION = (0, 7, 9) +VERSION = (0, 8, 0, '+') def get_version(): diff --git a/mongoengine/base.py b/mongoengine/base.py deleted file mode 100644 index fa6f825f..00000000 --- a/mongoengine/base.py +++ /dev/null @@ -1,1524 +0,0 @@ -import operator -import sys -import warnings -import weakref - -from collections import defaultdict -from functools import partial - -from queryset import QuerySet, QuerySetManager -from queryset import DoesNotExist, MultipleObjectsReturned -from queryset import DO_NOTHING - -from mongoengine import signals -from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, - to_str_keys_recursive) - -import pymongo -from bson import ObjectId -from bson.dbref import DBRef - -ALLOW_INHERITANCE = True - -_document_registry = {} -_class_registry = {} - - -class NotRegistered(Exception): - pass - - -class InvalidDocumentError(Exception): - pass - - -class ValidationError(AssertionError): - """Validation exception. - - May represent an error validating a field or a - document containing fields with validation errors. - - :ivar errors: A dictionary of errors for fields within this - document or list, or None if the error is for an - individual field. 
- """ - - errors = {} - field_name = None - _message = None - - def __init__(self, message="", **kwargs): - self.errors = kwargs.get('errors', {}) - self.field_name = kwargs.get('field_name') - self.message = message - - def __str__(self): - return txt_type(self.message) - - def __repr__(self): - return '%s(%s,)' % (self.__class__.__name__, self.message) - - def __getattribute__(self, name): - message = super(ValidationError, self).__getattribute__(name) - if name == 'message': - if self.field_name: - message = '%s' % message - if self.errors: - message = '%s(%s)' % (message, self._format_errors()) - return message - - def _get_message(self): - return self._message - - def _set_message(self, message): - self._message = message - - message = property(_get_message, _set_message) - - def to_dict(self): - """Returns a dictionary of all errors within a document - - Keys are field names or list indices and values are the - validation error messages, or a nested dictionary of - errors for an embedded document or list. 
- """ - - def build_dict(source): - errors_dict = {} - if not source: - return errors_dict - if isinstance(source, dict): - for field_name, error in source.iteritems(): - errors_dict[field_name] = build_dict(error) - elif isinstance(source, ValidationError) and source.errors: - return build_dict(source.errors) - else: - return unicode(source) - return errors_dict - if not self.errors: - return {} - return build_dict(self.errors) - - def _format_errors(self): - """Returns a string listing all errors within a document""" - - def generate_key(value, prefix=''): - if isinstance(value, list): - value = ' '.join([generate_key(k) for k in value]) - if isinstance(value, dict): - value = ' '.join( - [generate_key(v, k) for k, v in value.iteritems()]) - - results = "%s.%s" % (prefix, value) if prefix else value - return results - - error_dict = defaultdict(list) - for k, v in self.to_dict().iteritems(): - error_dict[generate_key(v)].append(k) - return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()]) - - -def get_document(name): - doc = _document_registry.get(name, None) - if not doc: - # Possible old style name - single_end = name.split('.')[-1] - compound_end = '.%s' % single_end - possible_match = [k for k in _document_registry.keys() - if k.endswith(compound_end) or k == single_end] - if len(possible_match) == 1: - doc = _document_registry.get(possible_match.pop(), None) - if not doc: - raise NotRegistered(""" - `%s` has not been registered in the document registry. - Importing the document class automatically registers it, has it - been imported? - """.strip() % name) - return doc - - -class BaseField(object): - """A base class for fields in a MongoDB document. Instances of this class - may be added to subclasses of `Document` to define a document's schema. - - .. 
versionchanged:: 0.5 - added verbose and help text - """ - - name = None - - # Fields may have _types inserted into indexes by default - _index_with_types = True - _geo_index = False - - # These track each time a Field instance is created. Used to retain order. - # The auto_creation_counter is used for fields that MongoEngine implicitly - # creates, creation_counter is used for all user-specified fields. - creation_counter = 0 - auto_creation_counter = -1 - - def __init__(self, db_field=None, name=None, required=False, default=None, - unique=False, unique_with=None, primary_key=False, - validation=None, choices=None, verbose_name=None, - help_text=None): - self.db_field = (db_field or name) if not primary_key else '_id' - if name: - msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" - warnings.warn(msg, DeprecationWarning) - self.name = None - self.required = required or primary_key - self.default = default - self.unique = bool(unique or unique_with) - self.unique_with = unique_with - self.primary_key = primary_key - self.validation = validation - self.choices = choices - self.verbose_name = verbose_name - self.help_text = help_text - - # Adjust the appropriate creation counter, and save our local copy. - if self.db_field == '_id': - self.creation_counter = BaseField.auto_creation_counter - BaseField.auto_creation_counter -= 1 - else: - self.creation_counter = BaseField.creation_counter - BaseField.creation_counter += 1 - - def __get__(self, instance, owner): - """Descriptor for retrieving a value from a field in a document. Do - any necessary conversion between Python and MongoDB types. 
- """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available, if not use default - value = instance._data.get(self.name) - - if value is None: - value = self.default - # Allow callable default values - if callable(value): - value = value() - - return value - - def __set__(self, instance, value): - """Descriptor for assigning a value to a field in a document. - """ - instance._data[self.name] = value - if instance._initialised: - instance._mark_as_changed(self.name) - - def error(self, message="", errors=None, field_name=None): - """Raises a ValidationError. - """ - field_name = field_name if field_name else self.name - raise ValidationError(message, errors=errors, field_name=field_name) - - def to_python(self, value): - """Convert a MongoDB-compatible type to a Python type. - """ - return value - - def to_mongo(self, value): - """Convert a Python type to a MongoDB-compatible type. - """ - return self.to_python(value) - - def prepare_query_value(self, op, value): - """Prepare a value that is being used in a query for PyMongo. - """ - return value - - def validate(self, value): - """Perform validation on a value. 
- """ - pass - - def _validate(self, value): - Document = _import_class('Document') - EmbeddedDocument = _import_class('EmbeddedDocument') - # check choices - if self.choices: - is_cls = isinstance(value, (Document, EmbeddedDocument)) - value_to_check = value.__class__ if is_cls else value - err_msg = 'an instance' if is_cls else 'one' - if isinstance(self.choices[0], (list, tuple)): - option_keys = [k for k, v in self.choices] - if value_to_check not in option_keys: - msg = ('Value must be %s of %s' % - (err_msg, unicode(option_keys))) - self.error(msg) - elif value_to_check not in self.choices: - msg = ('Value must be %s of %s' % - (err_msg, unicode(self.choices))) - self.error(msg) - - # check validation argument - if self.validation is not None: - if callable(self.validation): - if not self.validation(value): - self.error('Value does not match custom validation method') - else: - raise ValueError('validation argument for "%s" must be a ' - 'callable.' % self.name) - - self.validate(value) - - -class ComplexBaseField(BaseField): - """Handles complex fields, such as lists / dictionaries. - - Allows for nesting of embedded documents inside complex types. - Handles the lazy dereferencing of a queryset by lazily dereferencing all - items in a list / dict rather than one at a time. - - .. versionadded:: 0.5 - """ - - field = None - __dereference = False - - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. 
- """ - if instance is None: - # Document class being used rather than a document object - return self - - ReferenceField = _import_class('ReferenceField') - GenericReferenceField = _import_class('GenericReferenceField') - dereference = self.field is None or isinstance(self.field, - (GenericReferenceField, ReferenceField)) - if not self._dereference and instance._initialised and dereference: - instance._data[self.name] = self._dereference( - instance._data.get(self.name), max_depth=1, instance=instance, - name=self.name - ) - - value = super(ComplexBaseField, self).__get__(instance, owner) - - # Convert lists / values so we can watch for any changes on them - if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): - value = BaseList(value, instance, self.name) - instance._data[self.name] = value - elif isinstance(value, dict) and not isinstance(value, BaseDict): - value = BaseDict(value, instance, self.name) - instance._data[self.name] = value - - if (instance._initialised and isinstance(value, (BaseList, BaseDict)) - and not value._dereferenced): - value = self._dereference( - value, max_depth=1, instance=instance, name=self.name - ) - value._dereferenced = True - instance._data[self.name] = value - - return value - - def __set__(self, instance, value): - """Descriptor for assigning a value to a field in a document. - """ - instance._data[self.name] = value - instance._mark_as_changed(self.name) - - def to_python(self, value): - """Convert a MongoDB-compatible type to a Python type. 
- """ - Document = _import_class('Document') - - if isinstance(value, basestring): - return value - - if hasattr(value, 'to_python'): - return value.to_python() - - is_list = False - if not hasattr(value, 'items'): - try: - is_list = True - value = dict([(k, v) for k, v in enumerate(value)]) - except TypeError: # Not iterable return the value - return value - - if self.field: - value_dict = dict([(key, self.field.to_python(item)) - for key, item in value.items()]) - else: - value_dict = {} - for k, v in value.items(): - if isinstance(v, Document): - # We need the id from the saved object to create the DBRef - if v.pk is None: - self.error('You can only reference documents once they' - ' have been saved to the database') - collection = v._get_collection_name() - value_dict[k] = DBRef(collection, v.pk) - elif hasattr(v, 'to_python'): - value_dict[k] = v.to_python() - else: - value_dict[k] = self.to_python(v) - - if is_list: # Convert back to a list - return [v for k, v in sorted(value_dict.items(), - key=operator.itemgetter(0))] - return value_dict - - def to_mongo(self, value): - """Convert a Python type to a MongoDB-compatible type. 
- """ - Document = _import_class("Document") - - if isinstance(value, basestring): - return value - - if hasattr(value, 'to_mongo'): - return value.to_mongo() - - is_list = False - if not hasattr(value, 'items'): - try: - is_list = True - value = dict([(k, v) for k, v in enumerate(value)]) - except TypeError: # Not iterable return the value - return value - - if self.field: - value_dict = dict([(key, self.field.to_mongo(item)) - for key, item in value.items()]) - else: - value_dict = {} - for k, v in value.items(): - if isinstance(v, Document): - # We need the id from the saved object to create the DBRef - if v.pk is None: - self.error('You can only reference documents once they' - ' have been saved to the database') - - # If its a document that is not inheritable it won't have - # _types / _cls data so make it a generic reference allows - # us to dereference - meta = getattr(v, '_meta', {}) - allow_inheritance = ( - meta.get('allow_inheritance', ALLOW_INHERITANCE) - == False) - if allow_inheritance and not self.field: - GenericReferenceField = _import_class("GenericReferenceField") - value_dict[k] = GenericReferenceField().to_mongo(v) - else: - collection = v._get_collection_name() - value_dict[k] = DBRef(collection, v.pk) - elif hasattr(v, 'to_mongo'): - value_dict[k] = v.to_mongo() - else: - value_dict[k] = self.to_mongo(v) - - if is_list: # Convert back to a list - return [v for k, v in sorted(value_dict.items(), - key=operator.itemgetter(0))] - return value_dict - - def validate(self, value): - """If field is provided ensure the value is valid. 
- """ - errors = {} - if self.field: - if hasattr(value, 'iteritems') or hasattr(value, 'items'): - sequence = value.iteritems() - else: - sequence = enumerate(value) - for k, v in sequence: - try: - self.field._validate(v) - except ValidationError, error: - errors[k] = error.errors or error - except (ValueError, AssertionError), error: - errors[k] = error - - if errors: - field_class = self.field.__class__.__name__ - self.error('Invalid %s item (%s)' % (field_class, value), - errors=errors) - # Don't allow empty values if required - if self.required and not value: - self.error('Field is required and cannot be empty') - - def prepare_query_value(self, op, value): - return self.to_mongo(value) - - def lookup_member(self, member_name): - if self.field: - return self.field.lookup_member(member_name) - return None - - def _set_owner_document(self, owner_document): - if self.field: - self.field.owner_document = owner_document - self._owner_document = owner_document - - def _get_owner_document(self, owner_document): - self._owner_document = owner_document - - owner_document = property(_get_owner_document, _set_owner_document) - - @property - def _dereference(self,): - if not self.__dereference: - DeReference = _import_class("DeReference") - self.__dereference = DeReference() # Cached - return self.__dereference - - -class ObjectIdField(BaseField): - """An field wrapper around MongoDB's ObjectIds. 
- """ - - def to_python(self, value): - if not isinstance(value, ObjectId): - value = ObjectId(value) - return value - - def to_mongo(self, value): - if not isinstance(value, ObjectId): - try: - return ObjectId(unicode(value)) - except Exception, e: - # e.message attribute has been deprecated since Python 2.6 - self.error(unicode(e)) - return value - - def prepare_query_value(self, op, value): - return self.to_mongo(value) - - def validate(self, value): - try: - ObjectId(unicode(value)) - except: - self.error('Invalid Object ID') - - -class DocumentMetaclass(type): - """Metaclass for all documents. - """ - - def __new__(cls, name, bases, attrs): - flattened_bases = cls._get_bases(bases) - super_new = super(DocumentMetaclass, cls).__new__ - - # If a base class just call super - metaclass = attrs.get('my_metaclass') - if metaclass and issubclass(metaclass, DocumentMetaclass): - return super_new(cls, name, bases, attrs) - - attrs['_is_document'] = attrs.get('_is_document', False) - - # EmbeddedDocuments could have meta data for inheritance - if 'meta' in attrs: - attrs['_meta'] = attrs.pop('meta') - - # Handle document Fields - - # Merge all fields from subclasses - doc_fields = {} - for base in flattened_bases[::-1]: - if hasattr(base, '_fields'): - doc_fields.update(base._fields) - - # Standard object mixin - merge in any Fields - if not hasattr(base, '_meta'): - base_fields = {} - for attr_name, attr_value in base.__dict__.iteritems(): - if not isinstance(attr_value, BaseField): - continue - attr_value.name = attr_name - if not attr_value.db_field: - attr_value.db_field = attr_name - base_fields[attr_name] = attr_value - doc_fields.update(base_fields) - - # Discover any document fields - field_names = {} - for attr_name, attr_value in attrs.iteritems(): - if not isinstance(attr_value, BaseField): - continue - attr_value.name = attr_name - if not attr_value.db_field: - attr_value.db_field = attr_name - doc_fields[attr_name] = attr_value - - # Count names to ensure 
no db_field redefinitions - field_names[attr_value.db_field] = field_names.get( - attr_value.db_field, 0) + 1 - - # Ensure no duplicate db_fields - duplicate_db_fields = [k for k, v in field_names.items() if v > 1] - if duplicate_db_fields: - msg = ("Multiple db_fields defined for: %s " % - ", ".join(duplicate_db_fields)) - raise InvalidDocumentError(msg) - - # Set _fields and db_field maps - attrs['_fields'] = doc_fields - attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) - for k, v in doc_fields.iteritems()]) - attrs['_reverse_db_field_map'] = dict( - (v, k) for k, v in attrs['_db_field_map'].iteritems()) - - # - # Set document hierarchy - # - superclasses = {} - class_name = [name] - for base in flattened_bases: - if (not getattr(base, '_is_base_cls', True) and - not getattr(base, '_meta', {}).get('abstract', True)): - # Collate heirarchy for _cls and _types - class_name.append(base.__name__) - - # Get superclasses from superclass - superclasses[base._class_name] = base - superclasses.update(base._superclasses) - - if hasattr(base, '_meta'): - # Warn if allow_inheritance isn't set and prevent - # inheritance of classes where inheritance is set to False - allow_inheritance = base._meta.get('allow_inheritance', - ALLOW_INHERITANCE) - if (not getattr(base, '_is_base_cls', True) - and allow_inheritance is None): - warnings.warn( - "%s uses inheritance, the default for " - "allow_inheritance is changing to off by default. " - "Please add it to the document meta." 
% name, - FutureWarning - ) - elif (allow_inheritance == False and - not base._meta.get('abstract')): - raise ValueError('Document %s may not be subclassed' % - base.__name__) - - attrs['_class_name'] = '.'.join(reversed(class_name)) - attrs['_superclasses'] = superclasses - - # Create the new_class - new_class = super_new(cls, name, bases, attrs) - - # Handle delete rules - Document, EmbeddedDocument, DictField = cls._import_classes() - for field in new_class._fields.itervalues(): - f = field - f.owner_document = new_class - delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) - if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): - delete_rule = getattr(f.field, - 'reverse_delete_rule', - DO_NOTHING) - if isinstance(f, DictField) and delete_rule != DO_NOTHING: - msg = ("Reverse delete rules are not supported " - "for %s (field: %s)" % - (field.__class__.__name__, field.name)) - raise InvalidDocumentError(msg) - - f = field.field - - if delete_rule != DO_NOTHING: - if issubclass(new_class, EmbeddedDocument): - msg = ("Reverse delete rules are not supported for " - "EmbeddedDocuments (field: %s)" % field.name) - raise InvalidDocumentError(msg) - f.document_type.register_delete_rule(new_class, - field.name, delete_rule) - - if (field.name and hasattr(Document, field.name) and - EmbeddedDocument not in new_class.mro()): - msg = ("%s is a document method and not a valid " - "field name" % field.name) - raise InvalidDocumentError(msg) - - # Add class to the _document_registry - _document_registry[new_class._class_name] = new_class - - # In Python 2, User-defined methods objects have special read-only - # attributes 'im_func' and 'im_self' which contain the function obj - # and class instance object respectively. With Python 3 these special - # attributes have been replaced by __func__ and __self__. 
The Blinker - # module continues to use im_func and im_self, so the code below - # copies __func__ into im_func and __self__ into im_self for - # classmethod objects in Document derived classes. - if PY3: - for key, val in new_class.__dict__.items(): - if isinstance(val, classmethod): - f = val.__get__(new_class) - if hasattr(f, '__func__') and not hasattr(f, 'im_func'): - f.__dict__.update({'im_func': getattr(f, '__func__')}) - if hasattr(f, '__self__') and not hasattr(f, 'im_self'): - f.__dict__.update({'im_self': getattr(f, '__self__')}) - - return new_class - - def add_to_class(self, name, value): - setattr(self, name, value) - - @classmethod - def _get_bases(cls, bases): - if isinstance(bases, BasesTuple): - return bases - seen = [] - bases = cls.__get_bases(bases) - unique_bases = (b for b in bases if not (b in seen or seen.append(b))) - return BasesTuple(unique_bases) - - @classmethod - def __get_bases(cls, bases): - for base in bases: - if base is object: - continue - yield base - for child_base in cls.__get_bases(base.__bases__): - yield child_base - - @classmethod - def _import_classes(cls): - Document = _import_class('Document') - EmbeddedDocument = _import_class('EmbeddedDocument') - DictField = _import_class('DictField') - return (Document, EmbeddedDocument, DictField) - - -class TopLevelDocumentMetaclass(DocumentMetaclass): - """Metaclass for top-level documents (i.e. documents that have their own - collection in the database. 
- """ - - def __new__(cls, name, bases, attrs): - flattened_bases = cls._get_bases(bases) - super_new = super(TopLevelDocumentMetaclass, cls).__new__ - - # Set default _meta data if base class, otherwise get user defined meta - if (attrs.get('my_metaclass') == TopLevelDocumentMetaclass): - # defaults - attrs['_meta'] = { - 'abstract': True, - 'max_documents': None, - 'max_size': None, - 'ordering': [], # default ordering applied at runtime - 'indexes': [], # indexes to be ensured at runtime - 'id_field': None, - 'index_background': False, - 'index_drop_dups': False, - 'index_opts': None, - 'delete_rules': None, - 'allow_inheritance': None, - } - attrs['_is_base_cls'] = True - attrs['_meta'].update(attrs.get('meta', {})) - else: - attrs['_meta'] = attrs.get('meta', {}) - # Explictly set abstract to false unless set - attrs['_meta']['abstract'] = attrs['_meta'].get('abstract', False) - attrs['_is_base_cls'] = False - - # Set flag marking as document class - as opposed to an object mixin - attrs['_is_document'] = True - - # Ensure queryset_class is inherited - if 'objects' in attrs: - manager = attrs['objects'] - if hasattr(manager, 'queryset_class'): - attrs['_meta']['queryset_class'] = manager.queryset_class - - # Clean up top level meta - if 'meta' in attrs: - del(attrs['meta']) - - # Find the parent document class - parent_doc_cls = [b for b in flattened_bases - if b.__class__ == TopLevelDocumentMetaclass] - parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0] - - # Prevent classes setting collection different to their parents - # If parent wasn't an abstract class - if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) - and not parent_doc_cls._meta.get('abstract', True)): - msg = "Trying to set a collection on a subclass (%s)" % name - warnings.warn(msg, SyntaxWarning) - del(attrs['_meta']['collection']) - - # Ensure abstract documents have abstract bases - if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): - if (parent_doc_cls 
and - not parent_doc_cls._meta.get('abstract', False)): - msg = "Abstract document cannot have non-abstract base" - raise ValueError(msg) - return super_new(cls, name, bases, attrs) - - # Merge base class metas. - # Uses a special MetaDict that handles various merging rules - meta = MetaDict() - for base in flattened_bases[::-1]: - # Add any mixin metadata from plain objects - if hasattr(base, 'meta'): - meta.merge(base.meta) - elif hasattr(base, '_meta'): - meta.merge(base._meta) - - # Set collection in the meta if its callable - if (getattr(base, '_is_document', False) and - not base._meta.get('abstract')): - collection = meta.get('collection', None) - if callable(collection): - meta['collection'] = collection(base) - - meta.merge(attrs.get('_meta', {})) # Top level meta - - # Only simple classes (direct subclasses of Document) - # may set allow_inheritance to False - simple_class = all([b._meta.get('abstract') - for b in flattened_bases if hasattr(b, '_meta')]) - if (not simple_class and meta['allow_inheritance'] == False and - not meta['abstract']): - raise ValueError('Only direct subclasses of Document may set ' - '"allow_inheritance" to False') - - # Set default collection name - if 'collection' not in meta: - meta['collection'] = ''.join('_%s' % c if c.isupper() else c - for c in name).strip('_').lower() - attrs['_meta'] = meta - - # Call super and get the new class - new_class = super_new(cls, name, bases, attrs) - - meta = new_class._meta - - # Set index specifications - meta['index_specs'] = [QuerySet._build_index_spec(new_class, spec) - for spec in meta['indexes']] - unique_indexes = cls._unique_with_indexes(new_class) - new_class._meta['unique_indexes'] = unique_indexes - - # If collection is a callable - call it and set the value - collection = meta.get('collection') - if callable(collection): - new_class._meta['collection'] = collection(new_class) - - # Provide a default queryset unless one has been set - manager = attrs.get('objects', 
QuerySetManager()) - new_class.objects = manager - - # Validate the fields and set primary key if needed - for field_name, field in new_class._fields.iteritems(): - if field.primary_key: - # Ensure only one primary key is set - current_pk = new_class._meta.get('id_field') - if current_pk and current_pk != field_name: - raise ValueError('Cannot override primary key field') - - # Set primary key - if not current_pk: - new_class._meta['id_field'] = field_name - new_class.id = field - - # Set primary key if not defined by the document - if not new_class._meta.get('id_field'): - new_class._meta['id_field'] = 'id' - new_class._fields['id'] = ObjectIdField(db_field='_id') - new_class.id = new_class._fields['id'] - - # Merge in exceptions with parent hierarchy - exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) - module = attrs.get('__module__') - for exc in exceptions_to_merge: - name = exc.__name__ - parents = tuple(getattr(base, name) for base in flattened_bases - if hasattr(base, name)) or (exc,) - # Create new exception and set to new_class - exception = type(name, parents, {'__module__': module}) - setattr(new_class, name, exception) - - return new_class - - @classmethod - def _unique_with_indexes(cls, new_class, namespace=""): - """ - Find and set unique indexes - """ - unique_indexes = [] - for field_name, field in new_class._fields.items(): - # Generate a list of indexes needed by uniqueness constraints - if field.unique: - field.required = True - unique_fields = [field.db_field] - - # Add any unique_with fields to the back of the index spec - if field.unique_with: - if isinstance(field.unique_with, basestring): - field.unique_with = [field.unique_with] - - # Convert unique_with field names to real field names - unique_with = [] - for other_name in field.unique_with: - parts = other_name.split('.') - # Lookup real name - parts = QuerySet._lookup_field(new_class, parts) - name_parts = [part.db_field for part in parts] - 
unique_with.append('.'.join(name_parts)) - # Unique field should be required - parts[-1].required = True - unique_fields += unique_with - - # Add the new index to the list - index = [("%s%s" % (namespace, f), pymongo.ASCENDING) - for f in unique_fields] - unique_indexes.append(index) - - # Grab any embedded document field unique indexes - if (field.__class__.__name__ == "EmbeddedDocumentField" and - field.document_type != new_class): - field_namespace = "%s." % field_name - unique_indexes += cls._unique_with_indexes(field.document_type, - field_namespace) - - return unique_indexes - - -class MetaDict(dict): - """Custom dictionary for meta classes. - Handles the merging of set indexes - """ - _merge_options = ('indexes',) - - def merge(self, new_options): - for k, v in new_options.iteritems(): - if k in self._merge_options: - self[k] = self.get(k, []) + v - else: - self[k] = v - - -class BaseDocument(object): - - _dynamic = False - _created = True - _dynamic_lock = True - _initialised = False - - def __init__(self, **values): - signals.pre_init.send(self.__class__, document=self, values=values) - - self._data = {} - - # Assign default values to instance - for key, field in self._fields.iteritems(): - if self._db_field_map.get(key, key) in values: - continue - value = getattr(self, key, None) - setattr(self, key, value) - - # Set passed values after initialisation - if self._dynamic: - self._dynamic_fields = {} - dynamic_data = {} - for key, value in values.iteritems(): - if key in self._fields or key == '_id': - setattr(self, key, value) - elif self._dynamic: - dynamic_data[key] = value - else: - for key, value in values.iteritems(): - key = self._reverse_db_field_map.get(key, key) - setattr(self, key, value) - - # Set any get_fieldname_display methods - self.__set_field_display() - - if self._dynamic: - self._dynamic_lock = False - for key, value in dynamic_data.iteritems(): - setattr(self, key, value) - - # Flag initialised - self._initialised = True - 
signals.post_init.send(self.__class__, document=self) - - def __setattr__(self, name, value): - # Handle dynamic data only if an initialised dynamic document - if self._dynamic and not self._dynamic_lock: - - field = None - if not hasattr(self, name) and not name.startswith('_'): - DynamicField = _import_class("DynamicField") - field = DynamicField(db_field=name) - field.name = name - self._dynamic_fields[name] = field - - if not name.startswith('_'): - value = self.__expand_dynamic_values(name, value) - - # Handle marking data as changed - if name in self._dynamic_fields: - self._data[name] = value - if hasattr(self, '_changed_fields'): - self._mark_as_changed(name) - - if (self._is_document and not self._created and - name in self._meta.get('shard_key', tuple()) and - self._data.get(name) != value): - OperationError = _import_class('OperationError') - msg = "Shard Keys are immutable. Tried to update %s" % name - raise OperationError(msg) - - super(BaseDocument, self).__setattr__(name, value) - - def __expand_dynamic_values(self, name, value): - """expand any dynamic values to their correct types / values""" - if not isinstance(value, (dict, list, tuple)): - return value - - is_list = False - if not hasattr(value, 'items'): - is_list = True - value = dict([(k, v) for k, v in enumerate(value)]) - - if not is_list and '_cls' in value: - cls = get_document(value['_cls']) - return cls(**value) - - data = {} - for k, v in value.items(): - key = name if is_list else k - data[k] = self.__expand_dynamic_values(key, v) - - if is_list: # Convert back to a list - data_items = sorted(data.items(), key=operator.itemgetter(0)) - value = [v for k, v in data_items] - else: - value = data - - # Convert lists / values so we can watch for any changes on them - if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): - value = BaseList(value, self, name) - elif isinstance(value, dict) and not isinstance(value, BaseDict): - value = BaseDict(value, self, name) - - 
return value - - def validate(self): - """Ensure that all fields' values are valid and that required fields - are present. - """ - # Get a list of tuples of field names and their current values - fields = [(field, getattr(self, name)) - for name, field in self._fields.items()] - - # Ensure that each field is matched to a valid value - errors = {} - for field, value in fields: - if value is not None: - try: - field._validate(value) - except ValidationError, error: - errors[field.name] = error.errors or error - except (ValueError, AttributeError, AssertionError), error: - errors[field.name] = error - elif field.required: - errors[field.name] = ValidationError('Field is required', - field_name=field.name) - if errors: - raise ValidationError('ValidationError', errors=errors) - - def to_mongo(self): - """Return data dictionary ready for use with MongoDB. - """ - data = {} - for field_name, field in self._fields.items(): - value = getattr(self, field_name, None) - if value is not None: - data[field.db_field] = field.to_mongo(value) - # Only add _cls and _types if allow_inheritance is not False - if not (hasattr(self, '_meta') and - self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == False): - data['_cls'] = self._class_name - data['_types'] = self._superclasses.keys() + [self._class_name] - if '_id' in data and data['_id'] is None: - del data['_id'] - - if not self._dynamic: - return data - - for name, field in self._dynamic_fields.items(): - data[name] = field.to_mongo(self._data.get(name, None)) - return data - - @classmethod - def _get_collection_name(cls): - """Returns the collection name for this class. - """ - return cls._meta.get('collection', None) - - @classmethod - def _from_son(cls, son): - """Create an instance of a Document (subclass) from a PyMongo SON. 
- """ - # get the class name from the document, falling back to the given - # class if unavailable - class_name = son.get('_cls', cls._class_name) - data = dict(("%s" % key, value) for key, value in son.items()) - if not UNICODE_KWARGS: - # python 2.6.4 and lower cannot handle unicode keys - # passed to class constructor example: cls(**data) - to_str_keys_recursive(data) - - if '_types' in data: - del data['_types'] - - if '_cls' in data: - del data['_cls'] - - # Return correct subclass for document type - if class_name != cls._class_name: - cls = get_document(class_name) - - changed_fields = [] - errors_dict = {} - - for field_name, field in cls._fields.items(): - if field.db_field in data: - value = data[field.db_field] - try: - data[field_name] = (value if value is None - else field.to_python(value)) - if field_name != field.db_field: - del data[field.db_field] - except (AttributeError, ValueError), e: - errors_dict[field_name] = e - elif field.default: - default = field.default - if callable(default): - default = default() - if isinstance(default, BaseDocument): - changed_fields.append(field_name) - - if errors_dict: - errors = "\n".join(["%s - %s" % (k, v) - for k, v in errors_dict.items()]) - msg = ("Invalid data to create a `%s` instance.\n%s" - % (cls._class_name, errors)) - raise InvalidDocumentError(msg) - - obj = cls(**data) - obj._changed_fields = changed_fields - obj._created = False - return obj - - def _mark_as_changed(self, key): - """Marks a key as explicitly changed by the user - """ - if not key: - return - key = self._db_field_map.get(key, key) - if (hasattr(self, '_changed_fields') and - key not in self._changed_fields): - self._changed_fields.append(key) - - def _get_changed_fields(self, key='', inspected=None): - """Returns a list of all fields that have explicitly been changed. 
- """ - EmbeddedDocument = _import_class("EmbeddedDocument") - DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") - _changed_fields = [] - _changed_fields += getattr(self, '_changed_fields', []) - - inspected = inspected or set() - if hasattr(self, 'id'): - if self.id in inspected: - return _changed_fields - inspected.add(self.id) - - field_list = self._fields.copy() - if self._dynamic: - field_list.update(self._dynamic_fields) - - for field_name in field_list: - - db_field_name = self._db_field_map.get(field_name, field_name) - key = '%s.' % db_field_name - field = self._data.get(field_name, None) - if hasattr(field, 'id'): - if field.id in inspected: - continue - inspected.add(field.id) - - if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) - and db_field_name not in _changed_fields): - # Find all embedded fields that have been changed - changed = field._get_changed_fields(key, inspected) - _changed_fields += ["%s%s" % (key, k) for k in changed if k] - elif (isinstance(field, (list, tuple, dict)) and - db_field_name not in _changed_fields): - # Loop list / dict fields as they contain documents - # Determine the iterator to use - if not hasattr(field, 'items'): - iterator = enumerate(field) - else: - iterator = field.iteritems() - for index, value in iterator: - if not hasattr(value, '_get_changed_fields'): - continue - list_key = "%s%s." % (key, index) - changed = value._get_changed_fields(list_key, inspected) - _changed_fields += ["%s%s" % (list_key, k) - for k in changed if k] - return _changed_fields - - def _delta(self): - """Returns the delta (set, unset) of the changes for a document. - Gets any values that have been explicitly changed. 
- """ - # Handles cases where not loaded from_son but has _id - doc = self.to_mongo() - set_fields = self._get_changed_fields() - set_data = {} - unset_data = {} - parts = [] - if hasattr(self, '_changed_fields'): - set_data = {} - # Fetch each set item from its path - for path in set_fields: - parts = path.split('.') - d = doc - new_path = [] - for p in parts: - if isinstance(d, DBRef): - break - elif p.isdigit(): - d = d[int(p)] - elif hasattr(d, 'get'): - d = d.get(p) - new_path.append(p) - path = '.'.join(new_path) - set_data[path] = d - else: - set_data = doc - if '_id' in set_data: - del(set_data['_id']) - - # Determine if any changed items were actually unset. - for path, value in set_data.items(): - if value or isinstance(value, (bool, int)): - continue - - # If we've set a value that ain't the default value dont unset it. - default = None - if (self._dynamic and len(parts) and - parts[0] in self._dynamic_fields): - del(set_data[path]) - unset_data[path] = 1 - continue - elif path in self._fields: - default = self._fields[path].default - else: # Perform a full lookup for lists / embedded lookups - d = self - parts = path.split('.') - db_field_name = parts.pop() - for p in parts: - if p.isdigit(): - d = d[int(p)] - elif (hasattr(d, '__getattribute__') and - not isinstance(d, dict)): - real_path = d._reverse_db_field_map.get(p, p) - d = getattr(d, real_path) - else: - d = d.get(p) - - if hasattr(d, '_fields'): - field_name = d._reverse_db_field_map.get(db_field_name, - db_field_name) - - if field_name in d._fields: - default = d._fields.get(field_name).default - else: - default = None - - if default is not None: - if callable(default): - default = default() - if default != value: - continue - - del(set_data[path]) - unset_data[path] = 1 - return set_data, unset_data - - @classmethod - def _geo_indices(cls, inspected=None): - inspected = inspected or [] - geo_indices = [] - inspected.append(cls) - - EmbeddedDocumentField = 
_import_class("EmbeddedDocumentField") - GeoPointField = _import_class("GeoPointField") - - for field in cls._fields.values(): - if not isinstance(field, (EmbeddedDocumentField, GeoPointField)): - continue - if hasattr(field, 'document_type'): - field_cls = field.document_type - if field_cls in inspected: - continue - if hasattr(field_cls, '_geo_indices'): - geo_indices += field_cls._geo_indices(inspected) - elif field._geo_index: - geo_indices.append(field) - return geo_indices - - def __getstate__(self): - removals = ("get_%s_display" % k - for k, v in self._fields.items() if v.choices) - for k in removals: - if hasattr(self, k): - delattr(self, k) - return self.__dict__ - - def __setstate__(self, __dict__): - self.__dict__ = __dict__ - self.__set_field_display() - - def __set_field_display(self): - """Dynamically set the display value for a field with choices""" - for attr_name, field in self._fields.items(): - if field.choices: - setattr(self, - 'get_%s_display' % attr_name, - partial(self.__get_field_display, field=field)) - - def __get_field_display(self, field): - """Returns the display value for a choice field""" - value = getattr(self, field.name) - if field.choices and isinstance(field.choices[0], (list, tuple)): - return dict(field.choices).get(value, value) - return value - - def __iter__(self): - return iter(self._fields) - - def __getitem__(self, name): - """Dictionary-style field access, return a field's value if present. - """ - try: - if name in self._fields: - return getattr(self, name) - except AttributeError: - pass - raise KeyError(name) - - def __setitem__(self, name, value): - """Dictionary-style field access, set a field's value. 
- """ - # Ensure that the field exists before settings its value - if name not in self._fields: - raise KeyError(name) - return setattr(self, name, value) - - def __contains__(self, name): - try: - val = getattr(self, name) - return val is not None - except AttributeError: - return False - - def __len__(self): - return len(self._data) - - def __repr__(self): - try: - u = self.__str__() - except (UnicodeEncodeError, UnicodeDecodeError): - u = '[Bad Unicode data]' - repr_type = type(u) - return repr_type('<%s: %s>' % (self.__class__.__name__, u)) - - def __str__(self): - if hasattr(self, '__unicode__'): - if PY3: - return self.__unicode__() - else: - return unicode(self).encode('utf-8') - return txt_type('%s object' % self.__class__.__name__) - - def __eq__(self, other): - if isinstance(other, self.__class__) and hasattr(other, 'id'): - if self.id == other.id: - return True - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - if self.pk is None: - # For new object - return super(BaseDocument, self).__hash__() - else: - return hash(self.pk) - - -class BasesTuple(tuple): - """Special class to handle introspection of bases tuple in __new__""" - pass - - -class BaseList(list): - """A special list so we can watch any changes - """ - - _dereferenced = False - _instance = None - _name = None - - def __init__(self, list_items, instance, name): - self._instance = weakref.proxy(instance) - self._name = name - return super(BaseList, self).__init__(list_items) - - def __setitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__setitem__(*args, **kwargs) - - def __delitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__delitem__(*args, **kwargs) - - def __getstate__(self): - self.observer = None - return self - - def __setstate__(self, state): - self = state - return self - - def append(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, 
self).append(*args, **kwargs) - - def extend(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).extend(*args, **kwargs) - - def insert(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).insert(*args, **kwargs) - - def pop(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).pop(*args, **kwargs) - - def remove(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).remove(*args, **kwargs) - - def reverse(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).reverse(*args, **kwargs) - - def sort(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).sort(*args, **kwargs) - - def _mark_as_changed(self): - if hasattr(self._instance, '_mark_as_changed'): - self._instance._mark_as_changed(self._name) - - -class BaseDict(dict): - """A special dict so we can watch any changes - """ - - _dereferenced = False - _instance = None - _name = None - - def __init__(self, dict_items, instance, name): - self._instance = weakref.proxy(instance) - self._name = name - return super(BaseDict, self).__init__(dict_items) - - def __setitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__setitem__(*args, **kwargs) - - def __delete__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__delete__(*args, **kwargs) - - def __delitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__delitem__(*args, **kwargs) - - def __delattr__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__delattr__(*args, **kwargs) - - def __getstate__(self): - self.instance = None - self._dereferenced = False - return self - - def __setstate__(self, state): - self = state - return self - - def clear(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).clear(*args, **kwargs) - - def pop(self, *args, 
**kwargs): - self._mark_as_changed() - return super(BaseDict, self).pop(*args, **kwargs) - - def popitem(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).popitem(*args, **kwargs) - - def update(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).update(*args, **kwargs) - - def _mark_as_changed(self): - if hasattr(self._instance, '_mark_as_changed'): - self._instance._mark_as_changed(self._name) - - -def _import_class(cls_name): - """Cached mechanism for imports""" - if cls_name in _class_registry: - return _class_registry.get(cls_name) - - doc_classes = ['Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument'] - field_classes = ['DictField', 'DynamicField', 'EmbeddedDocumentField', - 'GenericReferenceField', 'GeoPointField', - 'ReferenceField'] - queryset_classes = ['OperationError'] - deref_classes = ['DeReference'] - - if cls_name in doc_classes: - from mongoengine import document as module - import_classes = doc_classes - elif cls_name in field_classes: - from mongoengine import fields as module - import_classes = field_classes - elif cls_name in queryset_classes: - from mongoengine import queryset as module - import_classes = queryset_classes - elif cls_name in deref_classes: - from mongoengine import dereference as module - import_classes = deref_classes - else: - raise ValueError('No import set for: ' % cls_name) - - for cls in import_classes: - _class_registry[cls] = getattr(module, cls) - - return _class_registry.get(cls_name) diff --git a/mongoengine/base/__init__.py b/mongoengine/base/__init__.py new file mode 100644 index 00000000..ce119b3a --- /dev/null +++ b/mongoengine/base/__init__.py @@ -0,0 +1,5 @@ +from mongoengine.base.common import * +from mongoengine.base.datastructures import * +from mongoengine.base.document import * +from mongoengine.base.fields import * +from mongoengine.base.metaclasses import * diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py new file mode 
100644 index 00000000..3a966c79 --- /dev/null +++ b/mongoengine/base/common.py @@ -0,0 +1,26 @@ +from mongoengine.errors import NotRegistered + +__all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry') + +ALLOW_INHERITANCE = False + +_document_registry = {} + + +def get_document(name): + doc = _document_registry.get(name, None) + if not doc: + # Possible old style name + single_end = name.split('.')[-1] + compound_end = '.%s' % single_end + possible_match = [k for k in _document_registry.keys() + if k.endswith(compound_end) or k == single_end] + if len(possible_match) == 1: + doc = _document_registry.get(possible_match.pop(), None) + if not doc: + raise NotRegistered(""" + `%s` has not been registered in the document registry. + Importing the document class automatically registers it, has it + been imported? + """.strip() % name) + return doc diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py new file mode 100644 index 00000000..c750b5ba --- /dev/null +++ b/mongoengine/base/datastructures.py @@ -0,0 +1,142 @@ +import weakref +from mongoengine.common import _import_class + +__all__ = ("BaseDict", "BaseList") + + +class BaseDict(dict): + """A special dict so we can watch any changes + """ + + _dereferenced = False + _instance = None + _name = None + + def __init__(self, dict_items, instance, name): + self._instance = weakref.proxy(instance) + self._name = name + return super(BaseDict, self).__init__(dict_items) + + def __getitem__(self, *args, **kwargs): + value = super(BaseDict, self).__getitem__(*args, **kwargs) + + EmbeddedDocument = _import_class('EmbeddedDocument') + if isinstance(value, EmbeddedDocument) and value._instance is None: + value._instance = self._instance + return value + + def __setitem__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).__setitem__(*args, **kwargs) + + def __delete__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, 
self).__delete__(*args, **kwargs) + + def __delitem__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).__delitem__(*args, **kwargs) + + def __delattr__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).__delattr__(*args, **kwargs) + + def __getstate__(self): + self.instance = None + self._dereferenced = False + return self + + def __setstate__(self, state): + self = state + return self + + def clear(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).clear(*args, **kwargs) + + def pop(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).pop(*args, **kwargs) + + def popitem(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).popitem(*args, **kwargs) + + def update(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).update(*args, **kwargs) + + def _mark_as_changed(self): + if hasattr(self._instance, '_mark_as_changed'): + self._instance._mark_as_changed(self._name) + + +class BaseList(list): + """A special list so we can watch any changes + """ + + _dereferenced = False + _instance = None + _name = None + + def __init__(self, list_items, instance, name): + self._instance = weakref.proxy(instance) + self._name = name + return super(BaseList, self).__init__(list_items) + + def __getitem__(self, *args, **kwargs): + value = super(BaseList, self).__getitem__(*args, **kwargs) + + EmbeddedDocument = _import_class('EmbeddedDocument') + if isinstance(value, EmbeddedDocument) and value._instance is None: + value._instance = self._instance + return value + + def __setitem__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).__setitem__(*args, **kwargs) + + def __delitem__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).__delitem__(*args, **kwargs) + + def __getstate__(self): + self.instance = None + self._dereferenced = False + return self + 
+ def __setstate__(self, state): + self = state + return self + + def append(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).append(*args, **kwargs) + + def extend(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).extend(*args, **kwargs) + + def insert(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).insert(*args, **kwargs) + + def pop(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).pop(*args, **kwargs) + + def remove(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).remove(*args, **kwargs) + + def reverse(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).reverse(*args, **kwargs) + + def sort(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).sort(*args, **kwargs) + + def _mark_as_changed(self): + if hasattr(self._instance, '_mark_as_changed'): + self._instance._mark_as_changed(self._name) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py new file mode 100644 index 00000000..53686b25 --- /dev/null +++ b/mongoengine/base/document.py @@ -0,0 +1,815 @@ +import copy +import operator +import numbers +from functools import partial + +import pymongo +from bson import json_util +from bson.dbref import DBRef +from bson.son import SON + +from mongoengine import signals +from mongoengine.common import _import_class +from mongoengine.errors import (ValidationError, InvalidDocumentError, + LookUpError) +from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, + to_str_keys_recursive) + +from mongoengine.base.common import get_document, ALLOW_INHERITANCE +from mongoengine.base.datastructures import BaseDict, BaseList +from mongoengine.base.fields import ComplexBaseField + +__all__ = ('BaseDocument', 'NON_FIELD_ERRORS') + +NON_FIELD_ERRORS = '__all__' + + +class BaseDocument(object): + + _dynamic = False + _created = True + 
_dynamic_lock = True + _initialised = False + + def __init__(self, *args, **values): + """ + Initialise a document or embedded document + + :param __auto_convert: Try and will cast python objects to Object types + :param values: A dictionary of values for the document + """ + if args: + # Combine positional arguments with named arguments. + # We only want named arguments. + field = iter(self._fields_ordered) + for value in args: + name = next(field) + if name in values: + raise TypeError("Multiple values for keyword argument '" + name + "'") + values[name] = value + __auto_convert = values.pop("__auto_convert", True) + signals.pre_init.send(self.__class__, document=self, values=values) + + self._data = {} + + # Assign default values to instance + for key, field in self._fields.iteritems(): + if self._db_field_map.get(key, key) in values: + continue + value = getattr(self, key, None) + setattr(self, key, value) + + # Set passed values after initialisation + if self._dynamic: + self._dynamic_fields = {} + dynamic_data = {} + for key, value in values.iteritems(): + if key in self._fields or key == '_id': + setattr(self, key, value) + elif self._dynamic: + dynamic_data[key] = value + else: + FileField = _import_class('FileField') + for key, value in values.iteritems(): + if key == '__auto_convert': + continue + key = self._reverse_db_field_map.get(key, key) + if key in self._fields or key in ('id', 'pk', '_cls'): + if __auto_convert and value is not None: + field = self._fields.get(key) + if field and not isinstance(field, FileField): + value = field.to_python(value) + setattr(self, key, value) + else: + self._data[key] = value + + # Set any get_fieldname_display methods + self.__set_field_display() + + if self._dynamic: + self._dynamic_lock = False + for key, value in dynamic_data.iteritems(): + setattr(self, key, value) + + # Flag initialised + self._initialised = True + signals.post_init.send(self.__class__, document=self) + + def __delattr__(self, *args, **kwargs): 
+ """Handle deletions of fields""" + field_name = args[0] + if field_name in self._fields: + default = self._fields[field_name].default + if callable(default): + default = default() + setattr(self, field_name, default) + else: + super(BaseDocument, self).__delattr__(*args, **kwargs) + + def __setattr__(self, name, value): + # Handle dynamic data only if an initialised dynamic document + if self._dynamic and not self._dynamic_lock: + + field = None + if not hasattr(self, name) and not name.startswith('_'): + DynamicField = _import_class("DynamicField") + field = DynamicField(db_field=name) + field.name = name + self._dynamic_fields[name] = field + + if not name.startswith('_'): + value = self.__expand_dynamic_values(name, value) + + # Handle marking data as changed + if name in self._dynamic_fields: + self._data[name] = value + if hasattr(self, '_changed_fields'): + self._mark_as_changed(name) + + if (self._is_document and not self._created and + name in self._meta.get('shard_key', tuple()) and + self._data.get(name) != value): + OperationError = _import_class('OperationError') + msg = "Shard Keys are immutable. 
Tried to update %s" % name + raise OperationError(msg) + + # Check if the user has created a new instance of a class + if (self._is_document and self._initialised + and self._created and name == self._meta['id_field']): + super(BaseDocument, self).__setattr__('_created', False) + + super(BaseDocument, self).__setattr__(name, value) + + def __getstate__(self): + removals = ("get_%s_display" % k + for k, v in self._fields.items() if v.choices) + for k in removals: + if hasattr(self, k): + delattr(self, k) + return self.__dict__ + + def __setstate__(self, __dict__): + self.__dict__ = __dict__ + self.__set_field_display() + + def __iter__(self): + if 'id' in self._fields and 'id' not in self._fields_ordered: + return iter(('id', ) + self._fields_ordered) + + return iter(self._fields_ordered) + + def __getitem__(self, name): + """Dictionary-style field access, return a field's value if present. + """ + try: + if name in self._fields: + return getattr(self, name) + except AttributeError: + pass + raise KeyError(name) + + def __setitem__(self, name, value): + """Dictionary-style field access, set a field's value. 
+ """ + # Ensure that the field exists before settings its value + if name not in self._fields: + raise KeyError(name) + return setattr(self, name, value) + + def __contains__(self, name): + try: + val = getattr(self, name) + return val is not None + except AttributeError: + return False + + def __len__(self): + return len(self._data) + + def __repr__(self): + try: + u = self.__str__() + except (UnicodeEncodeError, UnicodeDecodeError): + u = '[Bad Unicode data]' + repr_type = type(u) + return repr_type('<%s: %s>' % (self.__class__.__name__, u)) + + def __str__(self): + if hasattr(self, '__unicode__'): + if PY3: + return self.__unicode__() + else: + return unicode(self).encode('utf-8') + return txt_type('%s object' % self.__class__.__name__) + + def __eq__(self, other): + if isinstance(other, self.__class__) and hasattr(other, 'id'): + if self.id == other.id: + return True + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + if self.pk is None: + # For new object + return super(BaseDocument, self).__hash__() + else: + return hash(self.pk) + + def clean(self): + """ + Hook for doing document level data cleaning before validation is run. + + Any ValidationError raised by this method will not be associated with + a particular field; it will have a special-case association with the + field defined by NON_FIELD_ERRORS. + """ + pass + + def to_mongo(self): + """Return as SON data ready for use with MongoDB. 
+ """ + data = SON() + data["_id"] = None + data['_cls'] = self._class_name + + for field_name in self: + value = self._data.get(field_name, None) + field = self._fields.get(field_name) + + if value is not None: + value = field.to_mongo(value) + + # Handle self generating fields + if value is None and field._auto_gen: + value = field.generate() + self._data[field_name] = value + + if value is not None: + data[field.db_field] = value + + # If "_id" has not been set, then try and set it + if data["_id"] is None: + data["_id"] = self._data.get("id", None) + + if data['_id'] is None: + data.pop('_id') + + # Only add _cls if allow_inheritance is True + if (not hasattr(self, '_meta') or + not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): + data.pop('_cls') + + if not self._dynamic: + return data + + # Sort dynamic fields by key + dynamic_fields = sorted(self._dynamic_fields.iteritems(), + key=operator.itemgetter(0)) + for name, field in dynamic_fields: + data[name] = field.to_mongo(self._data.get(name, None)) + + return data + + def validate(self, clean=True): + """Ensure that all fields' values are valid and that required fields + are present. 
+ """ + # Ensure that each field is matched to a valid value + errors = {} + if clean: + try: + self.clean() + except ValidationError, error: + errors[NON_FIELD_ERRORS] = error + + # Get a list of tuples of field names and their current values + fields = [(field, self._data.get(name)) + for name, field in self._fields.items()] + if self._dynamic: + fields += [(field, self._data.get(name)) + for name, field in self._dynamic_fields.items()] + + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") + + for field, value in fields: + if value is not None: + try: + if isinstance(field, (EmbeddedDocumentField, + GenericEmbeddedDocumentField)): + field._validate(value, clean=clean) + else: + field._validate(value) + except ValidationError, error: + errors[field.name] = error.errors or error + except (ValueError, AttributeError, AssertionError), error: + errors[field.name] = error + elif field.required and not getattr(field, '_auto_gen', False): + errors[field.name] = ValidationError('Field is required', + field_name=field.name) + + if errors: + pk = "None" + if hasattr(self, 'pk'): + pk = self.pk + elif self._instance: + pk = self._instance.pk + message = "ValidationError (%s:%s) " % (self._class_name, pk) + raise ValidationError(message, errors=errors) + + def to_json(self): + """Converts a document to JSON""" + return json_util.dumps(self.to_mongo()) + + @classmethod + def from_json(cls, json_data): + """Converts json data to an unsaved document instance""" + return cls._from_son(json_util.loads(json_data)) + + def __expand_dynamic_values(self, name, value): + """expand any dynamic values to their correct types / values""" + if not isinstance(value, (dict, list, tuple)): + return value + + is_list = False + if not hasattr(value, 'items'): + is_list = True + value = dict([(k, v) for k, v in enumerate(value)]) + + if not is_list and '_cls' in value: + cls = get_document(value['_cls']) + 
return cls(**value) + + data = {} + for k, v in value.items(): + key = name if is_list else k + data[k] = self.__expand_dynamic_values(key, v) + + if is_list: # Convert back to a list + data_items = sorted(data.items(), key=operator.itemgetter(0)) + value = [v for k, v in data_items] + else: + value = data + + # Convert lists / values so we can watch for any changes on them + if (isinstance(value, (list, tuple)) and + not isinstance(value, BaseList)): + value = BaseList(value, self, name) + elif isinstance(value, dict) and not isinstance(value, BaseDict): + value = BaseDict(value, self, name) + + return value + + def _mark_as_changed(self, key): + """Marks a key as explicitly changed by the user + """ + if not key: + return + key = self._db_field_map.get(key, key) + if (hasattr(self, '_changed_fields') and + key not in self._changed_fields): + self._changed_fields.append(key) + + def _clear_changed_fields(self): + self._changed_fields = [] + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + for field_name, field in self._fields.iteritems(): + if (isinstance(field, ComplexBaseField) and + isinstance(field.field, EmbeddedDocumentField)): + field_value = getattr(self, field_name, None) + if field_value: + for idx in (field_value if isinstance(field_value, dict) + else xrange(len(field_value))): + field_value[idx]._clear_changed_fields() + elif isinstance(field, EmbeddedDocumentField): + field_value = getattr(self, field_name, None) + if field_value: + field_value._clear_changed_fields() + + def _get_changed_fields(self, key='', inspected=None): + """Returns a list of all fields that have explicitly been changed. 
+ """ + EmbeddedDocument = _import_class("EmbeddedDocument") + DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") + _changed_fields = [] + _changed_fields += getattr(self, '_changed_fields', []) + + inspected = inspected or set() + if hasattr(self, 'id'): + if self.id in inspected: + return _changed_fields + inspected.add(self.id) + + field_list = self._fields.copy() + if self._dynamic: + field_list.update(self._dynamic_fields) + + for field_name in field_list: + + db_field_name = self._db_field_map.get(field_name, field_name) + key = '%s.' % db_field_name + field = self._data.get(field_name, None) + if hasattr(field, 'id'): + if field.id in inspected: + continue + inspected.add(field.id) + + if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) + and db_field_name not in _changed_fields): + # Find all embedded fields that have been changed + changed = field._get_changed_fields(key, inspected) + _changed_fields += ["%s%s" % (key, k) for k in changed if k] + elif (isinstance(field, (list, tuple, dict)) and + db_field_name not in _changed_fields): + # Loop list / dict fields as they contain documents + # Determine the iterator to use + if not hasattr(field, 'items'): + iterator = enumerate(field) + else: + iterator = field.iteritems() + for index, value in iterator: + if not hasattr(value, '_get_changed_fields'): + continue + list_key = "%s%s." % (key, index) + changed = value._get_changed_fields(list_key, inspected) + _changed_fields += ["%s%s" % (list_key, k) + for k in changed if k] + return _changed_fields + + def _delta(self): + """Returns the delta (set, unset) of the changes for a document. + Gets any values that have been explicitly changed. 
+ """ + # Handles cases where not loaded from_son but has _id + doc = self.to_mongo() + + set_fields = self._get_changed_fields() + set_data = {} + unset_data = {} + parts = [] + if hasattr(self, '_changed_fields'): + set_data = {} + # Fetch each set item from its path + for path in set_fields: + parts = path.split('.') + d = doc + new_path = [] + for p in parts: + if isinstance(d, DBRef): + break + elif isinstance(d, list) and p.isdigit(): + d = d[int(p)] + elif hasattr(d, 'get'): + d = d.get(p) + new_path.append(p) + path = '.'.join(new_path) + set_data[path] = d + else: + set_data = doc + if '_id' in set_data: + del(set_data['_id']) + + # Determine if any changed items were actually unset. + for path, value in set_data.items(): + if value or isinstance(value, (numbers.Number, bool)): + continue + + # If we've set a value that ain't the default value dont unset it. + default = None + if (self._dynamic and len(parts) and parts[0] in + self._dynamic_fields): + del(set_data[path]) + unset_data[path] = 1 + continue + elif path in self._fields: + default = self._fields[path].default + else: # Perform a full lookup for lists / embedded lookups + d = self + parts = path.split('.') + db_field_name = parts.pop() + for p in parts: + if isinstance(d, list) and p.isdigit(): + d = d[int(p)] + elif (hasattr(d, '__getattribute__') and + not isinstance(d, dict)): + real_path = d._reverse_db_field_map.get(p, p) + d = getattr(d, real_path) + else: + d = d.get(p) + + if hasattr(d, '_fields'): + field_name = d._reverse_db_field_map.get(db_field_name, + db_field_name) + if field_name in d._fields: + default = d._fields.get(field_name).default + else: + default = None + + if default is not None: + if callable(default): + default = default() + + if default != value: + continue + + del(set_data[path]) + unset_data[path] = 1 + return set_data, unset_data + + @classmethod + def _get_collection_name(cls): + """Returns the collection name for this class. 
+ """ + return cls._meta.get('collection', None) + + @classmethod + def _from_son(cls, son, _auto_dereference=True): + """Create an instance of a Document (subclass) from a PyMongo SON. + """ + + # get the class name from the document, falling back to the given + # class if unavailable + class_name = son.get('_cls', cls._class_name) + data = dict(("%s" % key, value) for key, value in son.iteritems()) + if not UNICODE_KWARGS: + # python 2.6.4 and lower cannot handle unicode keys + # passed to class constructor example: cls(**data) + to_str_keys_recursive(data) + + # Return correct subclass for document type + if class_name != cls._class_name: + cls = get_document(class_name) + + changed_fields = [] + errors_dict = {} + + fields = cls._fields + if not _auto_dereference: + fields = copy.copy(fields) + + for field_name, field in fields.iteritems(): + field._auto_dereference = _auto_dereference + if field.db_field in data: + value = data[field.db_field] + try: + data[field_name] = (value if value is None + else field.to_python(value)) + if field_name != field.db_field: + del data[field.db_field] + except (AttributeError, ValueError), e: + errors_dict[field_name] = e + elif field.default: + default = field.default + if callable(default): + default = default() + if isinstance(default, BaseDocument): + changed_fields.append(field_name) + + if errors_dict: + errors = "\n".join(["%s - %s" % (k, v) + for k, v in errors_dict.items()]) + msg = ("Invalid data to create a `%s` instance.\n%s" + % (cls._class_name, errors)) + raise InvalidDocumentError(msg) + + obj = cls(__auto_convert=False, **data) + obj._changed_fields = changed_fields + obj._created = False + if not _auto_dereference: + obj._fields = fields + return obj + + @classmethod + def _build_index_specs(cls, meta_indexes): + """Generate and merge the full index specs + """ + + geo_indices = cls._geo_indices() + unique_indices = cls._unique_with_indexes() + index_specs = [cls._build_index_spec(spec) + for spec in 
meta_indexes] + + def merge_index_specs(index_specs, indices): + if not indices: + return index_specs + + spec_fields = [v['fields'] + for k, v in enumerate(index_specs)] + # Merge unqiue_indexes with existing specs + for k, v in enumerate(indices): + if v['fields'] in spec_fields: + index_specs[spec_fields.index(v['fields'])].update(v) + else: + index_specs.append(v) + return index_specs + + index_specs = merge_index_specs(index_specs, geo_indices) + index_specs = merge_index_specs(index_specs, unique_indices) + return index_specs + + @classmethod + def _build_index_spec(cls, spec): + """Build a PyMongo index spec from a MongoEngine index spec. + """ + if isinstance(spec, basestring): + spec = {'fields': [spec]} + elif isinstance(spec, (list, tuple)): + spec = {'fields': list(spec)} + elif isinstance(spec, dict): + spec = dict(spec) + + index_list = [] + direction = None + + # Check to see if we need to include _cls + allow_inheritance = cls._meta.get('allow_inheritance', + ALLOW_INHERITANCE) + include_cls = allow_inheritance and not spec.get('sparse', False) + + for key in spec['fields']: + # If inherited spec continue + if isinstance(key, (list, tuple)): + continue + + # ASCENDING from +, + # DESCENDING from - + # GEO2D from * + direction = pymongo.ASCENDING + if key.startswith("-"): + direction = pymongo.DESCENDING + elif key.startswith("*"): + direction = pymongo.GEO2D + if key.startswith(("+", "-", "*")): + key = key[1:] + + # Use real field name, do it manually because we need field + # objects for the next part (list field checking) + parts = key.split('.') + if parts in (['pk'], ['id'], ['_id']): + key = '_id' + fields = [] + else: + fields = cls._lookup_field(parts) + parts = [field if field == '_id' else field.db_field + for field in fields] + key = '.'.join(parts) + index_list.append((key, direction)) + + # Don't add cls to a geo index + if include_cls and direction is not pymongo.GEO2D: + index_list.insert(0, ('_cls', 1)) + + spec['fields'] = 
index_list + if spec.get('sparse', False) and len(spec['fields']) > 1: + raise ValueError( + 'Sparse indexes can only have one field in them. ' + 'See https://jira.mongodb.org/browse/SERVER-2193') + + return spec + + @classmethod + def _unique_with_indexes(cls, namespace=""): + """ + Find and set unique indexes + """ + unique_indexes = [] + for field_name, field in cls._fields.items(): + sparse = False + # Generate a list of indexes needed by uniqueness constraints + if field.unique: + field.required = True + unique_fields = [field.db_field] + + # Add any unique_with fields to the back of the index spec + if field.unique_with: + if isinstance(field.unique_with, basestring): + field.unique_with = [field.unique_with] + + # Convert unique_with field names to real field names + unique_with = [] + for other_name in field.unique_with: + parts = other_name.split('.') + # Lookup real name + parts = cls._lookup_field(parts) + name_parts = [part.db_field for part in parts] + unique_with.append('.'.join(name_parts)) + # Unique field should be required + parts[-1].required = True + sparse = (not sparse and + parts[-1].name not in cls.__dict__) + unique_fields += unique_with + + # Add the new index to the list + fields = [("%s%s" % (namespace, f), pymongo.ASCENDING) + for f in unique_fields] + index = {'fields': fields, 'unique': True, 'sparse': sparse} + unique_indexes.append(index) + + # Grab any embedded document field unique indexes + if (field.__class__.__name__ == "EmbeddedDocumentField" and + field.document_type != cls): + field_namespace = "%s." 
% field_name + doc_cls = field.document_type + unique_indexes += doc_cls._unique_with_indexes(field_namespace) + + return unique_indexes + + @classmethod + def _geo_indices(cls, inspected=None): + inspected = inspected or [] + geo_indices = [] + inspected.append(cls) + + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + GeoPointField = _import_class("GeoPointField") + + for field in cls._fields.values(): + if not isinstance(field, (EmbeddedDocumentField, GeoPointField)): + continue + if hasattr(field, 'document_type'): + field_cls = field.document_type + if field_cls in inspected: + continue + if hasattr(field_cls, '_geo_indices'): + geo_indices += field_cls._geo_indices(inspected) + elif field._geo_index: + geo_indices.append({'fields': + [(field.db_field, pymongo.GEO2D)]}) + return geo_indices + + @classmethod + def _lookup_field(cls, parts): + """Lookup a field based on its attribute and return a list containing + the field's parents and the field. + """ + if not isinstance(parts, (list, tuple)): + parts = [parts] + fields = [] + field = None + + for field_name in parts: + # Handle ListField indexing: + if field_name.isdigit(): + new_field = field.field + fields.append(field_name) + continue + + if field is None: + # Look up first field from the document + if field_name == 'pk': + # Deal with "primary key" alias + field_name = cls._meta['id_field'] + if field_name in cls._fields: + field = cls._fields[field_name] + elif cls._dynamic: + DynamicField = _import_class('DynamicField') + field = DynamicField(db_field=field_name) + else: + raise LookUpError('Cannot resolve field "%s"' + % field_name) + else: + ReferenceField = _import_class('ReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') + if isinstance(field, (ReferenceField, GenericReferenceField)): + raise LookUpError('Cannot perform join in mongoDB: %s' % + '__'.join(parts)) + if hasattr(getattr(field, 'field', None), 'lookup_member'): + new_field = 
field.field.lookup_member(field_name) + else: + # Look up subfield on the previous field + new_field = field.lookup_member(field_name) + if not new_field and isinstance(field, ComplexBaseField): + fields.append(field_name) + continue + elif not new_field: + raise LookUpError('Cannot resolve field "%s"' + % field_name) + field = new_field # update field to the new field type + fields.append(field) + return fields + + @classmethod + def _translate_field_name(cls, field, sep='.'): + """Translate a field attribute name to a database field name. + """ + parts = field.split(sep) + parts = [f.db_field for f in cls._lookup_field(parts)] + return '.'.join(parts) + + def __set_field_display(self): + """Dynamically set the display value for a field with choices""" + for attr_name, field in self._fields.items(): + if field.choices: + setattr(self, + 'get_%s_display' % attr_name, + partial(self.__get_field_display, field=field)) + + def __get_field_display(self, field): + """Returns the display value for a choice field""" + value = getattr(self, field.name) + if field.choices and isinstance(field.choices[0], (list, tuple)): + return dict(field.choices).get(value, value) + return value diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py new file mode 100644 index 00000000..3929a3a5 --- /dev/null +++ b/mongoengine/base/fields.py @@ -0,0 +1,395 @@ +import operator +import warnings +import weakref + +from bson import DBRef, ObjectId + +from mongoengine.common import _import_class +from mongoengine.errors import ValidationError + +from mongoengine.base.common import ALLOW_INHERITANCE +from mongoengine.base.datastructures import BaseDict, BaseList + +__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField") + + +class BaseField(object): + """A base class for fields in a MongoDB document. Instances of this class + may be added to subclasses of `Document` to define a document's schema. + + .. 
versionchanged:: 0.5 - added verbose and help text + """ + + name = None + _geo_index = False + _auto_gen = False # Call `generate` to generate a value + _auto_dereference = True + + # These track each time a Field instance is created. Used to retain order. + # The auto_creation_counter is used for fields that MongoEngine implicitly + # creates, creation_counter is used for all user-specified fields. + creation_counter = 0 + auto_creation_counter = -1 + + def __init__(self, db_field=None, name=None, required=False, default=None, + unique=False, unique_with=None, primary_key=False, + validation=None, choices=None, verbose_name=None, + help_text=None): + self.db_field = (db_field or name) if not primary_key else '_id' + if name: + msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" + warnings.warn(msg, DeprecationWarning) + self.required = required or primary_key + self.default = default + self.unique = bool(unique or unique_with) + self.unique_with = unique_with + self.primary_key = primary_key + self.validation = validation + self.choices = choices + self.verbose_name = verbose_name + self.help_text = help_text + + # Adjust the appropriate creation counter, and save our local copy. + if self.db_field == '_id': + self.creation_counter = BaseField.auto_creation_counter + BaseField.auto_creation_counter -= 1 + else: + self.creation_counter = BaseField.creation_counter + BaseField.creation_counter += 1 + + def __get__(self, instance, owner): + """Descriptor for retrieving a value from a field in a document. Do + any necessary conversion between Python and MongoDB types. 
+ """ + if instance is None: + # Document class being used rather than a document object + return self + # Get value from document instance if available, if not use default + value = instance._data.get(self.name) + + if value is None: + value = self.default + # Allow callable default values + if callable(value): + value = value() + + EmbeddedDocument = _import_class('EmbeddedDocument') + if isinstance(value, EmbeddedDocument) and value._instance is None: + value._instance = weakref.proxy(instance) + return value + + def __set__(self, instance, value): + """Descriptor for assigning a value to a field in a document. + """ + changed = False + if (self.name not in instance._data or + instance._data[self.name] != value): + changed = True + instance._data[self.name] = value + if changed and instance._initialised: + instance._mark_as_changed(self.name) + + def error(self, message="", errors=None, field_name=None): + """Raises a ValidationError. + """ + field_name = field_name if field_name else self.name + raise ValidationError(message, errors=errors, field_name=field_name) + + def to_python(self, value): + """Convert a MongoDB-compatible type to a Python type. + """ + return value + + def to_mongo(self, value): + """Convert a Python type to a MongoDB-compatible type. + """ + return self.to_python(value) + + def prepare_query_value(self, op, value): + """Prepare a value that is being used in a query for PyMongo. + """ + return value + + def validate(self, value, clean=True): + """Perform validation on a value. 
+ """ + pass + + def _validate(self, value, **kwargs): + Document = _import_class('Document') + EmbeddedDocument = _import_class('EmbeddedDocument') + # check choices + if self.choices: + is_cls = isinstance(value, (Document, EmbeddedDocument)) + value_to_check = value.__class__ if is_cls else value + err_msg = 'an instance' if is_cls else 'one' + if isinstance(self.choices[0], (list, tuple)): + option_keys = [k for k, v in self.choices] + if value_to_check not in option_keys: + msg = ('Value must be %s of %s' % + (err_msg, unicode(option_keys))) + self.error(msg) + elif value_to_check not in self.choices: + msg = ('Value must be %s of %s' % + (err_msg, unicode(self.choices))) + self.error(msg) + + # check validation argument + if self.validation is not None: + if callable(self.validation): + if not self.validation(value): + self.error('Value does not match custom validation method') + else: + raise ValueError('validation argument for "%s" must be a ' + 'callable.' % self.name) + + self.validate(value, **kwargs) + + +class ComplexBaseField(BaseField): + """Handles complex fields, such as lists / dictionaries. + + Allows for nesting of embedded documents inside complex types. + Handles the lazy dereferencing of a queryset by lazily dereferencing all + items in a list / dict rather than one at a time. + + .. versionadded:: 0.5 + """ + + field = None + __dereference = False + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. 
+ """ + if instance is None: + # Document class being used rather than a document object + return self + + ReferenceField = _import_class('ReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') + dereference = (self._auto_dereference and + (self.field is None or isinstance(self.field, + (GenericReferenceField, ReferenceField)))) + + self._auto_dereference = instance._fields[self.name]._auto_dereference + if not self.__dereference and instance._initialised and dereference: + instance._data[self.name] = self._dereference( + instance._data.get(self.name), max_depth=1, instance=instance, + name=self.name + ) + + value = super(ComplexBaseField, self).__get__(instance, owner) + + # Convert lists / values so we can watch for any changes on them + if (isinstance(value, (list, tuple)) and + not isinstance(value, BaseList)): + value = BaseList(value, instance, self.name) + instance._data[self.name] = value + elif isinstance(value, dict) and not isinstance(value, BaseDict): + value = BaseDict(value, instance, self.name) + instance._data[self.name] = value + + if (self._auto_dereference and instance._initialised and + isinstance(value, (BaseList, BaseDict)) + and not value._dereferenced): + value = self._dereference( + value, max_depth=1, instance=instance, name=self.name + ) + value._dereferenced = True + instance._data[self.name] = value + + return value + + def __set__(self, instance, value): + """Descriptor for assigning a value to a field in a document. + """ + instance._data[self.name] = value + instance._mark_as_changed(self.name) + + def to_python(self, value): + """Convert a MongoDB-compatible type to a Python type. 
+ """ + Document = _import_class('Document') + + if isinstance(value, basestring): + return value + + if hasattr(value, 'to_python'): + return value.to_python() + + is_list = False + if not hasattr(value, 'items'): + try: + is_list = True + value = dict([(k, v) for k, v in enumerate(value)]) + except TypeError: # Not iterable return the value + return value + + if self.field: + value_dict = dict([(key, self.field.to_python(item)) + for key, item in value.items()]) + else: + value_dict = {} + for k, v in value.items(): + if isinstance(v, Document): + # We need the id from the saved object to create the DBRef + if v.pk is None: + self.error('You can only reference documents once they' + ' have been saved to the database') + collection = v._get_collection_name() + value_dict[k] = DBRef(collection, v.pk) + elif hasattr(v, 'to_python'): + value_dict[k] = v.to_python() + else: + value_dict[k] = self.to_python(v) + + if is_list: # Convert back to a list + return [v for k, v in sorted(value_dict.items(), + key=operator.itemgetter(0))] + return value_dict + + def to_mongo(self, value): + """Convert a Python type to a MongoDB-compatible type. 
+ """ + Document = _import_class("Document") + EmbeddedDocument = _import_class("EmbeddedDocument") + GenericReferenceField = _import_class("GenericReferenceField") + + if isinstance(value, basestring): + return value + + if hasattr(value, 'to_mongo'): + if isinstance(value, Document): + return GenericReferenceField().to_mongo(value) + cls = value.__class__ + val = value.to_mongo() + # If we its a document thats not inherited add _cls + if (isinstance(value, EmbeddedDocument)): + val['_cls'] = cls.__name__ + return val + + is_list = False + if not hasattr(value, 'items'): + try: + is_list = True + value = dict([(k, v) for k, v in enumerate(value)]) + except TypeError: # Not iterable return the value + return value + + if self.field: + value_dict = dict([(key, self.field.to_mongo(item)) + for key, item in value.iteritems()]) + else: + value_dict = {} + for k, v in value.iteritems(): + if isinstance(v, Document): + # We need the id from the saved object to create the DBRef + if v.pk is None: + self.error('You can only reference documents once they' + ' have been saved to the database') + + # If its a document that is not inheritable it won't have + # any _cls data so make it a generic reference allows + # us to dereference + meta = getattr(v, '_meta', {}) + allow_inheritance = ( + meta.get('allow_inheritance', ALLOW_INHERITANCE) + is True) + if not allow_inheritance and not self.field: + value_dict[k] = GenericReferenceField().to_mongo(v) + else: + collection = v._get_collection_name() + value_dict[k] = DBRef(collection, v.pk) + elif hasattr(v, 'to_mongo'): + cls = v.__class__ + val = v.to_mongo() + # If we its a document thats not inherited add _cls + if (isinstance(v, (Document, EmbeddedDocument))): + val['_cls'] = cls.__name__ + value_dict[k] = val + else: + value_dict[k] = self.to_mongo(v) + + if is_list: # Convert back to a list + return [v for k, v in sorted(value_dict.items(), + key=operator.itemgetter(0))] + return value_dict + + def validate(self, value): + 
"""If field is provided ensure the value is valid. + """ + errors = {} + if self.field: + if hasattr(value, 'iteritems') or hasattr(value, 'items'): + sequence = value.iteritems() + else: + sequence = enumerate(value) + for k, v in sequence: + try: + self.field._validate(v) + except ValidationError, error: + errors[k] = error.errors or error + except (ValueError, AssertionError), error: + errors[k] = error + + if errors: + field_class = self.field.__class__.__name__ + self.error('Invalid %s item (%s)' % (field_class, value), + errors=errors) + # Don't allow empty values if required + if self.required and not value: + self.error('Field is required and cannot be empty') + + def prepare_query_value(self, op, value): + return self.to_mongo(value) + + def lookup_member(self, member_name): + if self.field: + return self.field.lookup_member(member_name) + return None + + def _set_owner_document(self, owner_document): + if self.field: + self.field.owner_document = owner_document + self._owner_document = owner_document + + def _get_owner_document(self, owner_document): + self._owner_document = owner_document + + owner_document = property(_get_owner_document, _set_owner_document) + + @property + def _dereference(self,): + if not self.__dereference: + DeReference = _import_class("DeReference") + self.__dereference = DeReference() # Cached + return self.__dereference + + +class ObjectIdField(BaseField): + """A field wrapper around MongoDB's ObjectIds. 
+ """ + + def to_python(self, value): + if not isinstance(value, ObjectId): + value = ObjectId(value) + return value + + def to_mongo(self, value): + if not isinstance(value, ObjectId): + try: + return ObjectId(unicode(value)) + except Exception, e: + # e.message attribute has been deprecated since Python 2.6 + self.error(unicode(e)) + return value + + def prepare_query_value(self, op, value): + return self.to_mongo(value) + + def validate(self, value): + try: + ObjectId(unicode(value)) + except: + self.error('Invalid Object ID') diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py new file mode 100644 index 00000000..def8a055 --- /dev/null +++ b/mongoengine/base/metaclasses.py @@ -0,0 +1,396 @@ +import warnings + +import pymongo + +from mongoengine.common import _import_class +from mongoengine.errors import InvalidDocumentError +from mongoengine.python_support import PY3 +from mongoengine.queryset import (DO_NOTHING, DoesNotExist, + MultipleObjectsReturned, + QuerySet, QuerySetManager) + +from mongoengine.base.common import _document_registry, ALLOW_INHERITANCE +from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField + +__all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass') + + +class DocumentMetaclass(type): + """Metaclass for all documents. 
+ """ + + def __new__(cls, name, bases, attrs): + flattened_bases = cls._get_bases(bases) + super_new = super(DocumentMetaclass, cls).__new__ + + # If a base class just call super + metaclass = attrs.get('my_metaclass') + if metaclass and issubclass(metaclass, DocumentMetaclass): + return super_new(cls, name, bases, attrs) + + attrs['_is_document'] = attrs.get('_is_document', False) + + # EmbeddedDocuments could have meta data for inheritance + if 'meta' in attrs: + attrs['_meta'] = attrs.pop('meta') + + # EmbeddedDocuments should inherit meta data + if '_meta' not in attrs: + meta = MetaDict() + for base in flattened_bases[::-1]: + # Add any mixin metadata from plain objects + if hasattr(base, 'meta'): + meta.merge(base.meta) + elif hasattr(base, '_meta'): + meta.merge(base._meta) + attrs['_meta'] = meta + + # Handle document Fields + + # Merge all fields from subclasses + doc_fields = {} + for base in flattened_bases[::-1]: + if hasattr(base, '_fields'): + doc_fields.update(base._fields) + + # Standard object mixin - merge in any Fields + if not hasattr(base, '_meta'): + base_fields = {} + for attr_name, attr_value in base.__dict__.iteritems(): + if not isinstance(attr_value, BaseField): + continue + attr_value.name = attr_name + if not attr_value.db_field: + attr_value.db_field = attr_name + base_fields[attr_name] = attr_value + + doc_fields.update(base_fields) + + # Discover any document fields + field_names = {} + for attr_name, attr_value in attrs.iteritems(): + if not isinstance(attr_value, BaseField): + continue + attr_value.name = attr_name + if not attr_value.db_field: + attr_value.db_field = attr_name + doc_fields[attr_name] = attr_value + + # Count names to ensure no db_field redefinitions + field_names[attr_value.db_field] = field_names.get( + attr_value.db_field, 0) + 1 + + # Ensure no duplicate db_fields + duplicate_db_fields = [k for k, v in field_names.items() if v > 1] + if duplicate_db_fields: + msg = ("Multiple db_fields defined for: %s " % + ", 
".join(duplicate_db_fields)) + raise InvalidDocumentError(msg) + + # Set _fields and db_field maps + attrs['_fields'] = doc_fields + attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) + for k, v in doc_fields.iteritems()]) + attrs['_fields_ordered'] = tuple(i[1] for i in sorted( + (v.creation_counter, v.name) + for v in doc_fields.itervalues())) + attrs['_reverse_db_field_map'] = dict( + (v, k) for k, v in attrs['_db_field_map'].iteritems()) + + # + # Set document hierarchy + # + superclasses = () + class_name = [name] + for base in flattened_bases: + if (not getattr(base, '_is_base_cls', True) and + not getattr(base, '_meta', {}).get('abstract', True)): + # Collate heirarchy for _cls and _subclasses + class_name.append(base.__name__) + + if hasattr(base, '_meta'): + # Warn if allow_inheritance isn't set and prevent + # inheritance of classes where inheritance is set to False + allow_inheritance = base._meta.get('allow_inheritance', + ALLOW_INHERITANCE) + if (allow_inheritance is not True and + not base._meta.get('abstract')): + raise ValueError('Document %s may not be subclassed' % + base.__name__) + + # Get superclasses from last base superclass + document_bases = [b for b in flattened_bases + if hasattr(b, '_class_name')] + if document_bases: + superclasses = document_bases[0]._superclasses + superclasses += (document_bases[0]._class_name, ) + + _cls = '.'.join(reversed(class_name)) + attrs['_class_name'] = _cls + attrs['_superclasses'] = superclasses + attrs['_subclasses'] = (_cls, ) + attrs['_types'] = attrs['_subclasses'] # TODO depreciate _types + + # Create the new_class + new_class = super_new(cls, name, bases, attrs) + + # Set _subclasses + for base in document_bases: + if _cls not in base._subclasses: + base._subclasses += (_cls,) + base._types = base._subclasses # TODO depreciate _types + + # Handle delete rules + Document, EmbeddedDocument, DictField = cls._import_classes() + for field in new_class._fields.itervalues(): + f = field + 
f.owner_document = new_class + delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) + if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): + delete_rule = getattr(f.field, + 'reverse_delete_rule', + DO_NOTHING) + if isinstance(f, DictField) and delete_rule != DO_NOTHING: + msg = ("Reverse delete rules are not supported " + "for %s (field: %s)" % + (field.__class__.__name__, field.name)) + raise InvalidDocumentError(msg) + + f = field.field + + if delete_rule != DO_NOTHING: + if issubclass(new_class, EmbeddedDocument): + msg = ("Reverse delete rules are not supported for " + "EmbeddedDocuments (field: %s)" % field.name) + raise InvalidDocumentError(msg) + f.document_type.register_delete_rule(new_class, + field.name, delete_rule) + + if (field.name and hasattr(Document, field.name) and + EmbeddedDocument not in new_class.mro()): + msg = ("%s is a document method and not a valid " + "field name" % field.name) + raise InvalidDocumentError(msg) + + if issubclass(new_class, Document): + new_class._collection = None + + # Add class to the _document_registry + _document_registry[new_class._class_name] = new_class + + # In Python 2, User-defined methods objects have special read-only + # attributes 'im_func' and 'im_self' which contain the function obj + # and class instance object respectively. With Python 3 these special + # attributes have been replaced by __func__ and __self__. The Blinker + # module continues to use im_func and im_self, so the code below + # copies __func__ into im_func and __self__ into im_self for + # classmethod objects in Document derived classes. 
+ if PY3: + for key, val in new_class.__dict__.items(): + if isinstance(val, classmethod): + f = val.__get__(new_class) + if hasattr(f, '__func__') and not hasattr(f, 'im_func'): + f.__dict__.update({'im_func': getattr(f, '__func__')}) + if hasattr(f, '__self__') and not hasattr(f, 'im_self'): + f.__dict__.update({'im_self': getattr(f, '__self__')}) + + return new_class + + def add_to_class(self, name, value): + setattr(self, name, value) + + @classmethod + def _get_bases(cls, bases): + if isinstance(bases, BasesTuple): + return bases + seen = [] + bases = cls.__get_bases(bases) + unique_bases = (b for b in bases if not (b in seen or seen.append(b))) + return BasesTuple(unique_bases) + + @classmethod + def __get_bases(cls, bases): + for base in bases: + if base is object: + continue + yield base + for child_base in cls.__get_bases(base.__bases__): + yield child_base + + @classmethod + def _import_classes(cls): + Document = _import_class('Document') + EmbeddedDocument = _import_class('EmbeddedDocument') + DictField = _import_class('DictField') + return (Document, EmbeddedDocument, DictField) + + +class TopLevelDocumentMetaclass(DocumentMetaclass): + """Metaclass for top-level documents (i.e. documents that have their own + collection in the database. 
+ """ + + def __new__(cls, name, bases, attrs): + flattened_bases = cls._get_bases(bases) + super_new = super(TopLevelDocumentMetaclass, cls).__new__ + + # Set default _meta data if base class, otherwise get user defined meta + if (attrs.get('my_metaclass') == TopLevelDocumentMetaclass): + # defaults + attrs['_meta'] = { + 'abstract': True, + 'max_documents': None, + 'max_size': None, + 'ordering': [], # default ordering applied at runtime + 'indexes': [], # indexes to be ensured at runtime + 'id_field': None, + 'index_background': False, + 'index_drop_dups': False, + 'index_opts': None, + 'delete_rules': None, + 'allow_inheritance': None, + } + attrs['_is_base_cls'] = True + attrs['_meta'].update(attrs.get('meta', {})) + else: + attrs['_meta'] = attrs.get('meta', {}) + # Explictly set abstract to false unless set + attrs['_meta']['abstract'] = attrs['_meta'].get('abstract', False) + attrs['_is_base_cls'] = False + + # Set flag marking as document class - as opposed to an object mixin + attrs['_is_document'] = True + + # Ensure queryset_class is inherited + if 'objects' in attrs: + manager = attrs['objects'] + if hasattr(manager, 'queryset_class'): + attrs['_meta']['queryset_class'] = manager.queryset_class + + # Clean up top level meta + if 'meta' in attrs: + del(attrs['meta']) + + # Find the parent document class + parent_doc_cls = [b for b in flattened_bases + if b.__class__ == TopLevelDocumentMetaclass] + parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0] + + # Prevent classes setting collection different to their parents + # If parent wasn't an abstract class + if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) + and not parent_doc_cls._meta.get('abstract', True)): + msg = "Trying to set a collection on a subclass (%s)" % name + warnings.warn(msg, SyntaxWarning) + del(attrs['_meta']['collection']) + + # Ensure abstract documents have abstract bases + if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): + if (parent_doc_cls 
and + not parent_doc_cls._meta.get('abstract', False)): + msg = "Abstract document cannot have non-abstract base" + raise ValueError(msg) + return super_new(cls, name, bases, attrs) + + # Merge base class metas. + # Uses a special MetaDict that handles various merging rules + meta = MetaDict() + for base in flattened_bases[::-1]: + # Add any mixin metadata from plain objects + if hasattr(base, 'meta'): + meta.merge(base.meta) + elif hasattr(base, '_meta'): + meta.merge(base._meta) + + # Set collection in the meta if its callable + if (getattr(base, '_is_document', False) and + not base._meta.get('abstract')): + collection = meta.get('collection', None) + if callable(collection): + meta['collection'] = collection(base) + + meta.merge(attrs.get('_meta', {})) # Top level meta + + # Only simple classes (direct subclasses of Document) + # may set allow_inheritance to False + simple_class = all([b._meta.get('abstract') + for b in flattened_bases if hasattr(b, '_meta')]) + if (not simple_class and meta['allow_inheritance'] is False and + not meta['abstract']): + raise ValueError('Only direct subclasses of Document may set ' + '"allow_inheritance" to False') + + # Set default collection name + if 'collection' not in meta: + meta['collection'] = ''.join('_%s' % c if c.isupper() else c + for c in name).strip('_').lower() + attrs['_meta'] = meta + + # Call super and get the new class + new_class = super_new(cls, name, bases, attrs) + + meta = new_class._meta + + # Set index specifications + meta['index_specs'] = new_class._build_index_specs(meta['indexes']) + + # If collection is a callable - call it and set the value + collection = meta.get('collection') + if callable(collection): + new_class._meta['collection'] = collection(new_class) + + # Provide a default queryset unless exists or one has been set + if 'objects' not in dir(new_class): + new_class.objects = QuerySetManager() + + # Validate the fields and set primary key if needed + for field_name, field in 
new_class._fields.iteritems(): + if field.primary_key: + # Ensure only one primary key is set + current_pk = new_class._meta.get('id_field') + if current_pk and current_pk != field_name: + raise ValueError('Cannot override primary key field') + + # Set primary key + if not current_pk: + new_class._meta['id_field'] = field_name + new_class.id = field + + # Set primary key if not defined by the document + if not new_class._meta.get('id_field'): + new_class._meta['id_field'] = 'id' + new_class._fields['id'] = ObjectIdField(db_field='_id') + new_class._fields['id'].name = 'id' + new_class.id = new_class._fields['id'] + + # Merge in exceptions with parent hierarchy + exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) + module = attrs.get('__module__') + for exc in exceptions_to_merge: + name = exc.__name__ + parents = tuple(getattr(base, name) for base in flattened_bases + if hasattr(base, name)) or (exc,) + # Create new exception and set to new_class + exception = type(name, parents, {'__module__': module}) + setattr(new_class, name, exception) + + return new_class + + +class MetaDict(dict): + """Custom dictionary for meta classes. 
+ Handles the merging of set indexes + """ + _merge_options = ('indexes',) + + def merge(self, new_options): + for k, v in new_options.iteritems(): + if k in self._merge_options: + self[k] = self.get(k, []) + v + else: + self[k] = v + + +class BasesTuple(tuple): + """Special class to handle introspection of bases tuple in __new__""" + pass diff --git a/mongoengine/common.py b/mongoengine/common.py new file mode 100644 index 00000000..718ac0b2 --- /dev/null +++ b/mongoengine/common.py @@ -0,0 +1,36 @@ +_class_registry_cache = {} + + +def _import_class(cls_name): + """Cached mechanism for imports""" + if cls_name in _class_registry_cache: + return _class_registry_cache.get(cls_name) + + doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument', + 'MapReduceDocument') + field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', + 'FileField', 'GenericReferenceField', + 'GenericEmbeddedDocumentField', 'GeoPointField', + 'ReferenceField', 'StringField', 'ComplexBaseField') + queryset_classes = ('OperationError',) + deref_classes = ('DeReference',) + + if cls_name in doc_classes: + from mongoengine import document as module + import_classes = doc_classes + elif cls_name in field_classes: + from mongoengine import fields as module + import_classes = field_classes + elif cls_name in queryset_classes: + from mongoengine import queryset as module + import_classes = queryset_classes + elif cls_name in deref_classes: + from mongoengine import dereference as module + import_classes = deref_classes + else: + raise ValueError('No import set for: ' % cls_name) + + for cls in import_classes: + _class_registry_cache[cls] = getattr(module, cls) + + return _class_registry_cache.get(cls_name) \ No newline at end of file diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 1ccbbe31..3c53ea3c 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,5 +1,5 @@ import pymongo -from pymongo import Connection, 
ReplicaSetConnection, uri_parser +from pymongo import MongoClient, MongoReplicaSetClient, uri_parser __all__ = ['ConnectionError', 'connect', 'register_connection', @@ -28,8 +28,10 @@ def register_connection(alias, name, host='localhost', port=27017, :param name: the name of the specific database to use :param host: the host name of the :program:`mongod` instance to connect to :param port: the port that the :program:`mongod` instance is running on - :param is_slave: whether the connection can act as a slave ** Depreciated pymongo 2.0.1+ - :param read_preference: The read preference for the collection ** Added pymongo 2.1 + :param is_slave: whether the connection can act as a slave + ** Depreciated pymongo 2.0.1+ + :param read_preference: The read preference for the collection + ** Added pymongo 2.1 :param slaves: a list of aliases of slave connections; each of these must be a registered connection that has :attr:`is_slave` set to ``True`` :param username: username to authenticate with @@ -110,15 +112,15 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn_settings['slaves'] = slaves conn_settings.pop('read_preference', None) - connection_class = Connection + connection_class = MongoClient if 'replicaSet' in conn_settings: conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) - # Discard port since it can't be used on ReplicaSetConnection + # Discard port since it can't be used on MongoReplicaSetClient conn_settings.pop('port', None) # Discard replicaSet if not base string if not isinstance(conn_settings['replicaSet'], basestring): conn_settings.pop('replicaSet', None) - connection_class = ReplicaSetConnection + connection_class = MongoReplicaSetClient try: _connections[alias] = connection_class(**conn_settings) @@ -161,6 +163,7 @@ def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs): return get_connection(alias) + # Support old naming convention _get_connection = get_connection _get_db = get_db diff --git 
class switch_db(object):
    """ switch_db alias context manager.

    Temporarily routes a Document class to another registered database
    alias; the original alias and cached collection are restored on exit.

    Example ::

        # Register connections
        register_connection('default', 'mongoenginetest')
        register_connection('testdb-1', 'mongoenginetest2')

        class Group(Document):
            name = StringField()

        Group(name="test").save()  # Saves in the default db

        with switch_db(Group, 'testdb-1') as Group:
            Group(name="hello testdb!").save()  # Saves in testdb-1

    """

    def __init__(self, cls, db_alias):
        """ Construct the switch_db context manager

        :param cls: the class to change the registered db
        :param db_alias: the name of the specific database to use
        """
        self.cls = cls
        self.collection = cls._get_collection()
        self.db_alias = db_alias
        self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)

    def __enter__(self):
        """ change the db_alias and clear the cached collection """
        self.cls._meta["db_alias"] = self.db_alias
        self.cls._collection = None
        return self.cls

    def __exit__(self, t, value, traceback):
        """ Reset the db_alias and collection """
        self.cls._meta["db_alias"] = self.ori_db_alias
        self.cls._collection = self.collection


class switch_collection(object):
    """ switch_collection alias context manager.

    Temporarily redirects a Document class to a differently named
    collection in the same database.

    Example ::

        class Group(Document):
            name = StringField()

        Group(name="test").save()  # Saves in the default db

        with switch_collection(Group, 'group1') as Group:
            Group(name="hello testdb!").save()  # Saves in group1 collection

    """

    def __init__(self, cls, collection_name):
        """ Construct the switch_collection context manager

        :param cls: the class to change the registered db
        :param collection_name: the name of the collection to use
        """
        self.cls = cls
        self.ori_collection = cls._get_collection()
        self.ori_get_collection_name = cls._get_collection_name
        self.collection_name = collection_name

    def __enter__(self):
        """ change the _get_collection_name and clear the cached collection """

        # Replacement classmethod closes over this context manager to
        # report the temporary collection name.
        @classmethod
        def _get_collection_name(cls):
            return self.collection_name

        self.cls._get_collection_name = _get_collection_name
        self.cls._collection = None
        return self.cls

    def __exit__(self, t, value, traceback):
        """ Reset the collection """
        self.cls._collection = self.ori_collection
        self.cls._get_collection_name = self.ori_get_collection_name


class no_dereference(object):
    """ no_dereference context manager.

    Turns off all dereferencing in Documents for the duration of the context
    manager::

        with no_dereference(Group) as Group:
            Group.objects.find()

    """

    def __init__(self, cls):
        """ Construct the no_dereference context manager.

        :param cls: the class to turn dereferencing off on
        """
        self.cls = cls

        ReferenceField = _import_class('ReferenceField')
        GenericReferenceField = _import_class('GenericReferenceField')
        ComplexBaseField = _import_class('ComplexBaseField')

        # Only reference-like fields participate in dereferencing.
        self.deref_fields = [k for k, v in self.cls._fields.iteritems()
                             if isinstance(v, (ReferenceField,
                                               GenericReferenceField,
                                               ComplexBaseField))]

    def __enter__(self):
        """ change the objects default and _auto_dereference values"""
        for field in self.deref_fields:
            self.cls._fields[field]._auto_dereference = False
        return self.cls

    def __exit__(self, t, value, traceback):
        """ Reset the default and _auto_dereference values"""
        # NOTE(review): unconditionally re-enables dereferencing; fields
        # that were already non-dereferencing before entry are reset to
        # True as well — confirm this is intended.
        for field in self.deref_fields:
            self.cls._fields[field]._auto_dereference = True
        return self.cls


class QuerySetNoDeRef(QuerySet):
    """Special no_dereference QuerySet"""
    # Overrides the name-mangled QuerySet.__dereference with a no-op;
    # 'items' doubles as the implicit first (self-like) argument here.
    def __dereference(items, max_depth=1, instance=None, name=None):
        return items


class query_counter(object):
    """ Query_counter context manager to get the number of queries.

    Uses MongoDB's profiling collection (level 2 = profile everything)
    and supports direct comparison with integers.
    """

    def __init__(self):
        """ Construct the query_counter. """
        self.counter = 0
        self.db = get_db()

    def __enter__(self):
        """ On every with block we need to drop the profile collection. """
        self.db.set_profiling_level(0)
        self.db.system.profile.drop()
        self.db.set_profiling_level(2)
        return self

    def __exit__(self, t, value, traceback):
        """ Reset the profiling level. """
        self.db.set_profiling_level(0)

    def __eq__(self, value):
        """ == Compare querycounter. """
        return value == self._get_count()

    def __ne__(self, value):
        """ != Compare querycounter. """
        return not self.__eq__(value)

    def __lt__(self, value):
        """ < Compare querycounter. """
        return self._get_count() < value

    def __le__(self, value):
        """ <= Compare querycounter. """
        return self._get_count() <= value

    def __gt__(self, value):
        """ > Compare querycounter. """
        return self._get_count() > value

    def __ge__(self, value):
        """ >= Compare querycounter. """
        return self._get_count() >= value

    def __int__(self):
        """ int representation. """
        return self._get_count()

    def __repr__(self):
        """ repr query_counter as the number of queries. """
        return u"%s" % self._get_count()

    def _get_count(self):
        """ Get the number of queries. """
        # The find() issued here is itself profiled, so the offset is
        # bumped each call to keep subsequent reads accurate.
        count = self.db.system.profile.find().count() - self.counter
        self.counter += 1
        return count
class ContentType(Document):
    # MongoEngine port of django.contrib.contenttypes.models.ContentType.
    name = StringField(max_length=100)
    app_label = StringField(max_length=100)
    model = StringField(max_length=100, verbose_name=_('python model class name'),
                        unique_with='app_label')
    objects = ContentTypeManager()

    class Meta:
        # NOTE(review): Django-style inner Meta; mongoengine reads 'meta'
        # dicts, so this appears to exist for Django interop only — confirm.
        verbose_name = _('content type')
        verbose_name_plural = _('content types')
        # db_table = 'django_content_type'
        # ordering = ('name',)
        # unique_together = (('app_label', 'model'),)

    def __unicode__(self):
        return self.name

    def model_class(self):
        "Returns the Python model class for this type of content."
        from django.db import models
        return models.get_model(self.app_label, self.model)

    def get_object_for_this_type(self, **kwargs):
        """
        Returns an object of this type for the keyword arguments given.
        Basically, this is a proxy around this object_type's get_object() model
        method. The ObjectNotExist exception, if thrown, will not be caught,
        so code that calls this method should catch it.
        """
        return self.model_class()._default_manager.using(self._state.db).get(**kwargs)

    def natural_key(self):
        return (self.app_label, self.model)


class SiteProfileNotAvailable(Exception):
    # Raised by Django's profile machinery; kept for API compatibility.
    pass


class PermissionManager(models.Manager):
    def get_by_natural_key(self, codename, app_label, model):
        # Resolve (codename, app_label, model) back to a Permission.
        return self.get(
            codename=codename,
            content_type=ContentType.objects.get_by_natural_key(app_label, model)
        )


class Permission(Document):
    """The permissions system provides a way to assign permissions to specific
    users and groups of users.

    The permission system is used by the Django admin site, but may also be
    useful in your own code. The Django admin site uses permissions as follows:

        - The "add" permission limits the user's ability to view the "add"
          form and add an object.
        - The "change" permission limits a user's ability to view the change
          list, view the "change" form and change an object.
        - The "delete" permission limits the ability to delete an object.

    Permissions are set globally per type of object, not per specific object
    instance. It is possible to say "Mary may change news stories," but it's
    not currently possible to say "Mary may change news stories, but only the
    ones she created herself" or "Mary may only change news stories that have
    a certain status or publication date."

    Three basic permissions -- add, change and delete -- are automatically
    created for each Django model.
    """
    name = StringField(max_length=50, verbose_name=_('username'))
    content_type = ReferenceField(ContentType)
    codename = StringField(max_length=100, verbose_name=_('codename'))
    # FIXME: don't access field of the other class
    # unique_with=['content_type__app_label', 'content_type__model'])

    objects = PermissionManager()

    class Meta:
        verbose_name = _('permission')
        verbose_name_plural = _('permissions')
        # unique_together = (('content_type', 'codename'),)
        # ordering = ('content_type__app_label', 'content_type__model', 'codename')

    def __unicode__(self):
        return u"%s | %s | %s" % (
            unicode(self.content_type.app_label),
            unicode(self.content_type),
            unicode(self.name))

    def natural_key(self):
        return (self.codename,) + self.content_type.natural_key()
    natural_key.dependencies = ['contenttypes.contenttype']


class Group(Document):
    """Groups are a generic way of categorizing users to apply permissions,
    or some other label, to those users. A user can belong to any number of
    groups.

    A user in a group automatically has all the permissions granted to that
    group. For example, if the group Site editors has the permission
    can_edit_home_page, any user in that group will have that permission.

    Beyond permissions, groups are a convenient way to categorize users to
    apply some label, or extended functionality, to them. For example, you
    could create a group 'Special users', and you could write code that would
    do special things to those users -- such as giving them access to a
    members-only portion of your site, or sending them members-only
    e-mail messages.
    """
    name = StringField(max_length=80, unique=True, verbose_name=_('name'))
    permissions = ListField(ReferenceField(Permission, verbose_name=_('permissions'), required=False))

    class Meta:
        verbose_name = _('group')
        verbose_name_plural = _('groups')

    def __unicode__(self):
        return self.name


class UserManager(models.Manager):
    def create_user(self, username, email, password=None):
        """
        Creates and saves a User with the given username, e-mail and password.
        """
        now = datetime_now()

        # Normalize the address by lowercasing the domain part of the email
        # address.
        try:
            email_name, domain_part = email.strip().split('@', 1)
        except ValueError:
            pass
        else:
            email = '@'.join([email_name, domain_part.lower()])

        user = self.model(username=username, email=email, is_staff=False,
                          is_active=True, is_superuser=False, last_login=now,
                          date_joined=now)

        user.set_password(password)
        user.save(using=self._db)
        return user

    def create_superuser(self, username, email, password):
        # Same as create_user but with all privilege flags enabled.
        u = self.create_user(username, email, password)
        u.is_staff = True
        u.is_active = True
        u.is_superuser = True
        u.save(using=self._db)
        return u

    def make_random_password(self, length=10, allowed_chars='abcdefghjkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789'):
        "Generates a random password with the given length and given allowed_chars"
        # Note that default value of allowed_chars does not have "I" or letters
        # that look like it -- just to avoid confusion.
        # NOTE(review): random.choice is not cryptographically secure; for
        # security-sensitive tokens an os.urandom-based source is preferable.
        from random import choice
        return ''.join([choice(allowed_chars) for i in range(length)])
+ """ + permissions = set() + for backend in auth.get_backends(): + if hasattr(backend, "get_group_permissions"): + permissions.update(backend.get_group_permissions(self, obj)) + return permissions + def get_all_permissions(self, obj=None): return _user_get_all_permissions(self, obj) @@ -125,30 +325,50 @@ class User(Document): # Otherwise we need to check the backends. return _user_has_perm(self, perm, obj) - @classmethod - def create_user(cls, username, password, email=None): - """Create (and save) a new user with the given username, password and - email address. + def has_module_perms(self, app_label): """ - now = datetime.datetime.now() + Returns True if the user has any permissions in the given app label. + Uses pretty much the same logic as has_perm, above. + """ + # Active superusers have all permissions. + if self.is_active and self.is_superuser: + return True - # Normalize the address by lowercasing the domain part of the email - # address. - if email is not None: + return _user_has_module_perms(self, app_label) + + def email_user(self, subject, message, from_email=None): + "Sends an e-mail to this User." + from django.core.mail import send_mail + send_mail(subject, message, from_email, [self.email]) + + def get_profile(self): + """ + Returns site-specific profile for this user. Raises + SiteProfileNotAvailable if this site does not allow profiles. 
+ """ + if not hasattr(self, '_profile_cache'): + from django.conf import settings + if not getattr(settings, 'AUTH_PROFILE_MODULE', False): + raise SiteProfileNotAvailable('You need to set AUTH_PROFILE_MO' + 'DULE in your project settings') try: - email_name, domain_part = email.strip().split('@', 1) + app_label, model_name = settings.AUTH_PROFILE_MODULE.split('.') except ValueError: - pass - else: - email = '@'.join([email_name, domain_part.lower()]) + raise SiteProfileNotAvailable('app_label and model_name should' + ' be separated by a dot in the AUTH_PROFILE_MODULE set' + 'ting') - user = cls(username=username, email=email, date_joined=now) - user.set_password(password) - user.save() - return user - - def get_and_delete_messages(self): - return [] + try: + model = models.get_model(app_label, model_name) + if model is None: + raise SiteProfileNotAvailable('Unable to load the profile ' + 'model, check AUTH_PROFILE_MODULE in your project sett' + 'ings') + self._profile_cache = model._default_manager.using(self._state.db).get(user__id__exact=self.id) + self._profile_cache.user = self + except (ImportError, ImproperlyConfigured): + raise SiteProfileNotAvailable + return self._profile_cache class MongoEngineBackend(object): @@ -163,6 +383,8 @@ class MongoEngineBackend(object): user = User.objects(username=username).first() if user: if password and user.check_password(password): + backend = auth.get_backends()[0] + user.backend = "%s.%s" % (backend.__module__, backend.__class__.__name__) return user return None diff --git a/mongoengine/django/mongo_auth/__init__.py b/mongoengine/django/mongo_auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mongoengine/django/mongo_auth/models.py b/mongoengine/django/mongo_auth/models.py new file mode 100644 index 00000000..9629e644 --- /dev/null +++ b/mongoengine/django/mongo_auth/models.py @@ -0,0 +1,90 @@ +from importlib import import_module + +from django.conf import settings +from 
django.contrib.auth.models import UserManager +from django.core.exceptions import ImproperlyConfigured +from django.db import models +from django.utils.translation import ugettext_lazy as _ + + +MONGOENGINE_USER_DOCUMENT = getattr( + settings, 'MONGOENGINE_USER_DOCUMENT', 'mongoengine.django.auth.User') + + +class MongoUserManager(UserManager): + """A User manager which allows the use of MongoEngine documents in Django. + + To use the manager, you must tell django.contrib.auth to use MongoUser as + the user model. In your settings.py, you need: + + INSTALLED_APPS = ( + ... + 'django.contrib.auth', + 'mongoengine.django.mongo_auth', + ... + ) + AUTH_USER_MODEL = 'mongo_auth.MongoUser' + + Django will use the model object to access the custom Manager, which will + replace the original queryset with MongoEngine querysets. + + By default, mongoengine.django.auth.User will be used to store users. You + can specify another document class in MONGOENGINE_USER_DOCUMENT in your + settings.py. + + The User Document class has the same requirements as a standard custom user + model: https://docs.djangoproject.com/en/dev/topics/auth/customizing/ + + In particular, the User Document class must define USERNAME_FIELD and + REQUIRED_FIELDS. + + `AUTH_USER_MODEL` has been added in Django 1.5. 
+ + """ + + def contribute_to_class(self, model, name): + super(MongoUserManager, self).contribute_to_class(model, name) + self.dj_model = self.model + self.model = self._get_user_document() + + self.dj_model.USERNAME_FIELD = self.model.USERNAME_FIELD + username = models.CharField(_('username'), max_length=30, unique=True) + username.contribute_to_class(self.dj_model, self.dj_model.USERNAME_FIELD) + + self.dj_model.REQUIRED_FIELDS = self.model.REQUIRED_FIELDS + for name in self.dj_model.REQUIRED_FIELDS: + field = models.CharField(_(name), max_length=30) + field.contribute_to_class(self.dj_model, name) + + def _get_user_document(self): + try: + name = MONGOENGINE_USER_DOCUMENT + dot = name.rindex('.') + module = import_module(name[:dot]) + return getattr(module, name[dot + 1:]) + except ImportError: + raise ImproperlyConfigured("Error importing %s, please check " + "settings.MONGOENGINE_USER_DOCUMENT" + % name) + + def get(self, *args, **kwargs): + try: + return self.get_query_set().get(*args, **kwargs) + except self.model.DoesNotExist: + # ModelBackend expects this exception + raise self.dj_model.DoesNotExist + + @property + def db(self): + raise NotImplementedError + + def get_empty_query_set(self): + return self.model.objects.none() + + def get_query_set(self): + return self.model.objects + + +class MongoUser(models.Model): + objects = MongoUserManager() + diff --git a/mongoengine/django/sessions.py b/mongoengine/django/sessions.py index 810b6265..29583f5c 100644 --- a/mongoengine/django/sessions.py +++ b/mongoengine/django/sessions.py @@ -1,5 +1,3 @@ -from datetime import datetime - from django.conf import settings from django.contrib.sessions.backends.base import SessionBase, CreateError from django.core.exceptions import SuspiciousOperation @@ -10,6 +8,8 @@ from mongoengine import fields from mongoengine.queryset import OperationError from mongoengine.connection import DEFAULT_CONNECTION_NAME +from .utils import datetime_now + MONGOENGINE_SESSION_DB_ALIAS = 
getattr( settings, 'MONGOENGINE_SESSION_DB_ALIAS', @@ -25,15 +25,27 @@ MONGOENGINE_SESSION_DATA_ENCODE = getattr( settings, 'MONGOENGINE_SESSION_DATA_ENCODE', True) + class MongoSession(Document): session_key = fields.StringField(primary_key=True, max_length=40) session_data = fields.StringField() if MONGOENGINE_SESSION_DATA_ENCODE \ else fields.DictField() expire_date = fields.DateTimeField() - meta = {'collection': MONGOENGINE_SESSION_COLLECTION, - 'db_alias': MONGOENGINE_SESSION_DB_ALIAS, - 'allow_inheritance': False} + meta = { + 'collection': MONGOENGINE_SESSION_COLLECTION, + 'db_alias': MONGOENGINE_SESSION_DB_ALIAS, + 'allow_inheritance': False, + 'indexes': [ + { + 'fields': ['expire_date'], + 'expireAfterSeconds': settings.SESSION_COOKIE_AGE + } + ] + } + + def get_decoded(self): + return SessionStore().decode(self.session_data) class SessionStore(SessionBase): @@ -43,7 +55,7 @@ class SessionStore(SessionBase): def load(self): try: s = MongoSession.objects(session_key=self.session_key, - expire_date__gt=datetime.now())[0] + expire_date__gt=datetime_now())[0] if MONGOENGINE_SESSION_DATA_ENCODE: return self.decode(force_unicode(s.session_data)) else: @@ -76,7 +88,7 @@ class SessionStore(SessionBase): s.session_data = self._get_session(no_load=must_create) s.expire_date = self.get_expiry_date() try: - s.save(force_insert=must_create, safe=True) + s.save(force_insert=must_create) except OperationError: if must_create: raise CreateError diff --git a/mongoengine/django/shortcuts.py b/mongoengine/django/shortcuts.py index 637cee15..9cc8370b 100644 --- a/mongoengine/django/shortcuts.py +++ b/mongoengine/django/shortcuts.py @@ -1,6 +1,6 @@ from mongoengine.queryset import QuerySet from mongoengine.base import BaseDocument -from mongoengine.base import ValidationError +from mongoengine.errors import ValidationError def _get_queryset(cls): """Inspired by django.shortcuts.*""" diff --git a/mongoengine/django/utils.py b/mongoengine/django/utils.py new file mode 100644 
index 00000000..d3ef8a4b --- /dev/null +++ b/mongoengine/django/utils.py @@ -0,0 +1,6 @@ +try: + # django >= 1.4 + from django.utils.timezone import now as datetime_now +except ImportError: + from datetime import datetime + datetime_now = datetime.now diff --git a/mongoengine/document.py b/mongoengine/document.py index a251f589..bd6ce191 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,19 +1,21 @@ +from __future__ import with_statement import warnings import pymongo import re from bson.dbref import DBRef -from mongoengine import signals, queryset +from mongoengine import signals +from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, + BaseDocument, BaseDict, BaseList, + ALLOW_INHERITANCE, get_document) +from mongoengine.queryset import OperationError, NotUniqueError, QuerySet +from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME +from mongoengine.context_managers import switch_db, switch_collection -from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, - BaseDict, BaseList) -from queryset import OperationError, NotUniqueError -from connection import get_db, DEFAULT_CONNECTION_NAME - -__all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', +__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', - 'InvalidCollectionError', 'NotUniqueError'] + 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument') class InvalidCollectionError(Exception): @@ -28,11 +30,11 @@ class EmbeddedDocument(BaseDocument): A :class:`~mongoengine.EmbeddedDocument` subclass may be itself subclassed, to create a specialised version of the embedded document that will be - stored in the same collection. To facilitate this behaviour, `_cls` and - `_types` fields are added to documents (hidden though the MongoEngine - interface though). 
To disable this behaviour and remove the dependence on - the presence of `_cls` and `_types`, set :attr:`allow_inheritance` to - ``False`` in the :attr:`meta` dictionary. + stored in the same collection. To facilitate this behaviour a `_cls` + field is added to documents (hidden though the MongoEngine interface). + To disable this behaviour and remove the dependence on the presence of + `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` + dictionary. """ # The __metaclass__ attribute is removed by 2to3 when running with Python3 @@ -40,21 +42,12 @@ class EmbeddedDocument(BaseDocument): my_metaclass = DocumentMetaclass __metaclass__ = DocumentMetaclass + _instance = None + def __init__(self, *args, **kwargs): super(EmbeddedDocument, self).__init__(*args, **kwargs) self._changed_fields = [] - def __delattr__(self, *args, **kwargs): - """Handle deletions of fields""" - field_name = args[0] - if field_name in self._fields: - default = self._fields[field_name].default - if callable(default): - default = default() - setattr(self, field_name, default) - else: - super(EmbeddedDocument, self).__delattr__(*args, **kwargs) - def __eq__(self, other): if isinstance(other, self.__class__): return self._data == other._data @@ -76,11 +69,11 @@ class Document(BaseDocument): A :class:`~mongoengine.Document` subclass may be itself subclassed, to create a specialised version of the document that will be stored in the - same collection. To facilitate this behaviour, `_cls` and `_types` - fields are added to documents (hidden though the MongoEngine interface - though). To disable this behaviour and remove the dependence on the - presence of `_cls` and `_types`, set :attr:`allow_inheritance` to - ``False`` in the :attr:`meta` dictionary. + same collection. To facilitate this behaviour a `_cls` + field is added to documents (hidden though the MongoEngine interface). 
+ To disable this behaviour and remove the dependence on the presence of + `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` + dictionary. A :class:`~mongoengine.Document` may use a **Capped Collection** by specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` @@ -98,13 +91,13 @@ class Document(BaseDocument): Automatic index creation can be disabled by specifying attr:`auto_create_index` in the :attr:`meta` dictionary. If this is set to False then indexes will not be created by MongoEngine. This is useful in - production systems where index creation is performed as part of a deployment - system. + production systems where index creation is performed as part of a + deployment system. - By default, _types will be added to the start of every index (that + By default, _cls will be added to the start of every index (that doesn't contain a list) if allow_inheritance is True. This can be - disabled by either setting types to False on the specific index or - by setting index_types to False on the meta dictionary for the document. + disabled by either setting cls to False on the specific index or + by setting index_cls to False on the meta dictionary for the document. 
""" # The __metaclass__ attribute is removed by 2to3 when running with Python3 @@ -117,6 +110,7 @@ class Document(BaseDocument): """ def fget(self): return getattr(self, self._meta['id_field']) + def fset(self, value): return setattr(self, self._meta['id_field'], value) return property(fget, fset) @@ -125,7 +119,7 @@ class Document(BaseDocument): @classmethod def _get_db(cls): """Some Model using other db_alias""" - return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME )) + return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)) @classmethod def _get_collection(cls): @@ -148,7 +142,7 @@ class Document(BaseDocument): options.get('size') != max_size: msg = (('Cannot create collection "%s" as a capped ' 'collection as it already exists') - % cls._collection) + % cls._collection) raise InvalidCollectionError(msg) else: # Create the collection as a capped collection @@ -160,28 +154,28 @@ class Document(BaseDocument): ) else: cls._collection = db[collection_name] + if cls._meta.get('auto_create_index', True): + cls.ensure_indexes() return cls._collection - def save(self, safe=True, force_insert=False, validate=True, - write_options=None, cascade=None, cascade_kwargs=None, + def save(self, force_insert=False, validate=True, clean=True, + write_concern=None, cascade=None, cascade_kwargs=None, _refs=None, **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. - If ``safe=True`` and the operation is unsuccessful, an - :class:`~mongoengine.OperationError` will be raised. - - :param safe: check if the operation succeeded before returning :param force_insert: only try to create a new document, don't allow updates of existing documents :param validate: validates the document; set to ``False`` to skip. - :param write_options: Extra keyword arguments are passed down to + :param clean: call the document clean method, requires `validate` to be + True. 
+ :param write_concern: Extra keyword arguments are passed down to :meth:`~pymongo.collection.Collection.save` OR :meth:`~pymongo.collection.Collection.insert` which will be used as options for the resultant ``getLastError`` command. For example, - ``save(..., write_options={w: 2, fsync: True}, ...)`` will + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will wait until at least two servers have recorded the write and will force an fsync on the primary server. :param cascade: Sets the flag for cascading saves. You can set a @@ -205,24 +199,22 @@ class Document(BaseDocument): signals.pre_save.send(self.__class__, document=self) if validate: - self.validate() + self.validate(clean=clean) - if not write_options: - write_options = {} + if not write_concern: + write_concern = {} doc = self.to_mongo() - created = force_insert or '_id' not in doc + created = ('_id' not in doc or self._created or force_insert) try: - collection = self.__class__.objects._collection + collection = self._get_collection() if created: if force_insert: - object_id = collection.insert(doc, safe=safe, - **write_options) + object_id = collection.insert(doc, **write_concern) else: - object_id = collection.save(doc, safe=safe, - **write_options) + object_id = collection.save(doc, **write_concern) else: object_id = doc['_id'] updates, removals = self._delta() @@ -233,29 +225,38 @@ class Document(BaseDocument): actual_key = self._db_field_map.get(k, k) select_dict[actual_key] = doc[actual_key] - upsert = self._created - if updates: - collection.update(select_dict, {"$set": updates}, - upsert=upsert, safe=safe, **write_options) - if removals: - collection.update(select_dict, {"$unset": removals}, - upsert=upsert, safe=safe, **write_options) + def is_new_object(last_error): + if last_error is not None: + updated = last_error.get("updatedExisting") + if updated is not None: + return not updated + return created + + upsert = self._created + update_query = {} + + if updates: + update_query["$set"] = 
updates + if removals: + update_query["$unset"] = removals + if updates or removals: + last_error = collection.update(select_dict, update_query, + upsert=upsert, **write_concern) + created = is_new_object(last_error) - warn_cascade = not cascade and 'cascade' not in self._meta cascade = (self._meta.get('cascade', True) if cascade is None else cascade) if cascade: kwargs = { - "safe": safe, "force_insert": force_insert, "validate": validate, - "write_options": write_options, + "write_concern": write_concern, "cascade": cascade } if cascade_kwargs: # Allow granular control over cascades kwargs.update(cascade_kwargs) kwargs['_refs'] = _refs - self.cascade_save(warn_cascade=warn_cascade, **kwargs) + self.cascade_save(**kwargs) except pymongo.errors.OperationFailure, err: message = 'Could not save document (%s)' @@ -269,12 +270,12 @@ class Document(BaseDocument): if id_field not in self._meta.get('shard_key', []): self[id_field] = self._fields[id_field].to_python(object_id) - self._changed_fields = [] + self._clear_changed_fields() self._created = False signals.post_save.send(self.__class__, document=self, created=created) return self - def cascade_save(self, warn_cascade=None, *args, **kwargs): + def cascade_save(self, *args, **kwargs): """Recursively saves any references / generic references on an objects""" import fields @@ -294,15 +295,20 @@ class Document(BaseDocument): ref_id = "%s,%s" % (ref.__class__.__name__, str(ref._data)) if ref and ref_id not in _refs: - if warn_cascade: - msg = ("Cascading saves will default to off in 0.8, " - "please explicitly set `.save(cascade=True)`") - warnings.warn(msg, FutureWarning) _refs.append(ref_id) kwargs["_refs"] = _refs ref.save(**kwargs) ref._changed_fields = [] + @property + def _qs(self): + """ + Returns the queryset to use for updating / reloading / deletions + """ + if not hasattr(self, '__objects'): + self.__objects = QuerySet(self, self._get_collection()) + return self.__objects + @property def _object_key(self): 
"""Dict to identify object in collection @@ -324,24 +330,80 @@ class Document(BaseDocument): raise OperationError('attempt to update a document not yet saved') # Need to add shard key to query, or you get an error - return self.__class__.objects(**self._object_key).update_one(**kwargs) + return self._qs.filter(**self._object_key).update_one(**kwargs) - def delete(self, safe=False): + def delete(self, **write_concern): """Delete the :class:`~mongoengine.Document` from the database. This will only take effect if the document has been previously saved. - :param safe: check if the operation succeeded before returning + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. """ signals.pre_delete.send(self.__class__, document=self) try: - self.__class__.objects(**self._object_key).delete(safe=safe) + self._qs.filter(**self._object_key).delete(write_concern=write_concern) except pymongo.errors.OperationFailure, err: message = u'Could not delete document (%s)' % err.message raise OperationError(message) signals.post_delete.send(self.__class__, document=self) + def switch_db(self, db_alias): + """ + Temporarily switch the database for a document instance. 
+ + Only really useful for archiving off data and calling `save()`:: + + user = User.objects.get(id=user_id) + user.switch_db('archive-db') + user.save() + + If you need to read from another database see + :class:`~mongoengine.context_managers.switch_db` + + :param db_alias: The database alias to use for saving the document + """ + with switch_db(self.__class__, db_alias) as cls: + collection = cls._get_collection() + db = cls._get_db + self._get_collection = lambda: collection + self._get_db = lambda: db + self._collection = collection + self._created = True + self.__objects = self._qs + self.__objects._collection_obj = collection + return self + + def switch_collection(self, collection_name): + """ + Temporarily switch the collection for a document instance. + + Only really useful for archiving off data and calling `save()`:: + + user = User.objects.get(id=user_id) + user.switch_collection('old-users') + user.save() + + If you need to read from another database see + :class:`~mongoengine.context_managers.switch_collection` + + :param collection_name: The database alias to use for saving the + document + """ + with switch_collection(self.__class__, collection_name) as cls: + collection = cls._get_collection() + self._get_collection = lambda: collection + self._collection = collection + self._created = True + self.__objects = self._qs + self.__objects._collection_obj = collection + return self + def select_related(self, max_depth=1): """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to a maximum depth in order to cut down the number queries to mongodb. @@ -359,9 +421,8 @@ class Document(BaseDocument): .. 
versionchanged:: 0.6 Now chainable """ id_field = self._meta['id_field'] - obj = self.__class__.objects( - **{id_field: self[id_field]} - ).limit(1).select_related(max_depth=max_depth) + obj = self._qs.filter(**{id_field: self[id_field]} + ).limit(1).select_related(max_depth=max_depth) if obj: obj = obj[0] else: @@ -373,6 +434,7 @@ class Document(BaseDocument): for name in self._dynamic_fields.keys(): setattr(self, name, self._reload(name, obj._data[name])) self._changed_fields = obj._changed_fields + self._created = False return obj def _reload(self, key, value): @@ -402,18 +464,93 @@ class Document(BaseDocument): """This method registers the delete rules to apply when removing this object. """ - delete_rules = cls._meta.get('delete_rules') or {} - delete_rules[(document_cls, field_name)] = rule - cls._meta['delete_rules'] = delete_rules + classes = [get_document(class_name) + for class_name in cls._subclasses + if class_name != cls.__name__] + [cls] + documents = [get_document(class_name) + for class_name in document_cls._subclasses + if class_name != document_cls.__name__] + [document_cls] + + for cls in classes: + for document_cls in documents: + delete_rules = cls._meta.get('delete_rules') or {} + delete_rules[(document_cls, field_name)] = rule + cls._meta['delete_rules'] = delete_rules @classmethod def drop_collection(cls): """Drops the entire collection associated with this :class:`~mongoengine.Document` type from the database. """ + cls._collection = None db = cls._get_db() db.drop_collection(cls._get_collection_name()) - queryset.QuerySet._reset_already_indexed(cls) + + @classmethod + def ensure_index(cls, key_or_list, drop_dups=False, background=False, + **kwargs): + """Ensure that the given indexes are in place. 
+ + :param key_or_list: a single index key or a list of index keys (to + construct a multi-field index); keys may be prefixed with a **+** + or a **-** to determine the index ordering + """ + index_spec = cls._build_index_spec(key_or_list) + index_spec = index_spec.copy() + fields = index_spec.pop('fields') + index_spec['drop_dups'] = drop_dups + index_spec['background'] = background + index_spec.update(kwargs) + + return cls._get_collection().ensure_index(fields, **index_spec) + + @classmethod + def ensure_indexes(cls): + """Checks the document meta data and ensures all the indexes exist. + + .. note:: You can disable automatic index creation by setting + `auto_create_index` to False in the documents meta data + """ + background = cls._meta.get('index_background', False) + drop_dups = cls._meta.get('index_drop_dups', False) + index_opts = cls._meta.get('index_opts') or {} + index_cls = cls._meta.get('index_cls', True) + + collection = cls._get_collection() + + # determine if an index which we are creating includes + # _cls as its first field; if so, we can avoid creating + # an extra index on _cls, as mongodb will use the existing + # index to service queries against _cls + cls_indexed = False + + def includes_cls(fields): + first_field = None + if len(fields): + if isinstance(fields[0], basestring): + first_field = fields[0] + elif isinstance(fields[0], (list, tuple)) and len(fields[0]): + first_field = fields[0][0] + return first_field == '_cls' + + # Ensure document-defined indexes are created + if cls._meta['index_specs']: + index_spec = cls._meta['index_specs'] + for spec in index_spec: + spec = spec.copy() + fields = spec.pop('fields') + cls_indexed = cls_indexed or includes_cls(fields) + opts = index_opts.copy() + opts.update(spec) + collection.ensure_index(fields, background=background, + drop_dups=drop_dups, **opts) + + # If _cls is being used (for polymorphism), it needs an index, + # only if another index doesn't begin with _cls + if (index_cls and not 
cls_indexed and + cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): + collection.ensure_index('_cls', background=background, + **index_opts) class DynamicDocument(Document): @@ -422,7 +559,7 @@ class DynamicDocument(Document): way as an ordinary document but has expando style properties. Any data passed or set against the :class:`~mongoengine.DynamicDocument` that is not a field is automatically converted into a - :class:`~mongoengine.DynamicField` and data can be attributed to that + :class:`~mongoengine.fields.DynamicField` and data can be attributed to that field. .. note:: @@ -464,7 +601,13 @@ class DynamicEmbeddedDocument(EmbeddedDocument): """Deletes the attribute by setting to None and allowing _delta to unset it""" field_name = args[0] - setattr(self, field_name, None) + if field_name in self._fields: + default = self._fields[field_name].default + if callable(default): + default = default() + setattr(self, field_name, default) + else: + setattr(self, field_name, None) class MapReduceDocument(object): diff --git a/mongoengine/errors.py b/mongoengine/errors.py new file mode 100644 index 00000000..4b6b562c --- /dev/null +++ b/mongoengine/errors.py @@ -0,0 +1,126 @@ +from collections import defaultdict + +from mongoengine.python_support import txt_type + + +__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', + 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', + 'OperationError', 'NotUniqueError', 'ValidationError') + + +class NotRegistered(Exception): + pass + + +class InvalidDocumentError(Exception): + pass + + +class LookUpError(AttributeError): + pass + + +class DoesNotExist(Exception): + pass + + +class MultipleObjectsReturned(Exception): + pass + + +class InvalidQueryError(Exception): + pass + + +class OperationError(Exception): + pass + + +class NotUniqueError(OperationError): + pass + + +class ValidationError(AssertionError): + """Validation exception. 
+ + May represent an error validating a field or a + document containing fields with validation errors. + + :ivar errors: A dictionary of errors for fields within this + document or list, or None if the error is for an + individual field. + """ + + errors = {} + field_name = None + _message = None + + def __init__(self, message="", **kwargs): + self.errors = kwargs.get('errors', {}) + self.field_name = kwargs.get('field_name') + self.message = message + + def __str__(self): + return txt_type(self.message) + + def __repr__(self): + return '%s(%s,)' % (self.__class__.__name__, self.message) + + def __getattribute__(self, name): + message = super(ValidationError, self).__getattribute__(name) + if name == 'message': + if self.field_name: + message = '%s' % message + if self.errors: + message = '%s(%s)' % (message, self._format_errors()) + return message + + def _get_message(self): + return self._message + + def _set_message(self, message): + self._message = message + + message = property(_get_message, _set_message) + + def to_dict(self): + """Returns a dictionary of all errors within a document + + Keys are field names or list indices and values are the + validation error messages, or a nested dictionary of + errors for an embedded document or list. 
+ """ + + def build_dict(source): + errors_dict = {} + if not source: + return errors_dict + if isinstance(source, dict): + for field_name, error in source.iteritems(): + errors_dict[field_name] = build_dict(error) + elif isinstance(source, ValidationError) and source.errors: + return build_dict(source.errors) + else: + return unicode(source) + return errors_dict + if not self.errors: + return {} + return build_dict(self.errors) + + def _format_errors(self): + """Returns a string listing all errors within a document""" + + def generate_key(value, prefix=''): + if isinstance(value, list): + value = ' '.join([generate_key(k) for k in value]) + if isinstance(value, dict): + value = ' '.join( + [generate_key(v, k) for k, v in value.iteritems()]) + + results = "%s.%s" % (prefix, value) if prefix else value + return results + + error_dict = defaultdict(list) + for k, v in self.to_dict().iteritems(): + error_dict[generate_key(v)].append(k) + return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()]) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index de484a1d..cf2c802c 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -4,7 +4,6 @@ import itertools import re import time import urllib2 -import urlparse import uuid import warnings from operator import itemgetter @@ -12,28 +11,32 @@ from operator import itemgetter import gridfs from bson import Binary, DBRef, SON, ObjectId +from mongoengine.errors import ValidationError from mongoengine.python_support import (PY3, bin_type, txt_type, str_types, StringIO) from base import (BaseField, ComplexBaseField, ObjectIdField, - ValidationError, get_document, BaseDocument) + get_document, BaseDocument) from queryset import DO_NOTHING, QuerySet from document import Document, EmbeddedDocument from connection import get_db, DEFAULT_CONNECTION_NAME - try: from PIL import Image, ImageOps except ImportError: Image = None ImageOps = None -__all__ = ['StringField', 'IntField', 'FloatField', 
'BooleanField', - 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', - 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField', - 'DecimalField', 'ComplexDateTimeField', 'URLField', 'DynamicField', - 'GenericReferenceField', 'FileField', 'BinaryField', - 'SortedListField', 'EmailField', 'GeoPointField', 'ImageField', - 'SequenceField', 'UUIDField', 'GenericEmbeddedDocumentField'] +__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField', + 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', + 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', + 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', + 'SortedListField', 'DictField', 'MapField', 'ReferenceField', + 'GenericReferenceField', 'BinaryField', 'GridFSError', + 'GridFSProxy', 'FileField', 'ImageGridFsProxy', + 'ImproperlyConfigured', 'ImageField', 'GeoPointField', + 'SequenceField', 'UUIDField'] + + RECURSIVE_REFERENCE_CONSTANT = 'self' @@ -143,7 +146,7 @@ class EmailField(StringField): EMAIL_REGEX = re.compile( r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string - r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain + r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain ) def validate(self, value): @@ -153,7 +156,7 @@ class EmailField(StringField): class IntField(BaseField): - """An integer field. + """An 32-bit integer field. """ def __init__(self, min_value=None, max_value=None, **kwargs): @@ -186,6 +189,40 @@ class IntField(BaseField): return int(value) +class LongField(BaseField): + """An 64-bit integer field. 
+ """ + + def __init__(self, min_value=None, max_value=None, **kwargs): + self.min_value, self.max_value = min_value, max_value + super(LongField, self).__init__(**kwargs) + + def to_python(self, value): + try: + value = long(value) + except ValueError: + pass + return value + + def validate(self, value): + try: + value = long(value) + except: + self.error('%s could not be converted to long' % value) + + if self.min_value is not None and value < self.min_value: + self.error('Long value is too small') + + if self.max_value is not None and value > self.max_value: + self.error('Long value is too large') + + def prepare_query_value(self, op, value): + if value is None: + return value + + return long(value) + + class FloatField(BaseField): """An floating point number field. """ @@ -223,30 +260,58 @@ class FloatField(BaseField): class DecimalField(BaseField): """A fixed-point decimal number field. + .. versionchanged:: 0.8 .. versionadded:: 0.3 """ - def __init__(self, min_value=None, max_value=None, **kwargs): - self.min_value, self.max_value = min_value, max_value + def __init__(self, min_value=None, max_value=None, force_string=False, + precision=2, rounding=decimal.ROUND_HALF_UP, **kwargs): + """ + :param min_value: Validation rule for the minimum acceptable value. + :param max_value: Validation rule for the maximum acceptable value. + :param force_string: Store as a string. + :param precision: Number of decimal places to store. 
+ :param rounding: The rounding rule from the python decimal libary: + + - decimial.ROUND_CEILING (towards Infinity) + - decimial.ROUND_DOWN (towards zero) + - decimial.ROUND_FLOOR (towards -Infinity) + - decimial.ROUND_HALF_DOWN (to nearest with ties going towards zero) + - decimial.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer) + - decimial.ROUND_HALF_UP (to nearest with ties going away from zero) + - decimial.ROUND_UP (away from zero) + - decimial.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero) + + Defaults to: ``decimal.ROUND_HALF_UP`` + + """ + self.min_value = min_value + self.max_value = max_value + self.force_string = force_string + self.precision = decimal.Decimal(".%s" % ("0" * precision)) + self.rounding = rounding + super(DecimalField, self).__init__(**kwargs) def to_python(self, value): - original_value = value - if not isinstance(value, basestring): - value = unicode(value) - try: - value = decimal.Decimal(value) - except ValueError: - return original_value - return value + if value is None: + return value + + # Convert to string for python 2.6 before casting to Decimal + value = decimal.Decimal("%s" % value) + return value.quantize(self.precision, rounding=self.rounding) def to_mongo(self, value): - return unicode(value) + if value is None: + return value + if self.force_string: + return unicode(value) + return float(self.to_python(value)) def validate(self, value): if not isinstance(value, decimal.Decimal): if not isinstance(value, basestring): - value = str(value) + value = unicode(value) try: value = decimal.Decimal(value) except Exception, exc: @@ -258,6 +323,9 @@ class DecimalField(BaseField): if self.max_value is not None and value > self.max_value: self.error('Decimal value is too large') + def prepare_query_value(self, op, value): + return self.to_mongo(value) + class BooleanField(BaseField): """A boolean field type. 
@@ -300,6 +368,8 @@ class DateTimeField(BaseField): return value if isinstance(value, datetime.date): return datetime.datetime(value.year, value.month, value.day) + if callable(value): + return value() # Attempt to parse a datetime: # value = smart_str(value) @@ -314,16 +384,16 @@ class DateTimeField(BaseField): usecs = 0 kwargs = {'microsecond': usecs} try: # Seconds are optional, so try converting seconds first. - return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6], - **kwargs) + return datetime.datetime(*time.strptime(value, + '%Y-%m-%d %H:%M:%S')[:6], **kwargs) except ValueError: try: # Try without seconds. - return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5], - **kwargs) + return datetime.datetime(*time.strptime(value, + '%Y-%m-%d %H:%M')[:5], **kwargs) except ValueError: # Try without hour/minutes/seconds. try: - return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3], - **kwargs) + return datetime.datetime(*time.strptime(value, + '%Y-%m-%d')[:3], **kwargs) except ValueError: return None @@ -410,6 +480,7 @@ class ComplexDateTimeField(StringField): return super(ComplexDateTimeField, self).__set__(instance, value) def validate(self, value): + value = self.to_python(value) if not isinstance(value, datetime.datetime): self.error('Only datetime objects may used in a ' 'ComplexDateTimeField') @@ -422,6 +493,7 @@ class ComplexDateTimeField(StringField): return original_value def to_mongo(self, value): + value = self.to_python(value) return self._convert_from_datetime(value) def prepare_query_value(self, op, value): @@ -460,7 +532,7 @@ class EmbeddedDocumentField(BaseField): return value return self.document_type.to_mongo(value) - def validate(self, value): + def validate(self, value, clean=True): """Make sure that the document instance is an instance of the EmbeddedDocument subclass provided when the document was defined. 
""" @@ -468,7 +540,7 @@ class EmbeddedDocumentField(BaseField): if not isinstance(value, self.document_type): self.error('Invalid embedded document instance provided to an ' 'EmbeddedDocumentField') - self.document_type.validate(value) + self.document_type.validate(value, clean) def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -498,12 +570,12 @@ class GenericEmbeddedDocumentField(BaseField): return value - def validate(self, value): + def validate(self, value, clean=True): if not isinstance(value, EmbeddedDocument): self.error('Invalid embedded document instance provided to an ' 'GenericEmbeddedDocumentField') - value.validate() + value.validate(clean=clean) def to_mongo(self, document): if document is None: @@ -529,7 +601,12 @@ class DynamicField(BaseField): return value if hasattr(value, 'to_mongo'): - return value.to_mongo() + cls = value.__class__ + val = value.to_mongo() + # If we its a document thats not inherited add _cls + if (isinstance(value, (Document, EmbeddedDocument))): + val['_cls'] = cls.__name__ + return val if not isinstance(value, (dict, list, tuple)): return value @@ -540,13 +617,12 @@ class DynamicField(BaseField): value = dict([(k, v) for k, v in enumerate(value)]) data = {} - for k, v in value.items(): + for k, v in value.iteritems(): data[k] = self.to_mongo(v) + value = data if is_list: # Convert back to a list - value = [v for k, v in sorted(data.items(), key=itemgetter(0))] - else: - value = data + value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))] return value def lookup_member(self, member_name): @@ -558,6 +634,10 @@ class DynamicField(BaseField): return StringField().prepare_query_value(op, value) return self.to_mongo(value) + def validate(self, value, clean=True): + if hasattr(value, "validate"): + value.validate(clean=clean) + class ListField(ComplexBaseField): """A list field that wraps a standard field, allowing multiple instances @@ -569,9 +649,6 @@ class 
ListField(ComplexBaseField): Required means it cannot be empty - as the default for ListFields is [] """ - # ListFields cannot be indexed with _types - MongoDB doesn't support this - _index_with_types = False - def __init__(self, field=None, **kwargs): self.field = field kwargs.setdefault('default', lambda: []) @@ -623,7 +700,8 @@ class SortedListField(ListField): def to_mongo(self, value): value = super(SortedListField, self).to_mongo(value) if self._ordering is not None: - return sorted(value, key=itemgetter(self._ordering), reverse=self._order_reverse) + return sorted(value, key=itemgetter(self._ordering), + reverse=self._order_reverse) return sorted(value, reverse=self._order_reverse) @@ -653,7 +731,9 @@ class DictField(ComplexBaseField): self.error('Only dictionaries may be used in a DictField') if any(k for k in value.keys() if not isinstance(k, basestring)): - self.error('Invalid dictionary key - documents must have only string keys') + msg = ("Invalid dictionary key - documents must " + "have only string keys") + self.error(msg) if any(('.' in k or '$' in k) for k in value.keys()): self.error('Invalid dictionary key name - keys may not contain "."' ' or "$" characters') @@ -669,7 +749,6 @@ class DictField(ComplexBaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) - return super(DictField, self).prepare_query_value(op, value) @@ -703,7 +782,7 @@ class ReferenceField(BaseField): * NULLIFY - Updates the reference to null. * CASCADE - Deletes the documents associated with the reference. * DENY - Prevent the deletion of the reference object. - * PULL - Pull the reference from a :class:`~mongoengine.ListField` + * PULL - Pull the reference from a :class:`~mongoengine.fields.ListField` of references Alternative syntax for registering delete rules (useful when implementing @@ -724,7 +803,7 @@ class ReferenceField(BaseField): .. 
versionchanged:: 0.5 added `reverse_delete_rule` """ - def __init__(self, document_type, dbref=None, + def __init__(self, document_type, dbref=False, reverse_delete_rule=DO_NOTHING, **kwargs): """Initialises the Reference Field. @@ -738,12 +817,7 @@ class ReferenceField(BaseField): self.error('Argument to ReferenceField constructor must be a ' 'document class or a string') - if dbref is None: - msg = ("ReferenceFields will default to using ObjectId " - " strings in 0.8, set DBRef=True if this isn't desired") - warnings.warn(msg, FutureWarning) - - self.dbref = dbref if dbref is not None else True # To change in 0.8 + self.dbref = dbref self.document_type_obj = document_type self.reverse_delete_rule = reverse_delete_rule super(ReferenceField, self).__init__(**kwargs) @@ -766,9 +840,9 @@ class ReferenceField(BaseField): # Get value from document instance if available value = instance._data.get(self.name) - + self._auto_dereference = instance._fields[self.name]._auto_dereference # Dereference DBRefs - if isinstance(value, DBRef): + if self._auto_dereference and isinstance(value, DBRef): value = self.document_type._get_db().dereference(value) if value is not None: instance._data[self.name] = self.document_type._from_son(value) @@ -848,17 +922,22 @@ class GenericReferenceField(BaseField): return self value = instance._data.get(self.name) - if isinstance(value, (dict, SON)): + self._auto_dereference = instance._fields[self.name]._auto_dereference + if self._auto_dereference and isinstance(value, (dict, SON)): instance._data[self.name] = self.dereference(value) return super(GenericReferenceField, self).__get__(instance, owner) def validate(self, value): - if not isinstance(value, (Document, DBRef)): + if not isinstance(value, (Document, DBRef, dict, SON)): self.error('GenericReferences can only contain documents') + if isinstance(value, (dict, SON)): + if '_ref' not in value or '_cls' not in value: + self.error('GenericReferences can only contain documents') + # We need 
the id from the saved object to create the DBRef - if isinstance(value, Document) and value.id is None: + elif isinstance(value, Document) and value.id is None: self.error('You can only reference documents once they have been' ' saved to the database') @@ -960,7 +1039,7 @@ class GridFSProxy(object): if name in attrs: return self.__getattribute__(name) obj = self.get() - if name in dir(obj): + if hasattr(obj, name): return getattr(obj, name) raise AttributeError @@ -975,14 +1054,22 @@ class GridFSProxy(object): self_dict['_fs'] = None return self_dict + def __copy__(self): + copied = GridFSProxy() + copied.__dict__.update(self.__getstate__()) + return copied + + def __deepcopy__(self, memo): + return self.__copy__() + def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self.grid_id) def __eq__(self, other): if isinstance(other, GridFSProxy): - return ((self.grid_id == other.grid_id) and - (self.collection_name == other.collection_name) and - (self.db_alias == other.db_alias)) + return ((self.grid_id == other.grid_id) and + (self.collection_name == other.collection_name) and + (self.db_alias == other.db_alias)) else: return False @@ -1107,13 +1194,11 @@ class FileField(BaseField): grid_file.delete() except: pass - # Create a new file with the new data - grid_file.put(value) - else: - # Create a new proxy object as we don't already have one - instance._data[key] = self.proxy_class(key=key, instance=instance, - collection_name=self.collection_name) - instance._data[key].put(value) + + # Create a new proxy object as we don't already have one + instance._data[key] = self.proxy_class(key=key, instance=instance, + collection_name=self.collection_name) + instance._data[key].put(value) else: instance._data[key] = value @@ -1150,6 +1235,8 @@ class ImageGridFsProxy(GridFSProxy): Insert a image in database applying field properties (size, thumbnail_size) """ + if not self.instance: + import ipdb; ipdb.set_trace(); field = self.instance._fields[self.key] try: @@ 
-1177,10 +1264,7 @@ class ImageGridFsProxy(GridFSProxy): size = field.thumbnail_size if size['force']: - thumbnail = ImageOps.fit(img, - (size['width'], - size['height']), - Image.ANTIALIAS) + thumbnail = ImageOps.fit(img, (size['width'], size['height']), Image.ANTIALIAS) else: thumbnail = img.copy() thumbnail.thumbnail((size['width'], @@ -1188,8 +1272,7 @@ class ImageGridFsProxy(GridFSProxy): Image.ANTIALIAS) if thumbnail: - thumb_id = self._put_thumbnail(thumbnail, - img_format) + thumb_id = self._put_thumbnail(thumbnail, img_format) else: thumb_id = None @@ -1292,7 +1375,7 @@ class ImageField(FileField): if isinstance(att, (tuple, list)): if PY3: value = dict(itertools.zip_longest(params_size, att, - fillvalue=None)) + fillvalue=None)) else: value = dict(map(None, params_size, att)) @@ -1325,8 +1408,9 @@ class GeoPointField(BaseField): self.error('Both values in point must be float or int') -class SequenceField(IntField): - """Provides a sequental counter (see http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers) +class SequenceField(BaseField): + """Provides a sequental counter see: + http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers .. note:: @@ -1336,15 +1420,29 @@ class SequenceField(IntField): cluster of machines, it is easier to create an object ID than have global, uniformly increasing sequence numbers. + Use any callable as `value_decorator` to transform calculated counter into + any value suitable for your needs, e.g. string or hexadecimal + representation of the default integer counter value. + .. versionadded:: 0.5 + + .. 
versionchanged:: 0.8 added `value_decorator` """ - def __init__(self, collection_name=None, db_alias=None, sequence_name=None, *args, **kwargs): - self.collection_name = collection_name or 'mongoengine.counters' + + _auto_gen = True + COLLECTION_NAME = 'mongoengine.counters' + VALUE_DECORATOR = int + + def __init__(self, collection_name=None, db_alias=None, sequence_name=None, + value_decorator=None, *args, **kwargs): + self.collection_name = collection_name or self.COLLECTION_NAME self.db_alias = db_alias or DEFAULT_CONNECTION_NAME self.sequence_name = sequence_name + self.value_decorator = (callable(value_decorator) and + value_decorator or self.VALUE_DECORATOR) return super(SequenceField, self).__init__(*args, **kwargs) - def generate_new_value(self): + def generate(self): """ Generate and Increment the counter """ @@ -1355,7 +1453,18 @@ class SequenceField(IntField): update={"$inc": {"next": 1}}, new=True, upsert=True) - return counter['next'] + return self.value_decorator(counter['next']) + + def set_next_value(self, value): + """Helper method to set the next sequence value""" + sequence_name = self.get_sequence_name() + sequence_id = "%s.%s" % (sequence_name, self.name) + collection = get_db(alias=self.db_alias)[self.collection_name] + counter = collection.find_and_modify(query={"_id": sequence_id}, + update={"$set": {"next": value}}, + new=True, + upsert=True) + return self.value_decorator(counter['next']) def get_sequence_name(self): if self.sequence_name: @@ -1365,35 +1474,27 @@ class SequenceField(IntField): return owner._get_collection_name() else: return ''.join('_%s' % c if c.isupper() else c - for c in owner._class_name).strip('_').lower() + for c in owner._class_name).strip('_').lower() def __get__(self, instance, owner): - - if instance is None: - return self - - if not instance._data: - return - - value = instance._data.get(self.name) - - if not value and instance._initialised: - value = self.generate_new_value() + value = super(SequenceField, 
self).__get__(instance, owner) + if value is None and instance._initialised: + value = self.generate() instance._data[self.name] = value instance._mark_as_changed(self.name) - return int(value) if value else None + return value def __set__(self, instance, value): if value is None and instance._initialised: - value = self.generate_new_value() + value = self.generate() return super(SequenceField, self).__set__(instance, value) def to_python(self, value): if value is None: - value = self.generate_new_value() + value = self.generate() return value @@ -1404,19 +1505,15 @@ class UUIDField(BaseField): """ _binary = None - def __init__(self, binary=None, **kwargs): + def __init__(self, binary=True, **kwargs): """ Store UUID data in the database - :param binary: (optional) boolean store as binary. + :param binary: if False store as a string. + .. versionchanged:: 0.8.0 .. versionchanged:: 0.6.19 """ - if binary is None: - binary = False - msg = ("UUIDFields will soon default to store as binary, please " - "configure binary=False if you wish to store as a string") - warnings.warn(msg, FutureWarning) self._binary = binary super(UUIDField, self).__init__(**kwargs) @@ -1434,6 +1531,8 @@ class UUIDField(BaseField): def to_mongo(self, value): if not self._binary: return unicode(value) + elif isinstance(value, basestring): + return uuid.UUID(value) return value def prepare_query_value(self, op, value): diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py deleted file mode 100644 index 727f56ea..00000000 --- a/mongoengine/queryset.py +++ /dev/null @@ -1,2059 +0,0 @@ -import pprint -import re -import copy -import itertools -import operator - -from collections import defaultdict -from functools import partial - -from mongoengine.python_support import product, reduce, PY3 - -import pymongo -from bson.code import Code -from bson.son import SON - -from mongoengine import signals - -__all__ = ['queryset_manager', 'Q', 'InvalidQueryError', - 'DO_NOTHING', 'NULLIFY', 'CASCADE', 
'DENY', 'PULL'] - - -# The maximum number of items to display in a QuerySet.__repr__ -REPR_OUTPUT_SIZE = 20 - -# Delete rules -DO_NOTHING = 0 -NULLIFY = 1 -CASCADE = 2 -DENY = 3 -PULL = 4 - - -class DoesNotExist(Exception): - pass - - -class MultipleObjectsReturned(Exception): - pass - - -class InvalidQueryError(Exception): - pass - - -class OperationError(Exception): - pass - - -class NotUniqueError(OperationError): - pass - - -RE_TYPE = type(re.compile('')) - - -class QNodeVisitor(object): - """Base visitor class for visiting Q-object nodes in a query tree. - """ - - def visit_combination(self, combination): - """Called by QCombination objects. - """ - return combination - - def visit_query(self, query): - """Called by (New)Q objects. - """ - return query - - -class SimplificationVisitor(QNodeVisitor): - """Simplifies query trees by combinging unnecessary 'and' connection nodes - into a single Q-object. - """ - - def visit_combination(self, combination): - if combination.operation == combination.AND: - # The simplification only applies to 'simple' queries - if all(isinstance(node, Q) for node in combination.children): - queries = [node.query for node in combination.children] - return Q(**self._query_conjunction(queries)) - return combination - - def _query_conjunction(self, queries): - """Merges query dicts - effectively &ing them together. - """ - query_ops = set() - combined_query = {} - for query in queries: - ops = set(query.keys()) - # Make sure that the same operation isn't applied more than once - # to a single field - intersection = ops.intersection(query_ops) - if intersection: - msg = 'Duplicate query conditions: ' - raise InvalidQueryError(msg + ', '.join(intersection)) - - query_ops.update(ops) - combined_query.update(copy.deepcopy(query)) - return combined_query - - -class QueryTreeTransformerVisitor(QNodeVisitor): - """Transforms the query tree in to a form that may be used with MongoDB. 
- """ - - def visit_combination(self, combination): - if combination.operation == combination.AND: - # MongoDB doesn't allow us to have too many $or operations in our - # queries, so the aim is to move the ORs up the tree to one - # 'master' $or. Firstly, we must find all the necessary parts (part - # of an AND combination or just standard Q object), and store them - # separately from the OR parts. - or_groups = [] - and_parts = [] - for node in combination.children: - if isinstance(node, QCombination): - if node.operation == node.OR: - # Any of the children in an $or component may cause - # the query to succeed - or_groups.append(node.children) - elif node.operation == node.AND: - and_parts.append(node) - elif isinstance(node, Q): - and_parts.append(node) - - # Now we combine the parts into a usable query. AND together all of - # the necessary parts. Then for each $or part, create a new query - # that ANDs the necessary part with the $or part. - clauses = [] - for or_group in product(*or_groups): - q_object = reduce(lambda a, b: a & b, and_parts, Q()) - q_object = reduce(lambda a, b: a & b, or_group, q_object) - clauses.append(q_object) - # Finally, $or the generated clauses in to one query. Each of the - # clauses is sufficient for the query to succeed. - return reduce(lambda a, b: a | b, clauses, Q()) - - if combination.operation == combination.OR: - children = [] - # Crush any nested ORs in to this combination as MongoDB doesn't - # support nested $or operations - for node in combination.children: - if (isinstance(node, QCombination) and - node.operation == combination.OR): - children += node.children - else: - children.append(node) - combination.children = children - - return combination - - -class QueryCompilerVisitor(QNodeVisitor): - """Compiles the nodes in a query tree to a PyMongo-compatible query - dictionary. 
- """ - - def __init__(self, document): - self.document = document - - def visit_combination(self, combination): - if combination.operation == combination.OR: - return {'$or': combination.children} - elif combination.operation == combination.AND: - return self._mongo_query_conjunction(combination.children) - return combination - - def visit_query(self, query): - return QuerySet._transform_query(self.document, **query.query) - - def _mongo_query_conjunction(self, queries): - """Merges Mongo query dicts - effectively &ing them together. - """ - combined_query = {} - for query in queries: - for field, ops in query.items(): - if field not in combined_query: - combined_query[field] = ops - else: - # The field is already present in the query the only way - # we can merge is if both the existing value and the new - # value are operation dicts, reject anything else - if (not isinstance(combined_query[field], dict) or - not isinstance(ops, dict)): - message = 'Conflicting values for ' + field - raise InvalidQueryError(message) - - current_ops = set(combined_query[field].keys()) - new_ops = set(ops.keys()) - # Make sure that the same operation isn't applied more than - # once to a single field - intersection = current_ops.intersection(new_ops) - if intersection: - msg = 'Duplicate query conditions: ' - raise InvalidQueryError(msg + ', '.join(intersection)) - - # Right! We've got two non-overlapping dicts of operations! - combined_query[field].update(copy.deepcopy(ops)) - return combined_query - - -class QNode(object): - """Base class for nodes in query trees. - """ - - AND = 0 - OR = 1 - - def to_query(self, document): - query = self.accept(SimplificationVisitor()) - query = query.accept(QueryTreeTransformerVisitor()) - query = query.accept(QueryCompilerVisitor(document)) - return query - - def accept(self, visitor): - raise NotImplementedError - - def _combine(self, other, operation): - """Combine this node with another node into a QCombination object. 
- """ - if getattr(other, 'empty', True): - return self - - if self.empty: - return other - - return QCombination(operation, [self, other]) - - @property - def empty(self): - return False - - def __or__(self, other): - return self._combine(other, self.OR) - - def __and__(self, other): - return self._combine(other, self.AND) - - -class QCombination(QNode): - """Represents the combination of several conditions by a given logical - operator. - """ - - def __init__(self, operation, children): - self.operation = operation - self.children = [] - for node in children: - # If the child is a combination of the same type, we can merge its - # children directly into this combinations children - if isinstance(node, QCombination) and node.operation == operation: - self.children += node.children - else: - self.children.append(node) - - def accept(self, visitor): - for i in range(len(self.children)): - if isinstance(self.children[i], QNode): - self.children[i] = self.children[i].accept(visitor) - - return visitor.visit_combination(self) - - @property - def empty(self): - return not bool(self.children) - - -class Q(QNode): - """A simple query object, used in a query tree to build up more complex - query structures. 
- """ - - def __init__(self, **query): - self.query = query - - def accept(self, visitor): - return visitor.visit_query(self) - - @property - def empty(self): - return not bool(self.query) - - -class QueryFieldList(object): - """Object that handles combinations of .only() and .exclude() calls""" - ONLY = 1 - EXCLUDE = 0 - - def __init__(self, fields=[], value=ONLY, always_include=[]): - self.value = value - self.fields = set(fields) - self.always_include = set(always_include) - self._id = None - - def as_dict(self): - field_list = dict((field, self.value) for field in self.fields) - if self._id is not None: - field_list['_id'] = self._id - return field_list - - def __add__(self, f): - if not self.fields: - self.fields = f.fields - self.value = f.value - elif self.value is self.ONLY and f.value is self.ONLY: - self.fields = self.fields.intersection(f.fields) - elif self.value is self.EXCLUDE and f.value is self.EXCLUDE: - self.fields = self.fields.union(f.fields) - elif self.value is self.ONLY and f.value is self.EXCLUDE: - self.fields -= f.fields - elif self.value is self.EXCLUDE and f.value is self.ONLY: - self.value = self.ONLY - self.fields = f.fields - self.fields - - if '_id' in f.fields: - self._id = f.value - - if self.always_include: - if self.value is self.ONLY and self.fields: - self.fields = self.fields.union(self.always_include) - else: - self.fields -= self.always_include - return self - - def reset(self): - self.fields = set([]) - self.value = self.ONLY - - def __nonzero__(self): - return bool(self.fields) - - -class QuerySet(object): - """A set of results returned from a query. Wraps a MongoDB cursor, - providing :class:`~mongoengine.Document` objects as the results. 
- """ - - __already_indexed = set() - __dereference = False - - def __init__(self, document, collection): - self._document = document - self._collection_obj = collection - self._mongo_query = None - self._query_obj = Q() - self._initial_query = {} - self._where_clause = None - self._loaded_fields = QueryFieldList() - self._ordering = [] - self._snapshot = False - self._timeout = True - self._class_check = True - self._slave_okay = False - self._iter = False - self._scalar = [] - self._as_pymongo = False - self._as_pymongo_coerce = False - - # If inheritance is allowed, only return instances and instances of - # subclasses of the class being used - if document._meta.get('allow_inheritance') != False: - self._initial_query = {'_types': self._document._class_name} - self._loaded_fields = QueryFieldList(always_include=['_cls']) - self._cursor_obj = None - self._limit = None - self._skip = None - self._hint = -1 # Using -1 as None is a valid value for hint - - def __deepcopy__(self, memo): - """Essential for chained queries with ReferenceFields involved""" - return self.clone() - - def clone(self): - """Creates a copy of the current :class:`~mongoengine.queryset.QuerySet` - - .. versionadded:: 0.5 - """ - c = self.__class__(self._document, self._collection_obj) - - copy_props = ('_initial_query', '_query_obj', '_where_clause', - '_loaded_fields', '_ordering', '_snapshot', '_timeout', - '_limit', '_skip', '_slave_okay', '_hint') - - for prop in copy_props: - val = getattr(self, prop) - setattr(c, prop, copy.deepcopy(val)) - - return c - - @property - def _query(self): - if self._mongo_query is None: - self._mongo_query = self._query_obj.to_query(self._document) - if self._class_check: - self._mongo_query.update(self._initial_query) - return self._mongo_query - - def ensure_index(self, key_or_list, drop_dups=False, background=False, - **kwargs): - """Ensure that the given indexes are in place. 
- - :param key_or_list: a single index key or a list of index keys (to - construct a multi-field index); keys may be prefixed with a **+** - or a **-** to determine the index ordering - """ - index_spec = QuerySet._build_index_spec(self._document, key_or_list) - index_spec = index_spec.copy() - fields = index_spec.pop('fields') - index_spec['drop_dups'] = drop_dups - index_spec['background'] = background - index_spec.update(kwargs) - - self._collection.ensure_index(fields, **index_spec) - return self - - def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query): - """Filter the selected documents by calling the - :class:`~mongoengine.queryset.QuerySet` with a query. - - :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in - the query; the :class:`~mongoengine.queryset.QuerySet` is filtered - multiple times with different :class:`~mongoengine.queryset.Q` - objects, only the last one will be used - :param class_check: If set to False bypass class name check when - querying collection - :param slave_okay: if True, allows this query to be run against a - replica secondary. - :param query: Django-style query keyword arguments - """ - query = Q(**query) - if q_obj: - query &= q_obj - self._query_obj &= query - self._mongo_query = None - self._cursor_obj = None - self._class_check = class_check - return self - - def filter(self, *q_objs, **query): - """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` - """ - return self.__call__(*q_objs, **query) - - def all(self): - """Returns all documents.""" - return self.__call__() - - def _ensure_indexes(self): - """Checks the document meta data and ensures all the indexes exist. - - .. 
note:: You can disable automatic index creation by setting - `auto_create_index` to False in the documents meta data - """ - background = self._document._meta.get('index_background', False) - drop_dups = self._document._meta.get('index_drop_dups', False) - index_opts = self._document._meta.get('index_opts') or {} - index_types = self._document._meta.get('index_types', True) - - # determine if an index which we are creating includes - # _type as its first field; if so, we can avoid creating - # an extra index on _type, as mongodb will use the existing - # index to service queries against _type - types_indexed = False - - def includes_types(fields): - first_field = None - if len(fields): - if isinstance(fields[0], basestring): - first_field = fields[0] - elif isinstance(fields[0], (list, tuple)) and len(fields[0]): - first_field = fields[0][0] - return first_field == '_types' - - # Ensure indexes created by uniqueness constraints - for index in self._document._meta['unique_indexes']: - types_indexed = types_indexed or includes_types(index) - self._collection.ensure_index(index, unique=True, - background=background, drop_dups=drop_dups, **index_opts) - - # Ensure document-defined indexes are created - if self._document._meta['index_specs']: - index_spec = self._document._meta['index_specs'] - for spec in index_spec: - spec = spec.copy() - fields = spec.pop('fields') - types_indexed = types_indexed or includes_types(fields) - opts = index_opts.copy() - opts.update(spec) - self._collection.ensure_index(fields, - background=background, **opts) - - # If _types is being used (for polymorphism), it needs an index, - # only if another index doesn't begin with _types - if index_types and '_types' in self._query and not types_indexed: - self._collection.ensure_index('_types', - background=background, **index_opts) - - # Add geo indicies - for field in self._document._geo_indices(): - index_spec = [(field.db_field, pymongo.GEO2D)] - self._collection.ensure_index(index_spec, - 
background=background, **index_opts) - - @classmethod - def _build_index_spec(cls, doc_cls, spec): - """Build a PyMongo index spec from a MongoEngine index spec. - """ - if isinstance(spec, basestring): - spec = {'fields': [spec]} - elif isinstance(spec, (list, tuple)): - spec = {'fields': list(spec)} - elif isinstance(spec, dict): - spec = dict(spec) - - index_list = [] - direction = None - - allow_inheritance = doc_cls._meta.get('allow_inheritance') != False - - # If sparse - dont include types - use_types = allow_inheritance and not spec.get('sparse', False) - - for key in spec['fields']: - # If inherited spec continue - if isinstance(key, (list, tuple)): - continue - - # Get ASCENDING direction from +, DESCENDING from -, and GEO2D from * - direction = pymongo.ASCENDING - if key.startswith("-"): - direction = pymongo.DESCENDING - elif key.startswith("*"): - direction = pymongo.GEO2D - if key.startswith(("+", "-", "*")): - key = key[1:] - - # Use real field name, do it manually because we need field - # objects for the next part (list field checking) - parts = key.split('.') - if parts in (['pk'], ['id'], ['_id']): - key = '_id' - fields = [] - else: - fields = QuerySet._lookup_field(doc_cls, parts) - parts = [field if field == '_id' else field.db_field - for field in fields] - key = '.'.join(parts) - index_list.append((key, direction)) - - # Check if a list field is being used, don't use _types if it is - if use_types and not all(f._index_with_types for f in fields): - use_types = False - - # If _types is being used, prepend it to every specified index - index_types = doc_cls._meta.get('index_types', True) - - if (spec.get('types', index_types) and use_types - and direction is not pymongo.GEO2D): - index_list.insert(0, ('_types', 1)) - - spec['fields'] = index_list - if spec.get('sparse', False) and len(spec['fields']) > 1: - raise ValueError( - 'Sparse indexes can only have one field in them. 
' - 'See https://jira.mongodb.org/browse/SERVER-2193') - - return spec - - @classmethod - def _reset_already_indexed(cls, document=None): - """Helper to reset already indexed, can be useful for testing purposes""" - if document: - cls.__already_indexed.discard(document) - cls.__already_indexed.clear() - - - @property - def _collection(self): - """Property that returns the collection object. This allows us to - perform operations only if the collection is accessed. - """ - if self._document not in QuerySet.__already_indexed: - # Ensure collection exists - db = self._document._get_db() - if self._collection_obj.name not in db.collection_names(): - self._document._collection = None - self._collection_obj = self._document._get_collection() - - QuerySet.__already_indexed.add(self._document) - - if self._document._meta.get('auto_create_index', True): - self._ensure_indexes() - - return self._collection_obj - - @property - def _cursor_args(self): - cursor_args = { - 'snapshot': self._snapshot, - 'timeout': self._timeout, - 'slave_okay': self._slave_okay - } - if self._loaded_fields: - cursor_args['fields'] = self._loaded_fields.as_dict() - return cursor_args - - @property - def _cursor(self): - if self._cursor_obj is None: - - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) - # Apply where clauses to cursor - if self._where_clause: - self._cursor_obj.where(self._where_clause) - - if self._ordering: - # Apply query ordering - self._cursor_obj.sort(self._ordering) - elif self._document._meta['ordering']: - # Otherwise, apply the ordering from the document model - self.order_by(*self._document._meta['ordering']) - self._cursor_obj.sort(self._ordering) - - if self._limit is not None: - self._cursor_obj.limit(self._limit - (self._skip or 0)) - - if self._skip is not None: - self._cursor_obj.skip(self._skip) - - if self._hint != -1: - self._cursor_obj.hint(self._hint) - return self._cursor_obj - - @classmethod - def _lookup_field(cls, document, 
parts): - """Lookup a field based on its attribute and return a list containing - the field's parents and the field. - """ - if not isinstance(parts, (list, tuple)): - parts = [parts] - fields = [] - field = None - - for field_name in parts: - # Handle ListField indexing: - if field_name.isdigit(): - try: - new_field = field.field - except AttributeError, err: - raise InvalidQueryError( - "Can't use index on unsubscriptable field (%s)" % err) - fields.append(field_name) - continue - - if field is None: - # Look up first field from the document - if field_name == 'pk': - # Deal with "primary key" alias - field_name = document._meta['id_field'] - if field_name in document._fields: - field = document._fields[field_name] - elif document._dynamic: - from fields import DynamicField - field = DynamicField(db_field=field_name) - else: - raise InvalidQueryError('Cannot resolve field "%s"' - % field_name) - else: - from mongoengine.fields import ReferenceField, GenericReferenceField - if isinstance(field, (ReferenceField, GenericReferenceField)): - raise InvalidQueryError('Cannot perform join in mongoDB: %s' % '__'.join(parts)) - if hasattr(getattr(field, 'field', None), 'lookup_member'): - new_field = field.field.lookup_member(field_name) - else: - # Look up subfield on the previous field - new_field = field.lookup_member(field_name) - from base import ComplexBaseField - if not new_field and isinstance(field, ComplexBaseField): - fields.append(field_name) - continue - elif not new_field: - raise InvalidQueryError('Cannot resolve field "%s"' - % field_name) - field = new_field # update field to the new field type - fields.append(field) - return fields - - @classmethod - def _translate_field_name(cls, doc_cls, field, sep='.'): - """Translate a field attribute name to a database field name. 
- """ - parts = field.split(sep) - parts = [f.db_field for f in QuerySet._lookup_field(doc_cls, parts)] - return '.'.join(parts) - - @classmethod - def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): - """Transform a query from Django-style format to Mongo format. - """ - operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists', 'not'] - geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'within_polygon', 'near', 'near_sphere'] - match_operators = ['contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', - 'exact', 'iexact'] - custom_operators = ['match'] - - mongo_query = {} - merge_query = defaultdict(list) - for key, value in sorted(query.items()): - if key == "__raw__": - mongo_query.update(value) - continue - - parts = key.split('__') - indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] - parts = [part for part in parts if not part.isdigit()] - # Check for an operator and transform to mongo-style if there is - op = None - if parts[-1] in operators + match_operators + geo_operators + custom_operators: - op = parts.pop() - - negate = False - if parts[-1] == 'not': - parts.pop() - negate = True - - if _doc_cls: - # Switch field names to proper names [set in Field(name='foo')] - fields = QuerySet._lookup_field(_doc_cls, parts) - parts = [] - - cleaned_fields = [] - for field in fields: - append_field = True - if isinstance(field, basestring): - parts.append(field) - append_field = False - else: - parts.append(field.db_field) - if append_field: - cleaned_fields.append(field) - - # Convert value to proper value - field = cleaned_fields[-1] - - singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] - singular_ops += match_operators - if op in singular_ops: - if isinstance(field, basestring): - if op in match_operators and isinstance(value, basestring): - from mongoengine import StringField - value = StringField.prepare_query_value(op, value) 
- else: - value = field - else: - value = field.prepare_query_value(op, value) - elif op in ('in', 'nin', 'all', 'near'): - # 'in', 'nin' and 'all' require a list of values - value = [field.prepare_query_value(op, v) for v in value] - - # if op and op not in match_operators: - if op: - if op in geo_operators: - if op == "within_distance": - value = {'$within': {'$center': value}} - elif op == "within_spherical_distance": - value = {'$within': {'$centerSphere': value}} - elif op == "within_polygon": - value = {'$within': {'$polygon': value}} - elif op == "near": - value = {'$near': value} - elif op == "near_sphere": - value = {'$nearSphere': value} - elif op == 'within_box': - value = {'$within': {'$box': value}} - else: - raise NotImplementedError("Geo method '%s' has not " - "been implemented" % op) - elif op in custom_operators: - if op == 'match': - value = {"$elemMatch": value} - else: - NotImplementedError("Custom method '%s' has not " - "been implemented" % op) - elif op not in match_operators: - value = {'$' + op: value} - - if negate: - value = {'$not': value} - - for i, part in indices: - parts.insert(i, part) - key = '.'.join(parts) - if op is None or key not in mongo_query: - mongo_query[key] = value - elif key in mongo_query: - if key in mongo_query and isinstance(mongo_query[key], dict): - mongo_query[key].update(value) - else: - # Store for manually merging later - merge_query[key].append(value) - - # The queryset has been filter in such a way we must manually merge - for k, v in merge_query.items(): - merge_query[k].append(mongo_query[k]) - del mongo_query[k] - if isinstance(v, list): - value = [{k:val} for val in v] - if '$and' in mongo_query.keys(): - mongo_query['$and'].append(value) - else: - mongo_query['$and'] = value - return mongo_query - - def get(self, *q_objs, **query): - """Retrieve the the matching object raising - :class:`~mongoengine.queryset.MultipleObjectsReturned` or - `DocumentName.MultipleObjectsReturned` exception if multiple 
results and - :class:`~mongoengine.queryset.DoesNotExist` or `DocumentName.DoesNotExist` - if no results are found. - - .. versionadded:: 0.3 - """ - self.limit(2) - self.__call__(*q_objs, **query) - try: - result1 = self.next() - except StopIteration: - raise self._document.DoesNotExist("%s matching query does not exist." - % self._document._class_name) - try: - result2 = self.next() - except StopIteration: - return result1 - - self.rewind() - message = u'%d items returned, instead of 1' % self.count() - raise self._document.MultipleObjectsReturned(message) - - def get_or_create(self, write_options=None, auto_save=True, *q_objs, **query): - """Retrieve unique object or create, if it doesn't exist. Returns a tuple of - ``(object, created)``, where ``object`` is the retrieved or created object - and ``created`` is a boolean specifying whether a new object was created. Raises - :class:`~mongoengine.queryset.MultipleObjectsReturned` or - `DocumentName.MultipleObjectsReturned` if multiple results are found. - A new document will be created if the document doesn't exists; a - dictionary of default values for the new document may be provided as a - keyword argument called :attr:`defaults`. - - .. note:: This requires two separate operations and therefore a - race condition exists. Because there are no transactions in mongoDB - other approaches should be investigated, to ensure you don't - accidently duplicate data when using this method. - - :param write_options: optional extra keyword arguments used if we - have to create a new document. - Passes any write_options onto :meth:`~mongoengine.Document.save` - - :param auto_save: if the object is to be saved automatically if not found. - - .. versionchanged:: 0.6 - added `auto_save` - .. 
versionadded:: 0.3 - """ - defaults = query.get('defaults', {}) - if 'defaults' in query: - del query['defaults'] - - try: - doc = self.get(*q_objs, **query) - return doc, False - except self._document.DoesNotExist: - query.update(defaults) - doc = self._document(**query) - - if auto_save: - doc.save(write_options=write_options) - return doc, True - - def create(self, **kwargs): - """Create new object. Returns the saved object instance. - - .. versionadded:: 0.4 - """ - doc = self._document(**kwargs) - doc.save() - return doc - - def first(self): - """Retrieve the first object matching the query. - """ - try: - result = self[0] - except IndexError: - result = None - return result - - def insert(self, doc_or_docs, load_bulk=True, safe=False, write_options=None): - """bulk insert documents - - If ``safe=True`` and the operation is unsuccessful, an - :class:`~mongoengine.OperationError` will be raised. - - :param docs_or_doc: a document or list of documents to be inserted - :param load_bulk (optional): If True returns the list of document instances - :param safe: check if the operation succeeded before returning - :param write_options: Extra keyword arguments are passed down to - :meth:`~pymongo.collection.Collection.insert` - which will be used as options for the resultant ``getLastError`` command. - For example, ``insert(..., {w: 2, fsync: True})`` will wait until at least two - servers have recorded the write and will force an fsync on each server being - written to. - - By default returns document instances, set ``load_bulk`` to False to - return just ``ObjectIds`` - - .. 
versionadded:: 0.5 - """ - from document import Document - - if not write_options: - write_options = {} - write_options.update({'safe': safe}) - - docs = doc_or_docs - return_one = False - if isinstance(docs, Document) or issubclass(docs.__class__, Document): - return_one = True - docs = [docs] - - raw = [] - for doc in docs: - if not isinstance(doc, self._document): - msg = "Some documents inserted aren't instances of %s" % str(self._document) - raise OperationError(msg) - if doc.pk and not doc._created: - msg = "Some documents have ObjectIds use doc.update() instead" - raise OperationError(msg) - raw.append(doc.to_mongo()) - - signals.pre_bulk_insert.send(self._document, documents=docs) - try: - ids = self._collection.insert(raw, **write_options) - except pymongo.errors.OperationFailure, err: - message = 'Could not save document (%s)' - if re.match('^E1100[01] duplicate key', unicode(err)): - # E11000 - duplicate key error index - # E11001 - duplicate key on update - message = u'Tried to save duplicate unique keys (%s)' - raise NotUniqueError(message % unicode(err)) - raise OperationError(message % unicode(err)) - - if not load_bulk: - signals.post_bulk_insert.send( - self._document, documents=docs, loaded=False) - return return_one and ids[0] or ids - - documents = self.in_bulk(ids) - results = [] - for obj_id in ids: - results.append(documents.get(obj_id)) - signals.post_bulk_insert.send( - self._document, documents=results, loaded=True) - return return_one and results[0] or results - - def with_id(self, object_id): - """Retrieve the object matching the id provided. Uses `object_id` only - and raises InvalidQueryError if a filter has been applied. - - :param object_id: the value for the id of the document to look up - - .. 
versionchanged:: 0.6 Raises InvalidQueryError if filter has been set - """ - if not self._query_obj.empty: - raise InvalidQueryError("Cannot use a filter whilst using `with_id`") - return self.filter(pk=object_id).first() - - def in_bulk(self, object_ids): - """Retrieve a set of documents by their ids. - - :param object_ids: a list or tuple of ``ObjectId``\ s - :rtype: dict of ObjectIds as keys and collection-specific - Document subclasses as values. - - .. versionadded:: 0.3 - """ - doc_map = {} - - docs = self._collection.find({'_id': {'$in': object_ids}}, - **self._cursor_args) - if self._scalar: - for doc in docs: - doc_map[doc['_id']] = self._get_scalar( - self._document._from_son(doc)) - elif self._as_pymongo: - for doc in docs: - doc_map[doc['_id']] = self._get_as_pymongo(doc) - else: - for doc in docs: - doc_map[doc['_id']] = self._document._from_son(doc) - - return doc_map - - def next(self): - """Wrap the result in a :class:`~mongoengine.Document` object. - """ - self._iter = True - try: - if self._limit == 0: - raise StopIteration - if self._scalar: - return self._get_scalar(self._document._from_son( - self._cursor.next())) - if self._as_pymongo: - return self._get_as_pymongo(self._cursor.next()) - - return self._document._from_son(self._cursor.next()) - except StopIteration, e: - self.rewind() - raise e - - def rewind(self): - """Rewind the cursor to its unevaluated state. - - .. versionadded:: 0.3 - """ - self._iter = False - self._cursor.rewind() - - def count(self): - """Count the selected elements in the query. - """ - if self._limit == 0: - return 0 - return self._cursor.count(with_limit_and_skip=True) - - def __len__(self): - return self.count() - - def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, - scope=None): - """Perform a map/reduce query using the current query spec - and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, - it must be the last call made, as it does not return a maleable - ``QuerySet``. 
- - See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` - and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` - tests in ``tests.queryset.QuerySetTest`` for usage examples. - - :param map_f: map function, as :class:`~bson.code.Code` or string - :param reduce_f: reduce function, as - :class:`~bson.code.Code` or string - :param output: output collection name, if set to 'inline' will try to - use :class:`~pymongo.collection.Collection.inline_map_reduce` - This can also be a dictionary containing output options - see: http://docs.mongodb.org/manual/reference/commands/#mapReduce - :param finalize_f: finalize function, an optional function that - performs any post-reduction processing. - :param scope: values to insert into map/reduce global scope. Optional. - :param limit: number of objects from current query to provide - to map/reduce method - - Returns an iterator yielding - :class:`~mongoengine.document.MapReduceDocument`. - - .. note:: - - Map/Reduce changed in server version **>= 1.7.4**. The PyMongo - :meth:`~pymongo.collection.Collection.map_reduce` helper requires - PyMongo version **>= 1.11**. - - .. versionchanged:: 0.5 - - removed ``keep_temp`` keyword argument, which was only relevant - for MongoDB server versions older than 1.7.4 - - .. 
versionadded:: 0.3 - """ - from document import MapReduceDocument - - if not hasattr(self._collection, "map_reduce"): - raise NotImplementedError("Requires MongoDB >= 1.7.1") - - map_f_scope = {} - if isinstance(map_f, Code): - map_f_scope = map_f.scope - map_f = unicode(map_f) - map_f = Code(self._sub_js_fields(map_f), map_f_scope) - - reduce_f_scope = {} - if isinstance(reduce_f, Code): - reduce_f_scope = reduce_f.scope - reduce_f = unicode(reduce_f) - reduce_f_code = self._sub_js_fields(reduce_f) - reduce_f = Code(reduce_f_code, reduce_f_scope) - - mr_args = {'query': self._query} - - if finalize_f: - finalize_f_scope = {} - if isinstance(finalize_f, Code): - finalize_f_scope = finalize_f.scope - finalize_f = unicode(finalize_f) - finalize_f_code = self._sub_js_fields(finalize_f) - finalize_f = Code(finalize_f_code, finalize_f_scope) - mr_args['finalize'] = finalize_f - - if scope: - mr_args['scope'] = scope - - if limit: - mr_args['limit'] = limit - - if output == 'inline' and not self._ordering: - map_reduce_function = 'inline_map_reduce' - else: - map_reduce_function = 'map_reduce' - mr_args['out'] = output - - results = getattr(self._collection, map_reduce_function)(map_f, reduce_f, **mr_args) - - if map_reduce_function == 'map_reduce': - results = results.find() - - if self._ordering: - results = results.sort(self._ordering) - - for doc in results: - yield MapReduceDocument(self._document, self._collection, - doc['_id'], doc['value']) - - def limit(self, n): - """Limit the number of returned documents to `n`. This may also be - achieved using array-slicing syntax (e.g. ``User.objects[:5]``). - - :param n: the maximum number of objects to return - """ - if n == 0: - self._cursor.limit(1) - else: - self._cursor.limit(n) - self._limit = n - - # Return self to allow chaining - return self - - def skip(self, n): - """Skip `n` documents before returning the results. This may also be - achieved using array-slicing syntax (e.g. ``User.objects[5:]``). 
- - :param n: the number of objects to skip before returning results - """ - self._cursor.skip(n) - self._skip = n - return self - - def hint(self, index=None): - """Added 'hint' support, telling Mongo the proper index to use for the - query. - - Judicious use of hints can greatly improve query performance. When doing - a query on multiple fields (at least one of which is indexed) pass the - indexed field as a hint to the query. - - Hinting will not do anything if the corresponding index does not exist. - The last hint applied to this cursor takes precedence over all others. - - .. versionadded:: 0.5 - """ - self._cursor.hint(index) - self._hint = index - return self - - def __getitem__(self, key): - """Support skip and limit using getitem and slicing syntax. - """ - # Slice provided - if isinstance(key, slice): - try: - self._cursor_obj = self._cursor[key] - self._skip, self._limit = key.start, key.stop - except IndexError, err: - # PyMongo raises an error if key.start == key.stop, catch it, - # bin it, kill it. - start = key.start or 0 - if start >= 0 and key.stop >= 0 and key.step is None: - if start == key.stop: - self.limit(0) - self._skip, self._limit = key.start, key.stop - start - return self - raise err - # Allow further QuerySet modifications to be performed - return self - # Integer index provided - elif isinstance(key, int): - if self._scalar: - return self._get_scalar(self._document._from_son( - self._cursor[key])) - if self._as_pymongo: - return self._get_as_pymongo(self._cursor.next()) - return self._document._from_son(self._cursor[key]) - raise AttributeError - - def distinct(self, field): - """Return a list of distinct values for a given field. - - :param field: the field to select distinct values from - - .. versionadded:: 0.4 - .. versionchanged:: 0.5 - Fixed handling references - .. 
versionchanged:: 0.6 - Improved db_field refrence handling - """ - try: - field = self._fields_to_dbfields([field]).pop() - finally: - return self._dereference(self._cursor.distinct(field), 1, - name=field, instance=self._document) - - def only(self, *fields): - """Load only a subset of this document's fields. :: - - post = BlogPost.objects(...).only("title", "author.name") - - :param fields: fields to include - - .. versionadded:: 0.3 - .. versionchanged:: 0.5 - Added subfield support - """ - fields = dict([(f, QueryFieldList.ONLY) for f in fields]) - return self.fields(**fields) - - def exclude(self, *fields): - """Opposite to .only(), exclude some document's fields. :: - - post = BlogPost.objects(...).exclude("comments") - - :param fields: fields to exclude - - .. versionadded:: 0.5 - """ - fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) - return self.fields(**fields) - - def fields(self, **kwargs): - """Manipulate how you load this document's fields. Used by `.only()` - and `.exclude()` to manipulate which fields to retrieve. Fields also - allows for a greater level of control for example: - - Retrieving a Subrange of Array Elements: - - You can use the $slice operator to retrieve a subrange of elements in - an array :: - - post = BlogPost.objects(...).fields(slice__comments=5) // first 5 comments - - :param kwargs: A dictionary identifying what to include - - .. 
versionadded:: 0.5 - """ - - # Check for an operator and transform to mongo-style if there is - operators = ["slice"] - cleaned_fields = [] - for key, value in kwargs.items(): - parts = key.split('__') - op = None - if parts[0] in operators: - op = parts.pop(0) - value = {'$' + op: value} - key = '.'.join(parts) - cleaned_fields.append((key, value)) - - fields = sorted(cleaned_fields, key=operator.itemgetter(1)) - for value, group in itertools.groupby(fields, lambda x: x[1]): - fields = [field for field, value in group] - fields = self._fields_to_dbfields(fields) - self._loaded_fields += QueryFieldList(fields, value=value) - return self - - def all_fields(self): - """Include all fields. Reset all previously calls of .only() and .exclude(). :: - - post = BlogPost.objects(...).exclude("comments").only("title").all_fields() - - .. versionadded:: 0.5 - """ - self._loaded_fields = QueryFieldList(always_include=self._loaded_fields.always_include) - return self - - def _fields_to_dbfields(self, fields): - """Translate fields paths to its db equivalents""" - ret = [] - for field in fields: - field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.'))) - ret.append(field) - return ret - - def order_by(self, *keys): - """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The - order may be specified by prepending each of the keys by a + or a -. - Ascending order is assumed. 
- - :param keys: fields to order the query results by; keys may be - prefixed with **+** or **-** to determine the ordering direction - """ - key_list = [] - for key in keys: - if not key: continue - direction = pymongo.ASCENDING - if key[0] == '-': - direction = pymongo.DESCENDING - if key[0] in ('-', '+'): - key = key[1:] - key = key.replace('__', '.') - try: - key = QuerySet._translate_field_name(self._document, key) - except: - pass - key_list.append((key, direction)) - - self._ordering = key_list - if self._cursor_obj: - self._cursor_obj.sort(key_list) - return self - - def explain(self, format=False): - """Return an explain plan record for the - :class:`~mongoengine.queryset.QuerySet`\ 's cursor. - - :param format: format the plan before returning it - """ - - plan = self._cursor.explain() - if format: - plan = pprint.pformat(plan) - return plan - - def snapshot(self, enabled): - """Enable or disable snapshot mode when querying. - - :param enabled: whether or not snapshot mode is enabled - - ..versionchanged:: 0.5 - made chainable - """ - self._snapshot = enabled - return self - - def timeout(self, enabled): - """Enable or disable the default mongod timeout when querying. - - :param enabled: whether or not the timeout is used - - ..versionchanged:: 0.5 - made chainable - """ - self._timeout = enabled - return self - - def slave_okay(self, enabled): - """Enable or disable the slave_okay when querying. - - :param enabled: whether or not the slave_okay is enabled - """ - self._slave_okay = enabled - return self - - def delete(self, safe=False): - """Delete the documents matched by the query. 
- - :param safe: check if the operation succeeded before returning - """ - doc = self._document - - # Handle deletes where skips or limits have been applied - if self._skip or self._limit: - for doc in self: - doc.delete() - return - - delete_rules = doc._meta.get('delete_rules') or {} - # Check for DENY rules before actually deleting/nullifying any other - # references - for rule_entry in delete_rules: - document_cls, field_name = rule_entry - rule = doc._meta['delete_rules'][rule_entry] - if rule == DENY and document_cls.objects(**{field_name + '__in': self}).count() > 0: - msg = u'Could not delete document (at least %s.%s refers to it)' % \ - (document_cls.__name__, field_name) - raise OperationError(msg) - - for rule_entry in delete_rules: - document_cls, field_name = rule_entry - rule = doc._meta['delete_rules'][rule_entry] - if rule == CASCADE: - ref_q = document_cls.objects(**{field_name + '__in': self}) - ref_q_count = ref_q.count() - if (doc != document_cls and ref_q_count > 0 - or (doc == document_cls and ref_q_count > 0)): - ref_q.delete(safe=safe) - elif rule == NULLIFY: - document_cls.objects(**{field_name + '__in': self}).update( - safe_update=safe, - **{'unset__%s' % field_name: 1}) - elif rule == PULL: - document_cls.objects(**{field_name + '__in': self}).update( - safe_update=safe, - **{'pull_all__%s' % field_name: self}) - - self._collection.remove(self._query, safe=safe) - - @classmethod - def _transform_update(cls, _doc_cls=None, **update): - """Transform an update spec from Django-style format to Mongo format. 
- """ - operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all', - 'pull', 'pull_all', 'add_to_set'] - match_operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists', 'not'] - - mongo_update = {} - for key, value in update.items(): - if key == "__raw__": - mongo_update.update(value) - continue - parts = key.split('__') - # Check for an operator and transform to mongo-style if there is - op = None - if parts[0] in operators: - op = parts.pop(0) - # Convert Pythonic names to Mongo equivalents - if op in ('push_all', 'pull_all'): - op = op.replace('_all', 'All') - elif op == 'dec': - # Support decrement by flipping a positive value's sign - # and using 'inc' - op = 'inc' - if value > 0: - value = -value - elif op == 'add_to_set': - op = op.replace('_to_set', 'ToSet') - - match = None - if parts[-1] in match_operators: - match = parts.pop() - - if _doc_cls: - # Switch field names to proper names [set in Field(name='foo')] - fields = QuerySet._lookup_field(_doc_cls, parts) - parts = [] - - cleaned_fields = [] - for field in fields: - append_field = True - if isinstance(field, basestring): - # Convert the S operator to $ - if field == 'S': - field = '$' - parts.append(field) - append_field = False - else: - parts.append(field.db_field) - if append_field: - cleaned_fields.append(field) - - # Convert value to proper value - field = cleaned_fields[-1] - - if op in (None, 'set', 'push', 'pull'): - if field.required or value is not None: - value = field.prepare_query_value(op, value) - elif op in ('pushAll', 'pullAll'): - value = [field.prepare_query_value(op, v) for v in value] - elif op == 'addToSet': - if isinstance(value, (list, tuple, set)): - value = [field.prepare_query_value(op, v) for v in value] - elif field.required or value is not None: - value = field.prepare_query_value(op, value) - - if match: - match = '$' + match - value = {match: value} - - key = '.'.join(parts) - - if not op: - raise InvalidQueryError("Updates 
must supply an operation " - "eg: set__FIELD=value") - - if 'pull' in op and '.' in key: - # Dot operators don't work on pull operations - # it uses nested dict syntax - if op == 'pullAll': - raise InvalidQueryError("pullAll operations only support " - "a single field depth") - - parts.reverse() - for key in parts: - value = {key: value} - elif op == 'addToSet' and isinstance(value, list): - value = {key: {"$each": value}} - else: - value = {key: value} - key = '$' + op - - if key not in mongo_update: - mongo_update[key] = value - elif key in mongo_update and isinstance(mongo_update[key], dict): - mongo_update[key].update(value) - - return mongo_update - - def update(self, safe_update=True, upsert=False, multi=True, write_options=None, **update): - """Perform an atomic update on the fields matched by the query. When - ``safe_update`` is used, the number of affected documents is returned. - - :param safe_update: check if the operation succeeded before returning - :param upsert: Any existing document with that "_id" is overwritten. - :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update` - - .. versionadded:: 0.2 - """ - if not update: - raise OperationError("No update parameters, would remove data") - - if not write_options: - write_options = {} - - update = QuerySet._transform_update(self._document, **update) - query = self._query - - # SERVER-5247 hack - remove_types = "_types" in query and ".$." 
in unicode(update) - if remove_types: - del query["_types"] - - try: - ret = self._collection.update(query, update, multi=multi, - upsert=upsert, safe=safe_update, - **write_options) - if ret is not None and 'n' in ret: - return ret['n'] - except pymongo.errors.OperationFailure, err: - if unicode(err) == u'multi not coded yet': - message = u'update() method requires MongoDB 1.1.3+' - raise OperationError(message) - raise OperationError(u'Update failed (%s)' % unicode(err)) - - def update_one(self, safe_update=True, upsert=False, write_options=None, **update): - """Perform an atomic update on first field matched by the query. When - ``safe_update`` is used, the number of affected documents is returned. - - :param safe_update: check if the operation succeeded before returning - :param upsert: Any existing document with that "_id" is overwritten. - :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update` - :param update: Django-style update keyword arguments - - .. versionadded:: 0.2 - """ - if not update: - raise OperationError("No update parameters, would remove data") - - if not write_options: - write_options = {} - update = QuerySet._transform_update(self._document, **update) - query = self._query - - # SERVER-5247 hack - remove_types = "_types" in query and ".$." 
in unicode(update) - if remove_types: - del query["_types"] - - try: - # Explicitly provide 'multi=False' to newer versions of PyMongo - # as the default may change to 'True' - ret = self._collection.update(query, update, multi=False, - upsert=upsert, safe=safe_update, - **write_options) - - if ret is not None and 'n' in ret: - return ret['n'] - except pymongo.errors.OperationFailure, e: - raise OperationError(u'Update failed [%s]' % unicode(e)) - - def __iter__(self): - self.rewind() - return self - - def _get_scalar(self, doc): - - def lookup(obj, name): - chunks = name.split('__') - for chunk in chunks: - obj = getattr(obj, chunk) - return obj - - data = [lookup(doc, n) for n in self._scalar] - if len(data) == 1: - return data[0] - - return tuple(data) - - def _get_as_pymongo(self, row): - # Extract which fields paths we should follow if .fields(...) was - # used. If not, handle all fields. - if not getattr(self, '__as_pymongo_fields', None): - self.__as_pymongo_fields = [] - for field in self._loaded_fields.fields - set(['_cls', '_id', '_types']): - self.__as_pymongo_fields.append(field) - while '.' in field: - field, _ = field.rsplit('.', 1) - self.__as_pymongo_fields.append(field) - - all_fields = not self.__as_pymongo_fields - - def clean(data, path=None): - path = path or '' - - if isinstance(data, dict): - new_data = {} - for key, value in data.iteritems(): - new_path = '%s.%s' % (path, key) if path else key - if all_fields or new_path in self.__as_pymongo_fields: - new_data[key] = clean(value, path=new_path) - data = new_data - elif isinstance(data, list): - data = [clean(d, path=path) for d in data] - else: - if self._as_pymongo_coerce: - # If we need to coerce types, we need to determine the - # type of this field and use the corresponding .to_python(...) 
- from mongoengine.fields import EmbeddedDocumentField - obj = self._document - for chunk in path.split('.'): - obj = getattr(obj, chunk, None) - if obj is None: - break - elif isinstance(obj, EmbeddedDocumentField): - obj = obj.document_type - if obj and data is not None: - data = obj.to_python(data) - return data - return clean(row) - - def scalar(self, *fields): - """Instead of returning Document instances, return either a specific - value or a tuple of values in order. - - This effects all results and can be unset by calling ``scalar`` - without arguments. Calls ``only`` automatically. - - :param fields: One or more fields to return instead of a Document. - """ - self._scalar = list(fields) - - if fields: - self.only(*fields) - else: - self.all_fields() - - return self - - def values_list(self, *fields): - """An alias for scalar""" - return self.scalar(*fields) - - def as_pymongo(self, coerce_types=False): - """Instead of returning Document instances, return raw values from - pymongo. - - :param coerce_type: Field types (if applicable) would be use to coerce types. - """ - self._as_pymongo = True - self._as_pymongo_coerce = coerce_types - return self - - def _sub_js_fields(self, code): - """When fields are specified with [~fieldname] syntax, where - *fieldname* is the Python name of a field, *fieldname* will be - substituted for the MongoDB name of the field (specified using the - :attr:`name` keyword argument in a field's constructor). 
- """ - def field_sub(match): - # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') - fields = QuerySet._lookup_field(self._document, field_name) - # Substitute the correct name for the field into the javascript - return u'["%s"]' % fields[-1].db_field - - def field_path_sub(match): - # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') - fields = QuerySet._lookup_field(self._document, field_name) - # Substitute the correct name for the field into the javascript - return ".".join([f.db_field for f in fields]) - - code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) - code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, code) - return code - - def exec_js(self, code, *fields, **options): - """Execute a Javascript function on the server. A list of fields may be - provided, which will be translated to their correct names and supplied - as the arguments to the function. A few extra variables are added to - the function's scope: ``collection``, which is the name of the - collection in use; ``query``, which is an object representing the - current query; and ``options``, which is an object containing any - options specified as keyword arguments. - - As fields in MongoEngine may use different names in the database (set - using the :attr:`db_field` keyword argument to a :class:`Field` - constructor), a mechanism exists for replacing MongoEngine field names - with the database field names in Javascript code. When accessing a - field, use square-bracket notation, and prefix the MongoEngine field - name with a tilde (~). 
- - :param code: a string of Javascript code to execute - :param fields: fields that you will be using in your function, which - will be passed in to your function as arguments - :param options: options that you want available to the function - (accessed in Javascript through the ``options`` object) - """ - code = self._sub_js_fields(code) - - fields = [QuerySet._translate_field_name(self._document, f) - for f in fields] - collection = self._document._get_collection_name() - - scope = { - 'collection': collection, - 'options': options or {}, - } - - query = self._query - if self._where_clause: - query['$where'] = self._where_clause - - scope['query'] = query - code = Code(code, scope=scope) - - db = self._document._get_db() - return db.eval(code, *fields) - - def where(self, where_clause): - """Filter ``QuerySet`` results with a ``$where`` clause (a Javascript - expression). Performs automatic field name substitution like - :meth:`mongoengine.queryset.Queryset.exec_js`. - - .. note:: When using this mode of query, the database will call your - function, or evaluate your predicate clause, for each object - in the collection. - - .. versionadded:: 0.5 - """ - where_clause = self._sub_js_fields(where_clause) - self._where_clause = where_clause - return self - - def sum(self, field): - """Sum over the values of the specified field. - - :param field: the field to sum over; use dot-notation to refer to - embedded document fields - - .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work - with sharding. - """ - map_func = Code(""" - function() { - emit(1, this[field] || 0); - } - """, scope={'field': field}) - - reduce_func = Code(""" - function(key, values) { - var sum = 0; - for (var i in values) { - sum += values[i]; - } - return sum; - } - """) - - for result in self.map_reduce(map_func, reduce_func, output='inline'): - return result.value - else: - return 0 - - def average(self, field): - """Average over the values of the specified field. 
- - :param field: the field to average over; use dot-notation to refer to - embedded document fields - - .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work - with sharding. - """ - map_func = Code(""" - function() { - if (this.hasOwnProperty(field)) - emit(1, {t: this[field] || 0, c: 1}); - } - """, scope={'field': field}) - - reduce_func = Code(""" - function(key, values) { - var out = {t: 0, c: 0}; - for (var i in values) { - var value = values[i]; - out.t += value.t; - out.c += value.c; - } - return out; - } - """) - - finalize_func = Code(""" - function(key, value) { - return value.t / value.c; - } - """) - - for result in self.map_reduce(map_func, reduce_func, finalize_f=finalize_func, output='inline'): - return result.value - else: - return 0 - - def item_frequencies(self, field, normalize=False, map_reduce=True): - """Returns a dictionary of all items present in a field across - the whole queried set of documents, and their corresponding frequency. - This is useful for generating tag clouds, or searching documents. - - .. note:: - - Can only do direct simple mappings and cannot map across - :class:`~mongoengine.ReferenceField` or - :class:`~mongoengine.GenericReferenceField` for more complex - counting a manual map reduce call would is required. - - If the field is a :class:`~mongoengine.ListField`, the items within - each list will be counted individually. - - :param field: the field to use - :param normalize: normalize the results so they add to 1.0 - :param map_reduce: Use map_reduce over exec_js - - .. 
versionchanged:: 0.5 defaults to map_reduce and can handle embedded - document lookups - """ - if map_reduce: - return self._item_frequencies_map_reduce(field, normalize=normalize) - return self._item_frequencies_exec_js(field, normalize=normalize) - - def _item_frequencies_map_reduce(self, field, normalize=False): - map_func = """ - function() { - var path = '{{~%(field)s}}'.split('.'); - var field = this; - - for (p in path) { - if (typeof field != 'undefined') - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - field.forEach(function(item) { - emit(item, 1); - }); - } else if (typeof field != 'undefined') { - emit(field, 1); - } else { - emit(null, 1); - } - } - """ % dict(field=field) - reduce_func = """ - function(key, values) { - var total = 0; - var valuesSize = values.length; - for (var i=0; i < valuesSize; i++) { - total += parseInt(values[i], 10); - } - return total; - } - """ - values = self.map_reduce(map_func, reduce_func, 'inline') - frequencies = {} - for f in values: - key = f.key - if isinstance(key, float): - if int(key) == key: - key = int(key) - frequencies[key] = int(f.value) - - if normalize: - count = sum(frequencies.values()) - frequencies = dict([(k, float(v) / count) - for k, v in frequencies.items()]) - - return frequencies - - def _item_frequencies_exec_js(self, field, normalize=False): - """Uses exec_js to execute""" - freq_func = """ - function(path) { - var path = path.split('.'); - - var total = 0.0; - db[collection].find(query).forEach(function(doc) { - var field = doc; - for (p in path) { - if (field) - field = field[path[p]]; - else - break; - } - if (field && field.constructor == Array) { - total += field.length; - } else { - total++; - } - }); - - var frequencies = {}; - var types = {}; - var inc = 1.0; - - db[collection].find(query).forEach(function(doc) { - field = doc; - for (p in path) { - if (field) - field = field[path[p]]; - else - break; - } - if (field && field.constructor == 
Array) { - field.forEach(function(item) { - frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); - }); - } else { - var item = field; - types[item] = item; - frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); - } - }); - return [total, frequencies, types]; - } - """ - total, data, types = self.exec_js(freq_func, field) - values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) - - if normalize: - values = dict([(k, float(v) / total) for k, v in values.items()]) - - frequencies = {} - for k, v in values.iteritems(): - if isinstance(k, float): - if int(k) == k: - k = int(k) - - frequencies[k] = v - - return frequencies - - def __repr__(self): - """Provides the string representation of the QuerySet - - .. versionchanged:: 0.6.13 Now doesnt modify the cursor - """ - - if self._iter: - return '.. queryset mid-iteration ..' - - data = [] - for i in xrange(REPR_OUTPUT_SIZE + 1): - try: - data.append(self.next()) - except StopIteration: - break - if len(data) > REPR_OUTPUT_SIZE: - data[-1] = "...(remaining elements truncated)..." - - self.rewind() - return repr(data) - - def select_related(self, max_depth=1): - """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to - a maximum depth in order to cut down the number queries to mongodb. - - .. versionadded:: 0.5 - """ - # Make select related work the same for querysets - max_depth += 1 - return self._dereference(self, max_depth=max_depth) - - @property - def _dereference(self): - if not self.__dereference: - from dereference import DeReference - self.__dereference = DeReference() # Cached - return self.__dereference - - -class QuerySetManager(object): - """ - The default QuerySet Manager. - - Custom QuerySet Manager functions can extend this class and users can - add extra queryset functionality. 
Any custom manager methods must accept a - :class:`~mongoengine.Document` class as its first argument, and a - :class:`~mongoengine.queryset.QuerySet` as its second argument. - - The method function should return a :class:`~mongoengine.queryset.QuerySet` - , probably the same one that was passed in, but modified in some way. - """ - - get_queryset = None - - def __init__(self, queryset_func=None): - if queryset_func: - self.get_queryset = queryset_func - self._collections = {} - - def __get__(self, instance, owner): - """Descriptor for instantiating a new QuerySet object when - Document.objects is accessed. - """ - if instance is not None: - # Document class being used rather than a document object - return self - - # owner is the document that contains the QuerySetManager - queryset_class = owner._meta.get('queryset_class') or QuerySet - queryset = queryset_class(owner, owner._get_collection()) - if self.get_queryset: - arg_count = self.get_queryset.func_code.co_argcount - if arg_count == 1: - queryset = self.get_queryset(queryset) - elif arg_count == 2: - queryset = self.get_queryset(owner, queryset) - else: - queryset = partial(self.get_queryset, owner, queryset) - return queryset - - -def queryset_manager(func): - """Decorator that allows you to define custom QuerySet managers on - :class:`~mongoengine.Document` classes. The manager must be a function that - accepts a :class:`~mongoengine.Document` class as its first argument, and a - :class:`~mongoengine.queryset.QuerySet` as its second argument. The method - function should return a :class:`~mongoengine.queryset.QuerySet`, probably - the same one that was passed in, but modified in some way. 
- """ - if func.func_code.co_argcount == 1: - import warnings - msg = 'Methods decorated with queryset_manager should take 2 arguments' - warnings.warn(msg, DeprecationWarning) - return QuerySetManager(func) diff --git a/mongoengine/queryset/__init__.py b/mongoengine/queryset/__init__.py new file mode 100644 index 00000000..026a7acd --- /dev/null +++ b/mongoengine/queryset/__init__.py @@ -0,0 +1,11 @@ +from mongoengine.errors import (DoesNotExist, MultipleObjectsReturned, + InvalidQueryError, OperationError, + NotUniqueError) +from mongoengine.queryset.field_list import * +from mongoengine.queryset.manager import * +from mongoengine.queryset.queryset import * +from mongoengine.queryset.transform import * +from mongoengine.queryset.visitor import * + +__all__ = (field_list.__all__ + manager.__all__ + queryset.__all__ + + transform.__all__ + visitor.__all__) diff --git a/mongoengine/queryset/field_list.py b/mongoengine/queryset/field_list.py new file mode 100644 index 00000000..73d3cc24 --- /dev/null +++ b/mongoengine/queryset/field_list.py @@ -0,0 +1,85 @@ + +__all__ = ('QueryFieldList',) + + +class QueryFieldList(object): + """Object that handles combinations of .only() and .exclude() calls""" + ONLY = 1 + EXCLUDE = 0 + + def __init__(self, fields=None, value=ONLY, always_include=None, _only_called=False): + """The QueryFieldList builder + + :param fields: A list of fields used in `.only()` or `.exclude()` + :param value: How to handle the fields; either `ONLY` or `EXCLUDE` + :param always_include: Any fields to always_include eg `_cls` + :param _only_called: Has `.only()` been called? If so its a set of fields + otherwise it performs a union. 
+ """ + self.value = value + self.fields = set(fields or []) + self.always_include = set(always_include or []) + self._id = None + self._only_called = _only_called + self.slice = {} + + def __add__(self, f): + if isinstance(f.value, dict): + for field in f.fields: + self.slice[field] = f.value + if not self.fields: + self.fields = f.fields + elif not self.fields: + self.fields = f.fields + self.value = f.value + self.slice = {} + elif self.value is self.ONLY and f.value is self.ONLY: + self._clean_slice() + if self._only_called: + self.fields = self.fields.union(f.fields) + else: + self.fields = f.fields + elif self.value is self.EXCLUDE and f.value is self.EXCLUDE: + self.fields = self.fields.union(f.fields) + self._clean_slice() + elif self.value is self.ONLY and f.value is self.EXCLUDE: + self.fields -= f.fields + self._clean_slice() + elif self.value is self.EXCLUDE and f.value is self.ONLY: + self.value = self.ONLY + self.fields = f.fields - self.fields + self._clean_slice() + + if '_id' in f.fields: + self._id = f.value + + if self.always_include: + if self.value is self.ONLY and self.fields: + self.fields = self.fields.union(self.always_include) + else: + self.fields -= self.always_include + + if getattr(f, '_only_called', False): + self._only_called = True + return self + + def __nonzero__(self): + return bool(self.fields) + + def as_dict(self): + field_list = dict((field, self.value) for field in self.fields) + if self.slice: + field_list.update(self.slice) + if self._id is not None: + field_list['_id'] = self._id + return field_list + + def reset(self): + self.fields = set([]) + self.slice = {} + self.value = self.ONLY + + def _clean_slice(self): + if self.slice: + for field in set(self.slice.keys()) - self.fields: + del self.slice[field] diff --git a/mongoengine/queryset/manager.py b/mongoengine/queryset/manager.py new file mode 100644 index 00000000..47c2143d --- /dev/null +++ b/mongoengine/queryset/manager.py @@ -0,0 +1,57 @@ +from functools import 
partial +from mongoengine.queryset.queryset import QuerySet + +__all__ = ('queryset_manager', 'QuerySetManager') + + +class QuerySetManager(object): + """ + The default QuerySet Manager. + + Custom QuerySet Manager functions can extend this class and users can + add extra queryset functionality. Any custom manager methods must accept a + :class:`~mongoengine.Document` class as its first argument, and a + :class:`~mongoengine.queryset.QuerySet` as its second argument. + + The method function should return a :class:`~mongoengine.queryset.QuerySet` + , probably the same one that was passed in, but modified in some way. + """ + + get_queryset = None + default = QuerySet + + def __init__(self, queryset_func=None): + if queryset_func: + self.get_queryset = queryset_func + + def __get__(self, instance, owner): + """Descriptor for instantiating a new QuerySet object when + Document.objects is accessed. + """ + if instance is not None: + # Document class being used rather than a document object + return self + + # owner is the document that contains the QuerySetManager + queryset_class = owner._meta.get('queryset_class', self.default) + queryset = queryset_class(owner, owner._get_collection()) + if self.get_queryset: + arg_count = self.get_queryset.func_code.co_argcount + if arg_count == 1: + queryset = self.get_queryset(queryset) + elif arg_count == 2: + queryset = self.get_queryset(owner, queryset) + else: + queryset = partial(self.get_queryset, owner, queryset) + return queryset + + +def queryset_manager(func): + """Decorator that allows you to define custom QuerySet managers on + :class:`~mongoengine.Document` classes. The manager must be a function that + accepts a :class:`~mongoengine.Document` class as its first argument, and a + :class:`~mongoengine.queryset.QuerySet` as its second argument. The method + function should return a :class:`~mongoengine.queryset.QuerySet`, probably + the same one that was passed in, but modified in some way. 
+ """ + return QuerySetManager(func) diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py new file mode 100644 index 00000000..65d6553f --- /dev/null +++ b/mongoengine/queryset/queryset.py @@ -0,0 +1,1436 @@ +from __future__ import absolute_import + +import copy +import itertools +import operator +import pprint +import re +import warnings + +from bson.code import Code +from bson import json_util +import pymongo +from pymongo.common import validate_read_preference + +from mongoengine import signals +from mongoengine.common import _import_class +from mongoengine.errors import (OperationError, NotUniqueError, + InvalidQueryError) + +from mongoengine.queryset import transform +from mongoengine.queryset.field_list import QueryFieldList +from mongoengine.queryset.visitor import Q, QNode + + +__all__ = ('QuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') + +# The maximum number of items to display in a QuerySet.__repr__ +REPR_OUTPUT_SIZE = 20 + +# Delete rules +DO_NOTHING = 0 +NULLIFY = 1 +CASCADE = 2 +DENY = 3 +PULL = 4 + +RE_TYPE = type(re.compile('')) + + +class QuerySet(object): + """A set of results returned from a query. Wraps a MongoDB cursor, + providing :class:`~mongoengine.Document` objects as the results. 
+ """ + __dereference = False + _auto_dereference = True + + def __init__(self, document, collection): + self._document = document + self._collection_obj = collection + self._mongo_query = None + self._query_obj = Q() + self._initial_query = {} + self._where_clause = None + self._loaded_fields = QueryFieldList() + self._ordering = [] + self._snapshot = False + self._timeout = True + self._class_check = True + self._slave_okay = False + self._read_preference = None + self._iter = False + self._scalar = [] + self._none = False + self._as_pymongo = False + self._as_pymongo_coerce = False + + # If inheritance is allowed, only return instances and instances of + # subclasses of the class being used + if document._meta.get('allow_inheritance') is True: + self._initial_query = {"_cls": {"$in": self._document._subclasses}} + self._loaded_fields = QueryFieldList(always_include=['_cls']) + self._cursor_obj = None + self._limit = None + self._skip = None + self._slice = None + self._hint = -1 # Using -1 as None is a valid value for hint + + def __call__(self, q_obj=None, class_check=True, slave_okay=False, + read_preference=None, **query): + """Filter the selected documents by calling the + :class:`~mongoengine.queryset.QuerySet` with a query. + + :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in + the query; the :class:`~mongoengine.queryset.QuerySet` is filtered + multiple times with different :class:`~mongoengine.queryset.Q` + objects, only the last one will be used + :param class_check: If set to False bypass class name check when + querying collection + :param slave_okay: if True, allows this query to be run against a + replica secondary. + :params read_preference: if set, overrides connection-level + read_preference from `ReplicaSetConnection`. + :param query: Django-style query keyword arguments + """ + query = Q(**query) + if q_obj: + # make sure proper query object is passed + if not isinstance(q_obj, QNode): + msg = ("Not a query object: %s. 
" + "Did you intend to use key=value?" % q_obj) + raise InvalidQueryError(msg) + query &= q_obj + + queryset = self.clone() + queryset._query_obj &= query + queryset._mongo_query = None + queryset._cursor_obj = None + if read_preference is not None: + queryset.read_preference(read_preference) + queryset._class_check = class_check + return queryset + + def __iter__(self): + """Support iterator protocol""" + queryset = self + if queryset._iter: + queryset = self.clone() + queryset.rewind() + return queryset + + def __getitem__(self, key): + """Support skip and limit using getitem and slicing syntax. + """ + queryset = self.clone() + + # Slice provided + if isinstance(key, slice): + try: + queryset._cursor_obj = queryset._cursor[key] + queryset._slice = key + queryset._skip, queryset._limit = key.start, key.stop + except IndexError, err: + # PyMongo raises an error if key.start == key.stop, catch it, + # bin it, kill it. + start = key.start or 0 + if start >= 0 and key.stop >= 0 and key.step is None: + if start == key.stop: + queryset.limit(0) + queryset._skip = key.start + queryset._limit = key.stop - start + return queryset + raise err + # Allow further QuerySet modifications to be performed + return queryset + # Integer index provided + elif isinstance(key, int): + if queryset._scalar: + return queryset._get_scalar( + queryset._document._from_son(queryset._cursor[key], + _auto_dereference=self._auto_dereference)) + if queryset._as_pymongo: + return queryset._get_as_pymongo(queryset._cursor.next()) + return queryset._document._from_son(queryset._cursor[key], + _auto_dereference=self._auto_dereference) + raise AttributeError + + def __repr__(self): + """Provides the string representation of the QuerySet + + .. versionchanged:: 0.6.13 Now doesnt modify the cursor + """ + if self._iter: + return '.. queryset mid-iteration ..' 
+ + data = [] + for i in xrange(REPR_OUTPUT_SIZE + 1): + try: + data.append(self.next()) + except StopIteration: + break + if len(data) > REPR_OUTPUT_SIZE: + data[-1] = "...(remaining elements truncated)..." + + self.rewind() + return repr(data) + + # Core functions + + def all(self): + """Returns all documents.""" + return self.__call__() + + def filter(self, *q_objs, **query): + """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` + """ + return self.__call__(*q_objs, **query) + + def get(self, *q_objs, **query): + """Retrieve the the matching object raising + :class:`~mongoengine.queryset.MultipleObjectsReturned` or + `DocumentName.MultipleObjectsReturned` exception if multiple results + and :class:`~mongoengine.queryset.DoesNotExist` or + `DocumentName.DoesNotExist` if no results are found. + + .. versionadded:: 0.3 + """ + queryset = self.__call__(*q_objs, **query) + queryset = queryset.limit(2) + try: + result = queryset.next() + except StopIteration: + msg = ("%s matching query does not exist." + % queryset._document._class_name) + raise queryset._document.DoesNotExist(msg) + try: + queryset.next() + except StopIteration: + return result + + queryset.rewind() + message = u'%d items returned, instead of 1' % queryset.count() + raise queryset._document.MultipleObjectsReturned(message) + + def create(self, **kwargs): + """Create new object. Returns the saved object instance. + + .. versionadded:: 0.4 + """ + return self._document(**kwargs).save() + + def get_or_create(self, write_concern=None, auto_save=True, + *q_objs, **query): + """Retrieve unique object or create, if it doesn't exist. Returns a + tuple of ``(object, created)``, where ``object`` is the retrieved or + created object and ``created`` is a boolean specifying whether a new + object was created. Raises + :class:`~mongoengine.queryset.MultipleObjectsReturned` or + `DocumentName.MultipleObjectsReturned` if multiple results are found. 
+ A new document will be created if the document doesn't exists; a + dictionary of default values for the new document may be provided as a + keyword argument called :attr:`defaults`. + + .. note:: This requires two separate operations and therefore a + race condition exists. Because there are no transactions in + mongoDB other approaches should be investigated, to ensure you + don't accidently duplicate data when using this method. This is + now scheduled to be removed before 1.0 + + :param write_concern: optional extra keyword arguments used if we + have to create a new document. + Passes any write_concern onto :meth:`~mongoengine.Document.save` + + :param auto_save: if the object is to be saved automatically if + not found. + + .. deprecated:: 0.8 + .. versionchanged:: 0.6 - added `auto_save` + .. versionadded:: 0.3 + """ + msg = ("get_or_create is scheduled to be deprecated. The approach is " + "flawed without transactions. Upserts should be preferred.") + warnings.warn(msg, DeprecationWarning) + + defaults = query.get('defaults', {}) + if 'defaults' in query: + del query['defaults'] + + try: + doc = self.get(*q_objs, **query) + return doc, False + except self._document.DoesNotExist: + query.update(defaults) + doc = self._document(**query) + + if auto_save: + doc.save(write_concern=write_concern) + return doc, True + + def first(self): + """Retrieve the first object matching the query. + """ + queryset = self.clone() + try: + result = queryset[0] + except IndexError: + result = None + return result + + def insert(self, doc_or_docs, load_bulk=True, write_concern=None): + """bulk insert documents + + :param docs_or_doc: a document or list of documents to be inserted + :param load_bulk (optional): If True returns the list of document + instances + :param write_concern: Extra keyword arguments are passed down to + :meth:`~pymongo.collection.Collection.insert` + which will be used as options for the resultant + ``getLastError`` command. 
For example, + ``insert(..., {w: 2, fsync: True})`` will wait until at least + two servers have recorded the write and will force an fsync on + each server being written to. + + By default returns document instances, set ``load_bulk`` to False to + return just ``ObjectIds`` + + .. versionadded:: 0.5 + """ + Document = _import_class('Document') + + if not write_concern: + write_concern = {} + + docs = doc_or_docs + return_one = False + if isinstance(docs, Document) or issubclass(docs.__class__, Document): + return_one = True + docs = [docs] + + raw = [] + for doc in docs: + if not isinstance(doc, self._document): + msg = ("Some documents inserted aren't instances of %s" + % str(self._document)) + raise OperationError(msg) + if doc.pk and not doc._created: + msg = "Some documents have ObjectIds use doc.update() instead" + raise OperationError(msg) + raw.append(doc.to_mongo()) + + signals.pre_bulk_insert.send(self._document, documents=docs) + try: + ids = self._collection.insert(raw, **write_concern) + except pymongo.errors.OperationFailure, err: + message = 'Could not save document (%s)' + if re.match('^E1100[01] duplicate key', unicode(err)): + # E11000 - duplicate key error index + # E11001 - duplicate key on update + message = u'Tried to save duplicate unique keys (%s)' + raise NotUniqueError(message % unicode(err)) + raise OperationError(message % unicode(err)) + + if not load_bulk: + signals.post_bulk_insert.send( + self._document, documents=docs, loaded=False) + return return_one and ids[0] or ids + + documents = self.in_bulk(ids) + results = [] + for obj_id in ids: + results.append(documents.get(obj_id)) + signals.post_bulk_insert.send( + self._document, documents=results, loaded=True) + return return_one and results[0] or results + + def count(self, with_limit_and_skip=True): + """Count the selected elements in the query. 
+ + :param with_limit_and_skip (optional): take any :meth:`limit` or + :meth:`skip` that has been applied to this cursor into account when + getting the count + """ + if self._limit == 0: + return 0 + return self._cursor.count(with_limit_and_skip=with_limit_and_skip) + + def delete(self, write_concern=None): + """Delete the documents matched by the query. + + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + """ + queryset = self.clone() + doc = queryset._document + + has_delete_signal = signals.signals_available and ( + signals.pre_delete.has_receivers_for(self._document) or + signals.post_delete.has_receivers_for(self._document)) + + if not write_concern: + write_concern = {} + + # Handle deletes where skips or limits have been applied or has a + # delete signal + if queryset._skip or queryset._limit or has_delete_signal: + for doc in queryset: + doc.delete(write_concern=write_concern) + return + + delete_rules = doc._meta.get('delete_rules') or {} + # Check for DENY rules before actually deleting/nullifying any other + # references + for rule_entry in delete_rules: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == DENY and document_cls.objects( + **{field_name + '__in': self}).count() > 0: + msg = ("Could not delete document (%s.%s refers to it)" + % (document_cls.__name__, field_name)) + raise OperationError(msg) + + for rule_entry in delete_rules: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == CASCADE: + ref_q = document_cls.objects(**{field_name + '__in': self}) + ref_q_count = ref_q.count() + if (doc != document_cls and ref_q_count > 0 + or (doc == document_cls and ref_q_count > 0)): + 
ref_q.delete(write_concern=write_concern) + elif rule == NULLIFY: + document_cls.objects(**{field_name + '__in': self}).update( + write_concern=write_concern, **{'unset__%s' % field_name: 1}) + elif rule == PULL: + document_cls.objects(**{field_name + '__in': self}).update( + write_concern=write_concern, + **{'pull_all__%s' % field_name: self}) + + queryset._collection.remove(queryset._query, write_concern=write_concern) + + def update(self, upsert=False, multi=True, write_concern=None, **update): + """Perform an atomic update on the fields matched by the query. + + :param upsert: Any existing document with that "_id" is overwritten. + :param multi: Update multiple documents. + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param update: Django-style update keyword arguments + + .. 
versionadded:: 0.2
+        """
+        if not update:
+            raise OperationError("No update parameters, would remove data")
+
+        if not write_concern:
+            write_concern = {}
+
+        queryset = self.clone()
+        query = queryset._query
+        update = transform.update(queryset._document, **update)
+
+        # If doing an atomic upsert on an inheritable class
+        # then ensure we add _cls to the update operation
+        if upsert and '_cls' in query:
+            if '$set' in update:
+                update["$set"]["_cls"] = queryset._document._class_name
+            else:
+                update["$set"] = {"_cls": queryset._document._class_name}
+
+        try:
+            ret = queryset._collection.update(query, update, multi=multi,
+                                              upsert=upsert, **write_concern)
+            if ret is not None and 'n' in ret:
+                return ret['n']
+        except pymongo.errors.OperationFailure, err:
+            if unicode(err) == u'multi not coded yet':
+                message = u'update() method requires MongoDB 1.1.3+'
+                raise OperationError(message)
+            raise OperationError(u'Update failed (%s)' % unicode(err))
+
+    def update_one(self, upsert=False, write_concern=None, **update):
+        """Perform an atomic update on the first document matched by the query.
+
+        :param upsert: Any existing document with that "_id" is overwritten.
+        :param write_concern: Extra keyword arguments are passed down which
+            will be used as options for the resultant
+            ``getLastError`` command.  For example,
+            ``save(..., write_concern={w: 2, fsync: True}, ...)`` will
+            wait until at least two servers have recorded the write and
+            will force an fsync on the primary server.
+        :param update: Django-style update keyword arguments
+
+        .. versionadded:: 0.2
+        """
+        return self.update(upsert=upsert, multi=False, write_concern=write_concern, **update)
+
+    def with_id(self, object_id):
+        """Retrieve the object matching the id provided.  Uses `object_id` only
+        and raises InvalidQueryError if a filter has been applied. Returns
+        `None` if no document exists with that id.
+
+        :param object_id: the value for the id of the document to look up
+
+        .. 
versionchanged:: 0.6 Raises InvalidQueryError if filter has been set + """ + queryset = self.clone() + if not queryset._query_obj.empty: + msg = "Cannot use a filter whilst using `with_id`" + raise InvalidQueryError(msg) + return queryset.filter(pk=object_id).first() + + def in_bulk(self, object_ids): + """Retrieve a set of documents by their ids. + + :param object_ids: a list or tuple of ``ObjectId``\ s + :rtype: dict of ObjectIds as keys and collection-specific + Document subclasses as values. + + .. versionadded:: 0.3 + """ + doc_map = {} + + docs = self._collection.find({'_id': {'$in': object_ids}}, + **self._cursor_args) + if self._scalar: + for doc in docs: + doc_map[doc['_id']] = self._get_scalar( + self._document._from_son(doc)) + elif self._as_pymongo: + for doc in docs: + doc_map[doc['_id']] = self._get_as_pymongo(doc) + else: + for doc in docs: + doc_map[doc['_id']] = self._document._from_son(doc) + + return doc_map + + def none(self): + """Helper that just returns a list""" + queryset = self.clone() + queryset._none = True + return queryset + + def clone(self): + """Creates a copy of the current + :class:`~mongoengine.queryset.QuerySet` + + .. 
versionadded:: 0.5 + """ + c = self.__class__(self._document, self._collection_obj) + + copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', + '_where_clause', '_loaded_fields', '_ordering', '_snapshot', + '_timeout', '_class_check', '_slave_okay', '_read_preference', + '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', + '_limit', '_skip', '_hint', '_auto_dereference') + + for prop in copy_props: + val = getattr(self, prop) + setattr(c, prop, copy.copy(val)) + + if self._slice: + c._slice = self._slice + + if self._cursor_obj: + c._cursor_obj = self._cursor_obj.clone() + + if self._slice: + c._cursor[self._slice] + + return c + + def select_related(self, max_depth=1): + """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to + a maximum depth in order to cut down the number queries to mongodb. + + .. versionadded:: 0.5 + """ + # Make select related work the same for querysets + max_depth += 1 + queryset = self.clone() + return queryset._dereference(queryset, max_depth=max_depth) + + def limit(self, n): + """Limit the number of returned documents to `n`. This may also be + achieved using array-slicing syntax (e.g. ``User.objects[:5]``). + + :param n: the maximum number of objects to return + """ + queryset = self.clone() + if n == 0: + queryset._cursor.limit(1) + else: + queryset._cursor.limit(n) + queryset._limit = n + + # Return self to allow chaining + return queryset + + def skip(self, n): + """Skip `n` documents before returning the results. This may also be + achieved using array-slicing syntax (e.g. ``User.objects[5:]``). + + :param n: the number of objects to skip before returning results + """ + queryset = self.clone() + queryset._cursor.skip(n) + queryset._skip = n + return queryset + + def hint(self, index=None): + """Added 'hint' support, telling Mongo the proper index to use for the + query. + + Judicious use of hints can greatly improve query performance. 
When + doing a query on multiple fields (at least one of which is indexed) + pass the indexed field as a hint to the query. + + Hinting will not do anything if the corresponding index does not exist. + The last hint applied to this cursor takes precedence over all others. + + .. versionadded:: 0.5 + """ + queryset = self.clone() + queryset._cursor.hint(index) + queryset._hint = index + return queryset + + def distinct(self, field): + """Return a list of distinct values for a given field. + + :param field: the field to select distinct values from + + .. note:: This is a command and won't take ordering or limit into + account. + + .. versionadded:: 0.4 + .. versionchanged:: 0.5 - Fixed handling references + .. versionchanged:: 0.6 - Improved db_field refrence handling + """ + queryset = self.clone() + try: + field = self._fields_to_dbfields([field]).pop() + finally: + return self._dereference(queryset._cursor.distinct(field), 1, + name=field, instance=self._document) + + def only(self, *fields): + """Load only a subset of this document's fields. :: + + post = BlogPost.objects(...).only("title", "author.name") + + .. note :: `only()` is chainable and will perform a union :: + So with the following it will fetch both: `title` and `author.name`:: + + post = BlogPost.objects.only("title").only("author.name") + + :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any + field filters. + + :param fields: fields to include + + .. versionadded:: 0.3 + .. versionchanged:: 0.5 - Added subfield support + """ + fields = dict([(f, QueryFieldList.ONLY) for f in fields]) + return self.fields(True, **fields) + + def exclude(self, *fields): + """Opposite to .only(), exclude some document's fields. :: + + post = BlogPost.objects(...).exclude("comments") + + .. 
note :: `exclude()` is chainable and will perform a union :: + So with the following it will exclude both: `title` and `author.name`:: + + post = BlogPost.objects.exclude("title").exclude("author.name") + + :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any + field filters. + + :param fields: fields to exclude + + .. versionadded:: 0.5 + """ + fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) + return self.fields(**fields) + + def fields(self, _only_called=False, **kwargs): + """Manipulate how you load this document's fields. Used by `.only()` + and `.exclude()` to manipulate which fields to retrieve. Fields also + allows for a greater level of control for example: + + Retrieving a Subrange of Array Elements: + + You can use the $slice operator to retrieve a subrange of elements in + an array. For example to get the first 5 comments:: + + post = BlogPost.objects(...).fields(slice__comments=5) + + :param kwargs: A dictionary identifying what to include + + .. versionadded:: 0.5 + """ + + # Check for an operator and transform to mongo-style if there is + operators = ["slice"] + cleaned_fields = [] + for key, value in kwargs.items(): + parts = key.split('__') + op = None + if parts[0] in operators: + op = parts.pop(0) + value = {'$' + op: value} + key = '.'.join(parts) + cleaned_fields.append((key, value)) + + fields = sorted(cleaned_fields, key=operator.itemgetter(1)) + queryset = self.clone() + for value, group in itertools.groupby(fields, lambda x: x[1]): + fields = [field for field, value in group] + fields = queryset._fields_to_dbfields(fields) + queryset._loaded_fields += QueryFieldList(fields, value=value, _only_called=_only_called) + + return queryset + + def all_fields(self): + """Include all fields. Reset all previously calls of .only() or + .exclude(). :: + + post = BlogPost.objects.exclude("comments").all_fields() + + .. 
versionadded:: 0.5 + """ + queryset = self.clone() + queryset._loaded_fields = QueryFieldList( + always_include=queryset._loaded_fields.always_include) + return queryset + + def order_by(self, *keys): + """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The + order may be specified by prepending each of the keys by a + or a -. + Ascending order is assumed. + + :param keys: fields to order the query results by; keys may be + prefixed with **+** or **-** to determine the ordering direction + """ + queryset = self.clone() + queryset._ordering = queryset._get_order_by(keys) + return queryset + + def explain(self, format=False): + """Return an explain plan record for the + :class:`~mongoengine.queryset.QuerySet`\ 's cursor. + + :param format: format the plan before returning it + """ + plan = self._cursor.explain() + if format: + plan = pprint.pformat(plan) + return plan + + def snapshot(self, enabled): + """Enable or disable snapshot mode when querying. + + :param enabled: whether or not snapshot mode is enabled + + ..versionchanged:: 0.5 - made chainable + """ + queryset = self.clone() + queryset._snapshot = enabled + return queryset + + def timeout(self, enabled): + """Enable or disable the default mongod timeout when querying. + + :param enabled: whether or not the timeout is used + + ..versionchanged:: 0.5 - made chainable + """ + queryset = self.clone() + queryset._timeout = enabled + return queryset + + def slave_okay(self, enabled): + """Enable or disable the slave_okay when querying. + + :param enabled: whether or not the slave_okay is enabled + """ + queryset = self.clone() + queryset._slave_okay = enabled + return queryset + + def read_preference(self, read_preference): + """Change the read_preference when querying. + + :param read_preference: override ReplicaSetConnection-level + preference. 
+ """ + validate_read_preference('read_preference', read_preference) + queryset = self.clone() + queryset._read_preference = read_preference + return queryset + + def scalar(self, *fields): + """Instead of returning Document instances, return either a specific + value or a tuple of values in order. + + Can be used along with + :func:`~mongoengine.queryset.QuerySet.no_dereference` to turn off + dereferencing. + + .. note:: This effects all results and can be unset by calling + ``scalar`` without arguments. Calls ``only`` automatically. + + :param fields: One or more fields to return instead of a Document. + """ + queryset = self.clone() + queryset._scalar = list(fields) + + if fields: + queryset = queryset.only(*fields) + else: + queryset = queryset.all_fields() + + return queryset + + def values_list(self, *fields): + """An alias for scalar""" + return self.scalar(*fields) + + def as_pymongo(self, coerce_types=False): + """Instead of returning Document instances, return raw values from + pymongo. + + :param coerce_type: Field types (if applicable) would be use to + coerce types. + """ + queryset = self.clone() + queryset._as_pymongo = True + queryset._as_pymongo_coerce = coerce_types + return queryset + + # JSON Helpers + + def to_json(self): + """Converts a queryset to JSON""" + queryset = self.clone() + return json_util.dumps(queryset._collection_obj.find(queryset._query)) + + def from_json(self, json_data): + """Converts json data to unsaved objects""" + son_data = json_util.loads(json_data) + return [self._document._from_son(data) for data in son_data] + + # JS functionality + + def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, + scope=None): + """Perform a map/reduce query using the current query spec + and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, + it must be the last call made, as it does not return a maleable + ``QuerySet``. 
+ + See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` + and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` + tests in ``tests.queryset.QuerySetTest`` for usage examples. + + :param map_f: map function, as :class:`~bson.code.Code` or string + :param reduce_f: reduce function, as + :class:`~bson.code.Code` or string + :param output: output collection name, if set to 'inline' will try to + use :class:`~pymongo.collection.Collection.inline_map_reduce` + This can also be a dictionary containing output options + see: http://docs.mongodb.org/manual/reference/commands/#mapReduce + :param finalize_f: finalize function, an optional function that + performs any post-reduction processing. + :param scope: values to insert into map/reduce global scope. Optional. + :param limit: number of objects from current query to provide + to map/reduce method + + Returns an iterator yielding + :class:`~mongoengine.document.MapReduceDocument`. + + .. note:: + + Map/Reduce changed in server version **>= 1.7.4**. The PyMongo + :meth:`~pymongo.collection.Collection.map_reduce` helper requires + PyMongo version **>= 1.11**. + + .. versionchanged:: 0.5 + - removed ``keep_temp`` keyword argument, which was only relevant + for MongoDB server versions older than 1.7.4 + + .. 
versionadded:: 0.3 + """ + queryset = self.clone() + + MapReduceDocument = _import_class('MapReduceDocument') + + if not hasattr(self._collection, "map_reduce"): + raise NotImplementedError("Requires MongoDB >= 1.7.1") + + map_f_scope = {} + if isinstance(map_f, Code): + map_f_scope = map_f.scope + map_f = unicode(map_f) + map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) + + reduce_f_scope = {} + if isinstance(reduce_f, Code): + reduce_f_scope = reduce_f.scope + reduce_f = unicode(reduce_f) + reduce_f_code = queryset._sub_js_fields(reduce_f) + reduce_f = Code(reduce_f_code, reduce_f_scope) + + mr_args = {'query': queryset._query} + + if finalize_f: + finalize_f_scope = {} + if isinstance(finalize_f, Code): + finalize_f_scope = finalize_f.scope + finalize_f = unicode(finalize_f) + finalize_f_code = queryset._sub_js_fields(finalize_f) + finalize_f = Code(finalize_f_code, finalize_f_scope) + mr_args['finalize'] = finalize_f + + if scope: + mr_args['scope'] = scope + + if limit: + mr_args['limit'] = limit + + if output == 'inline' and not queryset._ordering: + map_reduce_function = 'inline_map_reduce' + else: + map_reduce_function = 'map_reduce' + mr_args['out'] = output + + results = getattr(queryset._collection, map_reduce_function)( + map_f, reduce_f, **mr_args) + + if map_reduce_function == 'map_reduce': + results = results.find() + + if queryset._ordering: + results = results.sort(queryset._ordering) + + for doc in results: + yield MapReduceDocument(queryset._document, queryset._collection, + doc['_id'], doc['value']) + + def exec_js(self, code, *fields, **options): + """Execute a Javascript function on the server. A list of fields may be + provided, which will be translated to their correct names and supplied + as the arguments to the function. 
A few extra variables are added to + the function's scope: ``collection``, which is the name of the + collection in use; ``query``, which is an object representing the + current query; and ``options``, which is an object containing any + options specified as keyword arguments. + + As fields in MongoEngine may use different names in the database (set + using the :attr:`db_field` keyword argument to a :class:`Field` + constructor), a mechanism exists for replacing MongoEngine field names + with the database field names in Javascript code. When accessing a + field, use square-bracket notation, and prefix the MongoEngine field + name with a tilde (~). + + :param code: a string of Javascript code to execute + :param fields: fields that you will be using in your function, which + will be passed in to your function as arguments + :param options: options that you want available to the function + (accessed in Javascript through the ``options`` object) + """ + queryset = self.clone() + + code = queryset._sub_js_fields(code) + + fields = [queryset._document._translate_field_name(f) for f in fields] + collection = queryset._document._get_collection_name() + + scope = { + 'collection': collection, + 'options': options or {}, + } + + query = queryset._query + if queryset._where_clause: + query['$where'] = queryset._where_clause + + scope['query'] = query + code = Code(code, scope=scope) + + db = queryset._document._get_db() + return db.eval(code, *fields) + + def where(self, where_clause): + """Filter ``QuerySet`` results with a ``$where`` clause (a Javascript + expression). Performs automatic field name substitution like + :meth:`mongoengine.queryset.Queryset.exec_js`. + + .. note:: When using this mode of query, the database will call your + function, or evaluate your predicate clause, for each object + in the collection. + + .. 
versionadded:: 0.5 + """ + queryset = self.clone() + where_clause = queryset._sub_js_fields(where_clause) + queryset._where_clause = where_clause + return queryset + + def sum(self, field): + """Sum over the values of the specified field. + + :param field: the field to sum over; use dot-notation to refer to + embedded document fields + + .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work + with sharding. + """ + map_func = Code(""" + function() { + emit(1, this[field] || 0); + } + """, scope={'field': field}) + + reduce_func = Code(""" + function(key, values) { + var sum = 0; + for (var i in values) { + sum += values[i]; + } + return sum; + } + """) + + for result in self.map_reduce(map_func, reduce_func, output='inline'): + return result.value + else: + return 0 + + def average(self, field): + """Average over the values of the specified field. + + :param field: the field to average over; use dot-notation to refer to + embedded document fields + + .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work + with sharding. + """ + map_func = Code(""" + function() { + if (this.hasOwnProperty(field)) + emit(1, {t: this[field] || 0, c: 1}); + } + """, scope={'field': field}) + + reduce_func = Code(""" + function(key, values) { + var out = {t: 0, c: 0}; + for (var i in values) { + var value = values[i]; + out.t += value.t; + out.c += value.c; + } + return out; + } + """) + + finalize_func = Code(""" + function(key, value) { + return value.t / value.c; + } + """) + + for result in self.map_reduce(map_func, reduce_func, + finalize_f=finalize_func, output='inline'): + return result.value + else: + return 0 + + def item_frequencies(self, field, normalize=False, map_reduce=True): + """Returns a dictionary of all items present in a field across + the whole queried set of documents, and their corresponding frequency. + This is useful for generating tag clouds, or searching documents. + + .. 
note:: + + Can only do direct simple mappings and cannot map across + :class:`~mongoengine.fields.ReferenceField` or + :class:`~mongoengine.fields.GenericReferenceField` for more complex + counting a manual map reduce call would is required. + + If the field is a :class:`~mongoengine.fields.ListField`, the items within + each list will be counted individually. + + :param field: the field to use + :param normalize: normalize the results so they add to 1.0 + :param map_reduce: Use map_reduce over exec_js + + .. versionchanged:: 0.5 defaults to map_reduce and can handle embedded + document lookups + """ + if map_reduce: + return self._item_frequencies_map_reduce(field, + normalize=normalize) + return self._item_frequencies_exec_js(field, normalize=normalize) + + # Iterator helpers + + def next(self): + """Wrap the result in a :class:`~mongoengine.Document` object. + """ + self._iter = True + try: + if self._limit == 0 or self._none: + raise StopIteration + if self._scalar: + return self._get_scalar(self._document._from_son( + self._cursor.next())) + if self._as_pymongo: + return self._get_as_pymongo(self._cursor.next()) + + return self._document._from_son(self._cursor.next()) + except StopIteration, e: + self.rewind() + raise e + + def rewind(self): + """Rewind the cursor to its unevaluated state. + + .. versionadded:: 0.3 + """ + self._iter = False + self._cursor.rewind() + + # Properties + + @property + def _collection(self): + """Property that returns the collection object. This allows us to + perform operations only if the collection is accessed. 
+ """ + return self._collection_obj + + @property + def _cursor_args(self): + cursor_args = { + 'snapshot': self._snapshot, + 'timeout': self._timeout + } + if self._read_preference is not None: + cursor_args['read_preference'] = self._read_preference + else: + cursor_args['slave_okay'] = self._slave_okay + if self._loaded_fields: + cursor_args['fields'] = self._loaded_fields.as_dict() + return cursor_args + + @property + def _cursor(self): + if self._cursor_obj is None: + + self._cursor_obj = self._collection.find(self._query, + **self._cursor_args) + # Apply where clauses to cursor + if self._where_clause: + where_clause = self._sub_js_fields(self._where_clause) + self._cursor_obj.where(where_clause) + + if self._ordering: + # Apply query ordering + self._cursor_obj.sort(self._ordering) + elif self._document._meta['ordering']: + # Otherwise, apply the ordering from the document model + order = self._get_order_by(self._document._meta['ordering']) + self._cursor_obj.sort(order) + + if self._limit is not None: + self._cursor_obj.limit(self._limit - (self._skip or 0)) + + if self._skip is not None: + self._cursor_obj.skip(self._skip) + + if self._hint != -1: + self._cursor_obj.hint(self._hint) + + return self._cursor_obj + + def __deepcopy__(self, memo): + """Essential for chained queries with ReferenceFields involved""" + return self.clone() + + @property + def _query(self): + if self._mongo_query is None: + self._mongo_query = self._query_obj.to_query(self._document) + if self._class_check: + self._mongo_query.update(self._initial_query) + return self._mongo_query + + @property + def _dereference(self): + if not self.__dereference: + self.__dereference = _import_class('DeReference')() + return self.__dereference + + def no_dereference(self): + """Turn off any dereferencing for the results of this queryset. 
+ """ + queryset = self.clone() + queryset._auto_dereference = False + return queryset + + # Helper Functions + + def _item_frequencies_map_reduce(self, field, normalize=False): + map_func = """ + function() { + var path = '{{~%(field)s}}'.split('.'); + var field = this; + + for (p in path) { + if (typeof field != 'undefined') + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + field.forEach(function(item) { + emit(item, 1); + }); + } else if (typeof field != 'undefined') { + emit(field, 1); + } else { + emit(null, 1); + } + } + """ % dict(field=field) + reduce_func = """ + function(key, values) { + var total = 0; + var valuesSize = values.length; + for (var i=0; i < valuesSize; i++) { + total += parseInt(values[i], 10); + } + return total; + } + """ + values = self.map_reduce(map_func, reduce_func, 'inline') + frequencies = {} + for f in values: + key = f.key + if isinstance(key, float): + if int(key) == key: + key = int(key) + frequencies[key] = int(f.value) + + if normalize: + count = sum(frequencies.values()) + frequencies = dict([(k, float(v) / count) + for k, v in frequencies.items()]) + + return frequencies + + def _item_frequencies_exec_js(self, field, normalize=False): + """Uses exec_js to execute""" + freq_func = """ + function(path) { + var path = path.split('.'); + + var total = 0.0; + db[collection].find(query).forEach(function(doc) { + var field = doc; + for (p in path) { + if (field) + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + total += field.length; + } else { + total++; + } + }); + + var frequencies = {}; + var types = {}; + var inc = 1.0; + + db[collection].find(query).forEach(function(doc) { + field = doc; + for (p in path) { + if (field) + field = field[path[p]]; + else + break; + } + if (field && field.constructor == Array) { + field.forEach(function(item) { + frequencies[item] = inc + (isNaN(frequencies[item]) ? 
0: frequencies[item]); + }); + } else { + var item = field; + types[item] = item; + frequencies[item] = inc + (isNaN(frequencies[item]) ? 0: frequencies[item]); + } + }); + return [total, frequencies, types]; + } + """ + total, data, types = self.exec_js(freq_func, field) + values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) + + if normalize: + values = dict([(k, float(v) / total) for k, v in values.items()]) + + frequencies = {} + for k, v in values.iteritems(): + if isinstance(k, float): + if int(k) == k: + k = int(k) + + frequencies[k] = v + + return frequencies + + def _fields_to_dbfields(self, fields): + """Translate fields paths to its db equivalents""" + ret = [] + for field in fields: + field = ".".join(f.db_field for f in + self._document._lookup_field(field.split('.'))) + ret.append(field) + return ret + + def _get_order_by(self, keys): + """Creates a list of order by fields + """ + key_list = [] + for key in keys: + if not key: + continue + direction = pymongo.ASCENDING + if key[0] == '-': + direction = pymongo.DESCENDING + if key[0] in ('-', '+'): + key = key[1:] + key = key.replace('__', '.') + try: + key = self._document._translate_field_name(key) + except: + pass + key_list.append((key, direction)) + + if self._cursor_obj: + self._cursor_obj.sort(key_list) + return key_list + + def _get_scalar(self, doc): + + def lookup(obj, name): + chunks = name.split('__') + for chunk in chunks: + obj = getattr(obj, chunk) + return obj + + data = [lookup(doc, n) for n in self._scalar] + if len(data) == 1: + return data[0] + + return tuple(data) + + def _get_as_pymongo(self, row): + # Extract which fields paths we should follow if .fields(...) was + # used. If not, handle all fields. + if not getattr(self, '__as_pymongo_fields', None): + self.__as_pymongo_fields = [] + for field in self._loaded_fields.fields - set(['_cls', '_id']): + self.__as_pymongo_fields.append(field) + while '.' 
in field: + field, _ = field.rsplit('.', 1) + self.__as_pymongo_fields.append(field) + + all_fields = not self.__as_pymongo_fields + + def clean(data, path=None): + path = path or '' + + if isinstance(data, dict): + new_data = {} + for key, value in data.iteritems(): + new_path = '%s.%s' % (path, key) if path else key + if all_fields or new_path in self.__as_pymongo_fields: + new_data[key] = clean(value, path=new_path) + data = new_data + elif isinstance(data, list): + data = [clean(d, path=path) for d in data] + else: + if self._as_pymongo_coerce: + # If we need to coerce types, we need to determine the + # type of this field and use the corresponding + # .to_python(...) + from mongoengine.fields import EmbeddedDocumentField + obj = self._document + for chunk in path.split('.'): + obj = getattr(obj, chunk, None) + if obj is None: + break + elif isinstance(obj, EmbeddedDocumentField): + obj = obj.document_type + if obj and data is not None: + data = obj.to_python(data) + return data + return clean(row) + + def _sub_js_fields(self, code): + """When fields are specified with [~fieldname] syntax, where + *fieldname* is the Python name of a field, *fieldname* will be + substituted for the MongoDB name of the field (specified using the + :attr:`name` keyword argument in a field's constructor). 
+ """ + def field_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = self._document._lookup_field(field_name) + # Substitute the correct name for the field into the javascript + return u'["%s"]' % fields[-1].db_field + + def field_path_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = self._document._lookup_field(field_name) + # Substitute the correct name for the field into the javascript + return ".".join([f.db_field for f in fields]) + + code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) + code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, + code) + return code + + # Deprecated + + def ensure_index(self, **kwargs): + """Deprecated use :func:`~Document.ensure_index`""" + msg = ("Doc.objects()._ensure_index() is deprecated. " + "Use Doc.ensure_index() instead.") + warnings.warn(msg, DeprecationWarning) + self._document.__class__.ensure_index(**kwargs) + return self + + def _ensure_indexes(self): + """Deprecated use :func:`~Document.ensure_indexes`""" + msg = ("Doc.objects()._ensure_indexes() is deprecated. 
" + "Use Doc.ensure_indexes() instead.") + warnings.warn(msg, DeprecationWarning) + self._document.__class__.ensure_indexes() diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py new file mode 100644 index 00000000..3da26935 --- /dev/null +++ b/mongoengine/queryset/transform.py @@ -0,0 +1,252 @@ +from collections import defaultdict + +from bson import SON + +from mongoengine.common import _import_class +from mongoengine.errors import InvalidQueryError, LookUpError + +__all__ = ('query', 'update') + + +COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', + 'all', 'size', 'exists', 'not') +GEO_OPERATORS = ('within_distance', 'within_spherical_distance', + 'within_box', 'within_polygon', 'near', 'near_sphere', + 'max_distance') +STRING_OPERATORS = ('contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', + 'exact', 'iexact') +CUSTOM_OPERATORS = ('match',) +MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + + STRING_OPERATORS + CUSTOM_OPERATORS) + +UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', + 'push_all', 'pull', 'pull_all', 'add_to_set') + + +def query(_doc_cls=None, _field_operation=False, **query): + """Transform a query from Django-style format to Mongo format. 
+ """ + mongo_query = {} + merge_query = defaultdict(list) + for key, value in sorted(query.items()): + if key == "__raw__": + mongo_query.update(value) + continue + + parts = key.split('__') + indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] + parts = [part for part in parts if not part.isdigit()] + # Check for an operator and transform to mongo-style if there is + op = None + if parts[-1] in MATCH_OPERATORS: + op = parts.pop() + + negate = False + if parts[-1] == 'not': + parts.pop() + negate = True + + if _doc_cls: + # Switch field names to proper names [set in Field(name='foo')] + try: + fields = _doc_cls._lookup_field(parts) + except Exception, e: + raise InvalidQueryError(e) + parts = [] + + cleaned_fields = [] + for field in fields: + append_field = True + if isinstance(field, basestring): + parts.append(field) + append_field = False + else: + parts.append(field.db_field) + if append_field: + cleaned_fields.append(field) + + # Convert value to proper value + field = cleaned_fields[-1] + + singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] + singular_ops += STRING_OPERATORS + if op in singular_ops: + if isinstance(field, basestring): + if (op in STRING_OPERATORS and + isinstance(value, basestring)): + StringField = _import_class('StringField') + value = StringField.prepare_query_value(op, value) + else: + value = field + else: + value = field.prepare_query_value(op, value) + elif op in ('in', 'nin', 'all', 'near'): + # 'in', 'nin' and 'all' require a list of values + value = [field.prepare_query_value(op, v) for v in value] + + # if op and op not in COMPARISON_OPERATORS: + if op: + if op in GEO_OPERATORS: + if op == "within_distance": + value = {'$within': {'$center': value}} + elif op == "within_spherical_distance": + value = {'$within': {'$centerSphere': value}} + elif op == "within_polygon": + value = {'$within': {'$polygon': value}} + elif op == "near": + value = {'$near': value} + elif op == "near_sphere": + value = 
{'$nearSphere': value} + elif op == 'within_box': + value = {'$within': {'$box': value}} + elif op == "max_distance": + value = {'$maxDistance': value} + else: + raise NotImplementedError("Geo method '%s' has not " + "been implemented" % op) + elif op in CUSTOM_OPERATORS: + if op == 'match': + value = {"$elemMatch": value} + else: + NotImplementedError("Custom method '%s' has not " + "been implemented" % op) + elif op not in STRING_OPERATORS: + value = {'$' + op: value} + + if negate: + value = {'$not': value} + + for i, part in indices: + parts.insert(i, part) + key = '.'.join(parts) + if op is None or key not in mongo_query: + mongo_query[key] = value + elif key in mongo_query: + if key in mongo_query and isinstance(mongo_query[key], dict): + mongo_query[key].update(value) + # $maxDistance needs to come last - convert to SON + if '$maxDistance' in mongo_query[key]: + value_dict = mongo_query[key] + value_son = SON() + for k, v in value_dict.iteritems(): + if k == '$maxDistance': + continue + value_son[k] = v + value_son['$maxDistance'] = value_dict['$maxDistance'] + mongo_query[key] = value_son + else: + # Store for manually merging later + merge_query[key].append(value) + + # The queryset has been filter in such a way we must manually merge + for k, v in merge_query.items(): + merge_query[k].append(mongo_query[k]) + del mongo_query[k] + if isinstance(v, list): + value = [{k: val} for val in v] + if '$and' in mongo_query.keys(): + mongo_query['$and'].append(value) + else: + mongo_query['$and'] = value + + return mongo_query + + +def update(_doc_cls=None, **update): + """Transform an update spec from Django-style format to Mongo format. 
+ """ + mongo_update = {} + for key, value in update.items(): + if key == "__raw__": + mongo_update.update(value) + continue + parts = key.split('__') + # Check for an operator and transform to mongo-style if there is + op = None + if parts[0] in UPDATE_OPERATORS: + op = parts.pop(0) + # Convert Pythonic names to Mongo equivalents + if op in ('push_all', 'pull_all'): + op = op.replace('_all', 'All') + elif op == 'dec': + # Support decrement by flipping a positive value's sign + # and using 'inc' + op = 'inc' + if value > 0: + value = -value + elif op == 'add_to_set': + op = op.replace('_to_set', 'ToSet') + + match = None + if parts[-1] in COMPARISON_OPERATORS: + match = parts.pop() + + if _doc_cls: + # Switch field names to proper names [set in Field(name='foo')] + try: + fields = _doc_cls._lookup_field(parts) + except Exception, e: + raise InvalidQueryError(e) + parts = [] + + cleaned_fields = [] + for field in fields: + append_field = True + if isinstance(field, basestring): + # Convert the S operator to $ + if field == 'S': + field = '$' + parts.append(field) + append_field = False + else: + parts.append(field.db_field) + if append_field: + cleaned_fields.append(field) + + # Convert value to proper value + field = cleaned_fields[-1] + + if op in (None, 'set', 'push', 'pull'): + if field.required or value is not None: + value = field.prepare_query_value(op, value) + elif op in ('pushAll', 'pullAll'): + value = [field.prepare_query_value(op, v) for v in value] + elif op == 'addToSet': + if isinstance(value, (list, tuple, set)): + value = [field.prepare_query_value(op, v) for v in value] + elif field.required or value is not None: + value = field.prepare_query_value(op, value) + + if match: + match = '$' + match + value = {match: value} + + key = '.'.join(parts) + + if not op: + raise InvalidQueryError("Updates must supply an operation " + "eg: set__FIELD=value") + + if 'pull' in op and '.' 
in key: + # Dot operators don't work on pull operations + # it uses nested dict syntax + if op == 'pullAll': + raise InvalidQueryError("pullAll operations only support " + "a single field depth") + + parts.reverse() + for key in parts: + value = {key: value} + elif op == 'addToSet' and isinstance(value, list): + value = {key: {"$each": value}} + else: + value = {key: value} + key = '$' + op + + if key not in mongo_update: + mongo_update[key] = value + elif key in mongo_update and isinstance(mongo_update[key], dict): + mongo_update[key].update(value) + + return mongo_update diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py new file mode 100644 index 00000000..95d11e8f --- /dev/null +++ b/mongoengine/queryset/visitor.py @@ -0,0 +1,155 @@ +import copy + +from mongoengine.errors import InvalidQueryError +from mongoengine.python_support import product, reduce + +from mongoengine.queryset import transform + +__all__ = ('Q',) + + +class QNodeVisitor(object): + """Base visitor class for visiting Q-object nodes in a query tree. + """ + + def visit_combination(self, combination): + """Called by QCombination objects. + """ + return combination + + def visit_query(self, query): + """Called by (New)Q objects. + """ + return query + + +class SimplificationVisitor(QNodeVisitor): + """Simplifies query trees by combinging unnecessary 'and' connection nodes + into a single Q-object. + """ + + def visit_combination(self, combination): + if combination.operation == combination.AND: + # The simplification only applies to 'simple' queries + if all(isinstance(node, Q) for node in combination.children): + queries = [n.query for n in combination.children] + return Q(**self._query_conjunction(queries)) + return combination + + def _query_conjunction(self, queries): + """Merges query dicts - effectively &ing them together. 
+ """ + query_ops = set() + combined_query = {} + for query in queries: + ops = set(query.keys()) + # Make sure that the same operation isn't applied more than once + # to a single field + intersection = ops.intersection(query_ops) + if intersection: + msg = 'Duplicate query conditions: ' + raise InvalidQueryError(msg + ', '.join(intersection)) + + query_ops.update(ops) + combined_query.update(copy.deepcopy(query)) + return combined_query + + +class QueryCompilerVisitor(QNodeVisitor): + """Compiles the nodes in a query tree to a PyMongo-compatible query + dictionary. + """ + + def __init__(self, document): + self.document = document + + def visit_combination(self, combination): + operator = "$and" + if combination.operation == combination.OR: + operator = "$or" + return {operator: combination.children} + + def visit_query(self, query): + return transform.query(self.document, **query.query) + + +class QNode(object): + """Base class for nodes in query trees. + """ + + AND = 0 + OR = 1 + + def to_query(self, document): + query = self.accept(SimplificationVisitor()) + query = query.accept(QueryCompilerVisitor(document)) + return query + + def accept(self, visitor): + raise NotImplementedError + + def _combine(self, other, operation): + """Combine this node with another node into a QCombination object. + """ + if getattr(other, 'empty', True): + return self + + if self.empty: + return other + + return QCombination(operation, [self, other]) + + @property + def empty(self): + return False + + def __or__(self, other): + return self._combine(other, self.OR) + + def __and__(self, other): + return self._combine(other, self.AND) + + +class QCombination(QNode): + """Represents the combination of several conditions by a given logical + operator. 
+ """ + + def __init__(self, operation, children): + self.operation = operation + self.children = [] + for node in children: + # If the child is a combination of the same type, we can merge its + # children directly into this combinations children + if isinstance(node, QCombination) and node.operation == operation: + # self.children += node.children + self.children.append(node) + else: + self.children.append(node) + + def accept(self, visitor): + for i in range(len(self.children)): + if isinstance(self.children[i], QNode): + self.children[i] = self.children[i].accept(visitor) + + return visitor.visit_combination(self) + + @property + def empty(self): + return not bool(self.children) + + +class Q(QNode): + """A simple query object, used in a query tree to build up more complex + query structures. + """ + + def __init__(self, **query): + self.query = query + + def accept(self, visitor): + return visitor.visit_query(self) + + @property + def empty(self): + return not bool(self.query) diff --git a/mongoengine/tests.py b/mongoengine/tests.py deleted file mode 100644 index 68663772..00000000 --- a/mongoengine/tests.py +++ /dev/null @@ -1,59 +0,0 @@ -from mongoengine.connection import get_db - - -class query_counter(object): - """ Query_counter contextmanager to get the number of queries. """ - - def __init__(self): - """ Construct the query_counter. """ - self.counter = 0 - self.db = get_db() - - def __enter__(self): - """ On every with block we need to drop the profile collection. """ - self.db.set_profiling_level(0) - self.db.system.profile.drop() - self.db.set_profiling_level(2) - return self - - def __exit__(self, t, value, traceback): - """ Reset the profiling level. """ - self.db.set_profiling_level(0) - - def __eq__(self, value): - """ == Compare querycounter. """ - return value == self._get_count() - - def __ne__(self, value): - """ != Compare querycounter. """ - return not self.__eq__(value) - - def __lt__(self, value): - """ < Compare querycounter. 
""" - return self._get_count() < value - - def __le__(self, value): - """ <= Compare querycounter. """ - return self._get_count() <= value - - def __gt__(self, value): - """ > Compare querycounter. """ - return self._get_count() > value - - def __ge__(self, value): - """ >= Compare querycounter. """ - return self._get_count() >= value - - def __int__(self): - """ int representation. """ - return self._get_count() - - def __repr__(self): - """ repr query_counter as the number of queries. """ - return u"%s" % self._get_count() - - def _get_count(self): - """ Get the number of queries. """ - count = self.db.system.profile.find().count() - self.counter - self.counter += 1 - return count diff --git a/python-mongoengine.spec b/python-mongoengine.spec index b1ec3361..eaf478dc 100644 --- a/python-mongoengine.spec +++ b/python-mongoengine.spec @@ -5,7 +5,7 @@ %define srcname mongoengine Name: python-%{srcname} -Version: 0.7.9 +Version: 0.7.10 Release: 1%{?dist} Summary: A Python Document-Object Mapper for working with MongoDB @@ -51,4 +51,4 @@ rm -rf $RPM_BUILD_ROOT # %{python_sitearch}/* %changelog -* See: http://readthedocs.org/docs/mongoengine-odm/en/latest/changelog.html \ No newline at end of file +* See: http://docs.mongoengine.org/en/latest/changelog.html \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index d95a9176..3f3faa8c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,4 +8,4 @@ detailed-errors = 1 #cover-package = mongoengine py3where = build where = tests -#tests = test_bugfix.py \ No newline at end of file +#tests = document/__init__.py \ No newline at end of file diff --git a/setup.py b/setup.py index 54c2cdca..bdd01825 100644 --- a/setup.py +++ b/setup.py @@ -8,8 +8,8 @@ try: except ImportError: pass -DESCRIPTION = """MongoEngine is a Python Object-Document -Mapper for working with MongoDB.""" +DESCRIPTION = 'MongoEngine is a Python Object-Document ' + \ +'Mapper for working with MongoDB.' 
LONG_DESCRIPTION = None try: LONG_DESCRIPTION = open('README.rst').read() @@ -38,7 +38,6 @@ CLASSIFIERS = [ 'Operating System :: OS Independent', 'Programming Language :: Python', "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.5", "Programming Language :: Python :: 2.6", "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", @@ -56,9 +55,9 @@ if sys.version_info[0] == 3: extra_opts['packages'] = find_packages(exclude=('tests',)) if "test" in sys.argv or "nosetests" in sys.argv: extra_opts['packages'].append("tests") - extra_opts['package_data'] = {"tests": ["mongoengine.png"]} + extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} else: - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django==1.4.2', 'PIL'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL'] extra_opts['packages'] = find_packages(exclude=('tests',)) setup(name='mongoengine', diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..b24df5d2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +from all_warnings import AllWarnings +from document import * +from queryset import * +from fields import * +from migration import * diff --git a/tests/all_warnings/__init__.py b/tests/all_warnings/__init__.py new file mode 100644 index 00000000..53ce638c --- /dev/null +++ b/tests/all_warnings/__init__.py @@ -0,0 +1,44 @@ +""" +This test has been put into a module. This is because it tests warnings that +only get triggered on first hit. This way we can ensure its imported into the +top level and called first by the test suite. 
+""" +import sys +sys.path[0:0] = [""] +import unittest +import warnings + +from mongoengine import * + + +__all__ = ('AllWarnings', ) + + +class AllWarnings(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.warning_list = [] + self.showwarning_default = warnings.showwarning + warnings.showwarning = self.append_to_warning_list + + def append_to_warning_list(self, message, category, *args): + self.warning_list.append({"message": message, + "category": category}) + + def tearDown(self): + # restore default handling of warnings + warnings.showwarning = self.showwarning_default + + def test_document_collection_syntax_warning(self): + + class NonAbstractBase(Document): + meta = {'allow_inheritance': True} + + class InheritedDocumentFailTest(NonAbstractBase): + meta = {'collection': 'fail'} + + warning = self.warning_list[0] + self.assertEqual(SyntaxWarning, warning["category"]) + self.assertEqual('non_abstract_base', + InheritedDocumentFailTest._get_collection_name()) diff --git a/tests/document/__init__.py b/tests/document/__init__.py new file mode 100644 index 00000000..1acc9f4b --- /dev/null +++ b/tests/document/__init__.py @@ -0,0 +1,15 @@ +import sys +sys.path[0:0] = [""] +import unittest + +from class_methods import * +from delta import * +from dynamic import * +from indexes import * +from inheritance import * +from instance import * +from json_serialisation import * +from validation import * + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py new file mode 100644 index 00000000..83e68ff8 --- /dev/null +++ b/tests/document/class_methods.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +from __future__ import with_statement +import sys +sys.path[0:0] = [""] +import unittest + +from mongoengine import * + +from mongoengine.queryset import NULLIFY +from mongoengine.connection import get_db + +__all__ = ("ClassMethodsTest", ) + + +class 
class ClassMethodsTest(unittest.TestCase):
    """Tests for Document class-level helpers: db access, collection
    naming/dropping, and delete-rule registration."""

    def setUp(self):
        # Fresh connection plus a simple inheritable Person model per test.
        connect(db='mongoenginetest')
        self.db = get_db()

        class Person(Document):
            name = StringField()
            age = IntField()

            non_field = True

            meta = {"allow_inheritance": True}

        self.Person = Person

    def tearDown(self):
        # Drop every non-system collection so tests stay independent.
        for collection in self.db.collection_names():
            if 'system.' in collection:
                continue
            self.db.drop_collection(collection)

    def test_definition(self):
        """Ensure that document may be defined using fields.
        """
        # 'id' is the implicit primary-key field; 'non_field' must not
        # appear since it is not a Field instance.
        self.assertEqual(['age', 'id', 'name'],
                         sorted(self.Person._fields.keys()))
        self.assertEqual(["IntField", "ObjectIdField", "StringField"],
                         sorted([x.__class__.__name__ for x in
                                 self.Person._fields.values()]))

    def test_get_db(self):
        """Ensure that get_db returns the expected db.
        """
        db = self.Person._get_db()
        self.assertEqual(self.db, db)

    def test_get_collection_name(self):
        """Ensure that get_collection_name returns the expected collection
        name.
        """
        collection_name = 'person'
        self.assertEqual(collection_name, self.Person._get_collection_name())

    def test_get_collection(self):
        """Ensure that get_collection returns the expected collection.
        """
        collection_name = 'person'
        collection = self.Person._get_collection()
        self.assertEqual(self.db[collection_name], collection)

    def test_drop_collection(self):
        """Ensure that the collection may be dropped from the database.
        """
        collection_name = 'person'
        self.Person(name='Test').save()
        self.assertTrue(collection_name in self.db.collection_names())

        self.Person.drop_collection()
        self.assertFalse(collection_name in self.db.collection_names())

    def test_register_delete_rule(self):
        """Ensure that register delete rule adds a delete rule to the document
        meta.
        """
        class Job(Document):
            employee = ReferenceField(self.Person)

        # No rules exist before registration.
        self.assertEqual(self.Person._meta.get('delete_rules'), None)

        self.Person.register_delete_rule(Job, 'employee', NULLIFY)
        self.assertEqual(self.Person._meta['delete_rules'],
                         {(Job, 'employee'): NULLIFY})

    def test_collection_naming(self):
        """Ensure that a collection with a specified name may be used.
        """

        # Default: snake_cased class name.
        class DefaultNamingTest(Document):
            pass
        self.assertEqual('default_naming_test',
                         DefaultNamingTest._get_collection_name())

        # Explicit string override.
        class CustomNamingTest(Document):
            meta = {'collection': 'pimp_my_collection'}

        self.assertEqual('pimp_my_collection',
                         CustomNamingTest._get_collection_name())

        # Callable override receives the class and returns the name.
        class DynamicNamingTest(Document):
            meta = {'collection': lambda c: "DYNAMO"}
        self.assertEqual('DYNAMO', DynamicNamingTest._get_collection_name())

        # Use Abstract class to handle backwards compatibility
        class BaseDocument(Document):
            meta = {
                'abstract': True,
                'collection': lambda c: c.__name__.lower()
            }

        class OldNamingConvention(BaseDocument):
            pass
        self.assertEqual('oldnamingconvention',
                         OldNamingConvention._get_collection_name())

        class InheritedAbstractNamingTest(BaseDocument):
            meta = {'collection': 'wibble'}
        self.assertEqual('wibble',
                         InheritedAbstractNamingTest._get_collection_name())

        # Mixin tests
        class BaseMixin(object):
            meta = {
                'collection': lambda c: c.__name__.lower()
            }

        class OldMixinNamingConvention(Document, BaseMixin):
            pass
        self.assertEqual('oldmixinnamingconvention',
                         OldMixinNamingConvention._get_collection_name())

        class BaseMixin(object):
            meta = {
                'collection': lambda c: c.__name__.lower()
            }

        class BaseDocument(Document, BaseMixin):
            meta = {'allow_inheritance': True}

        class MyDocument(BaseDocument):
            pass

        # Subclasses of a concrete base share the base's collection.
        self.assertEqual('basedocument', MyDocument._get_collection_name())

    def test_custom_collection_name_operations(self):
        """Ensure that a collection with a specified name is used as expected.
        """
        collection_name = 'personCollTest'

        class Person(Document):
            name = StringField()
            meta = {'collection': collection_name}

        Person(name="Test User").save()
        self.assertTrue(collection_name in self.db.collection_names())

        # Raw driver read and ODM read must both see the saved document.
        user_obj = self.db[collection_name].find_one()
        self.assertEqual(user_obj['name'], "Test User")

        user_obj = Person.objects[0]
        self.assertEqual(user_obj.name, "Test User")

        Person.drop_collection()
        self.assertFalse(collection_name in self.db.collection_names())

    def test_collection_name_and_primary(self):
        """Ensure that a collection with a specified name may be used.
        """

        class Person(Document):
            # name doubles as the primary key (_id) here.
            name = StringField(primary_key=True)
            meta = {'collection': 'app'}

        Person(name="Test User").save()

        user_obj = Person.objects.first()
        self.assertEqual(user_obj.name, "Test User")

        Person.drop_collection()


if __name__ == '__main__':
    unittest.main()
in collection: + continue + self.db.drop_collection(collection) + + def test_delta(self): + self.delta(Document) + self.delta(DynamicDocument) + + def delta(self, DocClass): + + class Doc(DocClass): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._delta(), ({}, {})) + + doc.string_field = 'hello' + self.assertEqual(doc._get_changed_fields(), ['string_field']) + self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) + + doc._changed_fields = [] + doc.int_field = 1 + self.assertEqual(doc._get_changed_fields(), ['int_field']) + self.assertEqual(doc._delta(), ({'int_field': 1}, {})) + + doc._changed_fields = [] + dict_value = {'hello': 'world', 'ping': 'pong'} + doc.dict_field = dict_value + self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) + + doc._changed_fields = [] + list_value = ['1', 2, {'hello': 'world'}] + doc.list_field = list_value + self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) + + # Test unsetting + doc._changed_fields = [] + doc.dict_field = {} + self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) + + doc._changed_fields = [] + doc.list_field = [] + self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._delta(), ({}, {'list_field': 1})) + + def test_delta_recursive(self): + self.delta_recursive(Document, EmbeddedDocument) + self.delta_recursive(DynamicDocument, EmbeddedDocument) + self.delta_recursive(Document, DynamicEmbeddedDocument) + self.delta_recursive(DynamicDocument, DynamicEmbeddedDocument) + + def delta_recursive(self, DocClass, EmbeddedClass): + + class 
Embedded(EmbeddedClass): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + + class Doc(DocClass): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + embedded_field = EmbeddedDocumentField(Embedded) + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._delta(), ({}, {})) + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + + self.assertEqual(doc._get_changed_fields(), ['embedded_field']) + + embedded_delta = { + 'string_field': 'hello', + 'int_field': 1, + 'dict_field': {'hello': 'world'}, + 'list_field': ['1', 2, {'hello': 'world'}] + } + self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) + self.assertEqual(doc._delta(), + ({'embedded_field': embedded_delta}, {})) + + doc.save() + doc = doc.reload(10) + + doc.embedded_field.dict_field = {} + self.assertEqual(doc._get_changed_fields(), + ['embedded_field.dict_field']) + self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) + self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.dict_field, {}) + + doc.embedded_field.list_field = [] + self.assertEqual(doc._get_changed_fields(), + ['embedded_field.list_field']) + self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) + self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field, []) + + embedded_2 = Embedded() + embedded_2.string_field = 'hello' + embedded_2.int_field = 1 + embedded_2.dict_field = {'hello': 'world'} + embedded_2.list_field = 
['1', 2, {'hello': 'world'}] + + doc.embedded_field.list_field = ['1', 2, embedded_2] + self.assertEqual(doc._get_changed_fields(), + ['embedded_field.list_field']) + + self.assertEqual(doc.embedded_field._delta(), ({ + 'list_field': ['1', 2, { + '_cls': 'Embedded', + 'string_field': 'hello', + 'dict_field': {'hello': 'world'}, + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + + self.assertEqual(doc._delta(), ({ + 'embedded_field.list_field': ['1', 2, { + '_cls': 'Embedded', + 'string_field': 'hello', + 'dict_field': {'hello': 'world'}, + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + doc.save() + doc = doc.reload(10) + + self.assertEqual(doc.embedded_field.list_field[0], '1') + self.assertEqual(doc.embedded_field.list_field[1], 2) + for k in doc.embedded_field.list_field[2]._fields: + self.assertEqual(doc.embedded_field.list_field[2][k], + embedded_2[k]) + + doc.embedded_field.list_field[2].string_field = 'world' + self.assertEqual(doc._get_changed_fields(), + ['embedded_field.list_field.2.string_field']) + self.assertEqual(doc.embedded_field._delta(), + ({'list_field.2.string_field': 'world'}, {})) + self.assertEqual(doc._delta(), + ({'embedded_field.list_field.2.string_field': 'world'}, {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].string_field, + 'world') + + # Test multiple assignments + doc.embedded_field.list_field[2].string_field = 'hello world' + doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] + self.assertEqual(doc._get_changed_fields(), + ['embedded_field.list_field']) + self.assertEqual(doc.embedded_field._delta(), ({ + 'list_field': ['1', 2, { + '_cls': 'Embedded', + 'string_field': 'hello world', + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + 'dict_field': {'hello': 'world'}}]}, {})) + self.assertEqual(doc._delta(), ({ + 'embedded_field.list_field': ['1', 2, { + '_cls': 'Embedded', + 'string_field': 
'hello world', + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + 'dict_field': {'hello': 'world'}} + ]}, {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].string_field, + 'hello world') + + # Test list native methods + doc.embedded_field.list_field[2].list_field.pop(0) + self.assertEqual(doc._delta(), + ({'embedded_field.list_field.2.list_field': + [2, {'hello': 'world'}]}, {})) + doc.save() + doc = doc.reload(10) + + doc.embedded_field.list_field[2].list_field.append(1) + self.assertEqual(doc._delta(), + ({'embedded_field.list_field.2.list_field': + [2, {'hello': 'world'}, 1]}, {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].list_field, + [2, {'hello': 'world'}, 1]) + + doc.embedded_field.list_field[2].list_field.sort(key=str) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].list_field, + [1, 2, {'hello': 'world'}]) + + del(doc.embedded_field.list_field[2].list_field[2]['hello']) + self.assertEqual(doc._delta(), + ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) + doc.save() + doc = doc.reload(10) + + del(doc.embedded_field.list_field[2].list_field) + self.assertEqual(doc._delta(), + ({}, {'embedded_field.list_field.2.list_field': 1})) + + doc.save() + doc = doc.reload(10) + + doc.dict_field['Embedded'] = embedded_1 + doc.save() + doc = doc.reload(10) + + doc.dict_field['Embedded'].string_field = 'Hello World' + self.assertEqual(doc._get_changed_fields(), + ['dict_field.Embedded.string_field']) + self.assertEqual(doc._delta(), + ({'dict_field.Embedded.string_field': 'Hello World'}, {})) + + def test_circular_reference_deltas(self): + self.circular_reference_deltas(Document, Document) + self.circular_reference_deltas(Document, DynamicDocument) + self.circular_reference_deltas(DynamicDocument, Document) + self.circular_reference_deltas(DynamicDocument, DynamicDocument) + + def circular_reference_deltas(self, 
DocClass1, DocClass2): + + class Person(DocClass1): + name = StringField() + owns = ListField(ReferenceField('Organization')) + + class Organization(DocClass2): + name = StringField() + owner = ReferenceField('Person') + + Person.drop_collection() + Organization.drop_collection() + + person = Person(name="owner").save() + organization = Organization(name="company").save() + + person.owns.append(organization) + organization.owner = person + + person.save() + organization.save() + + p = Person.objects[0].select_related() + o = Organization.objects.first() + self.assertEqual(p.owns[0], o) + self.assertEqual(o.owner, p) + + def test_circular_reference_deltas_2(self): + self.circular_reference_deltas_2(Document, Document) + self.circular_reference_deltas_2(Document, DynamicDocument) + self.circular_reference_deltas_2(DynamicDocument, Document) + self.circular_reference_deltas_2(DynamicDocument, DynamicDocument) + + def circular_reference_deltas_2(self, DocClass1, DocClass2): + + class Person(DocClass1): + name = StringField() + owns = ListField(ReferenceField('Organization')) + employer = ReferenceField('Organization') + + class Organization(DocClass2): + name = StringField() + owner = ReferenceField('Person') + employees = ListField(ReferenceField('Person')) + + Person.drop_collection() + Organization.drop_collection() + + person = Person(name="owner") + person.save() + + employee = Person(name="employee") + employee.save() + + organization = Organization(name="company") + organization.save() + + person.owns.append(organization) + organization.owner = person + + organization.employees.append(employee) + employee.employer = organization + + person.save() + organization.save() + employee.save() + + p = Person.objects.get(name="owner") + e = Person.objects.get(name="employee") + o = Organization.objects.first() + + self.assertEqual(p.owns[0], o) + self.assertEqual(o.owner, p) + self.assertEqual(e.employer, o) + + def test_delta_db_field(self): + 
self.delta_db_field(Document) + self.delta_db_field(DynamicDocument) + + def delta_db_field(self, DocClass): + + class Doc(DocClass): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._delta(), ({}, {})) + + doc.string_field = 'hello' + self.assertEqual(doc._get_changed_fields(), ['db_string_field']) + self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {})) + + doc._changed_fields = [] + doc.int_field = 1 + self.assertEqual(doc._get_changed_fields(), ['db_int_field']) + self.assertEqual(doc._delta(), ({'db_int_field': 1}, {})) + + doc._changed_fields = [] + dict_value = {'hello': 'world', 'ping': 'pong'} + doc.dict_field = dict_value + self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) + self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {})) + + doc._changed_fields = [] + list_value = ['1', 2, {'hello': 'world'}] + doc.list_field = list_value + self.assertEqual(doc._get_changed_fields(), ['db_list_field']) + self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {})) + + # Test unsetting + doc._changed_fields = [] + doc.dict_field = {} + self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) + self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1})) + + doc._changed_fields = [] + doc.list_field = [] + self.assertEqual(doc._get_changed_fields(), ['db_list_field']) + self.assertEqual(doc._delta(), ({}, {'db_list_field': 1})) + + # Test it saves that data + doc = Doc() + doc.save() + + doc.string_field = 'hello' + doc.int_field = 1 + doc.dict_field = {'hello': 'world'} + doc.list_field = ['1', 2, {'hello': 'world'}] + doc.save() + doc = doc.reload(10) + + self.assertEqual(doc.string_field, 'hello') + 
self.assertEqual(doc.int_field, 1) + self.assertEqual(doc.dict_field, {'hello': 'world'}) + self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}]) + + def test_delta_recursive_db_field(self): + self.delta_recursive_db_field(Document, EmbeddedDocument) + self.delta_recursive_db_field(Document, DynamicEmbeddedDocument) + self.delta_recursive_db_field(DynamicDocument, EmbeddedDocument) + self.delta_recursive_db_field(DynamicDocument, DynamicEmbeddedDocument) + + def delta_recursive_db_field(self, DocClass, EmbeddedClass): + + class Embedded(EmbeddedClass): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + + class Doc(DocClass): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + embedded_field = EmbeddedDocumentField(Embedded, + db_field='db_embedded_field') + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._delta(), ({}, {})) + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + + self.assertEqual(doc._get_changed_fields(), ['db_embedded_field']) + + embedded_delta = { + 'db_string_field': 'hello', + 'db_int_field': 1, + 'db_dict_field': {'hello': 'world'}, + 'db_list_field': ['1', 2, {'hello': 'world'}] + } + self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) + self.assertEqual(doc._delta(), + ({'db_embedded_field': embedded_delta}, {})) + + doc.save() + doc = doc.reload(10) + + doc.embedded_field.dict_field = {} + self.assertEqual(doc._get_changed_fields(), 
+ ['db_embedded_field.db_dict_field']) + self.assertEqual(doc.embedded_field._delta(), + ({}, {'db_dict_field': 1})) + self.assertEqual(doc._delta(), + ({}, {'db_embedded_field.db_dict_field': 1})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.dict_field, {}) + + doc.embedded_field.list_field = [] + self.assertEqual(doc._get_changed_fields(), + ['db_embedded_field.db_list_field']) + self.assertEqual(doc.embedded_field._delta(), + ({}, {'db_list_field': 1})) + self.assertEqual(doc._delta(), + ({}, {'db_embedded_field.db_list_field': 1})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field, []) + + embedded_2 = Embedded() + embedded_2.string_field = 'hello' + embedded_2.int_field = 1 + embedded_2.dict_field = {'hello': 'world'} + embedded_2.list_field = ['1', 2, {'hello': 'world'}] + + doc.embedded_field.list_field = ['1', 2, embedded_2] + self.assertEqual(doc._get_changed_fields(), + ['db_embedded_field.db_list_field']) + self.assertEqual(doc.embedded_field._delta(), ({ + 'db_list_field': ['1', 2, { + '_cls': 'Embedded', + 'db_string_field': 'hello', + 'db_dict_field': {'hello': 'world'}, + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + + self.assertEqual(doc._delta(), ({ + 'db_embedded_field.db_list_field': ['1', 2, { + '_cls': 'Embedded', + 'db_string_field': 'hello', + 'db_dict_field': {'hello': 'world'}, + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + doc.save() + doc = doc.reload(10) + + self.assertEqual(doc.embedded_field.list_field[0], '1') + self.assertEqual(doc.embedded_field.list_field[1], 2) + for k in doc.embedded_field.list_field[2]._fields: + self.assertEqual(doc.embedded_field.list_field[2][k], + embedded_2[k]) + + doc.embedded_field.list_field[2].string_field = 'world' + self.assertEqual(doc._get_changed_fields(), + ['db_embedded_field.db_list_field.2.db_string_field']) + 
self.assertEqual(doc.embedded_field._delta(), + ({'db_list_field.2.db_string_field': 'world'}, {})) + self.assertEqual(doc._delta(), + ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, + {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].string_field, + 'world') + + # Test multiple assignments + doc.embedded_field.list_field[2].string_field = 'hello world' + doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] + self.assertEqual(doc._get_changed_fields(), + ['db_embedded_field.db_list_field']) + self.assertEqual(doc.embedded_field._delta(), ({ + 'db_list_field': ['1', 2, { + '_cls': 'Embedded', + 'db_string_field': 'hello world', + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + 'db_dict_field': {'hello': 'world'}}]}, {})) + self.assertEqual(doc._delta(), ({ + 'db_embedded_field.db_list_field': ['1', 2, { + '_cls': 'Embedded', + 'db_string_field': 'hello world', + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + 'db_dict_field': {'hello': 'world'}} + ]}, {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].string_field, + 'hello world') + + # Test list native methods + doc.embedded_field.list_field[2].list_field.pop(0) + self.assertEqual(doc._delta(), + ({'db_embedded_field.db_list_field.2.db_list_field': + [2, {'hello': 'world'}]}, {})) + doc.save() + doc = doc.reload(10) + + doc.embedded_field.list_field[2].list_field.append(1) + self.assertEqual(doc._delta(), + ({'db_embedded_field.db_list_field.2.db_list_field': + [2, {'hello': 'world'}, 1]}, {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].list_field, + [2, {'hello': 'world'}, 1]) + + doc.embedded_field.list_field[2].list_field.sort(key=str) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].list_field, + [1, 2, {'hello': 'world'}]) + + 
del(doc.embedded_field.list_field[2].list_field[2]['hello']) + self.assertEqual(doc._delta(), + ({'db_embedded_field.db_list_field.2.db_list_field': + [1, 2, {}]}, {})) + doc.save() + doc = doc.reload(10) + + del(doc.embedded_field.list_field[2].list_field) + self.assertEqual(doc._delta(), ({}, + {'db_embedded_field.db_list_field.2.db_list_field': 1})) + + def test_delta_for_dynamic_documents(self): + class Person(DynamicDocument): + name = StringField() + meta = {'allow_inheritance': True} + + Person.drop_collection() + + p = Person(name="James", age=34) + self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', + '_cls': 'Person'}, {})) + + p.doc = 123 + del(p.doc) + self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', + '_cls': 'Person'}, {'doc': 1})) + + p = Person() + p.name = "Dean" + p.age = 22 + p.save() + + p.age = 24 + self.assertEqual(p.age, 24) + self.assertEqual(p._get_changed_fields(), ['age']) + self.assertEqual(p._delta(), ({'age': 24}, {})) + + p = self.Person.objects(age=22).get() + p.age = 24 + self.assertEqual(p.age, 24) + self.assertEqual(p._get_changed_fields(), ['age']) + self.assertEqual(p._delta(), ({'age': 24}, {})) + + p.save() + self.assertEqual(1, self.Person.objects(age=24).count()) + + def test_dynamic_delta(self): + + class Doc(DynamicDocument): + pass + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._delta(), ({}, {})) + + doc.string_field = 'hello' + self.assertEqual(doc._get_changed_fields(), ['string_field']) + self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) + + doc._changed_fields = [] + doc.int_field = 1 + self.assertEqual(doc._get_changed_fields(), ['int_field']) + self.assertEqual(doc._delta(), ({'int_field': 1}, {})) + + doc._changed_fields = [] + dict_value = {'hello': 'world', 'ping': 'pong'} + doc.dict_field = dict_value + self.assertEqual(doc._get_changed_fields(), ['dict_field']) + 
self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) + + doc._changed_fields = [] + list_value = ['1', 2, {'hello': 'world'}] + doc.list_field = list_value + self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) + + # Test unsetting + doc._changed_fields = [] + doc.dict_field = {} + self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) + + doc._changed_fields = [] + doc.list_field = [] + self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._delta(), ({}, {'list_field': 1})) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py new file mode 100644 index 00000000..6263e68c --- /dev/null +++ b/tests/document/dynamic.py @@ -0,0 +1,297 @@ +import unittest +import sys +sys.path[0:0] = [""] + +from mongoengine import * +from mongoengine.connection import get_db + +__all__ = ("DynamicTest", ) + + +class DynamicTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + class Person(DynamicDocument): + name = StringField() + meta = {'allow_inheritance': True} + + Person.drop_collection() + + self.Person = Person + + def test_simple_dynamic_document(self): + """Ensures simple dynamic documents are saved correctly""" + + p = self.Person() + p.name = "James" + p.age = 34 + + self.assertEqual(p.to_mongo(), {"_cls": "Person", "name": "James", + "age": 34}) + self.assertEqual(p.to_mongo().keys(), ["_cls", "name", "age"]) + p.save() + self.assertEqual(p.to_mongo().keys(), ["_id", "_cls", "name", "age"]) + + self.assertEqual(self.Person.objects.first().age, 34) + + # Confirm no changes to self.Person + self.assertFalse(hasattr(self.Person, 'age')) + + def test_change_scope_of_variable(self): + """Test changing the scope of a dynamic field has no adverse effects""" + p = self.Person() + p.name = "Dean" 
+ p.misc = 22 + p.save() + + p = self.Person.objects.get() + p.misc = {'hello': 'world'} + p.save() + + p = self.Person.objects.get() + self.assertEqual(p.misc, {'hello': 'world'}) + + def test_delete_dynamic_field(self): + """Test deleting a dynamic field works""" + self.Person.drop_collection() + p = self.Person() + p.name = "Dean" + p.misc = 22 + p.save() + + p = self.Person.objects.get() + p.misc = {'hello': 'world'} + p.save() + + p = self.Person.objects.get() + self.assertEqual(p.misc, {'hello': 'world'}) + collection = self.db[self.Person._get_collection_name()] + obj = collection.find_one() + self.assertEqual(sorted(obj.keys()), ['_cls', '_id', 'misc', 'name']) + + del(p.misc) + p.save() + + p = self.Person.objects.get() + self.assertFalse(hasattr(p, 'misc')) + + obj = collection.find_one() + self.assertEqual(sorted(obj.keys()), ['_cls', '_id', 'name']) + + def test_dynamic_document_queries(self): + """Ensure we can query dynamic fields""" + p = self.Person() + p.name = "Dean" + p.age = 22 + p.save() + + self.assertEqual(1, self.Person.objects(age=22).count()) + p = self.Person.objects(age=22) + p = p.get() + self.assertEqual(22, p.age) + + def test_complex_dynamic_document_queries(self): + class Person(DynamicDocument): + name = StringField() + + Person.drop_collection() + + p = Person(name="test") + p.age = "ten" + p.save() + + p1 = Person(name="test1") + p1.age = "less then ten and a half" + p1.save() + + p2 = Person(name="test2") + p2.age = 10 + p2.save() + + self.assertEqual(Person.objects(age__icontains='ten').count(), 2) + self.assertEqual(Person.objects(age__gte=10).count(), 1) + + def test_complex_data_lookups(self): + """Ensure you can query dynamic document dynamic fields""" + p = self.Person() + p.misc = {'hello': 'world'} + p.save() + + self.assertEqual(1, self.Person.objects(misc__hello='world').count()) + + def test_complex_embedded_document_validation(self): + """Ensure embedded dynamic documents may be validated""" + class 
Embedded(DynamicEmbeddedDocument): + content = URLField() + + class Doc(DynamicDocument): + pass + + Doc.drop_collection() + doc = Doc() + + embedded_doc_1 = Embedded(content='http://mongoengine.org') + embedded_doc_1.validate() + + embedded_doc_2 = Embedded(content='this is not a url') + self.assertRaises(ValidationError, embedded_doc_2.validate) + + doc.embedded_field_1 = embedded_doc_1 + doc.embedded_field_2 = embedded_doc_2 + self.assertRaises(ValidationError, doc.validate) + + def test_inheritance(self): + """Ensure that dynamic document plays nice with inheritance""" + class Employee(self.Person): + salary = IntField() + + Employee.drop_collection() + + self.assertTrue('name' in Employee._fields) + self.assertTrue('salary' in Employee._fields) + self.assertEqual(Employee._get_collection_name(), + self.Person._get_collection_name()) + + joe_bloggs = Employee() + joe_bloggs.name = "Joe Bloggs" + joe_bloggs.salary = 10 + joe_bloggs.age = 20 + joe_bloggs.save() + + self.assertEqual(1, self.Person.objects(age=20).count()) + self.assertEqual(1, Employee.objects(age=20).count()) + + joe_bloggs = self.Person.objects.first() + self.assertTrue(isinstance(joe_bloggs, Employee)) + + def test_embedded_dynamic_document(self): + """Test dynamic embedded documents""" + class Embedded(DynamicEmbeddedDocument): + pass + + class Doc(DynamicDocument): + pass + + Doc.drop_collection() + doc = Doc() + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + + self.assertEqual(doc.to_mongo(), { + "embedded_field": { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ['1', 2, {'hello': 'world'}] + } + }) + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc.embedded_field.__class__, Embedded) + 
self.assertEqual(doc.embedded_field.string_field, "hello") + self.assertEqual(doc.embedded_field.int_field, 1) + self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'}) + self.assertEqual(doc.embedded_field.list_field, + ['1', 2, {'hello': 'world'}]) + + def test_complex_embedded_documents(self): + """Test complex dynamic embedded documents setups""" + class Embedded(DynamicEmbeddedDocument): + pass + + class Doc(DynamicDocument): + pass + + Doc.drop_collection() + doc = Doc() + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + + embedded_2 = Embedded() + embedded_2.string_field = 'hello' + embedded_2.int_field = 1 + embedded_2.dict_field = {'hello': 'world'} + embedded_2.list_field = ['1', 2, {'hello': 'world'}] + + embedded_1.list_field = ['1', 2, embedded_2] + doc.embedded_field = embedded_1 + + self.assertEqual(doc.to_mongo(), { + "embedded_field": { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ['1', 2, + {"_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ['1', 2, {'hello': 'world'}]} + ] + } + }) + doc.save() + doc = Doc.objects.first() + self.assertEqual(doc.embedded_field.__class__, Embedded) + self.assertEqual(doc.embedded_field.string_field, "hello") + self.assertEqual(doc.embedded_field.int_field, 1) + self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'}) + self.assertEqual(doc.embedded_field.list_field[0], '1') + self.assertEqual(doc.embedded_field.list_field[1], 2) + + embedded_field = doc.embedded_field.list_field[2] + + self.assertEqual(embedded_field.__class__, Embedded) + self.assertEqual(embedded_field.string_field, "hello") + self.assertEqual(embedded_field.int_field, 1) + self.assertEqual(embedded_field.dict_field, {'hello': 'world'}) + self.assertEqual(embedded_field.list_field, ['1', 2, + 
{'hello': 'world'}]) + + def test_dynamic_and_embedded(self): + """Ensure embedded documents play nicely""" + + class Address(EmbeddedDocument): + city = StringField() + + class Person(DynamicDocument): + name = StringField() + + Person.drop_collection() + + Person(name="Ross", address=Address(city="London")).save() + + person = Person.objects.first() + person.address.city = "Lundenne" + person.save() + + self.assertEqual(Person.objects.first().address.city, "Lundenne") + + person = Person.objects.first() + person.address = Address(city="Londinium") + person.save() + + self.assertEqual(Person.objects.first().address.city, "Londinium") + + person = Person.objects.first() + person.age = 35 + person.save() + self.assertEqual(Person.objects.first().age, 35) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/indexes.py b/tests/document/indexes.py new file mode 100644 index 00000000..61e3c0e7 --- /dev/null +++ b/tests/document/indexes.py @@ -0,0 +1,738 @@ +# -*- coding: utf-8 -*- +from __future__ import with_statement +import unittest +import sys +sys.path[0:0] = [""] + +import os +import pymongo + +from nose.plugins.skip import SkipTest +from datetime import datetime + +from mongoengine import * +from mongoengine.connection import get_db, get_connection + +__all__ = ("IndexesTest", ) + + +class IndexesTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + class Person(Document): + name = StringField() + age = IntField() + + non_field = True + + meta = {"allow_inheritance": True} + + self.Person = Person + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' 
in collection: + continue + self.db.drop_collection(collection) + + def test_indexes_document(self): + """Ensure that indexes are used when meta[indexes] is specified for + Documents + """ + self._index_test(Document) + + def test_indexes_dynamic_document(self): + """Ensure that indexes are used when meta[indexes] is specified for + Dynamic Documents + """ + self._index_test(DynamicDocument) + + def _index_test(self, InheritFrom): + + class BlogPost(InheritFrom): + date = DateTimeField(db_field='addDate', default=datetime.now) + category = StringField() + tags = ListField(StringField()) + meta = { + 'indexes': [ + '-date', + 'tags', + ('category', '-date') + ] + } + + expected_specs = [{'fields': [('addDate', -1)]}, + {'fields': [('tags', 1)]}, + {'fields': [('category', 1), ('addDate', -1)]}] + self.assertEqual(expected_specs, BlogPost._meta['index_specs']) + + BlogPost.ensure_indexes() + info = BlogPost.objects._collection.index_information() + # _id, '-date', 'tags', ('cat', 'date') + self.assertEqual(len(info), 4) + info = [value['key'] for key, value in info.iteritems()] + for expected in expected_specs: + self.assertTrue(expected['fields'] in info) + + def _index_test_inheritance(self, InheritFrom): + + class BlogPost(InheritFrom): + date = DateTimeField(db_field='addDate', default=datetime.now) + category = StringField() + tags = ListField(StringField()) + meta = { + 'indexes': [ + '-date', + 'tags', + ('category', '-date') + ], + 'allow_inheritance': True + } + + expected_specs = [{'fields': [('_cls', 1), ('addDate', -1)]}, + {'fields': [('_cls', 1), ('tags', 1)]}, + {'fields': [('_cls', 1), ('category', 1), + ('addDate', -1)]}] + self.assertEqual(expected_specs, BlogPost._meta['index_specs']) + + BlogPost.ensure_indexes() + info = BlogPost.objects._collection.index_information() + # _id, '-date', 'tags', ('cat', 'date') + # NB: there is no index on _cls by itself, since + # the indices on -date and tags will both contain + # _cls as first element in the 
key + self.assertEqual(len(info), 4) + info = [value['key'] for key, value in info.iteritems()] + for expected in expected_specs: + self.assertTrue(expected['fields'] in info) + + class ExtendedBlogPost(BlogPost): + title = StringField() + meta = {'indexes': ['title']} + + expected_specs.append({'fields': [('_cls', 1), ('title', 1)]}) + self.assertEqual(expected_specs, ExtendedBlogPost._meta['index_specs']) + + BlogPost.drop_collection() + + ExtendedBlogPost.ensure_indexes() + info = ExtendedBlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + for expected in expected_specs: + self.assertTrue(expected['fields'] in info) + + def test_indexes_document_inheritance(self): + """Ensure that indexes are used when meta[indexes] is specified for + Documents + """ + self._index_test_inheritance(Document) + + def test_indexes_dynamic_document_inheritance(self): + """Ensure that indexes are used when meta[indexes] is specified for + Dynamic Documents + """ + self._index_test_inheritance(DynamicDocument) + + def test_inherited_index(self): + """Ensure index specs are inhertited correctly""" + + class A(Document): + title = StringField() + meta = { + 'indexes': [ + { + 'fields': ('title',), + }, + ], + 'allow_inheritance': True, + } + + class B(A): + description = StringField() + + self.assertEqual(A._meta['index_specs'], B._meta['index_specs']) + self.assertEqual([{'fields': [('_cls', 1), ('title', 1)]}], + A._meta['index_specs']) + + def test_build_index_spec_is_not_destructive(self): + + class MyDoc(Document): + keywords = StringField() + + meta = { + 'indexes': ['keywords'], + 'allow_inheritance': False + } + + self.assertEqual(MyDoc._meta['index_specs'], + [{'fields': [('keywords', 1)]}]) + + # Force index creation + MyDoc.ensure_indexes() + + self.assertEqual(MyDoc._meta['index_specs'], + [{'fields': [('keywords', 1)]}]) + + def test_embedded_document_index_meta(self): + """Ensure that embedded document indexes are 
created explicitly + """ + class Rank(EmbeddedDocument): + title = StringField(required=True) + + class Person(Document): + name = StringField(required=True) + rank = EmbeddedDocumentField(Rank, required=False) + + meta = { + 'indexes': [ + 'rank.title', + ], + 'allow_inheritance': False + } + + self.assertEqual([{'fields': [('rank.title', 1)]}], + Person._meta['index_specs']) + + Person.drop_collection() + + # Indexes are lazy so use list() to perform query + list(Person.objects) + info = Person.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertTrue([('rank.title', 1)] in info) + + def test_explicit_geo2d_index(self): + """Ensure that geo2d indexes work when created via meta[indexes] + """ + class Place(Document): + location = DictField() + meta = { + 'allow_inheritance': True, + 'indexes': [ + '*location.point', + ] + } + + self.assertEqual([{'fields': [('location.point', '2d')]}], + Place._meta['index_specs']) + + Place.ensure_indexes() + info = Place._get_collection().index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertTrue([('location.point', '2d')] in info) + + def test_explicit_geo2d_index_embedded(self): + """Ensure that geo2d indexes work when created via meta[indexes] + """ + class EmbeddedLocation(EmbeddedDocument): + location = DictField() + + class Place(Document): + current = DictField( + field=EmbeddedDocumentField('EmbeddedLocation')) + meta = { + 'allow_inheritance': True, + 'indexes': [ + '*current.location.point', + ] + } + + self.assertEqual([{'fields': [('current.location.point', '2d')]}], + Place._meta['index_specs']) + + Place.ensure_indexes() + info = Place._get_collection().index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertTrue([('current.location.point', '2d')] in info) + + def test_dictionary_indexes(self): + """Ensure that indexes are used when meta[indexes] contains + dictionaries instead of 
lists. + """ + class BlogPost(Document): + date = DateTimeField(db_field='addDate', default=datetime.now) + category = StringField() + tags = ListField(StringField()) + meta = { + 'indexes': [ + {'fields': ['-date'], 'unique': True, 'sparse': True}, + ], + } + + self.assertEqual([{'fields': [('addDate', -1)], 'unique': True, + 'sparse': True}], + BlogPost._meta['index_specs']) + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + # _id, '-date' + self.assertEqual(len(info), 2) + + # Indexes are lazy so use list() to perform query + list(BlogPost.objects) + info = BlogPost.objects._collection.index_information() + info = [(value['key'], + value.get('unique', False), + value.get('sparse', False)) + for key, value in info.iteritems()] + self.assertTrue(([('addDate', -1)], True, True) in info) + + BlogPost.drop_collection() + + def test_abstract_index_inheritance(self): + + class UserBase(Document): + user_guid = StringField(required=True) + meta = { + 'abstract': True, + 'indexes': ['user_guid'], + 'allow_inheritance': True + } + + class Person(UserBase): + name = StringField() + + meta = { + 'indexes': ['name'], + } + Person.drop_collection() + + Person(name="test", user_guid='123').save() + + self.assertEqual(1, Person.objects.count()) + info = Person.objects._collection.index_information() + self.assertEqual(sorted(info.keys()), + ['_cls_1_name_1', '_cls_1_user_guid_1', '_id_']) + + def test_disable_index_creation(self): + """Tests setting auto_create_index to False on the connection will + disable any index generation. 
+ """ + class User(Document): + meta = { + 'allow_inheritance': True, + 'indexes': ['user_guid'], + 'auto_create_index': False + } + user_guid = StringField(required=True) + + class MongoUser(User): + pass + + User.drop_collection() + + User(user_guid='123').save() + MongoUser(user_guid='123').save() + + self.assertEqual(2, User.objects.count()) + info = User.objects._collection.index_information() + self.assertEqual(info.keys(), ['_id_']) + + User.ensure_indexes() + info = User.objects._collection.index_information() + self.assertEqual(sorted(info.keys()), ['_cls_1_user_guid_1', '_id_']) + User.drop_collection() + + def test_embedded_document_index(self): + """Tests settings an index on an embedded document + """ + class Date(EmbeddedDocument): + year = IntField(db_field='yr') + + class BlogPost(Document): + title = StringField() + date = EmbeddedDocumentField(Date) + + meta = { + 'indexes': [ + '-date.year' + ], + } + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + self.assertEqual(sorted(info.keys()), ['_id_', 'date.yr_-1']) + BlogPost.drop_collection() + + def test_list_embedded_document_index(self): + """Ensure list embedded documents can be indexed + """ + class Tag(EmbeddedDocument): + name = StringField(db_field='tag') + + class BlogPost(Document): + title = StringField() + tags = ListField(EmbeddedDocumentField(Tag)) + + meta = { + 'indexes': [ + 'tags.name' + ] + } + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + # we don't use _cls in with list fields by default + self.assertEqual(sorted(info.keys()), ['_id_', 'tags.tag_1']) + + post1 = BlogPost(title="Embedded Indexes tests in place", + tags=[Tag(name="about"), Tag(name="time")] + ) + post1.save() + BlogPost.drop_collection() + + def test_recursive_embedded_objects_dont_break_indexes(self): + + class RecursiveObject(EmbeddedDocument): + obj = EmbeddedDocumentField('self') + + class RecursiveDocument(Document): + 
recursive_obj = EmbeddedDocumentField(RecursiveObject) + meta = {'allow_inheritance': True} + + RecursiveDocument.ensure_indexes() + info = RecursiveDocument._get_collection().index_information() + self.assertEqual(sorted(info.keys()), ['_cls_1', '_id_']) + + def test_geo_indexes_recursion(self): + + class Location(Document): + name = StringField() + location = GeoPointField() + + class Parent(Document): + name = StringField() + location = ReferenceField(Location, dbref=False) + + Location.drop_collection() + Parent.drop_collection() + + list(Parent.objects) + + collection = Parent._get_collection() + info = collection.index_information() + + self.assertFalse('location_2d' in info) + + self.assertEqual(len(Parent._geo_indices()), 0) + self.assertEqual(len(Location._geo_indices()), 1) + + def test_covered_index(self): + """Ensure that covered indexes can be used + """ + + class Test(Document): + a = IntField() + + meta = { + 'indexes': ['a'], + 'allow_inheritance': False + } + + Test.drop_collection() + + obj = Test(a=1) + obj.save() + + # Need to be explicit about covered indexes as mongoDB doesn't know if + # the documents returned might have more keys in that here. 
+ query_plan = Test.objects(id=obj.id).exclude('a').explain() + self.assertFalse(query_plan['indexOnly']) + + query_plan = Test.objects(id=obj.id).only('id').explain() + self.assertTrue(query_plan['indexOnly']) + + query_plan = Test.objects(a=1).only('a').exclude('id').explain() + self.assertTrue(query_plan['indexOnly']) + + def test_index_on_id(self): + + class BlogPost(Document): + meta = { + 'indexes': [ + ['categories', 'id'] + ] + } + + title = StringField(required=True) + description = StringField(required=True) + categories = ListField() + + BlogPost.drop_collection() + + indexes = BlogPost.objects._collection.index_information() + self.assertEqual(indexes['categories_1__id_1']['key'], + [('categories', 1), ('_id', 1)]) + + def test_hint(self): + + class BlogPost(Document): + tags = ListField(StringField()) + meta = { + 'indexes': [ + 'tags', + ], + } + + BlogPost.drop_collection() + + for i in xrange(0, 10): + tags = [("tag %i" % n) for n in xrange(0, i % 2)] + BlogPost(tags=tags).save() + + self.assertEqual(BlogPost.objects.count(), 10) + self.assertEqual(BlogPost.objects.hint().count(), 10) + self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10) + + self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).count(), 10) + + def invalid_index(): + BlogPost.objects.hint('tags') + self.assertRaises(TypeError, invalid_index) + + def invalid_index_2(): + return BlogPost.objects.hint(('tags', 1)) + self.assertRaises(TypeError, invalid_index_2) + + def test_unique(self): + """Ensure that uniqueness constraints are applied to fields. 
+ """ + class BlogPost(Document): + title = StringField() + slug = StringField(unique=True) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', slug='test') + post1.save() + + # Two posts with the same slug is not allowed + post2 = BlogPost(title='test2', slug='test') + self.assertRaises(NotUniqueError, post2.save) + + # Ensure backwards compatibilty for errors + self.assertRaises(OperationError, post2.save) + + def test_unique_with(self): + """Ensure that unique_with constraints are applied to fields. + """ + class Date(EmbeddedDocument): + year = IntField(db_field='yr') + + class BlogPost(Document): + title = StringField() + date = EmbeddedDocumentField(Date) + slug = StringField(unique_with='date.year') + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', date=Date(year=2009), slug='test') + post1.save() + + # day is different so won't raise exception + post2 = BlogPost(title='test2', date=Date(year=2010), slug='test') + post2.save() + + # Now there will be two docs with the same slug and the same day: fail + post3 = BlogPost(title='test3', date=Date(year=2010), slug='test') + self.assertRaises(OperationError, post3.save) + + BlogPost.drop_collection() + + def test_unique_embedded_document(self): + """Ensure that uniqueness constraints are applied to fields on embedded documents. 
+ """ + class SubDocument(EmbeddedDocument): + year = IntField(db_field='yr') + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField() + sub = EmbeddedDocumentField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', + sub=SubDocument(year=2009, slug="test")) + post1.save() + + # sub.slug is different so won't raise exception + post2 = BlogPost(title='test2', + sub=SubDocument(year=2010, slug='another-slug')) + post2.save() + + # Now there will be two docs with the same sub.slug + post3 = BlogPost(title='test3', + sub=SubDocument(year=2010, slug='test')) + self.assertRaises(NotUniqueError, post3.save) + + BlogPost.drop_collection() + + def test_unique_with_embedded_document_and_embedded_unique(self): + """Ensure that uniqueness constraints are applied to fields on + embedded documents. And work with unique_with as well. + """ + class SubDocument(EmbeddedDocument): + year = IntField(db_field='yr') + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField(unique_with='sub.year') + sub = EmbeddedDocumentField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', + sub=SubDocument(year=2009, slug="test")) + post1.save() + + # sub.slug is different so won't raise exception + post2 = BlogPost(title='test2', + sub=SubDocument(year=2010, slug='another-slug')) + post2.save() + + # Now there will be two docs with the same sub.slug + post3 = BlogPost(title='test3', + sub=SubDocument(year=2010, slug='test')) + self.assertRaises(NotUniqueError, post3.save) + + # Now there will be two docs with the same title and year + post3 = BlogPost(title='test1', + sub=SubDocument(year=2009, slug='test-1')) + self.assertRaises(NotUniqueError, post3.save) + + BlogPost.drop_collection() + + def test_ttl_indexes(self): + + class Log(Document): + created = DateTimeField(default=datetime.now) + meta = { + 'indexes': [ + {'fields': ['created'], 'expireAfterSeconds': 3600} + ] + } + + 
Log.drop_collection() + + if pymongo.version_tuple[0] < 2 and pymongo.version_tuple[1] < 3: + raise SkipTest('pymongo needs to be 2.3 or higher for this test') + + connection = get_connection() + version_array = connection.server_info()['versionArray'] + if version_array[0] < 2 and version_array[1] < 2: + raise SkipTest('MongoDB needs to be 2.2 or higher for this test') + + # Indexes are lazy so use list() to perform query + list(Log.objects) + info = Log.objects._collection.index_information() + self.assertEqual(3600, + info['created_1']['expireAfterSeconds']) + + def test_unique_and_indexes(self): + """Ensure that 'unique' constraints aren't overridden by + meta.indexes. + """ + class Customer(Document): + cust_id = IntField(unique=True, required=True) + meta = { + 'indexes': ['cust_id'], + 'allow_inheritance': False, + } + + Customer.drop_collection() + cust = Customer(cust_id=1) + cust.save() + + cust_dupe = Customer(cust_id=1) + try: + cust_dupe.save() + raise AssertionError("We saved a dupe!") + except NotUniqueError: + pass + Customer.drop_collection() + + def test_unique_and_primary(self): + """If you set a field as primary, then unexpected behaviour can occur. + You won't create a duplicate but you will update an existing document. 
+ """ + + class User(Document): + name = StringField(primary_key=True, unique=True) + password = StringField() + + User.drop_collection() + + user = User(name='huangz', password='secret') + user.save() + + user = User(name='huangz', password='secret2') + user.save() + + self.assertEqual(User.objects.count(), 1) + self.assertEqual(User.objects.get().password, 'secret2') + + User.drop_collection() + + def test_index_with_pk(self): + """Ensure you can use `pk` as part of a query""" + + class Comment(EmbeddedDocument): + comment_id = IntField(required=True) + + try: + class BlogPost(Document): + comments = EmbeddedDocumentField(Comment) + meta = {'indexes': [ + {'fields': ['pk', 'comments.comment_id'], + 'unique': True}]} + except UnboundLocalError: + self.fail('Unbound local error at index + pk definition') + + info = BlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + index_item = [('_id', 1), ('comments.comment_id', 1)] + self.assertTrue(index_item in info) + + def test_compound_key_embedded(self): + + class CompoundKey(EmbeddedDocument): + name = StringField(required=True) + term = StringField(required=True) + + class Report(Document): + key = EmbeddedDocumentField(CompoundKey, primary_key=True) + text = StringField() + + Report.drop_collection() + + my_key = CompoundKey(name="n", term="ok") + report = Report(text="OK", key=my_key).save() + + self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, + report.to_mongo()) + self.assertEqual(report, Report.objects.get(pk=my_key)) + + def test_compound_key_dictfield(self): + + class Report(Document): + key = DictField(primary_key=True) + text = StringField() + + Report.drop_collection() + + my_key = {"name": "n", "term": "ok"} + report = Report(text="OK", key=my_key).save() + + self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, + report.to_mongo()) + self.assertEqual(report, Report.objects.get(pk=my_key)) + +if __name__ == 
'__main__': + unittest.main() diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py new file mode 100644 index 00000000..f0116311 --- /dev/null +++ b/tests/document/inheritance.py @@ -0,0 +1,414 @@ +# -*- coding: utf-8 -*- +import sys +sys.path[0:0] = [""] +import unittest +import warnings + +from datetime import datetime + +from tests.fixtures import Base + +from mongoengine import Document, EmbeddedDocument, connect +from mongoengine.connection import get_db +from mongoengine.fields import (BooleanField, GenericReferenceField, + IntField, StringField) + +__all__ = ('InheritanceTest', ) + + +class InheritanceTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_superclasses(self): + """Ensure that the correct list of superclasses is assembled. + """ + class Animal(Document): + meta = {'allow_inheritance': True} + class Fish(Animal): pass + class Guppy(Fish): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + self.assertEqual(Animal._superclasses, ()) + self.assertEqual(Fish._superclasses, ('Animal',)) + self.assertEqual(Guppy._superclasses, ('Animal', 'Animal.Fish')) + self.assertEqual(Mammal._superclasses, ('Animal',)) + self.assertEqual(Dog._superclasses, ('Animal', 'Animal.Mammal')) + self.assertEqual(Human._superclasses, ('Animal', 'Animal.Mammal')) + + def test_external_superclasses(self): + """Ensure that the correct list of super classes is assembled when + importing part of the model. 
+ """ + class Animal(Base): pass + class Fish(Animal): pass + class Guppy(Fish): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + self.assertEqual(Animal._superclasses, ('Base', )) + self.assertEqual(Fish._superclasses, ('Base', 'Base.Animal',)) + self.assertEqual(Guppy._superclasses, ('Base', 'Base.Animal', + 'Base.Animal.Fish')) + self.assertEqual(Mammal._superclasses, ('Base', 'Base.Animal',)) + self.assertEqual(Dog._superclasses, ('Base', 'Base.Animal', + 'Base.Animal.Mammal')) + self.assertEqual(Human._superclasses, ('Base', 'Base.Animal', + 'Base.Animal.Mammal')) + + def test_subclasses(self): + """Ensure that the correct list of _subclasses (subclasses) is + assembled. + """ + class Animal(Document): + meta = {'allow_inheritance': True} + class Fish(Animal): pass + class Guppy(Fish): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + self.assertEqual(Animal._subclasses, ('Animal', + 'Animal.Fish', + 'Animal.Fish.Guppy', + 'Animal.Mammal', + 'Animal.Mammal.Dog', + 'Animal.Mammal.Human')) + self.assertEqual(Fish._subclasses, ('Animal.Fish', + 'Animal.Fish.Guppy',)) + self.assertEqual(Guppy._subclasses, ('Animal.Fish.Guppy',)) + self.assertEqual(Mammal._subclasses, ('Animal.Mammal', + 'Animal.Mammal.Dog', + 'Animal.Mammal.Human')) + self.assertEqual(Human._subclasses, ('Animal.Mammal.Human',)) + + def test_external_subclasses(self): + """Ensure that the correct list of _subclasses (subclasses) is + assembled when importing part of the model. 
+ """ + class Animal(Base): pass + class Fish(Animal): pass + class Guppy(Fish): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + self.assertEqual(Animal._subclasses, ('Base.Animal', + 'Base.Animal.Fish', + 'Base.Animal.Fish.Guppy', + 'Base.Animal.Mammal', + 'Base.Animal.Mammal.Dog', + 'Base.Animal.Mammal.Human')) + self.assertEqual(Fish._subclasses, ('Base.Animal.Fish', + 'Base.Animal.Fish.Guppy',)) + self.assertEqual(Guppy._subclasses, ('Base.Animal.Fish.Guppy',)) + self.assertEqual(Mammal._subclasses, ('Base.Animal.Mammal', + 'Base.Animal.Mammal.Dog', + 'Base.Animal.Mammal.Human')) + self.assertEqual(Human._subclasses, ('Base.Animal.Mammal.Human',)) + + def test_dynamic_declarations(self): + """Test that declaring an extra class updates meta data""" + + class Animal(Document): + meta = {'allow_inheritance': True} + + self.assertEqual(Animal._superclasses, ()) + self.assertEqual(Animal._subclasses, ('Animal',)) + + # Test dynamically adding a class changes the meta data + class Fish(Animal): + pass + + self.assertEqual(Animal._superclasses, ()) + self.assertEqual(Animal._subclasses, ('Animal', 'Animal.Fish')) + + self.assertEqual(Fish._superclasses, ('Animal', )) + self.assertEqual(Fish._subclasses, ('Animal.Fish',)) + + # Test dynamically adding an inherited class changes the meta data + class Pike(Fish): + pass + + self.assertEqual(Animal._superclasses, ()) + self.assertEqual(Animal._subclasses, ('Animal', 'Animal.Fish', + 'Animal.Fish.Pike')) + + self.assertEqual(Fish._superclasses, ('Animal', )) + self.assertEqual(Fish._subclasses, ('Animal.Fish', 'Animal.Fish.Pike')) + + self.assertEqual(Pike._superclasses, ('Animal', 'Animal.Fish')) + self.assertEqual(Pike._subclasses, ('Animal.Fish.Pike',)) + + def test_inheritance_meta_data(self): + """Ensure that document may inherit fields from a superclass document. 
+ """ + class Person(Document): + name = StringField() + age = IntField() + + meta = {'allow_inheritance': True} + + class Employee(Person): + salary = IntField() + + self.assertEqual(['age', 'id', 'name', 'salary'], + sorted(Employee._fields.keys())) + self.assertEqual(Employee._get_collection_name(), + Person._get_collection_name()) + + def test_inheritance_to_mongo_keys(self): + """Ensure that document may inherit fields from a superclass document. + """ + class Person(Document): + name = StringField() + age = IntField() + + meta = {'allow_inheritance': True} + + class Employee(Person): + salary = IntField() + + self.assertEqual(['age', 'id', 'name', 'salary'], + sorted(Employee._fields.keys())) + self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), + ['_cls', 'name', 'age']) + self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), + ['_cls', 'name', 'age', 'salary']) + self.assertEqual(Employee._get_collection_name(), + Person._get_collection_name()) + + def test_polymorphic_queries(self): + """Ensure that the correct subclasses are returned from a query + """ + + class Animal(Document): + meta = {'allow_inheritance': True} + class Fish(Animal): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + Animal.drop_collection() + + Animal().save() + Fish().save() + Mammal().save() + Dog().save() + Human().save() + + classes = [obj.__class__ for obj in Animal.objects] + self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) + + classes = [obj.__class__ for obj in Mammal.objects] + self.assertEqual(classes, [Mammal, Dog, Human]) + + classes = [obj.__class__ for obj in Human.objects] + self.assertEqual(classes, [Human]) + + def test_allow_inheritance(self): + """Ensure that inheritance may be disabled on simple classes and that + _cls and _subclasses will not be used. 
+ """ + + class Animal(Document): + name = StringField() + + def create_dog_class(): + class Dog(Animal): + pass + + self.assertRaises(ValueError, create_dog_class) + + # Check that _cls etc aren't present on simple documents + dog = Animal(name='dog').save() + self.assertEqual(dog.to_mongo().keys(), ['_id', 'name']) + + collection = self.db[Animal._get_collection_name()] + obj = collection.find_one() + self.assertFalse('_cls' in obj) + + def test_cant_turn_off_inheritance_on_subclass(self): + """Ensure if inheritance is on in a subclass you cant turn it off + """ + + class Animal(Document): + name = StringField() + meta = {'allow_inheritance': True} + + def create_mammal_class(): + class Mammal(Animal): + meta = {'allow_inheritance': False} + self.assertRaises(ValueError, create_mammal_class) + + def test_allow_inheritance_abstract_document(self): + """Ensure that abstract documents can set inheritance rules and that + _cls will not be used. + """ + class FinalDocument(Document): + meta = {'abstract': True, + 'allow_inheritance': False} + + class Animal(FinalDocument): + name = StringField() + + def create_mammal_class(): + class Mammal(Animal): + pass + self.assertRaises(ValueError, create_mammal_class) + + # Check that _cls isn't present in simple documents + doc = Animal(name='dog') + self.assertFalse('_cls' in doc.to_mongo()) + + def test_allow_inheritance_embedded_document(self): + """Ensure embedded documents respect inheritance + """ + + class Comment(EmbeddedDocument): + content = StringField() + + def create_special_comment(): + class SpecialComment(Comment): + pass + + self.assertRaises(ValueError, create_special_comment) + + doc = Comment(content='test') + self.assertFalse('_cls' in doc.to_mongo()) + + class Comment(EmbeddedDocument): + content = StringField() + meta = {'allow_inheritance': True} + + doc = Comment(content='test') + self.assertTrue('_cls' in doc.to_mongo()) + + def test_document_inheritance(self): + """Ensure mutliple inheritance of 
abstract documents + """ + class DateCreatedDocument(Document): + meta = { + 'allow_inheritance': True, + 'abstract': True, + } + + class DateUpdatedDocument(Document): + meta = { + 'allow_inheritance': True, + 'abstract': True, + } + + try: + class MyDocument(DateCreatedDocument, DateUpdatedDocument): + pass + except: + self.assertTrue(False, "Couldn't create MyDocument class") + + def test_abstract_documents(self): + """Ensure that a document superclass can be marked as abstract + thereby not using it as the name for the collection.""" + + defaults = {'index_background': True, + 'index_drop_dups': True, + 'index_opts': {'hello': 'world'}, + 'allow_inheritance': True, + 'queryset_class': 'QuerySet', + 'db_alias': 'myDB', + 'shard_key': ('hello', 'world')} + + meta_settings = {'abstract': True} + meta_settings.update(defaults) + + class Animal(Document): + name = StringField() + meta = meta_settings + + class Fish(Animal): pass + class Guppy(Fish): pass + + class Mammal(Animal): + meta = {'abstract': True} + class Human(Mammal): pass + + for k, v in defaults.iteritems(): + for cls in [Animal, Fish, Guppy]: + self.assertEqual(cls._meta[k], v) + + self.assertFalse('collection' in Animal._meta) + self.assertFalse('collection' in Mammal._meta) + + self.assertEqual(Animal._get_collection_name(), None) + self.assertEqual(Mammal._get_collection_name(), None) + + self.assertEqual(Fish._get_collection_name(), 'fish') + self.assertEqual(Guppy._get_collection_name(), 'fish') + self.assertEqual(Human._get_collection_name(), 'human') + + def create_bad_abstract(): + class EvilHuman(Human): + evil = BooleanField(default=True) + meta = {'abstract': True} + self.assertRaises(ValueError, create_bad_abstract) + + def test_inherited_collections(self): + """Ensure that subclassed documents don't override parents' + collections + """ + + class Drink(Document): + name = StringField() + meta = {'allow_inheritance': True} + + class Drinker(Document): + drink = GenericReferenceField() + + 
try: + warnings.simplefilter("error") + + class AcloholicDrink(Drink): + meta = {'collection': 'booze'} + + except SyntaxWarning: + warnings.simplefilter("ignore") + + class AlcoholicDrink(Drink): + meta = {'collection': 'booze'} + + else: + raise AssertionError("SyntaxWarning should be triggered") + + warnings.resetwarnings() + + Drink.drop_collection() + AlcoholicDrink.drop_collection() + Drinker.drop_collection() + + red_bull = Drink(name='Red Bull') + red_bull.save() + + programmer = Drinker(drink=red_bull) + programmer.save() + + beer = AlcoholicDrink(name='Beer') + beer.save() + real_person = Drinker(drink=beer) + real_person.save() + + self.assertEqual(Drinker.objects[0].drink.name, red_bull.name) + self.assertEqual(Drinker.objects[1].drink.name, beer.name) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/instance.py b/tests/document/instance.py new file mode 100644 index 00000000..06744ab4 --- /dev/null +++ b/tests/document/instance.py @@ -0,0 +1,2246 @@ +# -*- coding: utf-8 -*- +from __future__ import with_statement +import sys +sys.path[0:0] = [""] + +import bson +import os +import pickle +import unittest +import uuid + +from datetime import datetime +from tests.fixtures import PickleEmbedded, PickleTest + +from mongoengine import * +from mongoengine.errors import (NotRegistered, InvalidDocumentError, + InvalidQueryError) +from mongoengine.queryset import NULLIFY, Q +from mongoengine.connection import get_db +from mongoengine.base import get_document +from mongoengine.context_managers import switch_db, query_counter +from mongoengine import signals + +TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), + '../fields/mongoengine.png') + +__all__ = ("InstanceTest",) + + +class InstanceTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + class Person(Document): + name = StringField() + age = IntField() + + non_field = True + + meta = {"allow_inheritance": True} + + 
self.Person = Person + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_capped_collection(self): + """Ensure that capped collections work properly. + """ + class Log(Document): + date = DateTimeField(default=datetime.now) + meta = { + 'max_documents': 10, + 'max_size': 90000, + } + + Log.drop_collection() + + # Ensure that the collection handles up to its maximum + for _ in range(10): + Log().save() + + self.assertEqual(Log.objects.count(), 10) + + # Check that extra documents don't increase the size + Log().save() + self.assertEqual(Log.objects.count(), 10) + + options = Log.objects._collection.options() + self.assertEqual(options['capped'], True) + self.assertEqual(options['max'], 10) + self.assertEqual(options['size'], 90000) + + # Check that the document cannot be redefined with different options + def recreate_log_document(): + class Log(Document): + date = DateTimeField(default=datetime.now) + meta = { + 'max_documents': 11, + } + # Create the collection by accessing Document.objects + Log.objects + self.assertRaises(InvalidCollectionError, recreate_log_document) + + Log.drop_collection() + + def test_repr(self): + """Ensure that unicode representation works + """ + class Article(Document): + title = StringField() + + def __unicode__(self): + return self.title + + doc = Article(title=u'привет мир') + + self.assertEqual('', repr(doc)) + + def test_queryset_resurrects_dropped_collection(self): + self.Person.drop_collection() + + self.assertEqual([], list(self.Person.objects())) + + class Actor(self.Person): + pass + + # Ensure works correctly with inhertited classes + Actor.objects() + self.Person.drop_collection() + self.assertEqual([], list(Actor.objects())) + + def test_polymorphic_references(self): + """Ensure that the correct subclasses are returned from a query when + using references / generic references + """ + class Animal(Document): + meta 
= {'allow_inheritance': True} + class Fish(Animal): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + class Zoo(Document): + animals = ListField(ReferenceField(Animal)) + + Zoo.drop_collection() + Animal.drop_collection() + + Animal().save() + Fish().save() + Mammal().save() + Dog().save() + Human().save() + + # Save a reference to each animal + zoo = Zoo(animals=Animal.objects) + zoo.save() + zoo.reload() + + classes = [a.__class__ for a in Zoo.objects.first().animals] + self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) + + Zoo.drop_collection() + + class Zoo(Document): + animals = ListField(GenericReferenceField(Animal)) + + # Save a reference to each animal + zoo = Zoo(animals=Animal.objects) + zoo.save() + zoo.reload() + + classes = [a.__class__ for a in Zoo.objects.first().animals] + self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) + + Zoo.drop_collection() + Animal.drop_collection() + + def test_reference_inheritance(self): + class Stats(Document): + created = DateTimeField(default=datetime.now) + + meta = {'allow_inheritance': False} + + class CompareStats(Document): + generated = DateTimeField(default=datetime.now) + stats = ListField(ReferenceField(Stats)) + + Stats.drop_collection() + CompareStats.drop_collection() + + list_stats = [] + + for i in xrange(10): + s = Stats() + s.save() + list_stats.append(s) + + cmp_stats = CompareStats(stats=list_stats) + cmp_stats.save() + + self.assertEqual(list_stats, CompareStats.objects.first().stats) + + def test_db_field_load(self): + """Ensure we load data correctly + """ + class Person(Document): + name = StringField(required=True) + _rank = StringField(required=False, db_field="rank") + + @property + def rank(self): + return self._rank or "Private" + + Person.drop_collection() + + Person(name="Jack", _rank="Corporal").save() + + Person(name="Fred").save() + + self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") + 
self.assertEqual(Person.objects.get(name="Fred").rank, "Private") + + def test_db_embedded_doc_field_load(self): + """Ensure we load embedded document data correctly + """ + class Rank(EmbeddedDocument): + title = StringField(required=True) + + class Person(Document): + name = StringField(required=True) + rank_ = EmbeddedDocumentField(Rank, + required=False, + db_field='rank') + + @property + def rank(self): + if self.rank_ is None: + return "Private" + return self.rank_.title + + Person.drop_collection() + + Person(name="Jack", rank_=Rank(title="Corporal")).save() + Person(name="Fred").save() + + self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") + self.assertEqual(Person.objects.get(name="Fred").rank, "Private") + + def test_custom_id_field(self): + """Ensure that documents may be created with custom primary keys. + """ + class User(Document): + username = StringField(primary_key=True) + name = StringField() + + meta = {'allow_inheritance': True} + + User.drop_collection() + + self.assertEqual(User._fields['username'].db_field, '_id') + self.assertEqual(User._meta['id_field'], 'username') + + def create_invalid_user(): + User(name='test').save() # no primary key field + self.assertRaises(ValidationError, create_invalid_user) + + def define_invalid_user(): + class EmailUser(User): + email = StringField(primary_key=True) + self.assertRaises(ValueError, define_invalid_user) + + class EmailUser(User): + email = StringField() + + user = User(username='test', name='test user') + user.save() + + user_obj = User.objects.first() + self.assertEqual(user_obj.id, 'test') + self.assertEqual(user_obj.pk, 'test') + + user_son = User.objects._collection.find_one() + self.assertEqual(user_son['_id'], 'test') + self.assertTrue('username' not in user_son['_id']) + + User.drop_collection() + + user = User(pk='mongo', name='mongo user') + user.save() + + user_obj = User.objects.first() + self.assertEqual(user_obj.id, 'mongo') + self.assertEqual(user_obj.pk, 'mongo') + 
+ user_son = User.objects._collection.find_one() + self.assertEqual(user_son['_id'], 'mongo') + self.assertTrue('username' not in user_son['_id']) + + User.drop_collection() + + def test_document_not_registered(self): + + class Place(Document): + name = StringField() + + meta = {'allow_inheritance': True} + + class NicePlace(Place): + pass + + Place.drop_collection() + + Place(name="London").save() + NicePlace(name="Buckingham Palace").save() + + # Mimic Place and NicePlace definitions being in a different file + # and the NicePlace model not being imported in at query time. + from mongoengine.base import _document_registry + del(_document_registry['Place.NicePlace']) + + def query_without_importing_nice_place(): + print Place.objects.all() + self.assertRaises(NotRegistered, query_without_importing_nice_place) + + def test_document_registry_regressions(self): + + class Location(Document): + name = StringField() + meta = {'allow_inheritance': True} + + class Area(Location): + location = ReferenceField('Location', dbref=True) + + Location.drop_collection() + + self.assertEquals(Area, get_document("Area")) + self.assertEquals(Area, get_document("Location.Area")) + + def test_creation(self): + """Ensure that document may be created using keyword arguments. + """ + person = self.Person(name="Test User", age=30) + self.assertEqual(person.name, "Test User") + self.assertEqual(person.age, 30) + + def test_to_dbref(self): + """Ensure that you can get a dbref of a document""" + person = self.Person(name="Test User", age=30) + self.assertRaises(OperationError, person.to_dbref) + person.save() + + person.to_dbref() + + def test_reload(self): + """Ensure that attributes may be reloaded. 
+ """ + person = self.Person(name="Test User", age=20) + person.save() + + person_obj = self.Person.objects.first() + person_obj.name = "Mr Test User" + person_obj.age = 21 + person_obj.save() + + self.assertEqual(person.name, "Test User") + self.assertEqual(person.age, 20) + + person.reload() + self.assertEqual(person.name, "Mr Test User") + self.assertEqual(person.age, 21) + + def test_reload_sharded(self): + class Animal(Document): + superphylum = StringField() + meta = {'shard_key': ('superphylum',)} + + Animal.drop_collection() + doc = Animal(superphylum='Deuterostomia') + doc.save() + doc.reload() + Animal.drop_collection() + + def test_reload_referencing(self): + """Ensures reloading updates weakrefs correctly + """ + class Embedded(EmbeddedDocument): + dict_field = DictField() + list_field = ListField() + + class Doc(Document): + dict_field = DictField() + list_field = ListField() + embedded_field = EmbeddedDocumentField(Embedded) + + Doc.drop_collection() + doc = Doc() + doc.dict_field = {'hello': 'world'} + doc.list_field = ['1', 2, {'hello': 'world'}] + + embedded_1 = Embedded() + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + doc.save() + + doc = doc.reload(10) + doc.list_field.append(1) + doc.dict_field['woot'] = "woot" + doc.embedded_field.list_field.append(1) + doc.embedded_field.dict_field['woot'] = "woot" + + self.assertEqual(doc._get_changed_fields(), [ + 'list_field', 'dict_field', 'embedded_field.list_field', + 'embedded_field.dict_field']) + doc.save() + + doc = doc.reload(10) + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(len(doc.list_field), 4) + self.assertEqual(len(doc.dict_field), 2) + self.assertEqual(len(doc.embedded_field.list_field), 4) + self.assertEqual(len(doc.embedded_field.dict_field), 2) + + def test_dictionary_access(self): + """Ensure that dictionary-style field access works properly. 
+ """ + person = self.Person(name='Test User', age=30) + self.assertEqual(person['name'], 'Test User') + + self.assertRaises(KeyError, person.__getitem__, 'salary') + self.assertRaises(KeyError, person.__setitem__, 'salary', 50) + + person['name'] = 'Another User' + self.assertEqual(person['name'], 'Another User') + + # Length = length(assigned fields + id) + self.assertEqual(len(person), 3) + + self.assertTrue('age' in person) + person.age = None + self.assertFalse('age' in person) + self.assertFalse('nationality' in person) + + def test_embedded_document_to_mongo(self): + class Person(EmbeddedDocument): + name = StringField() + age = IntField() + + meta = {"allow_inheritance": True} + + class Employee(Person): + salary = IntField() + + self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), + ['_cls', 'name', 'age']) + self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), + ['_cls', 'name', 'age', 'salary']) + + def test_embedded_document(self): + """Ensure that embedded documents are set up correctly. 
+ """ + class Comment(EmbeddedDocument): + content = StringField() + + self.assertTrue('content' in Comment._fields) + self.assertFalse('id' in Comment._fields) + + def test_embedded_document_instance(self): + """Ensure that embedded documents can reference parent instance + """ + class Embedded(EmbeddedDocument): + string = StringField() + + class Doc(Document): + embedded_field = EmbeddedDocumentField(Embedded) + + Doc.drop_collection() + Doc(embedded_field=Embedded(string="Hi")).save() + + doc = Doc.objects.get() + self.assertEqual(doc, doc.embedded_field._instance) + + def test_embedded_document_complex_instance(self): + """Ensure that embedded documents in complex fields can reference + parent instance""" + class Embedded(EmbeddedDocument): + string = StringField() + + class Doc(Document): + embedded_field = ListField(EmbeddedDocumentField(Embedded)) + + Doc.drop_collection() + Doc(embedded_field=[Embedded(string="Hi")]).save() + + doc = Doc.objects.get() + self.assertEqual(doc, doc.embedded_field[0]._instance) + + def test_document_clean(self): + class TestDocument(Document): + status = StringField() + pub_date = DateTimeField() + + def clean(self): + if self.status == 'draft' and self.pub_date is not None: + msg = 'Draft entries may not have a publication date.' + raise ValidationError(msg) + # Set the pub_date for published items if not set. + if self.status == 'published' and self.pub_date is None: + self.pub_date = datetime.now() + + TestDocument.drop_collection() + + t = TestDocument(status="draft", pub_date=datetime.now()) + + try: + t.save() + except ValidationError, e: + expect_msg = "Draft entries may not have a publication date." 
+ self.assertTrue(expect_msg in e.message) + self.assertEqual(e.to_dict(), {'__all__': expect_msg}) + + t = TestDocument(status="published") + t.save(clean=False) + + self.assertEquals(t.pub_date, None) + + t = TestDocument(status="published") + t.save(clean=True) + + self.assertEquals(type(t.pub_date), datetime) + + def test_document_embedded_clean(self): + class TestEmbeddedDocument(EmbeddedDocument): + x = IntField(required=True) + y = IntField(required=True) + z = IntField(required=True) + + meta = {'allow_inheritance': False} + + def clean(self): + if self.z: + if self.z != self.x + self.y: + raise ValidationError('Value of z != x + y') + else: + self.z = self.x + self.y + + class TestDocument(Document): + doc = EmbeddedDocumentField(TestEmbeddedDocument) + status = StringField() + + TestDocument.drop_collection() + + t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15)) + try: + t.save() + except ValidationError, e: + expect_msg = "Value of z != x + y" + self.assertTrue(expect_msg in e.message) + self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}}) + + t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save() + self.assertEquals(t.doc.z, 35) + + # Asserts not raises + t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5)) + t.save(clean=False) + + def test_save(self): + """Ensure that a document may be saved in the database. 
+ """ + # Create person object and save it to the database + person = self.Person(name='Test User', age=30) + person.save() + # Ensure that the object is in the database + collection = self.db[self.Person._get_collection_name()] + person_obj = collection.find_one({'name': 'Test User'}) + self.assertEqual(person_obj['name'], 'Test User') + self.assertEqual(person_obj['age'], 30) + self.assertEqual(person_obj['_id'], person.id) + # Test skipping validation on save + + class Recipient(Document): + email = EmailField(required=True) + + recipient = Recipient(email='root@localhost') + self.assertRaises(ValidationError, recipient.save) + + try: + recipient.save(validate=False) + except ValidationError: + self.fail() + + def test_save_to_a_value_that_equates_to_false(self): + + class Thing(EmbeddedDocument): + count = IntField() + + class User(Document): + thing = EmbeddedDocumentField(Thing) + + User.drop_collection() + + user = User(thing=Thing(count=1)) + user.save() + user.reload() + + user.thing.count = 0 + user.save() + + user.reload() + self.assertEqual(user.thing.count, 0) + + def test_save_max_recursion_not_hit(self): + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + friend = ReferenceField('self') + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.parent = None + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p1.friend = p2 + p1.save() + + # Confirm can save and it resets the changed fields without hitting + # max recursion error + p0 = Person.objects.first() + p0.name = 'wpjunior' + p0.save() + + def test_save_max_recursion_not_hit_with_file_field(self): + + class Foo(Document): + name = StringField() + picture = FileField() + bar = ReferenceField('self') + + Foo.drop_collection() + + a = Foo(name='hello').save() + + a.bar = a + with open(TEST_IMAGE_PATH, 'rb') as test_image: + a.picture = test_image + a.save() + + # Confirm can save and it resets the changed fields without 
hitting + # max recursion error + b = Foo.objects.with_id(a.id) + b.name = 'world' + b.save() + + self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture) + + def test_save_cascades(self): + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.parent = None + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertEqual(p1.name, p.parent.name) + + def test_save_cascade_kwargs(self): + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.parent = None + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save(force_insert=True, cascade_kwargs={"force_insert": False}) + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertEqual(p1.name, p.parent.name) + + def test_save_cascade_meta_false(self): + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + + meta = {'cascade': False} + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.parent = None + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertNotEqual(p1.name, p.parent.name) + + p.save(cascade=True) + p1.reload() + self.assertEqual(p1.name, p.parent.name) + + def test_save_cascade_meta_true(self): + + class Person(Document): + name = StringField() + parent = ReferenceField('self') + + meta = {'cascade': False} + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.parent = None + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save(cascade=True) + + p = Person.objects(name="Wilson 
Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertNotEqual(p1.name, p.parent.name) + + def test_save_cascades_generically(self): + + class Person(Document): + name = StringField() + parent = GenericReferenceField() + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.save() + + p2 = Person(name="Wilson Jr") + p2.parent = p1 + p2.save() + + p = Person.objects(name="Wilson Jr").get() + p.parent.name = "Daddy Wilson" + p.save() + + p1.reload() + self.assertEqual(p1.name, p.parent.name) + + def test_update(self): + """Ensure that an existing document is updated instead of be + overwritten.""" + # Create person object and save it to the database + person = self.Person(name='Test User', age=30) + person.save() + + # Create same person object, with same id, without age + same_person = self.Person(name='Test') + same_person.id = person.id + same_person.save() + + # Confirm only one object + self.assertEqual(self.Person.objects.count(), 1) + + # reload + person.reload() + same_person.reload() + + # Confirm the same + self.assertEqual(person, same_person) + self.assertEqual(person.name, same_person.name) + self.assertEqual(person.age, same_person.age) + + # Confirm the saved values + self.assertEqual(person.name, 'Test') + self.assertEqual(person.age, 30) + + # Test only / exclude only updates included fields + person = self.Person.objects.only('name').get() + person.name = 'User' + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 30) + + # test exclude only updates set fields + person = self.Person.objects.exclude('name').get() + person.age = 21 + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 21) + + # Test only / exclude can set non excluded / included fields + person = self.Person.objects.only('name').get() + person.name = 'Test' + person.age = 30 + person.save() + + person.reload() + 
self.assertEqual(person.name, 'Test') + self.assertEqual(person.age, 30) + + # test exclude only updates set fields + person = self.Person.objects.exclude('name').get() + person.name = 'User' + person.age = 21 + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 21) + + # Confirm does remove unrequired fields + person = self.Person.objects.exclude('name').get() + person.age = None + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, None) + + person = self.Person.objects.get() + person.name = None + person.age = None + person.save() + + person.reload() + self.assertEqual(person.name, None) + self.assertEqual(person.age, None) + + def test_can_save_if_not_included(self): + + class EmbeddedDoc(EmbeddedDocument): + pass + + class Simple(Document): + pass + + class Doc(Document): + string_field = StringField(default='1') + int_field = IntField(default=1) + float_field = FloatField(default=1.1) + boolean_field = BooleanField(default=True) + datetime_field = DateTimeField(default=datetime.now) + embedded_document_field = EmbeddedDocumentField( + EmbeddedDoc, default=lambda: EmbeddedDoc()) + list_field = ListField(default=lambda: [1, 2, 3]) + dict_field = DictField(default=lambda: {"hello": "world"}) + objectid_field = ObjectIdField(default=bson.ObjectId) + reference_field = ReferenceField(Simple, default=lambda: + Simple().save()) + map_field = MapField(IntField(), default=lambda: {"simple": 1}) + decimal_field = DecimalField(default=1.0) + complex_datetime_field = ComplexDateTimeField(default=datetime.now) + url_field = URLField(default="http://mongoengine.org") + dynamic_field = DynamicField(default=1) + generic_reference_field = GenericReferenceField( + default=lambda: Simple().save()) + sorted_list_field = SortedListField(IntField(), + default=lambda: [1, 2, 3]) + email_field = EmailField(default="ross@example.com") + geo_point_field = 
GeoPointField(default=lambda: [1, 2]) + sequence_field = SequenceField() + uuid_field = UUIDField(default=uuid.uuid4) + generic_embedded_document_field = GenericEmbeddedDocumentField( + default=lambda: EmbeddedDoc()) + + Simple.drop_collection() + Doc.drop_collection() + + Doc().save() + my_doc = Doc.objects.only("string_field").first() + my_doc.string_field = "string" + my_doc.save() + + my_doc = Doc.objects.get(string_field="string") + self.assertEqual(my_doc.string_field, "string") + self.assertEqual(my_doc.int_field, 1) + + def test_document_update(self): + + def update_not_saved_raises(): + person = self.Person(name='dcrosta') + person.update(set__name='Dan Crosta') + + self.assertRaises(OperationError, update_not_saved_raises) + + author = self.Person(name='dcrosta') + author.save() + + author.update(set__name='Dan Crosta') + author.reload() + + p1 = self.Person.objects.first() + self.assertEqual(p1.name, author.name) + + def update_no_value_raises(): + person = self.Person.objects.first() + person.update() + + self.assertRaises(OperationError, update_no_value_raises) + + def update_no_op_raises(): + person = self.Person.objects.first() + person.update(name="Dan") + + self.assertRaises(InvalidQueryError, update_no_op_raises) + + def test_embedded_update(self): + """ + Test update on `EmbeddedDocumentField` fields + """ + + class Page(EmbeddedDocument): + log_message = StringField(verbose_name="Log message", + required=True) + + class Site(Document): + page = EmbeddedDocumentField(Page) + + Site.drop_collection() + site = Site(page=Page(log_message="Warning: Dummy message")) + site.save() + + # Update + site = Site.objects.first() + site.page.log_message = "Error: Dummy message" + site.save() + + site = Site.objects.first() + self.assertEqual(site.page.log_message, "Error: Dummy message") + + def test_embedded_update_db_field(self): + """ + Test update on `EmbeddedDocumentField` fields when db_field is other + than default. 
+ """ + + class Page(EmbeddedDocument): + log_message = StringField(verbose_name="Log message", + db_field="page_log_message", + required=True) + + class Site(Document): + page = EmbeddedDocumentField(Page) + + Site.drop_collection() + + site = Site(page=Page(log_message="Warning: Dummy message")) + site.save() + + # Update + site = Site.objects.first() + site.page.log_message = "Error: Dummy message" + site.save() + + site = Site.objects.first() + self.assertEqual(site.page.log_message, "Error: Dummy message") + + def test_save_only_changed_fields(self): + """Ensure save only sets / unsets changed fields + """ + + class User(self.Person): + active = BooleanField(default=True) + + User.drop_collection() + + # Create person object and save it to the database + user = User(name='Test User', age=30, active=True) + user.save() + user.reload() + + # Simulated Race condition + same_person = self.Person.objects.get() + same_person.active = False + + user.age = 21 + user.save() + + same_person.name = 'User' + same_person.save() + + person = self.Person.objects.get() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 21) + self.assertEqual(person.active, False) + + def test_set_unset_one_operation(self): + """Ensure that $set and $unset actions are performed in the same + operation. 
+ """ + class FooBar(Document): + foo = StringField(default=None) + bar = StringField(default=None) + + FooBar.drop_collection() + + # write an entity with a single prop + foo = FooBar(foo='foo').save() + + self.assertEqual(foo.foo, 'foo') + del foo.foo + foo.bar = 'bar' + + with query_counter() as q: + self.assertEqual(0, q) + foo.save() + self.assertEqual(1, q) + + def test_save_only_changed_fields_recursive(self): + """Ensure save only sets / unsets changed fields + """ + + class Comment(EmbeddedDocument): + published = BooleanField(default=True) + + class User(self.Person): + comments_dict = DictField() + comments = ListField(EmbeddedDocumentField(Comment)) + active = BooleanField(default=True) + + User.drop_collection() + + # Create person object and save it to the database + person = User(name='Test User', age=30, active=True) + person.comments.append(Comment()) + person.save() + person.reload() + + person = self.Person.objects.get() + self.assertTrue(person.comments[0].published) + + person.comments[0].published = False + person.save() + + person = self.Person.objects.get() + self.assertFalse(person.comments[0].published) + + # Simple dict w + person.comments_dict['first_post'] = Comment() + person.save() + + person = self.Person.objects.get() + self.assertTrue(person.comments_dict['first_post'].published) + + person.comments_dict['first_post'].published = False + person.save() + + person = self.Person.objects.get() + self.assertFalse(person.comments_dict['first_post'].published) + + def test_delete(self): + """Ensure that document may be deleted using the delete method. + """ + person = self.Person(name="Test User", age=30) + person.save() + self.assertEqual(self.Person.objects.count(), 1) + person.delete() + self.assertEqual(self.Person.objects.count(), 0) + + def test_save_custom_id(self): + """Ensure that a document may be saved with a custom _id. 
+ """ + # Create person object and save it to the database + person = self.Person(name='Test User', age=30, + id='497ce96f395f2f052a494fd4') + person.save() + # Ensure that the object is in the database with the correct _id + collection = self.db[self.Person._get_collection_name()] + person_obj = collection.find_one({'name': 'Test User'}) + self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') + + def test_save_custom_pk(self): + """Ensure that a document may be saved with a custom _id using pk alias. + """ + # Create person object and save it to the database + person = self.Person(name='Test User', age=30, + pk='497ce96f395f2f052a494fd4') + person.save() + # Ensure that the object is in the database with the correct _id + collection = self.db[self.Person._get_collection_name()] + person_obj = collection.find_one({'name': 'Test User'}) + self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') + + def test_save_list(self): + """Ensure that a list field may be properly saved. 
+ """ + class Comment(EmbeddedDocument): + content = StringField() + + class BlogPost(Document): + content = StringField() + comments = ListField(EmbeddedDocumentField(Comment)) + tags = ListField(StringField()) + + BlogPost.drop_collection() + + post = BlogPost(content='Went for a walk today...') + post.tags = tags = ['fun', 'leisure'] + comments = [Comment(content='Good for you'), Comment(content='Yay.')] + post.comments = comments + post.save() + + collection = self.db[BlogPost._get_collection_name()] + post_obj = collection.find_one() + self.assertEqual(post_obj['tags'], tags) + for comment_obj, comment in zip(post_obj['comments'], comments): + self.assertEqual(comment_obj['content'], comment['content']) + + BlogPost.drop_collection() + + def test_list_search_by_embedded(self): + class User(Document): + username = StringField(required=True) + + meta = {'allow_inheritance': False} + + class Comment(EmbeddedDocument): + comment = StringField() + user = ReferenceField(User, + required=True) + + meta = {'allow_inheritance': False} + + class Page(Document): + comments = ListField(EmbeddedDocumentField(Comment)) + meta = {'allow_inheritance': False, + 'indexes': [ + {'fields': ['comments.user']} + ]} + + User.drop_collection() + Page.drop_collection() + + u1 = User(username="wilson") + u1.save() + + u2 = User(username="rozza") + u2.save() + + u3 = User(username="hmarr") + u3.save() + + p1 = Page(comments=[Comment(user=u1, comment="Its very good"), + Comment(user=u2, comment="Hello world"), + Comment(user=u3, comment="Ping Pong"), + Comment(user=u1, comment="I like a beer")]) + p1.save() + + p2 = Page(comments=[Comment(user=u1, comment="Its very good"), + Comment(user=u2, comment="Hello world")]) + p2.save() + + p3 = Page(comments=[Comment(user=u3, comment="Its very good")]) + p3.save() + + p4 = Page(comments=[Comment(user=u2, comment="Heavy Metal song")]) + p4.save() + + self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1))) + self.assertEqual([p1, 
p2, p4], list(Page.objects.filter(comments__user=u2))) + self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3))) + + def test_save_embedded_document(self): + """Ensure that a document with an embedded document field may be + saved in the database. + """ + class EmployeeDetails(EmbeddedDocument): + position = StringField() + + class Employee(self.Person): + salary = IntField() + details = EmbeddedDocumentField(EmployeeDetails) + + # Create employee object and save it to the database + employee = Employee(name='Test Employee', age=50, salary=20000) + employee.details = EmployeeDetails(position='Developer') + employee.save() + + # Ensure that the object is in the database + collection = self.db[self.Person._get_collection_name()] + employee_obj = collection.find_one({'name': 'Test Employee'}) + self.assertEqual(employee_obj['name'], 'Test Employee') + self.assertEqual(employee_obj['age'], 50) + # Ensure that the 'details' embedded object saved correctly + self.assertEqual(employee_obj['details']['position'], 'Developer') + + def test_embedded_update_after_save(self): + """ + Test update of `EmbeddedDocumentField` attached to a newly saved + document. + """ + class Page(EmbeddedDocument): + log_message = StringField(verbose_name="Log message", + required=True) + + class Site(Document): + page = EmbeddedDocumentField(Page) + + Site.drop_collection() + site = Site(page=Page(log_message="Warning: Dummy message")) + site.save() + + # Update + site.page.log_message = "Error: Dummy message" + site.save() + + site = Site.objects.first() + self.assertEqual(site.page.log_message, "Error: Dummy message") + + def test_updating_an_embedded_document(self): + """Ensure that a document with an embedded document field may be + saved in the database. 
+ """ + class EmployeeDetails(EmbeddedDocument): + position = StringField() + + class Employee(self.Person): + salary = IntField() + details = EmbeddedDocumentField(EmployeeDetails) + + # Create employee object and save it to the database + employee = Employee(name='Test Employee', age=50, salary=20000) + employee.details = EmployeeDetails(position='Developer') + employee.save() + + # Test updating an embedded document + promoted_employee = Employee.objects.get(name='Test Employee') + promoted_employee.details.position = 'Senior Developer' + promoted_employee.save() + + promoted_employee.reload() + self.assertEqual(promoted_employee.name, 'Test Employee') + self.assertEqual(promoted_employee.age, 50) + + # Ensure that the 'details' embedded object saved correctly + self.assertEqual(promoted_employee.details.position, 'Senior Developer') + + # Test removal + promoted_employee.details = None + promoted_employee.save() + + promoted_employee.reload() + self.assertEqual(promoted_employee.details, None) + + def test_object_mixins(self): + + class NameMixin(object): + name = StringField() + + class Foo(EmbeddedDocument, NameMixin): + quantity = IntField() + + self.assertEqual(['name', 'quantity'], sorted(Foo._fields.keys())) + + class Bar(Document, NameMixin): + widgets = StringField() + + self.assertEqual(['id', 'name', 'widgets'], sorted(Bar._fields.keys())) + + def test_mixin_inheritance(self): + class BaseMixIn(object): + count = IntField() + data = StringField() + + class DoubleMixIn(BaseMixIn): + comment = StringField() + + class TestDoc(Document, DoubleMixIn): + age = IntField() + + TestDoc.drop_collection() + t = TestDoc(count=12, data="test", + comment="great!", age=19) + + t.save() + + t = TestDoc.objects.first() + + self.assertEqual(t.age, 19) + self.assertEqual(t.comment, "great!") + self.assertEqual(t.data, "test") + self.assertEqual(t.count, 12) + + def test_save_reference(self): + """Ensure that a document reference field may be saved in the database. 
+ """ + + class BlogPost(Document): + meta = {'collection': 'blogpost_1'} + content = StringField() + author = ReferenceField(self.Person) + + BlogPost.drop_collection() + + author = self.Person(name='Test User') + author.save() + + post = BlogPost(content='Watched some TV today... how exciting.') + # Should only reference author when saving + post.author = author + post.save() + + post_obj = BlogPost.objects.first() + + # Test laziness + self.assertTrue(isinstance(post_obj._data['author'], + bson.DBRef)) + self.assertTrue(isinstance(post_obj.author, self.Person)) + self.assertEqual(post_obj.author.name, 'Test User') + + # Ensure that the dereferenced object may be changed and saved + post_obj.author.age = 25 + post_obj.author.save() + + author = list(self.Person.objects(name='Test User'))[-1] + self.assertEqual(author.age, 25) + + BlogPost.drop_collection() + + def test_duplicate_db_fields_raise_invalid_document_error(self): + """Ensure a InvalidDocumentError is thrown if duplicate fields + declare the same db_field""" + + def throw_invalid_document_error(): + class Foo(Document): + name = StringField() + name2 = StringField(db_field='name') + + self.assertRaises(InvalidDocumentError, throw_invalid_document_error) + + def test_invalid_son(self): + """Raise an error if loading invalid data""" + class Occurrence(EmbeddedDocument): + number = IntField() + + class Word(Document): + stem = StringField() + count = IntField(default=1) + forms = ListField(StringField(), default=list) + occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) + + def raise_invalid_document(): + Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one', + 'occurs': {"hello": None}}) + + self.assertRaises(InvalidDocumentError, raise_invalid_document) + + def test_reverse_delete_rule_cascade_and_nullify(self): + """Ensure that a referenced document is also deleted upon deletion. 
+ """ + + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) + reviewer = ReferenceField(self.Person, reverse_delete_rule=NULLIFY) + + self.Person.drop_collection() + BlogPost.drop_collection() + + author = self.Person(name='Test User') + author.save() + + reviewer = self.Person(name='Re Viewer') + reviewer.save() + + post = BlogPost(content='Watched some TV') + post.author = author + post.reviewer = reviewer + post.save() + + reviewer.delete() + self.assertEqual(BlogPost.objects.count(), 1) # No effect on the BlogPost + self.assertEqual(BlogPost.objects.get().reviewer, None) + + # Delete the Person, which should lead to deletion of the BlogPost, too + author.delete() + self.assertEqual(BlogPost.objects.count(), 0) + + def test_reverse_delete_rule_with_document_inheritance(self): + """Ensure that a referenced document is also deleted upon deletion + of a child document. + """ + + class Writer(self.Person): + pass + + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) + reviewer = ReferenceField(self.Person, reverse_delete_rule=NULLIFY) + + self.Person.drop_collection() + BlogPost.drop_collection() + + author = Writer(name='Test User') + author.save() + + reviewer = Writer(name='Re Viewer') + reviewer.save() + + post = BlogPost(content='Watched some TV') + post.author = author + post.reviewer = reviewer + post.save() + + reviewer.delete() + self.assertEqual(BlogPost.objects.count(), 1) + self.assertEqual(BlogPost.objects.get().reviewer, None) + + # Delete the Writer should lead to deletion of the BlogPost + author.delete() + self.assertEqual(BlogPost.objects.count(), 0) + + def test_reverse_delete_rule_cascade_and_nullify_complex_field(self): + """Ensure that a referenced document is also deleted upon deletion for + complex fields. 
+ """ + + class BlogPost(Document): + content = StringField() + authors = ListField(ReferenceField(self.Person, reverse_delete_rule=CASCADE)) + reviewers = ListField(ReferenceField(self.Person, reverse_delete_rule=NULLIFY)) + + self.Person.drop_collection() + + BlogPost.drop_collection() + + author = self.Person(name='Test User') + author.save() + + reviewer = self.Person(name='Re Viewer') + reviewer.save() + + post = BlogPost(content='Watched some TV') + post.authors = [author] + post.reviewers = [reviewer] + post.save() + + # Deleting the reviewer should have no effect on the BlogPost + reviewer.delete() + self.assertEqual(BlogPost.objects.count(), 1) + self.assertEqual(BlogPost.objects.get().reviewers, []) + + # Delete the Person, which should lead to deletion of the BlogPost, too + author.delete() + self.assertEqual(BlogPost.objects.count(), 0) + + def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self): + ''' ensure the pre_delete signal is triggered upon a cascading deletion + setup a blog post with content, an author and editor + delete the author which triggers deletion of blogpost via cascade + blog post's pre_delete signal alters an editor attribute + ''' + class Editor(self.Person): + review_queue = IntField(default=0) + + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) + editor = ReferenceField(Editor) + + @classmethod + def pre_delete(cls, sender, document, **kwargs): + # decrement the docs-to-review count + document.editor.update(dec__review_queue=1) + + signals.pre_delete.connect(BlogPost.pre_delete, sender=BlogPost) + + self.Person.drop_collection() + BlogPost.drop_collection() + Editor.drop_collection() + + author = self.Person(name='Will S.').save() + editor = Editor(name='Max P.', review_queue=1).save() + BlogPost(content='wrote some books', author=author, + editor=editor).save() + + # delete the author, the post is also deleted due to the CASCADE rule + 
author.delete() + # the pre-delete signal should have decremented the editor's queue + editor = Editor.objects(name='Max P.').get() + self.assertEqual(editor.review_queue, 0) + + def test_two_way_reverse_delete_rule(self): + """Ensure that Bi-Directional relationships work with + reverse_delete_rule + """ + + class Bar(Document): + content = StringField() + foo = ReferenceField('Foo') + + class Foo(Document): + content = StringField() + bar = ReferenceField(Bar) + + Bar.register_delete_rule(Foo, 'bar', NULLIFY) + Foo.register_delete_rule(Bar, 'foo', NULLIFY) + + Bar.drop_collection() + Foo.drop_collection() + + b = Bar(content="Hello") + b.save() + + f = Foo(content="world", bar=b) + f.save() + + b.foo = f + b.save() + + f.delete() + + self.assertEqual(Bar.objects.count(), 1) # No effect on the BlogPost + self.assertEqual(Bar.objects.get().foo, None) + + def test_invalid_reverse_delete_rules_raise_errors(self): + + def throw_invalid_document_error(): + class Blog(Document): + content = StringField() + authors = MapField(ReferenceField(self.Person, reverse_delete_rule=CASCADE)) + reviewers = DictField(field=ReferenceField(self.Person, reverse_delete_rule=NULLIFY)) + + self.assertRaises(InvalidDocumentError, throw_invalid_document_error) + + def throw_invalid_document_error_embedded(): + class Parents(EmbeddedDocument): + father = ReferenceField('Person', reverse_delete_rule=DENY) + mother = ReferenceField('Person', reverse_delete_rule=DENY) + + self.assertRaises(InvalidDocumentError, throw_invalid_document_error_embedded) + + def test_reverse_delete_rule_cascade_recurs(self): + """Ensure that a chain of documents is also deleted upon cascaded + deletion. 
+ """ + + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) + + class Comment(Document): + text = StringField() + post = ReferenceField(BlogPost, reverse_delete_rule=CASCADE) + + self.Person.drop_collection() + BlogPost.drop_collection() + Comment.drop_collection() + + author = self.Person(name='Test User') + author.save() + + post = BlogPost(content = 'Watched some TV') + post.author = author + post.save() + + comment = Comment(text = 'Kudos.') + comment.post = post + comment.save() + + # Delete the Person, which should lead to deletion of the BlogPost, and, + # recursively to the Comment, too + author.delete() + self.assertEqual(Comment.objects.count(), 0) + + self.Person.drop_collection() + BlogPost.drop_collection() + Comment.drop_collection() + + def test_reverse_delete_rule_deny(self): + """Ensure that a document cannot be referenced if there are still + documents referring to it. + """ + + class BlogPost(Document): + content = StringField() + author = ReferenceField(self.Person, reverse_delete_rule=DENY) + + self.Person.drop_collection() + BlogPost.drop_collection() + + author = self.Person(name='Test User') + author.save() + + post = BlogPost(content = 'Watched some TV') + post.author = author + post.save() + + # Delete the Person should be denied + self.assertRaises(OperationError, author.delete) # Should raise denied error + self.assertEqual(BlogPost.objects.count(), 1) # No objects may have been deleted + self.assertEqual(self.Person.objects.count(), 1) + + # Other users, that don't have BlogPosts must be removable, like normal + author = self.Person(name='Another User') + author.save() + + self.assertEqual(self.Person.objects.count(), 2) + author.delete() + self.assertEqual(self.Person.objects.count(), 1) + + self.Person.drop_collection() + BlogPost.drop_collection() + + def subclasses_and_unique_keys_works(self): + + class A(Document): + pass + + class B(A): + foo = 
BooleanField(unique=True) + + A.drop_collection() + B.drop_collection() + + A().save() + A().save() + B(foo=True).save() + + self.assertEqual(A.objects.count(), 2) + self.assertEqual(B.objects.count(), 1) + A.drop_collection() + B.drop_collection() + + def test_document_hash(self): + """Test document in list, dict, set + """ + class User(Document): + pass + + class BlogPost(Document): + pass + + # Clear old datas + User.drop_collection() + BlogPost.drop_collection() + + u1 = User.objects.create() + u2 = User.objects.create() + u3 = User.objects.create() + u4 = User() # New object + + b1 = BlogPost.objects.create() + b2 = BlogPost.objects.create() + + # in List + all_user_list = list(User.objects.all()) + + self.assertTrue(u1 in all_user_list) + self.assertTrue(u2 in all_user_list) + self.assertTrue(u3 in all_user_list) + self.assertFalse(u4 in all_user_list) # New object + self.assertFalse(b1 in all_user_list) # Other object + self.assertFalse(b2 in all_user_list) # Other object + + # in Dict + all_user_dic = {} + for u in User.objects.all(): + all_user_dic[u] = "OK" + + self.assertEqual(all_user_dic.get(u1, False), "OK") + self.assertEqual(all_user_dic.get(u2, False), "OK") + self.assertEqual(all_user_dic.get(u3, False), "OK") + self.assertEqual(all_user_dic.get(u4, False), False) # New object + self.assertEqual(all_user_dic.get(b1, False), False) # Other object + self.assertEqual(all_user_dic.get(b2, False), False) # Other object + + # in Set + all_user_set = set(User.objects.all()) + + self.assertTrue(u1 in all_user_set) + + def test_picklable(self): + + pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) + pickle_doc.embedded = PickleEmbedded() + pickle_doc.save() + + pickled_doc = pickle.dumps(pickle_doc) + resurrected = pickle.loads(pickled_doc) + + self.assertEqual(resurrected, pickle_doc) + + # Test pickling changed data + pickle_doc.lists.append("3") + pickled_doc = pickle.dumps(pickle_doc) + resurrected = pickle.loads(pickled_doc) + + 
self.assertEqual(resurrected, pickle_doc) + resurrected.string = "Two" + resurrected.save() + + pickle_doc = PickleTest.objects.first() + self.assertEqual(resurrected, pickle_doc) + self.assertEqual(pickle_doc.string, "Two") + self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) + + def test_throw_invalid_document_error(self): + + # test handles people trying to upsert + def throw_invalid_document_error(): + class Blog(Document): + validate = DictField() + + self.assertRaises(InvalidDocumentError, throw_invalid_document_error) + + def test_mutating_documents(self): + + class B(EmbeddedDocument): + field1 = StringField(default='field1') + + class A(Document): + b = EmbeddedDocumentField(B, default=lambda: B()) + + A.drop_collection() + a = A() + a.save() + a.reload() + self.assertEqual(a.b.field1, 'field1') + + class C(EmbeddedDocument): + c_field = StringField(default='cfield') + + class B(EmbeddedDocument): + field1 = StringField(default='field1') + field2 = EmbeddedDocumentField(C, default=lambda: C()) + + class A(Document): + b = EmbeddedDocumentField(B, default=lambda: B()) + + a = A.objects()[0] + a.b.field2.c_field = 'new value' + a.save() + + a.reload() + self.assertEqual(a.b.field2.c_field, 'new value') + + def test_can_save_false_values(self): + """Ensures you can save False values on save""" + class Doc(Document): + foo = StringField() + archived = BooleanField(default=False, required=True) + + Doc.drop_collection() + d = Doc() + d.save() + d.archived = False + d.save() + + self.assertEqual(Doc.objects(archived=False).count(), 1) + + def test_can_save_false_values_dynamic(self): + """Ensures you can save False values on dynamic docs""" + class Doc(DynamicDocument): + foo = StringField() + + Doc.drop_collection() + d = Doc() + d.save() + d.archived = False + d.save() + + self.assertEqual(Doc.objects(archived=False).count(), 1) + + def test_do_not_save_unchanged_references(self): + """Ensures cascading saves dont auto update""" + class Job(Document): + name 
= StringField() + + class Person(Document): + name = StringField() + age = IntField() + job = ReferenceField(Job) + + Job.drop_collection() + Person.drop_collection() + + job = Job(name="Job 1") + # job should not have any changed fields after the save + job.save() + + person = Person(name="name", age=10, job=job) + + from pymongo.collection import Collection + orig_update = Collection.update + try: + def fake_update(*args, **kwargs): + self.fail("Unexpected update for %s" % args[0].name) + return orig_update(*args, **kwargs) + + Collection.update = fake_update + person.save() + finally: + Collection.update = orig_update + + def test_db_alias_tests(self): + """ DB Alias tests """ + # mongoenginetest - Is default connection alias from setUp() + # Register Aliases + register_connection('testdb-1', 'mongoenginetest2') + register_connection('testdb-2', 'mongoenginetest3') + register_connection('testdb-3', 'mongoenginetest4') + + class User(Document): + name = StringField() + meta = {"db_alias": "testdb-1"} + + class Book(Document): + name = StringField() + meta = {"db_alias": "testdb-2"} + + # Drops + User.drop_collection() + Book.drop_collection() + + # Create + bob = User.objects.create(name="Bob") + hp = Book.objects.create(name="Harry Potter") + + # Selects + self.assertEqual(User.objects.first(), bob) + self.assertEqual(Book.objects.first(), hp) + + # DeReference + class AuthorBooks(Document): + author = ReferenceField(User) + book = ReferenceField(Book) + meta = {"db_alias": "testdb-3"} + + # Drops + AuthorBooks.drop_collection() + + ab = AuthorBooks.objects.create(author=bob, book=hp) + + # select + self.assertEqual(AuthorBooks.objects.first(), ab) + self.assertEqual(AuthorBooks.objects.first().book, hp) + self.assertEqual(AuthorBooks.objects.first().author, bob) + self.assertEqual(AuthorBooks.objects.filter(author=bob).first(), ab) + self.assertEqual(AuthorBooks.objects.filter(book=hp).first(), ab) + + # DB Alias + self.assertEqual(User._get_db(), 
get_db("testdb-1")) + self.assertEqual(Book._get_db(), get_db("testdb-2")) + self.assertEqual(AuthorBooks._get_db(), get_db("testdb-3")) + + # Collections + self.assertEqual(User._get_collection(), get_db("testdb-1")[User._get_collection_name()]) + self.assertEqual(Book._get_collection(), get_db("testdb-2")[Book._get_collection_name()]) + self.assertEqual(AuthorBooks._get_collection(), get_db("testdb-3")[AuthorBooks._get_collection_name()]) + + def test_db_alias_overrides(self): + """db_alias can be overriden + """ + # Register a connection with db_alias testdb-2 + register_connection('testdb-2', 'mongoenginetest2') + + class A(Document): + """Uses default db_alias + """ + name = StringField() + meta = {"allow_inheritance": True} + + class B(A): + """Uses testdb-2 db_alias + """ + meta = {"db_alias": "testdb-2"} + + A.objects.all() + + self.assertEquals('testdb-2', B._meta.get('db_alias')) + self.assertEquals('mongoenginetest', + A._get_collection().database.name) + self.assertEquals('mongoenginetest2', + B._get_collection().database.name) + + def test_db_alias_propagates(self): + """db_alias propagates? 
+ """ + register_connection('testdb-1', 'mongoenginetest2') + + class A(Document): + name = StringField() + meta = {"db_alias": "testdb-1", "allow_inheritance": True} + + class B(A): + pass + + self.assertEqual('testdb-1', B._meta.get('db_alias')) + + def test_db_ref_usage(self): + """ DB Ref usage in dict_fields""" + + class User(Document): + name = StringField() + + class Book(Document): + name = StringField() + author = ReferenceField(User) + extra = DictField() + meta = { + 'ordering': ['+name'] + } + + def __unicode__(self): + return self.name + + def __str__(self): + return self.name + + # Drops + User.drop_collection() + Book.drop_collection() + + # Authors + bob = User.objects.create(name="Bob") + jon = User.objects.create(name="Jon") + + # Redactors + karl = User.objects.create(name="Karl") + susan = User.objects.create(name="Susan") + peter = User.objects.create(name="Peter") + + # Bob + Book.objects.create(name="1", author=bob, extra={ + "a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]}) + Book.objects.create(name="2", author=bob, extra={ + "a": bob.to_dbref(), "b": karl.to_dbref()}) + Book.objects.create(name="3", author=bob, extra={ + "a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]}) + Book.objects.create(name="4", author=bob) + + # Jon + Book.objects.create(name="5", author=jon) + Book.objects.create(name="6", author=peter) + Book.objects.create(name="7", author=jon) + Book.objects.create(name="8", author=jon) + Book.objects.create(name="9", author=jon, + extra={"a": peter.to_dbref()}) + + # Checks + self.assertEqual(",".join([str(b) for b in Book.objects.all()]), + "1,2,3,4,5,6,7,8,9") + # bob related books + self.assertEqual(",".join([str(b) for b in Book.objects.filter( + Q(extra__a=bob) | + Q(author=bob) | + Q(extra__b=bob))]), + "1,2,3,4") + + # Susan & Karl related books + self.assertEqual(",".join([str(b) for b in Book.objects.filter( + Q(extra__a__all=[karl, susan]) | + Q(author__all=[karl, susan]) | + 
Q(extra__b__all=[ + karl.to_dbref(), susan.to_dbref()])) + ]), "1") + + # $Where + self.assertEqual(u",".join([str(b) for b in Book.objects.filter( + __raw__={ + "$where": """ + function(){ + return this.name == '1' || + this.name == '2';}""" + })]), + "1,2") + + def test_switch_db_instance(self): + register_connection('testdb-1', 'mongoenginetest2') + + class Group(Document): + name = StringField() + + Group.drop_collection() + with switch_db(Group, 'testdb-1') as Group: + Group.drop_collection() + + Group(name="hello - default").save() + self.assertEqual(1, Group.objects.count()) + + group = Group.objects.first() + group.switch_db('testdb-1') + group.name = "hello - testdb!" + group.save() + + with switch_db(Group, 'testdb-1') as Group: + group = Group.objects.first() + self.assertEqual("hello - testdb!", group.name) + + group = Group.objects.first() + self.assertEqual("hello - default", group.name) + + # Slightly contrived now - perform an update + # Only works as they have the same object_id + group.switch_db('testdb-1') + group.update(set__name="hello - update") + + with switch_db(Group, 'testdb-1') as Group: + group = Group.objects.first() + self.assertEqual("hello - update", group.name) + Group.drop_collection() + self.assertEqual(0, Group.objects.count()) + + group = Group.objects.first() + self.assertEqual("hello - default", group.name) + + # Totally contrived now - perform a delete + # Only works as they have the same object_id + group.switch_db('testdb-1') + group.delete() + + with switch_db(Group, 'testdb-1') as Group: + self.assertEqual(0, Group.objects.count()) + + group = Group.objects.first() + self.assertEqual("hello - default", group.name) + + def test_no_overwritting_no_data_loss(self): + + class User(Document): + username = StringField(primary_key=True) + name = StringField() + + @property + def foo(self): + return True + + User.drop_collection() + + user = User(username="Ross", foo="bar") + self.assertTrue(user.foo) + + 
User._get_collection().save({"_id": "Ross", "foo": "Bar", + "data": [1, 2, 3]}) + + user = User.objects.first() + self.assertEqual("Ross", user.username) + self.assertEqual(True, user.foo) + self.assertEqual("Bar", user._data["foo"]) + self.assertEqual([1, 2, 3], user._data["data"]) + + def test_spaces_in_keys(self): + + class Embedded(DynamicEmbeddedDocument): + pass + + class Doc(DynamicDocument): + pass + + Doc.drop_collection() + doc = Doc() + setattr(doc, 'hello world', 1) + doc.save() + + one = Doc.objects.filter(**{'hello world': 1}).count() + self.assertEqual(1, one) + + def test_shard_key(self): + class LogEntry(Document): + machine = StringField() + log = StringField() + + meta = { + 'shard_key': ('machine',) + } + + LogEntry.drop_collection() + + log = LogEntry() + log.machine = "Localhost" + log.save() + + log.log = "Saving" + log.save() + + def change_shard_key(): + log.machine = "127.0.0.1" + + self.assertRaises(OperationError, change_shard_key) + + def test_shard_key_primary(self): + class LogEntry(Document): + machine = StringField(primary_key=True) + log = StringField() + + meta = { + 'shard_key': ('machine',) + } + + LogEntry.drop_collection() + + log = LogEntry() + log.machine = "Localhost" + log.save() + + log.log = "Saving" + log.save() + + def change_shard_key(): + log.machine = "127.0.0.1" + + self.assertRaises(OperationError, change_shard_key) + + def test_kwargs_simple(self): + + class Embedded(EmbeddedDocument): + name = StringField() + + class Doc(Document): + doc_name = StringField() + doc = EmbeddedDocumentField(Embedded) + + classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) + dict_doc = Doc(**{"doc_name": "my doc", + "doc": {"name": "embedded doc"}}) + + self.assertEqual(classic_doc, dict_doc) + self.assertEqual(classic_doc._data, dict_doc._data) + + def test_kwargs_complex(self): + + class Embedded(EmbeddedDocument): + name = StringField() + + class Doc(Document): + doc_name = StringField() + docs = 
ListField(EmbeddedDocumentField(Embedded)) + + classic_doc = Doc(doc_name="my doc", docs=[ + Embedded(name="embedded doc1"), + Embedded(name="embedded doc2")]) + dict_doc = Doc(**{"doc_name": "my doc", + "docs": [{"name": "embedded doc1"}, + {"name": "embedded doc2"}]}) + + self.assertEqual(classic_doc, dict_doc) + self.assertEqual(classic_doc._data, dict_doc._data) + + def test_positional_creation(self): + """Ensure that document may be created using positional arguments. + """ + person = self.Person("Test User", 42) + self.assertEqual(person.name, "Test User") + self.assertEqual(person.age, 42) + + def test_mixed_creation(self): + """Ensure that document may be created using mixed arguments. + """ + person = self.Person("Test User", age=42) + self.assertEqual(person.name, "Test User") + self.assertEqual(person.age, 42) + + def test_bad_mixed_creation(self): + """Ensure that document gives correct error when duplicating arguments + """ + def construct_bad_instance(): + return self.Person("Test User", 42, name="Bad User") + + self.assertRaises(TypeError, construct_bad_instance) + + def test_data_contains_id_field(self): + """Ensure that asking for _data returns 'id' + """ + class Person(Document): + name = StringField() + + Person.drop_collection() + Person(name="Harry Potter").save() + + person = Person.objects.first() + self.assertTrue('id' in person._data.keys()) + self.assertEqual(person._data.get('id'), person.id) + + def test_complex_nesting_document_and_embedded_document(self): + + class Macro(EmbeddedDocument): + value = DynamicField(default="UNDEFINED") + + class Parameter(EmbeddedDocument): + macros = MapField(EmbeddedDocumentField(Macro)) + + def expand(self): + self.macros["test"] = Macro() + + class Node(Document): + parameters = MapField(EmbeddedDocumentField(Parameter)) + + def expand(self): + self.flattened_parameter = {} + for parameter_name, parameter in self.parameters.iteritems(): + parameter.expand() + + class System(Document): + name = 
StringField(required=True) + nodes = MapField(ReferenceField(Node, dbref=False)) + + def save(self, *args, **kwargs): + for node_name, node in self.nodes.iteritems(): + node.expand() + node.save(*args, **kwargs) + super(System, self).save(*args, **kwargs) + + System.drop_collection() + Node.drop_collection() + + system = System(name="system") + system.nodes["node"] = Node() + system.save() + system.nodes["node"].parameters["param"] = Parameter() + system.save() + + system = System.objects.first() + self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py new file mode 100644 index 00000000..dbc09d83 --- /dev/null +++ b/tests/document/json_serialisation.py @@ -0,0 +1,81 @@ +import sys +sys.path[0:0] = [""] + +import unittest +import uuid + +from nose.plugins.skip import SkipTest +from datetime import datetime +from bson import ObjectId + +import pymongo + +from mongoengine import * + +__all__ = ("TestJson",) + + +class TestJson(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + + def test_json_simple(self): + + class Embedded(EmbeddedDocument): + string = StringField() + + class Doc(Document): + string = StringField() + embedded_field = EmbeddedDocumentField(Embedded) + + doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) + + self.assertEqual(doc, Doc.from_json(doc.to_json())) + + def test_json_complex(self): + + if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3: + raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs") + + class EmbeddedDoc(EmbeddedDocument): + pass + + class Simple(Document): + pass + + class Doc(Document): + string_field = StringField(default='1') + int_field = IntField(default=1) + float_field = FloatField(default=1.1) + boolean_field = BooleanField(default=True) + datetime_field = DateTimeField(default=datetime.now) + 
embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, + default=lambda: EmbeddedDoc()) + list_field = ListField(default=lambda: [1, 2, 3]) + dict_field = DictField(default=lambda: {"hello": "world"}) + objectid_field = ObjectIdField(default=ObjectId) + reference_field = ReferenceField(Simple, default=lambda: + Simple().save()) + map_field = MapField(IntField(), default=lambda: {"simple": 1}) + decimal_field = DecimalField(default=1.0) + complex_datetime_field = ComplexDateTimeField(default=datetime.now) + url_field = URLField(default="http://mongoengine.org") + dynamic_field = DynamicField(default=1) + generic_reference_field = GenericReferenceField( + default=lambda: Simple().save()) + sorted_list_field = SortedListField(IntField(), + default=lambda: [1, 2, 3]) + email_field = EmailField(default="ross@example.com") + geo_point_field = GeoPointField(default=lambda: [1, 2]) + sequence_field = SequenceField() + uuid_field = UUIDField(default=uuid.uuid4) + generic_embedded_document_field = GenericEmbeddedDocumentField( + default=lambda: EmbeddedDoc()) + + doc = Doc() + self.assertEqual(doc, Doc.from_json(doc.to_json())) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/validation.py b/tests/document/validation.py new file mode 100644 index 00000000..d3f3fd70 --- /dev/null +++ b/tests/document/validation.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +import sys +sys.path[0:0] = [""] + +import unittest +from datetime import datetime + +from mongoengine import * + +__all__ = ("ValidatorErrorTest",) + + +class ValidatorErrorTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + + def test_to_dict(self): + """Ensure a ValidationError handles error to_dict correctly. 
+ """ + error = ValidationError('root') + self.assertEqual(error.to_dict(), {}) + + # 1st level error schema + error.errors = {'1st': ValidationError('bad 1st'), } + self.assertTrue('1st' in error.to_dict()) + self.assertEqual(error.to_dict()['1st'], 'bad 1st') + + # 2nd level error schema + error.errors = {'1st': ValidationError('bad 1st', errors={ + '2nd': ValidationError('bad 2nd'), + })} + self.assertTrue('1st' in error.to_dict()) + self.assertTrue(isinstance(error.to_dict()['1st'], dict)) + self.assertTrue('2nd' in error.to_dict()['1st']) + self.assertEqual(error.to_dict()['1st']['2nd'], 'bad 2nd') + + # moar levels + error.errors = {'1st': ValidationError('bad 1st', errors={ + '2nd': ValidationError('bad 2nd', errors={ + '3rd': ValidationError('bad 3rd', errors={ + '4th': ValidationError('Inception'), + }), + }), + })} + self.assertTrue('1st' in error.to_dict()) + self.assertTrue('2nd' in error.to_dict()['1st']) + self.assertTrue('3rd' in error.to_dict()['1st']['2nd']) + self.assertTrue('4th' in error.to_dict()['1st']['2nd']['3rd']) + self.assertEqual(error.to_dict()['1st']['2nd']['3rd']['4th'], + 'Inception') + + self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])") + + def test_model_validation(self): + + class User(Document): + username = StringField(primary_key=True) + name = StringField(required=True) + + try: + User().validate() + except ValidationError, e: + self.assertTrue("User:None" in e.message) + self.assertEqual(e.to_dict(), { + 'username': 'Field is required', + 'name': 'Field is required'}) + + user = User(username="RossC0", name="Ross").save() + user.name = None + try: + user.save() + except ValidationError, e: + self.assertTrue("User:RossC0" in e.message) + self.assertEqual(e.to_dict(), { + 'name': 'Field is required'}) + + def test_fields_rewrite(self): + class BasePerson(Document): + name = StringField() + age = IntField() + meta = {'abstract': True} + + class Person(BasePerson): + name = StringField(required=True) + + p = 
Person(age=15) + self.assertRaises(ValidationError, p.validate) + + def test_embedded_document_validation(self): + """Ensure that embedded documents may be validated. + """ + class Comment(EmbeddedDocument): + date = DateTimeField() + content = StringField(required=True) + + comment = Comment() + self.assertRaises(ValidationError, comment.validate) + + comment.content = 'test' + comment.validate() + + comment.date = 4 + self.assertRaises(ValidationError, comment.validate) + + comment.date = datetime.now() + comment.validate() + self.assertEqual(comment._instance, None) + + def test_embedded_db_field_validate(self): + + class SubDoc(EmbeddedDocument): + val = IntField(required=True) + + class Doc(Document): + id = StringField(primary_key=True) + e = EmbeddedDocumentField(SubDoc, db_field='eb') + + try: + Doc(id="bad").validate() + except ValidationError, e: + self.assertTrue("SubDoc:None" in e.message) + self.assertEqual(e.to_dict(), { + "e": {'val': 'OK could not be converted to int'}}) + + Doc.drop_collection() + + Doc(id="test", e=SubDoc(val=15)).save() + + doc = Doc.objects.first() + keys = doc._data.keys() + self.assertEqual(2, len(keys)) + self.assertTrue('e' in keys) + self.assertTrue('id' in keys) + + doc.e.val = "OK" + try: + doc.save() + except ValidationError, e: + self.assertTrue("Doc:test" in e.message) + self.assertEqual(e.to_dict(), { + "e": {'val': 'OK could not be converted to int'}}) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/fields/__init__.py b/tests/fields/__init__.py new file mode 100644 index 00000000..0731838b --- /dev/null +++ b/tests/fields/__init__.py @@ -0,0 +1,2 @@ +from fields import * +from file_tests import * \ No newline at end of file diff --git a/tests/test_fields.py b/tests/fields/fields.py similarity index 85% rename from tests/test_fields.py rename to tests/fields/fields.py index 28af1b23..4fa6989c 100644 --- a/tests/test_fields.py +++ b/tests/fields/fields.py @@ -1,23 +1,24 @@ # -*- coding: utf-8 -*- 
from __future__ import with_statement +import sys +sys.path[0:0] = [""] + import datetime -import os import unittest import uuid -import tempfile from decimal import Decimal from bson import Binary, DBRef, ObjectId -import gridfs -from nose.plugins.skip import SkipTest from mongoengine import * from mongoengine.connection import get_db -from mongoengine.base import _document_registry, NotRegistered -from mongoengine.python_support import PY3, b, StringIO, bin_type +from mongoengine.base import _document_registry +from mongoengine.errors import NotRegistered +from mongoengine.python_support import PY3, b, bin_type + +__all__ = ("FieldTest", ) -TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') class FieldTest(unittest.TestCase): @@ -144,6 +145,17 @@ class FieldTest(unittest.TestCase): self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) + def test_long_ne_operator(self): + class TestDocument(Document): + long_fld = LongField() + + TestDocument.drop_collection() + + TestDocument(long_fld=None).save() + TestDocument(long_fld=1).save() + + self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count()) + def test_object_id_validation(self): """Ensure that invalid values cannot be assigned to string fields. """ @@ -217,6 +229,23 @@ class FieldTest(unittest.TestCase): person.age = 'ten' self.assertRaises(ValidationError, person.validate) + def test_long_validation(self): + """Ensure that invalid values cannot be assigned to long fields. 
+ """ + class TestDocument(Document): + value = LongField(min_value=0, max_value=110) + + doc = TestDocument() + doc.value = 50 + doc.validate() + + doc.value = -1 + self.assertRaises(ValidationError, doc.validate) + doc.age = 120 + self.assertRaises(ValidationError, doc.validate) + doc.age = 'ten' + self.assertRaises(ValidationError, doc.validate) + def test_float_validation(self): """Ensure that invalid values cannot be assigned to float fields. """ @@ -243,10 +272,8 @@ class FieldTest(unittest.TestCase): Person.drop_collection() - person = Person() - person.height = Decimal('1.89') - person.save() - person.reload() + Person(height=Decimal('1.89')).save() + person = Person.objects.first() self.assertEqual(person.height, Decimal('1.89')) person.height = '2.0' @@ -260,6 +287,45 @@ class FieldTest(unittest.TestCase): Person.drop_collection() + def test_decimal_comparison(self): + + class Person(Document): + money = DecimalField() + + Person.drop_collection() + + Person(money=6).save() + Person(money=8).save() + Person(money=10).save() + + self.assertEqual(2, Person.objects(money__gt=Decimal("7")).count()) + self.assertEqual(2, Person.objects(money__gt=7).count()) + self.assertEqual(2, Person.objects(money__gt="7").count()) + + def test_decimal_storage(self): + class Person(Document): + btc = DecimalField(precision=4) + + Person.drop_collection() + Person(btc=10).save() + Person(btc=10.1).save() + Person(btc=10.11).save() + Person(btc="10.111").save() + Person(btc=Decimal("10.1111")).save() + Person(btc=Decimal("10.11111")).save() + + # How its stored + expected = [{'btc': 10.0}, {'btc': 10.1}, {'btc': 10.11}, + {'btc': 10.111}, {'btc': 10.1111}, {'btc': 10.1111}] + actual = list(Person.objects.exclude('id').as_pymongo()) + self.assertEqual(expected, actual) + + # How it comes out locally + expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'), + Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')] + actual = 
list(Person.objects().scalar('btc')) + self.assertEqual(expected, actual) + def test_boolean_validation(self): """Ensure that invalid values cannot be assigned to boolean fields. """ @@ -325,7 +391,6 @@ class FieldTest(unittest.TestCase): person.api_key = api_key self.assertRaises(ValidationError, person.validate) - def test_datetime_validation(self): """Ensure that invalid values cannot be assigned to datetime fields. """ @@ -602,7 +667,8 @@ class FieldTest(unittest.TestCase): name = StringField() class CategoryList(Document): - categories = SortedListField(EmbeddedDocumentField(Category), ordering='count', reverse=True) + categories = SortedListField(EmbeddedDocumentField(Category), + ordering='count', reverse=True) name = StringField() catlist = CategoryList(name="Top categories") @@ -726,7 +792,7 @@ class FieldTest(unittest.TestCase): """Ensure that the list fields can handle the complex types.""" class SettingBase(EmbeddedDocument): - pass + meta = {'allow_inheritance': True} class StringSetting(SettingBase): value = StringField() @@ -742,8 +808,9 @@ class FieldTest(unittest.TestCase): e.mapping.append(StringSetting(value='foo')) e.mapping.append(IntegerSetting(value=42)) e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001, - 'complex': IntegerSetting(value=42), 'list': - [IntegerSetting(value=42), StringSetting(value='foo')]}) + 'complex': IntegerSetting(value=42), + 'list': [IntegerSetting(value=42), + StringSetting(value='foo')]}) e.save() e2 = Simple.objects.get(id=e.id) @@ -843,7 +910,7 @@ class FieldTest(unittest.TestCase): """Ensure that the dict field can handle the complex types.""" class SettingBase(EmbeddedDocument): - pass + meta = {'allow_inheritance': True} class StringSetting(SettingBase): value = StringField() @@ -858,9 +925,11 @@ class FieldTest(unittest.TestCase): e = Simple() e.mapping['somestring'] = StringSetting(value='foo') e.mapping['someint'] = IntegerSetting(value=42) - e.mapping['nested_dict'] = {'number': 1, 'string': 
'Hi!', 'float': 1.001, - 'complex': IntegerSetting(value=42), 'list': - [IntegerSetting(value=42), StringSetting(value='foo')]} + e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', + 'float': 1.001, + 'complex': IntegerSetting(value=42), + 'list': [IntegerSetting(value=42), + StringSetting(value='foo')]} e.save() e2 = Simple.objects.get(id=e.id) @@ -914,7 +983,7 @@ class FieldTest(unittest.TestCase): """Ensure that the MapField can handle complex declared types.""" class SettingBase(EmbeddedDocument): - pass + meta = {"allow_inheritance": True} class StringSetting(SettingBase): value = StringField() @@ -950,7 +1019,8 @@ class FieldTest(unittest.TestCase): number = IntField(default=0, db_field='i') class Test(Document): - my_map = MapField(field=EmbeddedDocumentField(Embedded), db_field='x') + my_map = MapField(field=EmbeddedDocumentField(Embedded), + db_field='x') Test.drop_collection() @@ -965,6 +1035,24 @@ class FieldTest(unittest.TestCase): doc = self.db.test.find_one() self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2) + def test_mapfield_numerical_index(self): + """Ensure that MapField accept numeric strings as indexes.""" + class Embedded(EmbeddedDocument): + name = StringField() + + class Test(Document): + my_map = MapField(EmbeddedDocumentField(Embedded)) + + Test.drop_collection() + + test = Test() + test.my_map['1'] = Embedded(name='test') + test.save() + test.my_map['1'].name = 'test updated' + test.save() + + Test.drop_collection() + def test_map_field_lookup(self): """Ensure MapField lookups succeed on Fields without a lookup method""" @@ -1037,6 +1125,8 @@ class FieldTest(unittest.TestCase): class User(EmbeddedDocument): name = StringField() + meta = {'allow_inheritance': True} + class PowerUser(User): power = IntField() @@ -1045,8 +1135,10 @@ class FieldTest(unittest.TestCase): author = EmbeddedDocumentField(User) post = BlogPost(content='What I did today...') - post.author = User(name='Test User') post.author = PowerUser(name='Test User', 
power=47) + post.save() + + self.assertEqual(47, BlogPost.objects.first().author.power) def test_reference_validation(self): """Ensure that invalid docment objects cannot be assigned to reference @@ -1614,8 +1706,9 @@ class FieldTest(unittest.TestCase): """Ensure that value is in a container of allowed values. """ class Shirt(Document): - size = StringField(max_length=3, choices=(('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), - ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) + size = StringField(max_length=3, choices=( + ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), + ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) Shirt.drop_collection() @@ -1631,12 +1724,15 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() def test_choices_get_field_display(self): - """Test dynamic helper for returning the display value of a choices field. + """Test dynamic helper for returning the display value of a choices + field. """ class Shirt(Document): - size = StringField(max_length=3, choices=(('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), - ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) - style = StringField(max_length=3, choices=(('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') + size = StringField(max_length=3, choices=( + ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), + ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) + style = StringField(max_length=3, choices=( + ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') Shirt.drop_collection() @@ -1663,7 +1759,8 @@ class FieldTest(unittest.TestCase): """Ensure that value is in a container of allowed values. 
""" class Shirt(Document): - size = StringField(max_length=3, choices=('S', 'M', 'L', 'XL', 'XXL')) + size = StringField(max_length=3, + choices=('S', 'M', 'L', 'XL', 'XXL')) Shirt.drop_collection() @@ -1679,11 +1776,15 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() def test_simple_choices_get_field_display(self): - """Test dynamic helper for returning the display value of a choices field. + """Test dynamic helper for returning the display value of a choices + field. """ class Shirt(Document): - size = StringField(max_length=3, choices=('S', 'M', 'L', 'XL', 'XXL')) - style = StringField(max_length=3, choices=('Small', 'Baggy', 'wide'), default='Small') + size = StringField(max_length=3, + choices=('S', 'M', 'L', 'XL', 'XXL')) + style = StringField(max_length=3, + choices=('Small', 'Baggy', 'wide'), + default='Small') Shirt.drop_collection() @@ -1706,303 +1807,39 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() - def test_file_fields(self): - """Ensure that file fields can be written to and their data retrieved + def test_simple_choices_validation_invalid_value(self): + """Ensure that error messages are correct. 
""" - class PutFile(Document): - the_file = FileField() + SIZES = ('S', 'M', 'L', 'XL', 'XXL') + COLORS = (('R', 'Red'), ('B', 'Blue')) + SIZE_MESSAGE = u"Value must be one of ('S', 'M', 'L', 'XL', 'XXL')" + COLOR_MESSAGE = u"Value must be one of ['R', 'B']" - class StreamFile(Document): - the_file = FileField() + class Shirt(Document): + size = StringField(max_length=3, choices=SIZES) + color = StringField(max_length=1, choices=COLORS) - class SetFile(Document): - the_file = FileField() + Shirt.drop_collection() - text = b('Hello, World!') - more_text = b('Foo Bar') - content_type = 'text/plain' + shirt = Shirt() + shirt.validate() - PutFile.drop_collection() - StreamFile.drop_collection() - SetFile.drop_collection() + shirt.size = "S" + shirt.color = "R" + shirt.validate() - putfile = PutFile() - putfile.the_file.put(text, content_type=content_type) - putfile.save() - putfile.validate() - result = PutFile.objects.first() - self.assertTrue(putfile == result) - self.assertEqual(result.the_file.read(), text) - self.assertEqual(result.the_file.content_type, content_type) - result.the_file.delete() # Remove file from GridFS - PutFile.objects.delete() + shirt.size = "XS" + shirt.color = "G" - # Ensure file-like objects are stored - putfile = PutFile() - putstring = StringIO() - putstring.write(text) - putstring.seek(0) - putfile.the_file.put(putstring, content_type=content_type) - putfile.save() - putfile.validate() - result = PutFile.objects.first() - self.assertTrue(putfile == result) - self.assertEqual(result.the_file.read(), text) - self.assertEqual(result.the_file.content_type, content_type) - result.the_file.delete() + try: + shirt.validate() + except ValidationError, error: + # get the validation rules + error_dict = error.to_dict() + self.assertEqual(error_dict['size'], SIZE_MESSAGE) + self.assertEqual(error_dict['color'], COLOR_MESSAGE) - streamfile = StreamFile() - streamfile.the_file.new_file(content_type=content_type) - streamfile.the_file.write(text) - 
streamfile.the_file.write(more_text) - streamfile.the_file.close() - streamfile.save() - streamfile.validate() - result = StreamFile.objects.first() - self.assertTrue(streamfile == result) - self.assertEqual(result.the_file.read(), text + more_text) - self.assertEqual(result.the_file.content_type, content_type) - result.the_file.seek(0) - self.assertEqual(result.the_file.tell(), 0) - self.assertEqual(result.the_file.read(len(text)), text) - self.assertEqual(result.the_file.tell(), len(text)) - self.assertEqual(result.the_file.read(len(more_text)), more_text) - self.assertEqual(result.the_file.tell(), len(text + more_text)) - result.the_file.delete() - - # Ensure deleted file returns None - self.assertTrue(result.the_file.read() == None) - - setfile = SetFile() - setfile.the_file = text - setfile.save() - setfile.validate() - result = SetFile.objects.first() - self.assertTrue(setfile == result) - self.assertEqual(result.the_file.read(), text) - - # Try replacing file with new one - result.the_file.replace(more_text) - result.save() - result.validate() - result = SetFile.objects.first() - self.assertTrue(setfile == result) - self.assertEqual(result.the_file.read(), more_text) - result.the_file.delete() - - PutFile.drop_collection() - StreamFile.drop_collection() - SetFile.drop_collection() - - # Make sure FileField is optional and not required - class DemoFile(Document): - the_file = FileField() - DemoFile.objects.create() - - - def test_file_field_no_default(self): - - class GridDocument(Document): - the_file = FileField() - - GridDocument.drop_collection() - - with tempfile.TemporaryFile() as f: - f.write(b("Hello World!")) - f.flush() - - # Test without default - doc_a = GridDocument() - doc_a.save() - - - doc_b = GridDocument.objects.with_id(doc_a.id) - doc_b.the_file.replace(f, filename='doc_b') - doc_b.save() - self.assertNotEqual(doc_b.the_file.grid_id, None) - - # Test it matches - doc_c = GridDocument.objects.with_id(doc_b.id) - 
self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) - - # Test with default - doc_d = GridDocument(the_file=b('')) - doc_d.save() - - doc_e = GridDocument.objects.with_id(doc_d.id) - self.assertEqual(doc_d.the_file.grid_id, doc_e.the_file.grid_id) - - doc_e.the_file.replace(f, filename='doc_e') - doc_e.save() - - doc_f = GridDocument.objects.with_id(doc_e.id) - self.assertEqual(doc_e.the_file.grid_id, doc_f.the_file.grid_id) - - db = GridDocument._get_db() - grid_fs = gridfs.GridFS(db) - self.assertEqual(['doc_b', 'doc_e'], grid_fs.list()) - - def test_file_uniqueness(self): - """Ensure that each instance of a FileField is unique - """ - class TestFile(Document): - name = StringField() - the_file = FileField() - - # First instance - test_file = TestFile() - test_file.name = "Hello, World!" - test_file.the_file.put(b('Hello, World!')) - test_file.save() - - # Second instance - test_file_dupe = TestFile() - data = test_file_dupe.the_file.read() # Should be None - - self.assertTrue(test_file.name != test_file_dupe.name) - self.assertTrue(test_file.the_file.read() != data) - - TestFile.drop_collection() - - def test_file_boolean(self): - """Ensure that a boolean test of a FileField indicates its presence - """ - class TestFile(Document): - the_file = FileField() - - test_file = TestFile() - self.assertFalse(bool(test_file.the_file)) - test_file.the_file = b('Hello, World!') - test_file.the_file.content_type = 'text/plain' - test_file.save() - self.assertTrue(bool(test_file.the_file)) - - TestFile.drop_collection() - - def test_file_cmp(self): - """Test comparing against other types""" - class TestFile(Document): - the_file = FileField() - - test_file = TestFile() - self.assertFalse(test_file.the_file in [{"test": 1}]) - - def test_image_field(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') - - class TestImage(Document): - image = ImageField() - - TestImage.drop_collection() - - t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 
'r')) - t.save() - - t = TestImage.objects.first() - - self.assertEqual(t.image.format, 'PNG') - - w, h = t.image.size - self.assertEqual(w, 371) - self.assertEqual(h, 76) - - t.image.delete() - - def test_image_field_resize(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') - - class TestImage(Document): - image = ImageField(size=(185, 37)) - - TestImage.drop_collection() - - t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'r')) - t.save() - - t = TestImage.objects.first() - - self.assertEqual(t.image.format, 'PNG') - w, h = t.image.size - - self.assertEqual(w, 185) - self.assertEqual(h, 37) - - t.image.delete() - - def test_image_field_resize_force(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') - - class TestImage(Document): - image = ImageField(size=(185, 37, True)) - - TestImage.drop_collection() - - t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'r')) - t.save() - - t = TestImage.objects.first() - - self.assertEqual(t.image.format, 'PNG') - w, h = t.image.size - - self.assertEqual(w, 185) - self.assertEqual(h, 37) - - t.image.delete() - - def test_image_field_thumbnail(self): - if PY3: - raise SkipTest('PIL does not have Python 3 support') - - class TestImage(Document): - image = ImageField(thumbnail_size=(92, 18)) - - TestImage.drop_collection() - - t = TestImage() - t.image.put(open(TEST_IMAGE_PATH, 'r')) - t.save() - - t = TestImage.objects.first() - - self.assertEqual(t.image.thumbnail.format, 'PNG') - self.assertEqual(t.image.thumbnail.width, 92) - self.assertEqual(t.image.thumbnail.height, 18) - - t.image.delete() - - def test_file_multidb(self): - register_connection('test_files', 'test_files') - class TestFile(Document): - name = StringField() - the_file = FileField(db_alias="test_files", - collection_name="macumba") - - TestFile.drop_collection() - - # delete old filesystem - get_db("test_files").macumba.files.drop() - get_db("test_files").macumba.chunks.drop() - - # First instance - test_file = 
TestFile() - test_file.name = "Hello, World!" - test_file.the_file.put(b('Hello, World!'), - name="hello.txt") - test_file.save() - - data = get_db("test_files").macumba.files.find_one() - self.assertEqual(data.get('name'), 'hello.txt') - - test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.read(), - b('Hello, World!')) + Shirt.drop_collection() def test_geo_indexes(self): """Ensure that indexes are created automatically for GeoPointFields. @@ -2065,8 +1902,7 @@ class FieldTest(unittest.TestCase): Person.drop_collection() for x in xrange(10): - p = Person(name="Person %s" % x) - p.save() + Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) @@ -2077,6 +1913,10 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 1000) + def test_sequence_field_sequence_name(self): class Person(Document): id = SequenceField(primary_key=True, sequence_name='jelly') @@ -2086,8 +1926,7 @@ class FieldTest(unittest.TestCase): Person.drop_collection() for x in xrange(10): - p = Person(name="Person %s" % x) - p.save() + Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) self.assertEqual(c['next'], 10) @@ -2098,6 +1937,10 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) self.assertEqual(c['next'], 10) + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) + self.assertEqual(c['next'], 1000) + def test_multiple_sequence_fields(self): class Person(Document): id = SequenceField(primary_key=True) @@ -2108,8 +1951,7 @@ class FieldTest(unittest.TestCase): Person.drop_collection() for x in xrange(10): - p = Person(name="Person 
%s" % x) - p.save() + Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) @@ -2123,16 +1965,23 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) + Person.id.set_next_value(1000) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 1000) + + Person.counter.set_next_value(999) + c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) + self.assertEqual(c['next'], 999) + def test_sequence_fields_reload(self): class Animal(Document): counter = SequenceField() - type = StringField() + name = StringField() self.db['mongoengine.counters'].drop() Animal.drop_collection() - a = Animal(type="Boi") - a.save() + a = Animal(name="Boi").save() self.assertEqual(a.counter, 1) a.reload() @@ -2162,10 +2011,8 @@ class FieldTest(unittest.TestCase): Person.drop_collection() for x in xrange(10): - a = Animal(name="Animal %s" % x) - a.save() - p = Person(name="Person %s" % x) - p.save() + Animal(name="Animal %s" % x).save() + Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) self.assertEqual(c['next'], 10) @@ -2185,6 +2032,27 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) self.assertEqual(c['next'], 10) + def test_sequence_field_value_decorator(self): + class Person(Document): + id = SequenceField(primary_key=True, value_decorator=str) + name = StringField() + + self.db['mongoengine.counters'].drop() + Person.drop_collection() + + for x in xrange(10): + p = Person(name="Person %s" % x) + p.save() + + c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) + self.assertEqual(c['next'], 10) + + ids = [i.id for i in Person.objects] + self.assertEqual(ids, map(str, range(1, 11))) + + c = self.db['mongoengine.counters'].find_one({'_id': 
'person.id'}) + self.assertEqual(c['next'], 10) + def test_embedded_sequence_field(self): class Comment(EmbeddedDocument): id = SequenceField() @@ -2200,13 +2068,13 @@ class FieldTest(unittest.TestCase): Post(title="MongoEngine", comments=[Comment(content="NoSQL Rocks"), Comment(content="MongoEngine Rocks")]).save() - c = self.db['mongoengine.counters'].find_one({'_id': 'comment.id'}) self.assertEqual(c['next'], 2) post = Post.objects.first() self.assertEqual(1, post.comments[0].id) self.assertEqual(2, post.comments[1].id) + def test_generic_embedded_document(self): class Car(EmbeddedDocument): name = StringField() @@ -2325,12 +2193,26 @@ class FieldTest(unittest.TestCase): self.assertTrue(1 in error_dict['comments']) self.assertTrue('content' in error_dict['comments'][1]) self.assertEqual(error_dict['comments'][1]['content'], - u'Field is required') - + u'Field is required') post.comments[1].content = 'here we go' post.validate() + def test_email_field(self): + class User(Document): + email = EmailField() + + user = User(email="ross@example.com") + self.assertTrue(user.validate() is None) + + user = User(email=("Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S" + "ucictfqpdkK9iS1zeFw8sg7s7cwAF7suIfUfeyueLpfosjn3" + "aJIazqqWkm7.net")) + self.assertTrue(user.validate() is None) + + user = User(email='me@localhost') + self.assertRaises(ValidationError, user.validate) + def test_email_field_honors_regex(self): class User(Document): email = EmailField(regex=r'\w+@example.com') diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py new file mode 100644 index 00000000..c5842d81 --- /dev/null +++ b/tests/fields/file_tests.py @@ -0,0 +1,413 @@ +# -*- coding: utf-8 -*- +from __future__ import with_statement +import sys +sys.path[0:0] = [""] + +import copy +import os +import unittest +import tempfile + +import gridfs + +from nose.plugins.skip import SkipTest +from mongoengine import * +from mongoengine.connection import get_db +from mongoengine.python_support 
import PY3, b, StringIO + +TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') +TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') + + +class FileTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def tearDown(self): + self.db.drop_collection('fs.files') + self.db.drop_collection('fs.chunks') + + def test_file_field_optional(self): + # Make sure FileField is optional and not required + class DemoFile(Document): + the_file = FileField() + DemoFile.objects.create() + + def test_file_fields(self): + """Ensure that file fields can be written to and their data retrieved + """ + + class PutFile(Document): + the_file = FileField() + + PutFile.drop_collection() + + text = b('Hello, World!') + content_type = 'text/plain' + + putfile = PutFile() + putfile.the_file.put(text, content_type=content_type) + putfile.save() + + result = PutFile.objects.first() + self.assertTrue(putfile == result) + self.assertEqual(result.the_file.read(), text) + self.assertEqual(result.the_file.content_type, content_type) + result.the_file.delete() # Remove file from GridFS + PutFile.objects.delete() + + # Ensure file-like objects are stored + PutFile.drop_collection() + + putfile = PutFile() + putstring = StringIO() + putstring.write(text) + putstring.seek(0) + putfile.the_file.put(putstring, content_type=content_type) + putfile.save() + + result = PutFile.objects.first() + self.assertTrue(putfile == result) + self.assertEqual(result.the_file.read(), text) + self.assertEqual(result.the_file.content_type, content_type) + result.the_file.delete() + + def test_file_fields_stream(self): + """Ensure that file fields can be written to and their data retrieved + """ + class StreamFile(Document): + the_file = FileField() + + StreamFile.drop_collection() + + text = b('Hello, World!') + more_text = b('Foo Bar') + content_type = 'text/plain' + + streamfile = StreamFile() + 
streamfile.the_file.new_file(content_type=content_type) + streamfile.the_file.write(text) + streamfile.the_file.write(more_text) + streamfile.the_file.close() + streamfile.save() + + result = StreamFile.objects.first() + self.assertTrue(streamfile == result) + self.assertEqual(result.the_file.read(), text + more_text) + self.assertEqual(result.the_file.content_type, content_type) + result.the_file.seek(0) + self.assertEqual(result.the_file.tell(), 0) + self.assertEqual(result.the_file.read(len(text)), text) + self.assertEqual(result.the_file.tell(), len(text)) + self.assertEqual(result.the_file.read(len(more_text)), more_text) + self.assertEqual(result.the_file.tell(), len(text + more_text)) + result.the_file.delete() + + # Ensure deleted file returns None + self.assertTrue(result.the_file.read() == None) + + def test_file_fields_set(self): + + class SetFile(Document): + the_file = FileField() + + text = b('Hello, World!') + more_text = b('Foo Bar') + + SetFile.drop_collection() + + setfile = SetFile() + setfile.the_file = text + setfile.save() + + result = SetFile.objects.first() + self.assertTrue(setfile == result) + self.assertEqual(result.the_file.read(), text) + + # Try replacing file with new one + result.the_file.replace(more_text) + result.save() + + result = SetFile.objects.first() + self.assertTrue(setfile == result) + self.assertEqual(result.the_file.read(), more_text) + result.the_file.delete() + + def test_file_field_no_default(self): + + class GridDocument(Document): + the_file = FileField() + + GridDocument.drop_collection() + + with tempfile.TemporaryFile() as f: + f.write(b("Hello World!")) + f.flush() + + # Test without default + doc_a = GridDocument() + doc_a.save() + + doc_b = GridDocument.objects.with_id(doc_a.id) + doc_b.the_file.replace(f, filename='doc_b') + doc_b.save() + self.assertNotEqual(doc_b.the_file.grid_id, None) + + # Test it matches + doc_c = GridDocument.objects.with_id(doc_b.id) + self.assertEqual(doc_b.the_file.grid_id, 
doc_c.the_file.grid_id) + + # Test with default + doc_d = GridDocument(the_file=b('')) + doc_d.save() + + doc_e = GridDocument.objects.with_id(doc_d.id) + self.assertEqual(doc_d.the_file.grid_id, doc_e.the_file.grid_id) + + doc_e.the_file.replace(f, filename='doc_e') + doc_e.save() + + doc_f = GridDocument.objects.with_id(doc_e.id) + self.assertEqual(doc_e.the_file.grid_id, doc_f.the_file.grid_id) + + db = GridDocument._get_db() + grid_fs = gridfs.GridFS(db) + self.assertEqual(['doc_b', 'doc_e'], grid_fs.list()) + + def test_file_uniqueness(self): + """Ensure that each instance of a FileField is unique + """ + class TestFile(Document): + name = StringField() + the_file = FileField() + + # First instance + test_file = TestFile() + test_file.name = "Hello, World!" + test_file.the_file.put(b('Hello, World!')) + test_file.save() + + # Second instance + test_file_dupe = TestFile() + data = test_file_dupe.the_file.read() # Should be None + + self.assertTrue(test_file.name != test_file_dupe.name) + self.assertTrue(test_file.the_file.read() != data) + + TestFile.drop_collection() + + def test_file_saving(self): + """Ensure you can add meta data to file""" + + class Animal(Document): + genus = StringField() + family = StringField() + photo = FileField() + + Animal.drop_collection() + marmot = Animal(genus='Marmota', family='Sciuridae') + + marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk + marmot.photo.put(marmot_photo, content_type='image/jpeg', foo='bar') + marmot.photo.close() + marmot.save() + + marmot = Animal.objects.get() + self.assertEqual(marmot.photo.content_type, 'image/jpeg') + self.assertEqual(marmot.photo.foo, 'bar') + + def test_file_reassigning(self): + class TestFile(Document): + the_file = FileField() + TestFile.drop_collection() + + test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save() + self.assertEqual(test_file.the_file.get().length, 8313) + + test_file = TestFile.objects.first() + test_file.the_file = 
open(TEST_IMAGE2_PATH, 'rb') + test_file.save() + self.assertEqual(test_file.the_file.get().length, 4971) + + def test_file_boolean(self): + """Ensure that a boolean test of a FileField indicates its presence + """ + class TestFile(Document): + the_file = FileField() + TestFile.drop_collection() + + test_file = TestFile() + self.assertFalse(bool(test_file.the_file)) + test_file.the_file.put(b('Hello, World!'), content_type='text/plain') + test_file.save() + self.assertTrue(bool(test_file.the_file)) + + test_file = TestFile.objects.first() + self.assertEqual(test_file.the_file.content_type, "text/plain") + + def test_file_cmp(self): + """Test comparing against other types""" + class TestFile(Document): + the_file = FileField() + + test_file = TestFile() + self.assertFalse(test_file.the_file in [{"test": 1}]) + + def test_image_field(self): + if PY3: + raise SkipTest('PIL does not have Python 3 support') + + class TestImage(Document): + image = ImageField() + + TestImage.drop_collection() + + t = TestImage() + t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.save() + + t = TestImage.objects.first() + + self.assertEqual(t.image.format, 'PNG') + + w, h = t.image.size + self.assertEqual(w, 371) + self.assertEqual(h, 76) + + t.image.delete() + + def test_image_field_reassigning(self): + if PY3: + raise SkipTest('PIL does not have Python 3 support') + + class TestFile(Document): + the_file = ImageField() + TestFile.drop_collection() + + test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save() + self.assertEqual(test_file.the_file.size, (371, 76)) + + test_file = TestFile.objects.first() + test_file.the_file = open(TEST_IMAGE2_PATH, 'rb') + test_file.save() + self.assertEqual(test_file.the_file.size, (45, 101)) + + def test_image_field_resize(self): + if PY3: + raise SkipTest('PIL does not have Python 3 support') + + class TestImage(Document): + image = ImageField(size=(185, 37)) + + TestImage.drop_collection() + + t = TestImage() + t.image.put(open(TEST_IMAGE_PATH, 
'rb')) + t.save() + + t = TestImage.objects.first() + + self.assertEqual(t.image.format, 'PNG') + w, h = t.image.size + + self.assertEqual(w, 185) + self.assertEqual(h, 37) + + t.image.delete() + + def test_image_field_resize_force(self): + if PY3: + raise SkipTest('PIL does not have Python 3 support') + + class TestImage(Document): + image = ImageField(size=(185, 37, True)) + + TestImage.drop_collection() + + t = TestImage() + t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.save() + + t = TestImage.objects.first() + + self.assertEqual(t.image.format, 'PNG') + w, h = t.image.size + + self.assertEqual(w, 185) + self.assertEqual(h, 37) + + t.image.delete() + + def test_image_field_thumbnail(self): + if PY3: + raise SkipTest('PIL does not have Python 3 support') + + class TestImage(Document): + image = ImageField(thumbnail_size=(92, 18)) + + TestImage.drop_collection() + + t = TestImage() + t.image.put(open(TEST_IMAGE_PATH, 'rb')) + t.save() + + t = TestImage.objects.first() + + self.assertEqual(t.image.thumbnail.format, 'PNG') + self.assertEqual(t.image.thumbnail.width, 92) + self.assertEqual(t.image.thumbnail.height, 18) + + t.image.delete() + + def test_file_multidb(self): + register_connection('test_files', 'test_files') + + class TestFile(Document): + name = StringField() + the_file = FileField(db_alias="test_files", + collection_name="macumba") + + TestFile.drop_collection() + + # delete old filesystem + get_db("test_files").macumba.files.drop() + get_db("test_files").macumba.chunks.drop() + + # First instance + test_file = TestFile() + test_file.name = "Hello, World!" 
+ test_file.the_file.put(b('Hello, World!'), + name="hello.txt") + test_file.save() + + data = get_db("test_files").macumba.files.find_one() + self.assertEqual(data.get('name'), 'hello.txt') + + test_file = TestFile.objects.first() + self.assertEqual(test_file.the_file.read(), + b('Hello, World!')) + + def test_copyable(self): + class PutFile(Document): + the_file = FileField() + + PutFile.drop_collection() + + text = b('Hello, World!') + content_type = 'text/plain' + + putfile = PutFile() + putfile.the_file.put(text, content_type=content_type) + putfile.save() + + class TestFile(Document): + name = StringField() + + self.assertEqual(putfile, copy.copy(putfile)) + self.assertEqual(putfile, copy.deepcopy(putfile)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/fields/mongodb_leaf.png b/tests/fields/mongodb_leaf.png new file mode 100644 index 00000000..36661cef Binary files /dev/null and b/tests/fields/mongodb_leaf.png differ diff --git a/tests/mongoengine.png b/tests/fields/mongoengine.png similarity index 100% rename from tests/mongoengine.png rename to tests/fields/mongoengine.png diff --git a/tests/migration/__init__.py b/tests/migration/__init__.py new file mode 100644 index 00000000..6fc83e02 --- /dev/null +++ b/tests/migration/__init__.py @@ -0,0 +1,8 @@ +from convert_to_new_inheritance_model import * +from decimalfield_as_float import * +from refrencefield_dbref_to_object_id import * +from turn_off_inheritance import * +from uuidfield_to_binary import * + +if __name__ == '__main__': + unittest.main() diff --git a/tests/migration/convert_to_new_inheritance_model.py b/tests/migration/convert_to_new_inheritance_model.py new file mode 100644 index 00000000..89ee9e9d --- /dev/null +++ b/tests/migration/convert_to_new_inheritance_model.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField + +__all__ = 
('ConvertToNewInheritanceModel', ) + + +class ConvertToNewInheritanceModel(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_how_to_convert_to_the_new_inheritance_model(self): + """Demonstrates migrating from 0.7 to 0.8 + """ + + # 1. Declaration of the class + class Animal(Document): + name = StringField() + meta = { + 'allow_inheritance': True, + 'indexes': ['name'] + } + + # 2. Remove _types + collection = Animal._get_collection() + collection.update({}, {"$unset": {"_types": 1}}, multi=True) + + # 3. Confirm extra data is removed + count = collection.find({'_types': {"$exists": True}}).count() + self.assertEqual(0, count) + + # 4. Remove indexes + info = collection.index_information() + indexes_to_drop = [key for key, value in info.iteritems() + if '_types' in dict(value['key'])] + for index in indexes_to_drop: + collection.drop_index(index) + + # 5. Recreate indexes + Animal.ensure_indexes() diff --git a/tests/migration/decimalfield_as_float.py b/tests/migration/decimalfield_as_float.py new file mode 100644 index 00000000..3903c913 --- /dev/null +++ b/tests/migration/decimalfield_as_float.py @@ -0,0 +1,50 @@ + # -*- coding: utf-8 -*- +import unittest +import decimal +from decimal import Decimal + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField, DecimalField, ListField + +__all__ = ('ConvertDecimalField', ) + + +class ConvertDecimalField(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def test_how_to_convert_decimal_fields(self): + """Demonstrates migrating from 0.7 to 0.8 + """ + + # 1. 
Old definition - using dbrefs + class Person(Document): + name = StringField() + money = DecimalField(force_string=True) + monies = ListField(DecimalField(force_string=True)) + + Person.drop_collection() + Person(name="Wilson Jr", money=Decimal("2.50"), + monies=[Decimal("2.10"), Decimal("5.00")]).save() + + # 2. Start the migration by changing the schema + # Change DecimalField - add precision and rounding settings + class Person(Document): + name = StringField() + money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP) + monies = ListField(DecimalField(precision=2, + rounding=decimal.ROUND_HALF_UP)) + + # 3. Loop all the objects and mark parent as changed + for p in Person.objects: + p._mark_as_changed('money') + p._mark_as_changed('monies') + p.save() + + # 4. Confirmation of the fix! + wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] + self.assertTrue(isinstance(wilson['money'], float)) + self.assertTrue(all([isinstance(m, float) for m in wilson['monies']])) diff --git a/tests/migration/refrencefield_dbref_to_object_id.py b/tests/migration/refrencefield_dbref_to_object_id.py new file mode 100644 index 00000000..d3acbe92 --- /dev/null +++ b/tests/migration/refrencefield_dbref_to_object_id.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField, ReferenceField, ListField + +__all__ = ('ConvertToObjectIdsModel', ) + + +class ConvertToObjectIdsModel(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def test_how_to_convert_to_object_id_reference_fields(self): + """Demonstrates migrating from 0.7 to 0.8 + """ + + # 1. 
Old definition - using dbrefs + class Person(Document): + name = StringField() + parent = ReferenceField('self', dbref=True) + friends = ListField(ReferenceField('self', dbref=True)) + + Person.drop_collection() + + p1 = Person(name="Wilson", parent=None).save() + f1 = Person(name="John", parent=None).save() + f2 = Person(name="Paul", parent=None).save() + f3 = Person(name="George", parent=None).save() + f4 = Person(name="Ringo", parent=None).save() + Person(name="Wilson Jr", parent=p1, friends=[f1, f2, f3, f4]).save() + + # 2. Start the migration by changing the schema + # Change ReferenceField as now dbref defaults to False + class Person(Document): + name = StringField() + parent = ReferenceField('self') + friends = ListField(ReferenceField('self')) + + # 3. Loop all the objects and mark parent as changed + for p in Person.objects: + p._mark_as_changed('parent') + p._mark_as_changed('friends') + p.save() + + # 4. Confirmation of the fix! + wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] + self.assertEqual(p1.id, wilson['parent']) + self.assertEqual([f1.id, f2.id, f3.id, f4.id], wilson['friends']) diff --git a/tests/migration/turn_off_inheritance.py b/tests/migration/turn_off_inheritance.py new file mode 100644 index 00000000..ee461a84 --- /dev/null +++ b/tests/migration/turn_off_inheritance.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField + +__all__ = ('TurnOffInheritanceTest', ) + + +class TurnOffInheritanceTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_how_to_turn_off_inheritance(self): + """Demonstrates migrating from allow_inheritance = True to False. + """ + + # 1. 
Old declaration of the class + + class Animal(Document): + name = StringField() + meta = { + 'allow_inheritance': True, + 'indexes': ['name'] + } + + # 2. Turn off inheritance + class Animal(Document): + name = StringField() + meta = { + 'allow_inheritance': False, + 'indexes': ['name'] + } + + # 3. Remove _types and _cls + collection = Animal._get_collection() + collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, multi=True) + + # 3. Confirm extra data is removed + count = collection.find({"$or": [{'_types': {"$exists": True}}, + {'_cls': {"$exists": True}}]}).count() + assert count == 0 + + # 4. Remove indexes + info = collection.index_information() + indexes_to_drop = [key for key, value in info.iteritems() + if '_types' in dict(value['key']) + or '_cls' in dict(value['key'])] + for index in indexes_to_drop: + collection.drop_index(index) + + # 5. Recreate indexes + Animal.ensure_indexes() diff --git a/tests/migration/uuidfield_to_binary.py b/tests/migration/uuidfield_to_binary.py new file mode 100644 index 00000000..a535e91f --- /dev/null +++ b/tests/migration/uuidfield_to_binary.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +import unittest +import uuid + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField, UUIDField, ListField + +__all__ = ('ConvertToBinaryUUID', ) + + +class ConvertToBinaryUUID(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def test_how_to_convert_to_binary_uuid_fields(self): + """Demonstrates migrating from 0.7 to 0.8 + """ + + # 1. Old definition - using dbrefs + class Person(Document): + name = StringField() + uuid = UUIDField(binary=False) + uuids = ListField(UUIDField(binary=False)) + + Person.drop_collection() + Person(name="Wilson Jr", uuid=uuid.uuid4(), + uuids=[uuid.uuid4(), uuid.uuid4()]).save() + + # 2. 
Start the migration by changing the schema + # Change UUIDFIeld as now binary defaults to True + class Person(Document): + name = StringField() + uuid = UUIDField() + uuids = ListField(UUIDField()) + + # 3. Loop all the objects and mark parent as changed + for p in Person.objects: + p._mark_as_changed('uuid') + p._mark_as_changed('uuids') + p.save() + + # 4. Confirmation of the fix! + wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] + self.assertTrue(isinstance(wilson['uuid'], uuid.UUID)) + self.assertTrue(all([isinstance(u, uuid.UUID) for u in wilson['uuids']])) diff --git a/tests/queryset/__init__.py b/tests/queryset/__init__.py new file mode 100644 index 00000000..93cb8c23 --- /dev/null +++ b/tests/queryset/__init__.py @@ -0,0 +1,5 @@ + +from transform import * +from field_list import * +from queryset import * +from visitor import * \ No newline at end of file diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py new file mode 100644 index 00000000..2bdfce1f --- /dev/null +++ b/tests/queryset/field_list.py @@ -0,0 +1,399 @@ +import sys +sys.path[0:0] = [""] + +import unittest + +from mongoengine import * +from mongoengine.queryset import QueryFieldList + +__all__ = ("QueryFieldListTest", "OnlyExcludeAllTest") + + +class QueryFieldListTest(unittest.TestCase): + + def test_empty(self): + q = QueryFieldList() + self.assertFalse(q) + + q = QueryFieldList(always_include=['_cls']) + self.assertFalse(q) + + def test_include_include(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY, _only_called=True) + self.assertEqual(q.as_dict(), {'a': 1, 'b': 1}) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'a': 1, 'b': 1, 'c': 1}) + + def test_include_exclude(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'a': 1, 'b': 1}) + q += QueryFieldList(fields=['b', 'c'], 
value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': 1}) + + def test_exclude_exclude(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': 0, 'b': 0}) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': 0, 'b': 0, 'c': 0}) + + def test_exclude_include(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {'a': 0, 'b': 0}) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'c': 1}) + + def test_always_include(self): + q = QueryFieldList(always_include=['x', 'y']) + q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'c': 1}) + + def test_reset(self): + q = QueryFieldList(always_include=['x', 'y']) + q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'c': 1}) + q.reset() + self.assertFalse(q) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'b': 1, 'c': 1}) + + def test_using_a_slice(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a'], value={"$slice": 5}) + self.assertEqual(q.as_dict(), {'a': {"$slice": 5}}) + + +class OnlyExcludeAllTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + + class Person(Document): + name = StringField() + age = IntField() + meta = {'allow_inheritance': True} + + Person.drop_collection() + self.Person = Person + + def test_mixing_only_exclude(self): + + class MyDoc(Document): + a = StringField() + b = StringField() + c = StringField() + d = StringField() + e = StringField() + 
f = StringField() + + include = ['a', 'b', 'c', 'd', 'e'] + exclude = ['d', 'e'] + only = ['b', 'c'] + + qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) + self.assertEqual(qs._loaded_fields.as_dict(), + {'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1}) + qs = qs.only(*only) + self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) + qs = qs.exclude(*exclude) + self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) + + qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) + qs = qs.exclude(*exclude) + self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) + qs = qs.only(*only) + self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) + + qs = MyDoc.objects.exclude(*exclude) + qs = qs.fields(**dict(((i, 1) for i in include))) + self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) + qs = qs.only(*only) + self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) + + def test_slicing(self): + + class MyDoc(Document): + a = ListField() + b = ListField() + c = ListField() + d = ListField() + e = ListField() + f = ListField() + + include = ['a', 'b', 'c', 'd', 'e'] + exclude = ['d', 'e'] + only = ['b', 'c'] + + qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) + qs = qs.exclude(*exclude) + qs = qs.only(*only) + qs = qs.fields(slice__b=5) + self.assertEqual(qs._loaded_fields.as_dict(), + {'b': {'$slice': 5}, 'c': 1}) + + qs = qs.fields(slice__c=[5, 1]) + self.assertEqual(qs._loaded_fields.as_dict(), + {'b': {'$slice': 5}, 'c': {'$slice': [5, 1]}}) + + qs = qs.exclude('c') + self.assertEqual(qs._loaded_fields.as_dict(), + {'b': {'$slice': 5}}) + + def test_only(self): + """Ensure that QuerySet.only only returns the requested fields. 
+ """ + person = self.Person(name='test', age=25) + person.save() + + obj = self.Person.objects.only('name').get() + self.assertEqual(obj.name, person.name) + self.assertEqual(obj.age, None) + + obj = self.Person.objects.only('age').get() + self.assertEqual(obj.name, None) + self.assertEqual(obj.age, person.age) + + obj = self.Person.objects.only('name', 'age').get() + self.assertEqual(obj.name, person.name) + self.assertEqual(obj.age, person.age) + + # Check polymorphism still works + class Employee(self.Person): + salary = IntField(db_field='wage') + + employee = Employee(name='test employee', age=40, salary=30000) + employee.save() + + obj = self.Person.objects(id=employee.id).only('age').get() + self.assertTrue(isinstance(obj, Employee)) + + # Check field names are looked up properly + obj = Employee.objects(id=employee.id).only('salary').get() + self.assertEqual(obj.salary, employee.salary) + self.assertEqual(obj.name, None) + + def test_only_with_subfields(self): + class User(EmbeddedDocument): + name = StringField() + email = StringField() + + class Comment(EmbeddedDocument): + title = StringField() + text = StringField() + + class BlogPost(Document): + content = StringField() + author = EmbeddedDocumentField(User) + comments = ListField(EmbeddedDocumentField(Comment)) + + BlogPost.drop_collection() + + post = BlogPost(content='Had a good coffee today...') + post.author = User(name='Test User') + post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] + post.save() + + obj = BlogPost.objects.only('author.name',).get() + self.assertEqual(obj.content, None) + self.assertEqual(obj.author.email, None) + self.assertEqual(obj.author.name, 'Test User') + self.assertEqual(obj.comments, []) + + obj = BlogPost.objects.only('content', 'comments.title',).get() + self.assertEqual(obj.content, 'Had a good coffee today...') + self.assertEqual(obj.author, None) + self.assertEqual(obj.comments[0].title, 'I aggree') + 
self.assertEqual(obj.comments[1].title, 'Coffee') + self.assertEqual(obj.comments[0].text, None) + self.assertEqual(obj.comments[1].text, None) + + obj = BlogPost.objects.only('comments',).get() + self.assertEqual(obj.content, None) + self.assertEqual(obj.author, None) + self.assertEqual(obj.comments[0].title, 'I aggree') + self.assertEqual(obj.comments[1].title, 'Coffee') + self.assertEqual(obj.comments[0].text, 'Great post!') + self.assertEqual(obj.comments[1].text, 'I hate coffee') + + BlogPost.drop_collection() + + def test_exclude(self): + class User(EmbeddedDocument): + name = StringField() + email = StringField() + + class Comment(EmbeddedDocument): + title = StringField() + text = StringField() + + class BlogPost(Document): + content = StringField() + author = EmbeddedDocumentField(User) + comments = ListField(EmbeddedDocumentField(Comment)) + + BlogPost.drop_collection() + + post = BlogPost(content='Had a good coffee today...') + post.author = User(name='Test User') + post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] + post.save() + + obj = BlogPost.objects.exclude('author', 'comments.text').get() + self.assertEqual(obj.author, None) + self.assertEqual(obj.content, 'Had a good coffee today...') + self.assertEqual(obj.comments[0].title, 'I aggree') + self.assertEqual(obj.comments[0].text, None) + + BlogPost.drop_collection() + + def test_exclude_only_combining(self): + class Attachment(EmbeddedDocument): + name = StringField() + content = StringField() + + class Email(Document): + sender = StringField() + to = StringField() + subject = StringField() + body = StringField() + content_type = StringField() + attachments = ListField(EmbeddedDocumentField(Attachment)) + + Email.drop_collection() + email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain') + email.attachments = [ + Attachment(name='file1.doc', content='ABC'), + 
Attachment(name='file2.doc', content='XYZ'), + ] + email.save() + + obj = Email.objects.exclude('content_type').exclude('body').get() + self.assertEqual(obj.sender, 'me') + self.assertEqual(obj.to, 'you') + self.assertEqual(obj.subject, 'From Russia with Love') + self.assertEqual(obj.body, None) + self.assertEqual(obj.content_type, None) + + obj = Email.objects.only('sender', 'to').exclude('body', 'sender').get() + self.assertEqual(obj.sender, None) + self.assertEqual(obj.to, 'you') + self.assertEqual(obj.subject, None) + self.assertEqual(obj.body, None) + self.assertEqual(obj.content_type, None) + + obj = Email.objects.exclude('attachments.content').exclude('body').only('to', 'attachments.name').get() + self.assertEqual(obj.attachments[0].name, 'file1.doc') + self.assertEqual(obj.attachments[0].content, None) + self.assertEqual(obj.sender, None) + self.assertEqual(obj.to, 'you') + self.assertEqual(obj.subject, None) + self.assertEqual(obj.body, None) + self.assertEqual(obj.content_type, None) + + Email.drop_collection() + + def test_all_fields(self): + + class Email(Document): + sender = StringField() + to = StringField() + subject = StringField() + body = StringField() + content_type = StringField() + + Email.drop_collection() + + email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain') + email.save() + + obj = Email.objects.exclude('content_type', 'body').only('to', 'body').all_fields().get() + self.assertEqual(obj.sender, 'me') + self.assertEqual(obj.to, 'you') + self.assertEqual(obj.subject, 'From Russia with Love') + self.assertEqual(obj.body, 'Hello!') + self.assertEqual(obj.content_type, 'text/plain') + + Email.drop_collection() + + def test_slicing_fields(self): + """Ensure that query slicing an array works. 
+ """ + class Numbers(Document): + n = ListField(IntField()) + + Numbers.drop_collection() + + numbers = Numbers(n=[0, 1, 2, 3, 4, 5, -5, -4, -3, -2, -1]) + numbers.save() + + # first three + numbers = Numbers.objects.fields(slice__n=3).get() + self.assertEqual(numbers.n, [0, 1, 2]) + + # last three + numbers = Numbers.objects.fields(slice__n=-3).get() + self.assertEqual(numbers.n, [-3, -2, -1]) + + # skip 2, limit 3 + numbers = Numbers.objects.fields(slice__n=[2, 3]).get() + self.assertEqual(numbers.n, [2, 3, 4]) + + # skip to fifth from last, limit 4 + numbers = Numbers.objects.fields(slice__n=[-5, 4]).get() + self.assertEqual(numbers.n, [-5, -4, -3, -2]) + + # skip to fifth from last, limit 10 + numbers = Numbers.objects.fields(slice__n=[-5, 10]).get() + self.assertEqual(numbers.n, [-5, -4, -3, -2, -1]) + + # skip to fifth from last, limit 10 dict method + numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get() + self.assertEqual(numbers.n, [-5, -4, -3, -2, -1]) + + def test_slicing_nested_fields(self): + """Ensure that query slicing an embedded array works. 
+ """ + + class EmbeddedNumber(EmbeddedDocument): + n = ListField(IntField()) + + class Numbers(Document): + embedded = EmbeddedDocumentField(EmbeddedNumber) + + Numbers.drop_collection() + + numbers = Numbers() + numbers.embedded = EmbeddedNumber(n=[0, 1, 2, 3, 4, 5, -5, -4, -3, -2, -1]) + numbers.save() + + # first three + numbers = Numbers.objects.fields(slice__embedded__n=3).get() + self.assertEqual(numbers.embedded.n, [0, 1, 2]) + + # last three + numbers = Numbers.objects.fields(slice__embedded__n=-3).get() + self.assertEqual(numbers.embedded.n, [-3, -2, -1]) + + # skip 2, limit 3 + numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get() + self.assertEqual(numbers.embedded.n, [2, 3, 4]) + + # skip to fifth from last, limit 4 + numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get() + self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2]) + + # skip to fifth from last, limit 10 + numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get() + self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) + + # skip to fifth from last, limit 10 dict method + numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() + self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_queryset.py b/tests/queryset/queryset.py similarity index 75% rename from tests/test_queryset.py rename to tests/queryset/queryset.py index 6b569261..5e403c4e 100644 --- a/tests/test_queryset.py +++ b/tests/queryset/queryset.py @@ -1,19 +1,30 @@ from __future__ import with_statement +import sys +sys.path[0:0] = [""] + import unittest +import uuid +from nose.plugins.skip import SkipTest from datetime import datetime, timedelta import pymongo +from pymongo.errors import ConfigurationError +from pymongo.read_preferences import ReadPreference from bson import ObjectId from mongoengine import * from mongoengine.connection import get_connection from mongoengine.python_support import PY3 
-from mongoengine.tests import query_counter +from mongoengine.context_managers import query_counter from mongoengine.queryset import (QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, - QueryFieldList) + queryset_manager) +from mongoengine.errors import InvalidQueryError + +__all__ = ("QuerySetTest",) + class QuerySetTest(unittest.TestCase): @@ -37,33 +48,29 @@ class QuerySetTest(unittest.TestCase): self.assertTrue(isinstance(self.Person.objects._collection, pymongo.collection.Collection)) - def test_transform_query(self): - """Ensure that the _transform_query function operates correctly. - """ - self.assertEqual(QuerySet._transform_query(name='test', age=30), - {'name': 'test', 'age': 30}) - self.assertEqual(QuerySet._transform_query(age__lt=30), - {'age': {'$lt': 30}}) - self.assertEqual(QuerySet._transform_query(age__gt=20, age__lt=50), - {'age': {'$gt': 20, '$lt': 50}}) - self.assertEqual(QuerySet._transform_query(age=20, age__gt=50), - {'$and': [{'age': {'$gt': 50}}, {'age': 20}]}) - self.assertEqual(QuerySet._transform_query(friend__age__gte=30), - {'friend.age': {'$gte': 30}}) - self.assertEqual(QuerySet._transform_query(name__exists=True), - {'name': {'$exists': True}}) + def test_cannot_perform_joins_references(self): + + class BlogPost(Document): + author = ReferenceField(self.Person) + author2 = GenericReferenceField() + + def test_reference(): + list(BlogPost.objects(author__name="test")) + + self.assertRaises(InvalidQueryError, test_reference) + + def test_generic_reference(): + list(BlogPost.objects(author2__name="test")) def test_find(self): """Ensure that a query returns a valid set of results. 
""" - person1 = self.Person(name="User A", age=20) - person1.save() - person2 = self.Person(name="User B", age=30) - person2.save() + self.Person(name="User A", age=20).save() + self.Person(name="User B", age=30).save() # Find all people in the collection people = self.Person.objects - self.assertEqual(len(people), 2) + self.assertEqual(people.count(), 2) results = list(people) self.assertTrue(isinstance(results[0], self.Person)) self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode))) @@ -74,7 +81,7 @@ class QuerySetTest(unittest.TestCase): # Use a query to filter the people found to just person1 people = self.Person.objects(age=20) - self.assertEqual(len(people), 1) + self.assertEqual(people.count(), 1) person = people.next() self.assertEqual(person.name, "User A") self.assertEqual(person.age, 20) @@ -121,7 +128,7 @@ class QuerySetTest(unittest.TestCase): for i in xrange(55): self.Person(name='A%s' % i, age=i).save() - self.assertEqual(len(self.Person.objects), 55) + self.assertEqual(self.Person.objects.count(), 55) self.assertEqual("Person object", "%s" % self.Person.objects[0]) self.assertEqual("[, ]", "%s" % self.Person.objects[1:3]) self.assertEqual("[, ]", "%s" % self.Person.objects[51:53]) @@ -202,10 +209,10 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() Blog.objects.create(tags=['a', 'b']) - self.assertEqual(len(Blog.objects(tags__0='a')), 1) - self.assertEqual(len(Blog.objects(tags__0='b')), 0) - self.assertEqual(len(Blog.objects(tags__1='a')), 0) - self.assertEqual(len(Blog.objects(tags__1='b')), 1) + self.assertEqual(Blog.objects(tags__0='a').count(), 1) + self.assertEqual(Blog.objects(tags__0='b').count(), 0) + self.assertEqual(Blog.objects(tags__1='a').count(), 0) + self.assertEqual(Blog.objects(tags__1='b').count(), 1) Blog.drop_collection() @@ -220,16 +227,26 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(blog, blog1) query = Blog.objects(posts__1__comments__1__name='testb') - self.assertEqual(len(query), 2) 
+ self.assertEqual(query.count(), 2) query = Blog.objects(posts__1__comments__1__name='testa') - self.assertEqual(len(query), 0) + self.assertEqual(query.count(), 0) query = Blog.objects(posts__0__comments__1__name='testa') - self.assertEqual(len(query), 0) + self.assertEqual(query.count(), 0) Blog.drop_collection() + def test_none(self): + class A(Document): + s = StringField() + + A.drop_collection() + A().save() + + self.assertEqual(list(A.objects.none()), []) + self.assertEqual(list(A.objects.none().all()), []) + def test_chaining(self): class A(Document): s = StringField() @@ -259,23 +276,24 @@ class QuerySetTest(unittest.TestCase): query = query.filter(boolfield=True) self.assertEquals(query.count(), 1) - def test_update_write_options(self): - """Test that passing write_options works""" + def test_update_write_concern(self): + """Test that passing write_concern works""" self.Person.drop_collection() - write_options = {"fsync": True} + write_concern = {"fsync": True} author, created = self.Person.objects.get_or_create( - name='Test User', write_options=write_options) - author.save(write_options=write_options) + name='Test User', write_concern=write_concern) + author.save(write_concern=write_concern) - self.Person.objects.update(set__name='Ross', write_options=write_options) + self.Person.objects.update(set__name='Ross', + write_concern=write_concern) author = self.Person.objects.first() self.assertEqual(author.name, 'Ross') - self.Person.objects.update_one(set__name='Test User', write_options=write_options) + self.Person.objects.update_one(set__name='Test User', write_concern=write_concern) author = self.Person.objects.first() self.assertEqual(author.name, 'Test User') @@ -318,24 +336,23 @@ class QuerySetTest(unittest.TestCase): comment2 = Comment(name='testb') post1 = Post(comments=[comment1, comment2]) post2 = Post(comments=[comment2, comment2]) - blog1 = Blog.objects.create(posts=[post1, post2]) - blog2 = Blog.objects.create(posts=[post2, post1]) + 
Blog.objects.create(posts=[post1, post2]) + Blog.objects.create(posts=[post2, post1]) # Update all of the first comments of second posts of all blogs - blog = Blog.objects().update(set__posts__1__comments__0__name="testc") + Blog.objects().update(set__posts__1__comments__0__name="testc") testc_blogs = Blog.objects(posts__1__comments__0__name="testc") - self.assertEqual(len(testc_blogs), 2) + self.assertEqual(testc_blogs.count(), 2) Blog.drop_collection() - - blog1 = Blog.objects.create(posts=[post1, post2]) - blog2 = Blog.objects.create(posts=[post2, post1]) + Blog.objects.create(posts=[post1, post2]) + Blog.objects.create(posts=[post2, post1]) # Update only the first blog returned by the query - blog = Blog.objects().update_one( + Blog.objects().update_one( set__posts__1__comments__1__name="testc") testc_blogs = Blog.objects(posts__1__comments__1__name="testc") - self.assertEqual(len(testc_blogs), 1) + self.assertEqual(testc_blogs.count(), 1) # Check that using this indexing syntax on a non-list fails def non_list_indexing(): @@ -369,6 +386,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(post.comments[1].by, 'jane') self.assertEqual(post.comments[1].votes, 8) + def test_update_using_positional_operator_matches_first(self): + # Currently the $ operator only applies to the first matched item in # the query @@ -570,10 +589,17 @@ class QuerySetTest(unittest.TestCase): blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) Blog.objects.insert(blogs, load_bulk=False) - self.assertEqual(q, 1) # 1 for the insert + self.assertEqual(q, 1) # 1 for the insert + + Blog.drop_collection() + with query_counter() as q: + self.assertEqual(q, 0) + + Blog.ensure_indexes() + self.assertEqual(q, 1) Blog.objects.insert(blogs) - self.assertEqual(q, 3) # 1 for insert, and 1 for in bulk fetch (3 in total) + self.assertEqual(q, 3) # 1 for insert, and 1 for in bulk fetch (3 in total) Blog.drop_collection() @@ -597,7 +623,7 @@ class QuerySetTest(unittest.TestCase): 
self.assertRaises(OperationError, throw_operation_error) # Test can insert new doc - new_post = Blog(title="code", id=ObjectId()) + new_post = Blog(title="code123", id=ObjectId()) Blog.objects.insert(new_post) # test handles other classes being inserted @@ -633,12 +659,13 @@ class QuerySetTest(unittest.TestCase): Blog.objects.insert([blog1, blog2]) def throw_operation_error_not_unique(): - Blog.objects.insert([blog2, blog3], safe=True) + Blog.objects.insert([blog2, blog3]) self.assertRaises(NotUniqueError, throw_operation_error_not_unique) self.assertEqual(Blog.objects.count(), 2) - Blog.objects.insert([blog2, blog3], write_options={'continue_on_error': True}) + Blog.objects.insert([blog2, blog3], write_concern={"w": 0, + 'continue_on_error': True}) self.assertEqual(Blog.objects.count(), 3) def test_get_changed_fields_query_count(self): @@ -664,7 +691,7 @@ class QuerySetTest(unittest.TestCase): r2 = Project(name="r2").save() r3 = Project(name="r3").save() p1 = Person(name="p1", projects=[r1, r2]).save() - p2 = Person(name="p2", projects=[r2]).save() + p2 = Person(name="p2", projects=[r2, r3]).save() o1 = Organization(name="o1", employees=[p1]).save() with query_counter() as q: @@ -679,24 +706,24 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(q, 0) fresh_o1 = Organization.objects.get(id=o1.id) - fresh_o1.save() + fresh_o1.save() # No changes, does nothing - self.assertEqual(q, 2) + self.assertEqual(q, 1) with query_counter() as q: self.assertEqual(q, 0) fresh_o1 = Organization.objects.get(id=o1.id) - fresh_o1.save(cascade=False) + fresh_o1.save(cascade=False) # No changes, does nothing - self.assertEqual(q, 2) + self.assertEqual(q, 1) with query_counter() as q: self.assertEqual(q, 0) fresh_o1 = Organization.objects.get(id=o1.id) - fresh_o1.employees.append(p2) - fresh_o1.save(cascade=False) + fresh_o1.employees.append(p2) # Dereferences + fresh_o1.save(cascade=False) # Saves self.assertEqual(q, 3) @@ -722,19 +749,19 @@ class 
QuerySetTest(unittest.TestCase): self.assertEqual(p._cursor_args, {'snapshot': False, 'slave_okay': False, 'timeout': True}) - p.snapshot(False).slave_okay(False).timeout(False) + p = p.snapshot(False).slave_okay(False).timeout(False) self.assertEqual(p._cursor_args, {'snapshot': False, 'slave_okay': False, 'timeout': False}) - p.snapshot(True).slave_okay(False).timeout(False) + p = p.snapshot(True).slave_okay(False).timeout(False) self.assertEqual(p._cursor_args, {'snapshot': True, 'slave_okay': False, 'timeout': False}) - p.snapshot(True).slave_okay(True).timeout(False) + p = p.snapshot(True).slave_okay(True).timeout(False) self.assertEqual(p._cursor_args, {'snapshot': True, 'slave_okay': True, 'timeout': False}) - p.snapshot(True).slave_okay(True).timeout(True) + p = p.snapshot(True).slave_okay(True).timeout(True) self.assertEqual(p._cursor_args, {'snapshot': True, 'slave_okay': True, 'timeout': True}) @@ -763,7 +790,7 @@ class QuerySetTest(unittest.TestCase): number = IntField() def __repr__(self): - return "" % self.number + return "" % self.number Doc.drop_collection() @@ -773,19 +800,17 @@ class QuerySetTest(unittest.TestCase): docs = Doc.objects.order_by('number') self.assertEqual(docs.count(), 1000) - self.assertEqual(len(docs), 1000) docs_string = "%s" % docs self.assertTrue("Doc: 0" in docs_string) self.assertEqual(docs.count(), 1000) - self.assertEqual(len(docs), 1000) # Limit and skip - self.assertEqual('[, , ]', "%s" % docs[1:4]) + docs = docs[1:4] + self.assertEqual('[, , ]', "%s" % docs) self.assertEqual(docs.count(), 3) - self.assertEqual(len(docs), 3) for doc in docs: self.assertEqual('.. 
queryset mid-iteration ..', repr(docs)) @@ -800,48 +825,30 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(obj, person) obj = self.Person.objects(name__contains='Van').first() self.assertEqual(obj, None) - obj = self.Person.objects(Q(name__contains='van')).first() - self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__contains='Van')).first() - self.assertEqual(obj, None) # Test icontains obj = self.Person.objects(name__icontains='Van').first() self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__icontains='Van')).first() - self.assertEqual(obj, person) # Test startswith obj = self.Person.objects(name__startswith='Guido').first() self.assertEqual(obj, person) obj = self.Person.objects(name__startswith='guido').first() self.assertEqual(obj, None) - obj = self.Person.objects(Q(name__startswith='Guido')).first() - self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__startswith='guido')).first() - self.assertEqual(obj, None) # Test istartswith obj = self.Person.objects(name__istartswith='guido').first() self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__istartswith='guido')).first() - self.assertEqual(obj, person) # Test endswith obj = self.Person.objects(name__endswith='Rossum').first() self.assertEqual(obj, person) obj = self.Person.objects(name__endswith='rossuM').first() self.assertEqual(obj, None) - obj = self.Person.objects(Q(name__endswith='Rossum')).first() - self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__endswith='rossuM')).first() - self.assertEqual(obj, None) # Test iendswith obj = self.Person.objects(name__iendswith='rossuM').first() self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__iendswith='rossuM')).first() - self.assertEqual(obj, person) # Test exact obj = self.Person.objects(name__exact='Guido van Rossum').first() @@ -850,28 +857,18 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(obj, None) obj = self.Person.objects(name__exact='Guido van 
Rossu').first() self.assertEqual(obj, None) - obj = self.Person.objects(Q(name__exact='Guido van Rossum')).first() - self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__exact='Guido van rossum')).first() - self.assertEqual(obj, None) - obj = self.Person.objects(Q(name__exact='Guido van Rossu')).first() - self.assertEqual(obj, None) # Test iexact obj = self.Person.objects(name__iexact='gUIDO VAN rOSSUM').first() self.assertEqual(obj, person) obj = self.Person.objects(name__iexact='gUIDO VAN rOSSU').first() self.assertEqual(obj, None) - obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSUM')).first() - self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() - self.assertEqual(obj, None) # Test unsafe expressions person = self.Person(name='Guido van Rossum [.\'Geek\']') person.save() - obj = self.Person.objects(Q(name__icontains='[.\'Geek')).first() + obj = self.Person.objects(name__icontains='[.\'Geek').first() self.assertEqual(obj, person) def test_not(self): @@ -914,14 +911,14 @@ class QuerySetTest(unittest.TestCase): blog_3.save() blog_post_1 = BlogPost(blog=blog_1, title="Blog Post #1", - is_published = True, - published_date=datetime(2010, 1, 5, 0, 0 ,0)) + is_published=True, + published_date=datetime(2010, 1, 5, 0, 0, 0)) blog_post_2 = BlogPost(blog=blog_2, title="Blog Post #2", - is_published = True, - published_date=datetime(2010, 1, 6, 0, 0 ,0)) + is_published=True, + published_date=datetime(2010, 1, 6, 0, 0, 0)) blog_post_3 = BlogPost(blog=blog_3, title="Blog Post #3", - is_published = True, - published_date=datetime(2010, 1, 7, 0, 0 ,0)) + is_published=True, + published_date=datetime(2010, 1, 7, 0, 0, 0)) blog_post_1.save() blog_post_2.save() @@ -930,10 +927,9 @@ class QuerySetTest(unittest.TestCase): # find all published blog posts before 2010-01-07 published_posts = BlogPost.published() published_posts = published_posts.filter( - published_date__lt=datetime(2010, 1, 7, 0, 0 ,0)) + 
published_date__lt=datetime(2010, 1, 7, 0, 0, 0)) self.assertEqual(published_posts.count(), 2) - blog_posts = BlogPost.objects blog_posts = blog_posts.filter(blog__in=[blog_1, blog_2]) blog_posts = blog_posts.filter(blog=blog_3) @@ -942,24 +938,11 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() Blog.drop_collection() - def test_raw_and_merging(self): - class Doc(Document): - pass - - raw_query = Doc.objects(__raw__={'deleted': False, - 'scraped': 'yes', - '$nor': [{'views.extracted': 'no'}, - {'attachments.views.extracted':'no'}] - })._query - - expected = {'deleted': False, '_types': 'Doc', 'scraped': 'yes', - '$nor': [{'views.extracted': 'no'}, - {'attachments.views.extracted': 'no'}]} - self.assertEqual(expected, raw_query) - def assertSequence(self, qs, expected): + qs = list(qs) + expected = list(expected) self.assertEqual(len(qs), len(expected)) - for i in range(len(qs)): + for i in xrange(len(qs)): self.assertEqual(qs[i], expected[i]) def test_ordering(self): @@ -975,12 +958,12 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() - blog_post_2 = BlogPost(title="Blog Post #2", - published_date=datetime(2010, 1, 6, 0, 0 ,0)) blog_post_1 = BlogPost(title="Blog Post #1", - published_date=datetime(2010, 1, 5, 0, 0 ,0)) + published_date=datetime(2010, 1, 5, 0, 0, 0)) + blog_post_2 = BlogPost(title="Blog Post #2", + published_date=datetime(2010, 1, 6, 0, 0, 0)) blog_post_3 = BlogPost(title="Blog Post #3", - published_date=datetime(2010, 1, 7, 0, 0 ,0)) + published_date=datetime(2010, 1, 7, 0, 0, 0)) blog_post_1.save() blog_post_2.save() @@ -996,257 +979,6 @@ class QuerySetTest(unittest.TestCase): expected = [blog_post_1, blog_post_2, blog_post_3] self.assertSequence(qs, expected) - def test_only(self): - """Ensure that QuerySet.only only returns the requested fields. 
- """ - person = self.Person(name='test', age=25) - person.save() - - obj = self.Person.objects.only('name').get() - self.assertEqual(obj.name, person.name) - self.assertEqual(obj.age, None) - - obj = self.Person.objects.only('age').get() - self.assertEqual(obj.name, None) - self.assertEqual(obj.age, person.age) - - obj = self.Person.objects.only('name', 'age').get() - self.assertEqual(obj.name, person.name) - self.assertEqual(obj.age, person.age) - - # Check polymorphism still works - class Employee(self.Person): - salary = IntField(db_field='wage') - - employee = Employee(name='test employee', age=40, salary=30000) - employee.save() - - obj = self.Person.objects(id=employee.id).only('age').get() - self.assertTrue(isinstance(obj, Employee)) - - # Check field names are looked up properly - obj = Employee.objects(id=employee.id).only('salary').get() - self.assertEqual(obj.salary, employee.salary) - self.assertEqual(obj.name, None) - - def test_only_with_subfields(self): - class User(EmbeddedDocument): - name = StringField() - email = StringField() - - class Comment(EmbeddedDocument): - title = StringField() - text = StringField() - - class BlogPost(Document): - content = StringField() - author = EmbeddedDocumentField(User) - comments = ListField(EmbeddedDocumentField(Comment)) - - BlogPost.drop_collection() - - post = BlogPost(content='Had a good coffee today...') - post.author = User(name='Test User') - post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] - post.save() - - obj = BlogPost.objects.only('author.name',).get() - self.assertEqual(obj.content, None) - self.assertEqual(obj.author.email, None) - self.assertEqual(obj.author.name, 'Test User') - self.assertEqual(obj.comments, []) - - obj = BlogPost.objects.only('content', 'comments.title',).get() - self.assertEqual(obj.content, 'Had a good coffee today...') - self.assertEqual(obj.author, None) - self.assertEqual(obj.comments[0].title, 'I aggree') - 
self.assertEqual(obj.comments[1].title, 'Coffee') - self.assertEqual(obj.comments[0].text, None) - self.assertEqual(obj.comments[1].text, None) - - obj = BlogPost.objects.only('comments',).get() - self.assertEqual(obj.content, None) - self.assertEqual(obj.author, None) - self.assertEqual(obj.comments[0].title, 'I aggree') - self.assertEqual(obj.comments[1].title, 'Coffee') - self.assertEqual(obj.comments[0].text, 'Great post!') - self.assertEqual(obj.comments[1].text, 'I hate coffee') - - BlogPost.drop_collection() - - def test_exclude(self): - class User(EmbeddedDocument): - name = StringField() - email = StringField() - - class Comment(EmbeddedDocument): - title = StringField() - text = StringField() - - class BlogPost(Document): - content = StringField() - author = EmbeddedDocumentField(User) - comments = ListField(EmbeddedDocumentField(Comment)) - - BlogPost.drop_collection() - - post = BlogPost(content='Had a good coffee today...') - post.author = User(name='Test User') - post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] - post.save() - - obj = BlogPost.objects.exclude('author', 'comments.text').get() - self.assertEqual(obj.author, None) - self.assertEqual(obj.content, 'Had a good coffee today...') - self.assertEqual(obj.comments[0].title, 'I aggree') - self.assertEqual(obj.comments[0].text, None) - - BlogPost.drop_collection() - - def test_exclude_only_combining(self): - class Attachment(EmbeddedDocument): - name = StringField() - content = StringField() - - class Email(Document): - sender = StringField() - to = StringField() - subject = StringField() - body = StringField() - content_type = StringField() - attachments = ListField(EmbeddedDocumentField(Attachment)) - - Email.drop_collection() - email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain') - email.attachments = [ - Attachment(name='file1.doc', content='ABC'), - 
Attachment(name='file2.doc', content='XYZ'), - ] - email.save() - - obj = Email.objects.exclude('content_type').exclude('body').get() - self.assertEqual(obj.sender, 'me') - self.assertEqual(obj.to, 'you') - self.assertEqual(obj.subject, 'From Russia with Love') - self.assertEqual(obj.body, None) - self.assertEqual(obj.content_type, None) - - obj = Email.objects.only('sender', 'to').exclude('body', 'sender').get() - self.assertEqual(obj.sender, None) - self.assertEqual(obj.to, 'you') - self.assertEqual(obj.subject, None) - self.assertEqual(obj.body, None) - self.assertEqual(obj.content_type, None) - - obj = Email.objects.exclude('attachments.content').exclude('body').only('to', 'attachments.name').get() - self.assertEqual(obj.attachments[0].name, 'file1.doc') - self.assertEqual(obj.attachments[0].content, None) - self.assertEqual(obj.sender, None) - self.assertEqual(obj.to, 'you') - self.assertEqual(obj.subject, None) - self.assertEqual(obj.body, None) - self.assertEqual(obj.content_type, None) - - Email.drop_collection() - - def test_all_fields(self): - - class Email(Document): - sender = StringField() - to = StringField() - subject = StringField() - body = StringField() - content_type = StringField() - - Email.drop_collection() - - email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain') - email.save() - - obj = Email.objects.exclude('content_type', 'body').only('to', 'body').all_fields().get() - self.assertEqual(obj.sender, 'me') - self.assertEqual(obj.to, 'you') - self.assertEqual(obj.subject, 'From Russia with Love') - self.assertEqual(obj.body, 'Hello!') - self.assertEqual(obj.content_type, 'text/plain') - - Email.drop_collection() - - def test_slicing_fields(self): - """Ensure that query slicing an array works. 
- """ - class Numbers(Document): - n = ListField(IntField()) - - Numbers.drop_collection() - - numbers = Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) - numbers.save() - - # first three - numbers = Numbers.objects.fields(slice__n=3).get() - self.assertEqual(numbers.n, [0, 1, 2]) - - # last three - numbers = Numbers.objects.fields(slice__n=-3).get() - self.assertEqual(numbers.n, [-3, -2, -1]) - - # skip 2, limit 3 - numbers = Numbers.objects.fields(slice__n=[2, 3]).get() - self.assertEqual(numbers.n, [2, 3, 4]) - - # skip to fifth from last, limit 4 - numbers = Numbers.objects.fields(slice__n=[-5, 4]).get() - self.assertEqual(numbers.n, [-5, -4, -3, -2]) - - # skip to fifth from last, limit 10 - numbers = Numbers.objects.fields(slice__n=[-5, 10]).get() - self.assertEqual(numbers.n, [-5, -4, -3, -2, -1]) - - # skip to fifth from last, limit 10 dict method - numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get() - self.assertEqual(numbers.n, [-5, -4, -3, -2, -1]) - - def test_slicing_nested_fields(self): - """Ensure that query slicing an embedded array works. 
- """ - - class EmbeddedNumber(EmbeddedDocument): - n = ListField(IntField()) - - class Numbers(Document): - embedded = EmbeddedDocumentField(EmbeddedNumber) - - Numbers.drop_collection() - - numbers = Numbers() - numbers.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) - numbers.save() - - # first three - numbers = Numbers.objects.fields(slice__embedded__n=3).get() - self.assertEqual(numbers.embedded.n, [0, 1, 2]) - - # last three - numbers = Numbers.objects.fields(slice__embedded__n=-3).get() - self.assertEqual(numbers.embedded.n, [-3, -2, -1]) - - # skip 2, limit 3 - numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get() - self.assertEqual(numbers.embedded.n, [2, 3, 4]) - - # skip to fifth from last, limit 4 - numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get() - self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2]) - - # skip to fifth from last, limit 10 - numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get() - self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) - - # skip to fifth from last, limit 10 dict method - numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() - self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) - def test_find_embedded(self): """Ensure that an embedded document is properly returned from a query. """ @@ -1285,143 +1017,6 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() - def test_q(self): - """Ensure that Q objects may be used to query for documents. 
- """ - class BlogPost(Document): - title = StringField() - publish_date = DateTimeField() - published = BooleanField() - - BlogPost.drop_collection() - - post1 = BlogPost(title='Test 1', publish_date=datetime(2010, 1, 8), published=False) - post1.save() - - post2 = BlogPost(title='Test 2', publish_date=datetime(2010, 1, 15), published=True) - post2.save() - - post3 = BlogPost(title='Test 3', published=True) - post3.save() - - post4 = BlogPost(title='Test 4', publish_date=datetime(2010, 1, 8)) - post4.save() - - post5 = BlogPost(title='Test 1', publish_date=datetime(2010, 1, 15)) - post5.save() - - post6 = BlogPost(title='Test 1', published=False) - post6.save() - - # Check ObjectId lookup works - obj = BlogPost.objects(id=post1.id).first() - self.assertEqual(obj, post1) - - # Check Q object combination with one does not exist - q = BlogPost.objects(Q(title='Test 5') | Q(published=True)) - posts = [post.id for post in q] - - published_posts = (post2, post3) - self.assertTrue(all(obj.id in posts for obj in published_posts)) - - q = BlogPost.objects(Q(title='Test 1') | Q(published=True)) - posts = [post.id for post in q] - published_posts = (post1, post2, post3, post5, post6) - self.assertTrue(all(obj.id in posts for obj in published_posts)) - - # Check Q object combination - date = datetime(2010, 1, 10) - q = BlogPost.objects(Q(publish_date__lte=date) | Q(published=True)) - posts = [post.id for post in q] - - published_posts = (post1, post2, post3, post4) - self.assertTrue(all(obj.id in posts for obj in published_posts)) - - self.assertFalse(any(obj.id in posts for obj in [post5, post6])) - - BlogPost.drop_collection() - - # Check the 'in' operator - self.Person(name='user1', age=20).save() - self.Person(name='user2', age=20).save() - self.Person(name='user3', age=30).save() - self.Person(name='user4', age=40).save() - - self.assertEqual(len(self.Person.objects(Q(age__in=[20]))), 2) - self.assertEqual(len(self.Person.objects(Q(age__in=[20, 30]))), 3) - - def 
test_q_regex(self): - """Ensure that Q objects can be queried using regexes. - """ - person = self.Person(name='Guido van Rossum') - person.save() - - import re - obj = self.Person.objects(Q(name=re.compile('^Gui'))).first() - self.assertEqual(obj, person) - obj = self.Person.objects(Q(name=re.compile('^gui'))).first() - self.assertEqual(obj, None) - - obj = self.Person.objects(Q(name=re.compile('^gui', re.I))).first() - self.assertEqual(obj, person) - - obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() - self.assertEqual(obj, person) - - obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() - self.assertEqual(obj, None) - - def test_q_lists(self): - """Ensure that Q objects query ListFields correctly. - """ - class BlogPost(Document): - tags = ListField(StringField()) - - BlogPost.drop_collection() - - BlogPost(tags=['python', 'mongo']).save() - BlogPost(tags=['python']).save() - - self.assertEqual(len(BlogPost.objects(Q(tags='mongo'))), 1) - self.assertEqual(len(BlogPost.objects(Q(tags='python'))), 2) - - BlogPost.drop_collection() - - def test_raw_query_and_Q_objects(self): - """ - Test raw plays nicely - """ - class Foo(Document): - name = StringField() - a = StringField() - b = StringField() - c = StringField() - - meta = { - 'allow_inheritance': False - } - - query = Foo.objects(__raw__={'$nor': [{'name': 'bar'}]})._query - self.assertEqual(query, {'$nor': [{'name': 'bar'}]}) - - q1 = {'$or': [{'a': 1}, {'b': 1}]} - query = Foo.objects(Q(__raw__=q1) & Q(c=1))._query - self.assertEqual(query, {'$or': [{'a': 1}, {'b': 1}], 'c': 1}) - - def test_q_merge_queries_edge_case(self): - - class User(Document): - email = EmailField(required=False) - name = StringField() - - User.drop_collection() - pk = ObjectId() - User(email='example@example.com', pk=pk).save() - - self.assertEqual(1, User.objects.filter( - Q(email='example@example.com') | - Q(name='John Doe') - ).limit(2).filter(pk=pk).count()) def test_exec_js_query(self): """Ensure 
that queries are properly formed for use in exec_js. @@ -1458,13 +1053,6 @@ class QuerySetTest(unittest.TestCase): c = BlogPost.objects(published=False).exec_js(js_func, 'hits') self.assertEqual(c, 1) - # Ensure that Q object queries work - c = BlogPost.objects(Q(published=True)).exec_js(js_func, 'hits') - self.assertEqual(c, 2) - - c = BlogPost.objects(Q(published=False)).exec_js(js_func, 'hits') - self.assertEqual(c, 1) - BlogPost.drop_collection() def test_exec_js_field_sub(self): @@ -1532,13 +1120,13 @@ class QuerySetTest(unittest.TestCase): self.Person(name="User B", age=30).save() self.Person(name="User C", age=40).save() - self.assertEqual(len(self.Person.objects), 3) + self.assertEqual(self.Person.objects.count(), 3) self.Person.objects(age__lt=30).delete() - self.assertEqual(len(self.Person.objects), 2) + self.assertEqual(self.Person.objects.count(), 2) self.Person.objects.delete() - self.assertEqual(len(self.Person.objects), 0) + self.assertEqual(self.Person.objects.count(), 0) def test_reverse_delete_rule_cascade(self): """Ensure cascading deletion of referring documents from the database. 
@@ -1896,7 +1484,6 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() - def test_set_list_embedded_documents(self): class Author(EmbeddedDocument): @@ -1954,11 +1541,11 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() blog_post_3 = BlogPost(title="Blog Post #3", - published_date=datetime(2010, 1, 6, 0, 0 ,0)) + published_date=datetime(2010, 1, 6, 0, 0, 0)) blog_post_2 = BlogPost(title="Blog Post #2", - published_date=datetime(2010, 1, 5, 0, 0 ,0)) + published_date=datetime(2010, 1, 5, 0, 0, 0)) blog_post_4 = BlogPost(title="Blog Post #4", - published_date=datetime(2010, 1, 7, 0, 0 ,0)) + published_date=datetime(2010, 1, 7, 0, 0, 0)) blog_post_1 = BlogPost(title="Blog Post #1", published_date=None) blog_post_3.save() @@ -1984,11 +1571,11 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() blog_post_1 = BlogPost(title="A", - published_date=datetime(2010, 1, 6, 0, 0 ,0)) + published_date=datetime(2010, 1, 6, 0, 0, 0)) blog_post_2 = BlogPost(title="B", - published_date=datetime(2010, 1, 6, 0, 0 ,0)) + published_date=datetime(2010, 1, 6, 0, 0, 0)) blog_post_3 = BlogPost(title="C", - published_date=datetime(2010, 1, 7, 0, 0 ,0)) + published_date=datetime(2010, 1, 7, 0, 0, 0)) blog_post_2.save() blog_post_3.save() @@ -2025,6 +1612,7 @@ class QuerySetTest(unittest.TestCase): qs = self.Person.objects.all().limit(10) qs = qs.order_by('-age') + ages = [p.age for p in qs] self.assertEqual(ages, [40, 30, 20]) @@ -2615,56 +2203,68 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() - def test_query_field_name(self): - """Ensure that the correct field name is used when querying. 
- """ - class Comment(EmbeddedDocument): - content = StringField(db_field='commentContent') + def test_custom_manager_overriding_objects_works(self): - class BlogPost(Document): - title = StringField(db_field='postTitle') - comments = ListField(EmbeddedDocumentField(Comment), - db_field='postComments') + class Foo(Document): + bar = StringField(default='bar') + active = BooleanField(default=False) + @queryset_manager + def objects(doc_cls, queryset): + return queryset(active=True) - BlogPost.drop_collection() + @queryset_manager + def with_inactive(doc_cls, queryset): + return queryset(active=False) - data = {'title': 'Post 1', 'comments': [Comment(content='test')]} - post = BlogPost(**data) - post.save() + Foo.drop_collection() - self.assertTrue('postTitle' in - BlogPost.objects(title=data['title'])._query) - self.assertFalse('title' in - BlogPost.objects(title=data['title'])._query) - self.assertEqual(len(BlogPost.objects(title=data['title'])), 1) + Foo(active=True).save() + Foo(active=False).save() - self.assertTrue('_id' in BlogPost.objects(pk=post.id)._query) - self.assertEqual(len(BlogPost.objects(pk=post.id)), 1) + self.assertEqual(1, Foo.objects.count()) + self.assertEqual(1, Foo.with_inactive.count()) - self.assertTrue('postComments.commentContent' in - BlogPost.objects(comments__content='test')._query) - self.assertEqual(len(BlogPost.objects(comments__content='test')), 1) + Foo.with_inactive.first().delete() + self.assertEqual(0, Foo.with_inactive.count()) + self.assertEqual(1, Foo.objects.count()) - BlogPost.drop_collection() + def test_inherit_objects(self): - def test_query_pk_field_name(self): - """Ensure that the correct "primary key" field name is used when querying - """ - class BlogPost(Document): - title = StringField(primary_key=True, db_field='postTitle') + class Foo(Document): + meta = {'allow_inheritance': True} + active = BooleanField(default=True) - BlogPost.drop_collection() + @queryset_manager + def objects(klass, queryset): + return 
queryset(active=True) - data = { 'title':'Post 1' } - post = BlogPost(**data) - post.save() + class Bar(Foo): + pass - self.assertTrue('_id' in BlogPost.objects(pk=data['title'])._query) - self.assertTrue('_id' in BlogPost.objects(title=data['title'])._query) - self.assertEqual(len(BlogPost.objects(pk=data['title'])), 1) + Bar.drop_collection() + Bar.objects.create(active=False) + self.assertEqual(0, Bar.objects.count()) - BlogPost.drop_collection() + def test_inherit_objects_override(self): + + class Foo(Document): + meta = {'allow_inheritance': True} + active = BooleanField(default=True) + + @queryset_manager + def objects(klass, queryset): + return queryset(active=True) + + class Bar(Foo): + @queryset_manager + def objects(klass, queryset): + return queryset(active=False) + + Bar.drop_collection() + Bar.objects.create(active=False) + self.assertEqual(0, Foo.objects.count()) + self.assertEqual(1, Bar.objects.count()) def test_query_value_conversion(self): """Ensure that query values are properly converted when necessary. @@ -2717,68 +2317,6 @@ class QuerySetTest(unittest.TestCase): Group.drop_collection() - def test_types_index(self): - """Ensure that and index is used when '_types' is being used in a - query. - """ - class BlogPost(Document): - date = DateTimeField() - meta = {'indexes': ['-date']} - - # Indexes are lazy so use list() to perform query - list(BlogPost.objects) - info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1)] in info) - self.assertTrue([('_types', 1), ('date', -1)] in info) - - def test_dont_index_types(self): - """Ensure that index_types will, when disabled, prevent _types - being added to all indices. 
- """ - class BloggPost(Document): - date = DateTimeField() - meta = {'index_types': False, - 'indexes': ['-date']} - - # Indexes are lazy so use list() to perform query - list(BloggPost.objects) - info = BloggPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1)] not in info) - self.assertTrue([('date', -1)] in info) - - BloggPost.drop_collection() - - class BloggPost(Document): - title = StringField() - meta = {'allow_inheritance': False} - - # _types is not used on objects where allow_inheritance is False - list(BloggPost.objects) - info = BloggPost.objects._collection.index_information() - self.assertFalse([('_types', 1)] in info.values()) - - BloggPost.drop_collection() - - def test_types_index_with_pk(self): - - class Comment(EmbeddedDocument): - comment_id = IntField(required=True) - - try: - class BlogPost(Document): - comments = EmbeddedDocumentField(Comment) - meta = {'indexes': [{'fields': ['pk', 'comments.comment_id'], - 'unique': True}]} - except UnboundLocalError: - self.fail('Unbound local error at types index + pk definition') - - info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - index_item = [(u'_types', 1), (u'_id', 1), (u'comments.comment_id', 1)] - self.assertTrue(index_item in info) - def test_dict_with_custom_baseclass(self): """Ensure DictField working with custom base clases. 
""" @@ -2790,8 +2328,8 @@ class QuerySetTest(unittest.TestCase): t = Test(testdict={'f': 'Value'}) t.save() - self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 1) - self.assertEqual(len(Test.objects(testdict__f='Value')), 1) + self.assertEqual(Test.objects(testdict__f__startswith='Val').count(), 1) + self.assertEqual(Test.objects(testdict__f='Value').count(), 1) Test.drop_collection() class Test(Document): @@ -2800,8 +2338,8 @@ class QuerySetTest(unittest.TestCase): t = Test(testdict={'f': 'Value'}) t.save() - self.assertEqual(len(Test.objects(testdict__f='Value')), 1) - self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 1) + self.assertEqual(Test.objects(testdict__f='Value').count(), 1) + self.assertEqual(Test.objects(testdict__f__startswith='Val').count(), 1) Test.drop_collection() def test_bulk(self): @@ -2891,6 +2429,12 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(events.count(), 3) self.assertEqual(list(events), [event3, event1, event2]) + # find events within 10 degrees of san francisco + point = [37.7566023, -122.415579] + events = Event.objects(location__near=point, location__max_distance=10) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0], event2) + # find events within 10 degrees of san francisco point_and_distance = [[37.7566023, -122.415579], 10] events = Event.objects(location__within_distance=point_and_distance) @@ -2970,6 +2514,10 @@ class QuerySetTest(unittest.TestCase): ); self.assertEqual(points.count(), 2) + points = Point.objects(location__near_sphere=[-122, 37.5], + location__max_distance=60 / earth_radius); + self.assertEqual(points.count(), 2) + # Finds both points, but orders the north point first because it's # closer to the reference point to the north. points = Point.objects(location__near_sphere=[-122, 38.5]) @@ -2987,8 +2535,7 @@ class QuerySetTest(unittest.TestCase): # Finds only one point because only the first point is within 60km of # the reference point to the south. 
points = Point.objects( - location__within_spherical_distance=[[-122, 36.5], 60/earth_radius] - ); + location__within_spherical_distance=[[-122, 36.5], 60/earth_radius]) self.assertEqual(points.count(), 1) self.assertEqual(points[0].id, south_point.id) @@ -2999,7 +2546,7 @@ class QuerySetTest(unittest.TestCase): """ class CustomQuerySet(QuerySet): def not_empty(self): - return len(self) > 0 + return self.count() > 0 class Post(Document): meta = {'queryset_class': CustomQuerySet} @@ -3020,7 +2567,7 @@ class QuerySetTest(unittest.TestCase): class CustomQuerySet(QuerySet): def not_empty(self): - return len(self) > 0 + return self.count() > 0 class CustomQuerySetManager(QuerySetManager): queryset_class = CustomQuerySet @@ -3067,7 +2614,7 @@ class QuerySetTest(unittest.TestCase): class CustomQuerySet(QuerySet): def not_empty(self): - return len(self) > 0 + return self.count() > 0 class Base(Document): meta = {'abstract': True, 'queryset_class': CustomQuerySet} @@ -3090,7 +2637,7 @@ class QuerySetTest(unittest.TestCase): class CustomQuerySet(QuerySet): def not_empty(self): - return len(self) > 0 + return self.count() > 0 class CustomQuerySetManager(QuerySetManager): queryset_class = CustomQuerySet @@ -3111,6 +2658,19 @@ class QuerySetTest(unittest.TestCase): Post.drop_collection() + def test_count_limit_and_skip(self): + class Post(Document): + title = StringField() + + Post.drop_collection() + + for i in xrange(10): + Post(title="Post %s" % i).save() + + self.assertEqual(5, Post.objects.limit(5).skip(5).count()) + + self.assertEqual(10, Post.objects.limit(5).skip(5).count(with_limit_and_skip=False)) + def test_call_after_limits_set(self): """Ensure that re-filtering after slicing works """ @@ -3119,10 +2679,8 @@ class QuerySetTest(unittest.TestCase): Post.drop_collection() - post1 = Post(title="Post 1") - post1.save() - post2 = Post(title="Post 2") - post2.save() + Post(title="Post 1").save() + Post(title="Post 2").save() posts = Post.objects.all()[0:1] 
self.assertEqual(len(list(posts())), 1) @@ -3229,21 +2787,21 @@ class QuerySetTest(unittest.TestCase): self.assertEqual([1, 2, 3], numbers) Number.drop_collection() - def test_ensure_index(self): """Ensure that manual creation of indexes works. """ class Comment(Document): message = StringField() + meta = {'allow_inheritance': True} - Comment.objects.ensure_index('message') + Comment.ensure_index('message') info = Comment.objects._collection.index_information() info = [(value['key'], value.get('unique', False), value.get('sparse', False)) for key, value in info.iteritems()] - self.assertTrue(([('_types', 1), ('message', 1)], False, False) in info) + self.assertTrue(([('_cls', 1), ('message', 1)], False, False) in info) def test_where(self): """Ensure that where clauses work. @@ -3492,14 +3050,14 @@ class QuerySetTest(unittest.TestCase): # Find all people in the collection people = self.Person.objects.scalar('name') - self.assertEqual(len(people), 2) + self.assertEqual(people.count(), 2) results = list(people) self.assertEqual(results[0], "User A") self.assertEqual(results[1], "User B") # Use a query to filter the people found to just person1 people = self.Person.objects(age=20).scalar('name') - self.assertEqual(len(people), 1) + self.assertEqual(people.count(), 1) person = people.next() self.assertEqual(person, "User A") @@ -3545,7 +3103,7 @@ class QuerySetTest(unittest.TestCase): for i in xrange(55): self.Person(name='A%s' % i, age=i).save() - self.assertEqual(len(self.Person.objects.scalar('name')), 55) + self.assertEqual(self.Person.objects.scalar('name').count(), 55) self.assertEqual("A0", "%s" % self.Person.objects.order_by('name').scalar('name').first()) self.assertEqual("A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) if PY3: @@ -3565,227 +3123,6 @@ class QuerySetTest(unittest.TestCase): else: self.assertEqual("[u'A1', u'A2']", "%s" % sorted(self.Person.objects.scalar('name').in_bulk(list(pks)).values())) - -class 
QTest(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - - def test_empty_q(self): - """Ensure that empty Q objects won't hurt. - """ - q1 = Q() - q2 = Q(age__gte=18) - q3 = Q() - q4 = Q(name='test') - q5 = Q() - - class Person(Document): - name = StringField() - age = IntField() - - query = {'$or': [{'age': {'$gte': 18}}, {'name': 'test'}]} - self.assertEqual((q1 | q2 | q3 | q4 | q5).to_query(Person), query) - - query = {'age': {'$gte': 18}, 'name': 'test'} - self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) - - def test_q_with_dbref(self): - """Ensure Q objects handle DBRefs correctly""" - connect(db='mongoenginetest') - - class User(Document): - pass - - class Post(Document): - created_user = ReferenceField(User) - - user = User.objects.create() - Post.objects.create(created_user=user) - - self.assertEqual(Post.objects.filter(created_user=user).count(), 1) - self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1) - - def test_and_combination(self): - """Ensure that Q-objects correctly AND together. - """ - class TestDoc(Document): - x = IntField() - y = StringField() - - # Check than an error is raised when conflicting queries are anded - def invalid_combination(): - query = Q(x__lt=7) & Q(x__lt=3) - query.to_query(TestDoc) - self.assertRaises(InvalidQueryError, invalid_combination) - - # Check normal cases work without an error - query = Q(x__lt=7) & Q(x__gt=3) - - q1 = Q(x__lt=7) - q2 = Q(x__gt=3) - query = (q1 & q2).to_query(TestDoc) - self.assertEqual(query, {'x': {'$lt': 7, '$gt': 3}}) - - # More complex nested example - query = Q(x__lt=100) & Q(y__ne='NotMyString') - query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) - mongo_query = { - 'x': {'$lt': 100, '$gt': -100}, - 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, - } - self.assertEqual(query.to_query(TestDoc), mongo_query) - - def test_or_combination(self): - """Ensure that Q-objects correctly OR together. 
- """ - class TestDoc(Document): - x = IntField() - - q1 = Q(x__lt=3) - q2 = Q(x__gt=7) - query = (q1 | q2).to_query(TestDoc) - self.assertEqual(query, { - '$or': [ - {'x': {'$lt': 3}}, - {'x': {'$gt': 7}}, - ] - }) - - def test_and_or_combination(self): - """Ensure that Q-objects handle ANDing ORed components. - """ - class TestDoc(Document): - x = IntField() - y = BooleanField() - - query = (Q(x__gt=0) | Q(x__exists=False)) - query &= Q(x__lt=100) - self.assertEqual(query.to_query(TestDoc), { - '$or': [ - {'x': {'$lt': 100, '$gt': 0}}, - {'x': {'$lt': 100, '$exists': False}}, - ] - }) - - q1 = (Q(x__gt=0) | Q(x__exists=False)) - q2 = (Q(x__lt=100) | Q(y=True)) - query = (q1 & q2).to_query(TestDoc) - - self.assertEqual(['$or'], query.keys()) - conditions = [ - {'x': {'$lt': 100, '$gt': 0}}, - {'x': {'$lt': 100, '$exists': False}}, - {'x': {'$gt': 0}, 'y': True}, - {'x': {'$exists': False}, 'y': True}, - ] - self.assertEqual(len(conditions), len(query['$or'])) - for condition in conditions: - self.assertTrue(condition in query['$or']) - - def test_or_and_or_combination(self): - """Ensure that Q-objects handle ORing ANDed ORed components. 
:) - """ - class TestDoc(Document): - x = IntField() - y = BooleanField() - - q1 = (Q(x__gt=0) & (Q(y=True) | Q(y__exists=False))) - q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False))) - query = (q1 | q2).to_query(TestDoc) - - self.assertEqual(['$or'], query.keys()) - conditions = [ - {'x': {'$gt': 0}, 'y': True}, - {'x': {'$gt': 0}, 'y': {'$exists': False}}, - {'x': {'$lt': 100}, 'y':False}, - {'x': {'$lt': 100}, 'y': {'$exists': False}}, - ] - self.assertEqual(len(conditions), len(query['$or'])) - for condition in conditions: - self.assertTrue(condition in query['$or']) - - - def test_q_clone(self): - - class TestDoc(Document): - x = IntField() - - TestDoc.drop_collection() - for i in xrange(1, 101): - t = TestDoc(x=i) - t.save() - - # Check normal cases work without an error - test = TestDoc.objects(Q(x__lt=7) & Q(x__gt=3)) - - self.assertEqual(test.count(), 3) - - test2 = test.clone() - self.assertEqual(test2.count(), 3) - self.assertFalse(test2 == test) - - test2.filter(x=6) - self.assertEqual(test2.count(), 1) - self.assertEqual(test.count(), 3) - -class QueryFieldListTest(unittest.TestCase): - def test_empty(self): - q = QueryFieldList() - self.assertFalse(q) - - q = QueryFieldList(always_include=['_cls']) - self.assertFalse(q) - - def test_include_include(self): - q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'a': True, 'b': True}) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'b': True}) - - def test_include_exclude(self): - q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'a': True, 'b': True}) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': True}) - - def test_exclude_exclude(self): - q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) - 
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) - - def test_exclude_include(self): - q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': False, 'b': False}) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'c': True}) - - def test_always_include(self): - q = QueryFieldList(always_include=['x', 'y']) - q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) - - def test_reset(self): - q = QueryFieldList(always_include=['x', 'y']) - q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) - q.reset() - self.assertFalse(q) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) - - def test_using_a_slice(self): - q = QueryFieldList() - q += QueryFieldList(fields=['a'], value={"$slice": 5}) - self.assertEqual(q.as_dict(), {'a': {"$slice": 5}}) - def test_elem_match(self): class Foo(EmbeddedDocument): shape = StringField() @@ -3810,6 +3147,102 @@ class QueryFieldListTest(unittest.TestCase): ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) self.assertEqual([b1], ak) + def test_upsert_includes_cls(self): + """Upserts should include _cls information for inheritable classes + """ + + class Test(Document): + test = StringField() + + Test.drop_collection() + Test.objects(test='foo').update_one(upsert=True, set__test='foo') + self.assertFalse('_cls' in Test._collection.find_one()) + + class 
Test(Document): + meta = {'allow_inheritance': True} + test = StringField() + + Test.drop_collection() + + Test.objects(test='foo').update_one(upsert=True, set__test='foo') + self.assertTrue('_cls' in Test._collection.find_one()) + + def test_read_preference(self): + class Bar(Document): + pass + + Bar.drop_collection() + bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY)) + self.assertEqual([], bars) + + self.assertRaises(ConfigurationError, Bar.objects, + read_preference='Primary') + + def test_json_simple(self): + + class Embedded(EmbeddedDocument): + string = StringField() + + class Doc(Document): + string = StringField() + embedded_field = EmbeddedDocumentField(Embedded) + + Doc.drop_collection() + Doc(string="Hi", embedded_field=Embedded(string="Hi")).save() + Doc(string="Bye", embedded_field=Embedded(string="Bye")).save() + + Doc().save() + json_data = Doc.objects.to_json() + doc_objects = list(Doc.objects) + + self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) + + def test_json_complex(self): + if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3: + raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs") + + class EmbeddedDoc(EmbeddedDocument): + pass + + class Simple(Document): + pass + + class Doc(Document): + string_field = StringField(default='1') + int_field = IntField(default=1) + float_field = FloatField(default=1.1) + boolean_field = BooleanField(default=True) + datetime_field = DateTimeField(default=datetime.now) + embedded_document_field = EmbeddedDocumentField( + EmbeddedDoc, default=lambda: EmbeddedDoc()) + list_field = ListField(default=lambda: [1, 2, 3]) + dict_field = DictField(default=lambda: {"hello": "world"}) + objectid_field = ObjectIdField(default=ObjectId) + reference_field = ReferenceField(Simple, default=lambda: Simple().save()) + map_field = MapField(IntField(), default=lambda: {"simple": 1}) + decimal_field = DecimalField(default=1.0) + complex_datetime_field = 
ComplexDateTimeField(default=datetime.now) + url_field = URLField(default="http://mongoengine.org") + dynamic_field = DynamicField(default=1) + generic_reference_field = GenericReferenceField(default=lambda: Simple().save()) + sorted_list_field = SortedListField(IntField(), + default=lambda: [1, 2, 3]) + email_field = EmailField(default="ross@example.com") + geo_point_field = GeoPointField(default=lambda: [1, 2]) + sequence_field = SequenceField() + uuid_field = UUIDField(default=uuid.uuid4) + generic_embedded_document_field = GenericEmbeddedDocumentField( + default=lambda: EmbeddedDoc()) + + Simple.drop_collection() + Doc.drop_collection() + + Doc().save() + json_data = Doc.objects.to_json() + doc_objects = list(Doc.objects) + + self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) + def test_as_pymongo(self): from decimal import Decimal @@ -3829,9 +3262,9 @@ class QueryFieldListTest(unittest.TestCase): self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[1], dict)) self.assertEqual(results[0]['name'], 'Bob Dole') - self.assertEqual(results[0]['price'], '1.11') + self.assertEqual(results[0]['price'], 1.11) self.assertEqual(results[1]['name'], 'Barack Obama') - self.assertEqual(results[1]['price'], '2.22') + self.assertEqual(results[1]['price'], 2.22) # Test coerce_types users = User.objects.only('name', 'price').as_pymongo(coerce_types=True) @@ -3843,5 +3276,67 @@ class QueryFieldListTest(unittest.TestCase): self.assertEqual(results[1]['name'], 'Barack Obama') self.assertEqual(results[1]['price'], Decimal('2.22')) + def test_no_dereference(self): + + class Organization(Document): + name = StringField() + + class User(Document): + name = StringField() + organization = ReferenceField(Organization) + + User.drop_collection() + Organization.drop_collection() + + whitehouse = Organization(name="White House").save() + User(name="Bob Dole", organization=whitehouse).save() + + qs = User.objects() + 
self.assertTrue(isinstance(qs.first().organization, Organization)) + self.assertFalse(isinstance(qs.no_dereference().first().organization, + Organization)) + self.assertTrue(isinstance(qs.first().organization, Organization)) + + def test_nested_queryset_iterator(self): + # Try iterating the same queryset twice, nested. + names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George'] + + class User(Document): + name = StringField() + + def __unicode__(self): + return self.name + + User.drop_collection() + + for name in names: + User(name=name).save() + + users = User.objects.all().order_by('name') + + outer_count = 0 + inner_count = 0 + inner_total_count = 0 + + self.assertEqual(users.count(), 7) + + for i, outer_user in enumerate(users): + self.assertEqual(outer_user.name, names[i]) + outer_count += 1 + inner_count = 0 + + # Calling len might disrupt the inner loop if there are bugs + self.assertEqual(users.count(), 7) + + for j, inner_user in enumerate(users): + self.assertEqual(inner_user.name, names[j]) + inner_count += 1 + inner_total_count += 1 + + self.assertEqual(inner_count, 7) # inner loop should always be executed seven times + + self.assertEqual(outer_count, 7) # outer loop should be executed seven times total + self.assertEqual(inner_total_count, 7 * 7) # inner loop should be executed fourtynine times total + if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py new file mode 100644 index 00000000..bde4b6f1 --- /dev/null +++ b/tests/queryset/transform.py @@ -0,0 +1,148 @@ +from __future__ import with_statement +import sys +sys.path[0:0] = [""] + +import unittest + +from mongoengine import * +from mongoengine.queryset import Q +from mongoengine.queryset import transform + +__all__ = ("TransformTest",) + + +class TransformTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + + def test_transform_query(self): + """Ensure that the _transform_query function 
operates correctly. + """ + self.assertEqual(transform.query(name='test', age=30), + {'name': 'test', 'age': 30}) + self.assertEqual(transform.query(age__lt=30), + {'age': {'$lt': 30}}) + self.assertEqual(transform.query(age__gt=20, age__lt=50), + {'age': {'$gt': 20, '$lt': 50}}) + self.assertEqual(transform.query(age=20, age__gt=50), + {'$and': [{'age': {'$gt': 50}}, {'age': 20}]}) + self.assertEqual(transform.query(friend__age__gte=30), + {'friend.age': {'$gte': 30}}) + self.assertEqual(transform.query(name__exists=True), + {'name': {'$exists': True}}) + + def test_query_field_name(self): + """Ensure that the correct field name is used when querying. + """ + class Comment(EmbeddedDocument): + content = StringField(db_field='commentContent') + + class BlogPost(Document): + title = StringField(db_field='postTitle') + comments = ListField(EmbeddedDocumentField(Comment), + db_field='postComments') + + BlogPost.drop_collection() + + data = {'title': 'Post 1', 'comments': [Comment(content='test')]} + post = BlogPost(**data) + post.save() + + self.assertTrue('postTitle' in + BlogPost.objects(title=data['title'])._query) + self.assertFalse('title' in + BlogPost.objects(title=data['title'])._query) + self.assertEqual(BlogPost.objects(title=data['title']).count(), 1) + + self.assertTrue('_id' in BlogPost.objects(pk=post.id)._query) + self.assertEqual(BlogPost.objects(pk=post.id).count(), 1) + + self.assertTrue('postComments.commentContent' in + BlogPost.objects(comments__content='test')._query) + self.assertEqual(BlogPost.objects(comments__content='test').count(), 1) + + BlogPost.drop_collection() + + def test_query_pk_field_name(self): + """Ensure that the correct "primary key" field name is used when + querying + """ + class BlogPost(Document): + title = StringField(primary_key=True, db_field='postTitle') + + BlogPost.drop_collection() + + data = {'title': 'Post 1'} + post = BlogPost(**data) + post.save() + + self.assertTrue('_id' in 
BlogPost.objects(pk=data['title'])._query) + self.assertTrue('_id' in BlogPost.objects(title=data['title'])._query) + self.assertEqual(BlogPost.objects(pk=data['title']).count(), 1) + + BlogPost.drop_collection() + + def test_chaining(self): + class A(Document): + pass + + class B(Document): + a = ReferenceField(A) + + A.drop_collection() + B.drop_collection() + + a1 = A().save() + a2 = A().save() + + B(a=a1).save() + + # Works + q1 = B.objects.filter(a__in=[a1, a2], a=a1)._query + + # Doesn't work + q2 = B.objects.filter(a__in=[a1, a2]) + q2 = q2.filter(a=a1)._query + + self.assertEqual(q1, q2) + + def test_raw_query_and_Q_objects(self): + """ + Test raw plays nicely + """ + class Foo(Document): + name = StringField() + a = StringField() + b = StringField() + c = StringField() + + meta = { + 'allow_inheritance': False + } + + query = Foo.objects(__raw__={'$nor': [{'name': 'bar'}]})._query + self.assertEqual(query, {'$nor': [{'name': 'bar'}]}) + + q1 = {'$or': [{'a': 1}, {'b': 1}]} + query = Foo.objects(Q(__raw__=q1) & Q(c=1))._query + self.assertEqual(query, {'$or': [{'a': 1}, {'b': 1}], 'c': 1}) + + def test_raw_and_merging(self): + class Doc(Document): + meta = {'allow_inheritance': False} + + raw_query = Doc.objects(__raw__={'deleted': False, + 'scraped': 'yes', + '$nor': [{'views.extracted': 'no'}, + {'attachments.views.extracted':'no'}] + })._query + + expected = {'deleted': False, 'scraped': 'yes', + '$nor': [{'views.extracted': 'no'}, + {'attachments.views.extracted': 'no'}]} + self.assertEqual(expected, raw_query) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/queryset/visitor.py b/tests/queryset/visitor.py new file mode 100644 index 00000000..bd81a654 --- /dev/null +++ b/tests/queryset/visitor.py @@ -0,0 +1,335 @@ +from __future__ import with_statement +import sys +sys.path[0:0] = [""] + +import unittest + +from bson import ObjectId +from datetime import datetime + +from mongoengine import * +from mongoengine.queryset import Q +from 
mongoengine.errors import InvalidQueryError + +__all__ = ("QTest",) + + +class QTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + + class Person(Document): + name = StringField() + age = IntField() + meta = {'allow_inheritance': True} + + Person.drop_collection() + self.Person = Person + + def test_empty_q(self): + """Ensure that empty Q objects won't hurt. + """ + q1 = Q() + q2 = Q(age__gte=18) + q3 = Q() + q4 = Q(name='test') + q5 = Q() + + class Person(Document): + name = StringField() + age = IntField() + + query = {'$or': [{'age': {'$gte': 18}}, {'name': 'test'}]} + self.assertEqual((q1 | q2 | q3 | q4 | q5).to_query(Person), query) + + query = {'age': {'$gte': 18}, 'name': 'test'} + self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) + + def test_q_with_dbref(self): + """Ensure Q objects handle DBRefs correctly""" + connect(db='mongoenginetest') + + class User(Document): + pass + + class Post(Document): + created_user = ReferenceField(User) + + user = User.objects.create() + Post.objects.create(created_user=user) + + self.assertEqual(Post.objects.filter(created_user=user).count(), 1) + self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1) + + def test_and_combination(self): + """Ensure that Q-objects correctly AND together. 
+ """ + class TestDoc(Document): + x = IntField() + y = StringField() + + # Check than an error is raised when conflicting queries are anded + def invalid_combination(): + query = Q(x__lt=7) & Q(x__lt=3) + query.to_query(TestDoc) + self.assertRaises(InvalidQueryError, invalid_combination) + + # Check normal cases work without an error + query = Q(x__lt=7) & Q(x__gt=3) + + q1 = Q(x__lt=7) + q2 = Q(x__gt=3) + query = (q1 & q2).to_query(TestDoc) + self.assertEqual(query, {'x': {'$lt': 7, '$gt': 3}}) + + # More complex nested example + query = Q(x__lt=100) & Q(y__ne='NotMyString') + query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) + mongo_query = { + 'x': {'$lt': 100, '$gt': -100}, + 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, + } + self.assertEqual(query.to_query(TestDoc), mongo_query) + + def test_or_combination(self): + """Ensure that Q-objects correctly OR together. + """ + class TestDoc(Document): + x = IntField() + + q1 = Q(x__lt=3) + q2 = Q(x__gt=7) + query = (q1 | q2).to_query(TestDoc) + self.assertEqual(query, { + '$or': [ + {'x': {'$lt': 3}}, + {'x': {'$gt': 7}}, + ] + }) + + def test_and_or_combination(self): + """Ensure that Q-objects handle ANDing ORed components. 
+ """ + class TestDoc(Document): + x = IntField() + y = BooleanField() + + TestDoc.drop_collection() + + query = (Q(x__gt=0) | Q(x__exists=False)) + query &= Q(x__lt=100) + self.assertEqual(query.to_query(TestDoc), {'$and': [ + {'$or': [{'x': {'$gt': 0}}, + {'x': {'$exists': False}}]}, + {'x': {'$lt': 100}}] + }) + + q1 = (Q(x__gt=0) | Q(x__exists=False)) + q2 = (Q(x__lt=100) | Q(y=True)) + query = (q1 & q2).to_query(TestDoc) + + TestDoc(x=101).save() + TestDoc(x=10).save() + TestDoc(y=True).save() + + self.assertEqual(query, + {'$and': [ + {'$or': [{'x': {'$gt': 0}}, {'x': {'$exists': False}}]}, + {'$or': [{'x': {'$lt': 100}}, {'y': True}]} + ]}) + + self.assertEqual(2, TestDoc.objects(q1 & q2).count()) + + def test_or_and_or_combination(self): + """Ensure that Q-objects handle ORing ANDed ORed components. :) + """ + class TestDoc(Document): + x = IntField() + y = BooleanField() + + TestDoc.drop_collection() + TestDoc(x=-1, y=True).save() + TestDoc(x=101, y=True).save() + TestDoc(x=99, y=False).save() + TestDoc(x=101, y=False).save() + + q1 = (Q(x__gt=0) & (Q(y=True) | Q(y__exists=False))) + q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False))) + query = (q1 | q2).to_query(TestDoc) + + self.assertEqual(query, + {'$or': [ + {'$and': [{'x': {'$gt': 0}}, + {'$or': [{'y': True}, {'y': {'$exists': False}}]}]}, + {'$and': [{'x': {'$lt': 100}}, + {'$or': [{'y': False}, {'y': {'$exists': False}}]}]} + ]} + ) + + self.assertEqual(2, TestDoc.objects(q1 | q2).count()) + + def test_multiple_occurence_in_field(self): + class Test(Document): + name = StringField(max_length=40) + title = StringField(max_length=40) + + q1 = Q(name__contains='te') | Q(title__contains='te') + q2 = Q(name__contains='12') | Q(title__contains='12') + + q3 = q1 & q2 + + query = q3.to_query(Test) + self.assertEqual(query["$and"][0], q1.to_query(Test)) + self.assertEqual(query["$and"][1], q2.to_query(Test)) + + def test_q_clone(self): + + class TestDoc(Document): + x = IntField() + + 
TestDoc.drop_collection() + for i in xrange(1, 101): + t = TestDoc(x=i) + t.save() + + # Check normal cases work without an error + test = TestDoc.objects(Q(x__lt=7) & Q(x__gt=3)) + + self.assertEqual(test.count(), 3) + + test2 = test.clone() + self.assertEqual(test2.count(), 3) + self.assertFalse(test2 == test) + + test3 = test2.filter(x=6) + self.assertEqual(test3.count(), 1) + self.assertEqual(test.count(), 3) + + def test_q(self): + """Ensure that Q objects may be used to query for documents. + """ + class BlogPost(Document): + title = StringField() + publish_date = DateTimeField() + published = BooleanField() + + BlogPost.drop_collection() + + post1 = BlogPost(title='Test 1', publish_date=datetime(2010, 1, 8), published=False) + post1.save() + + post2 = BlogPost(title='Test 2', publish_date=datetime(2010, 1, 15), published=True) + post2.save() + + post3 = BlogPost(title='Test 3', published=True) + post3.save() + + post4 = BlogPost(title='Test 4', publish_date=datetime(2010, 1, 8)) + post4.save() + + post5 = BlogPost(title='Test 1', publish_date=datetime(2010, 1, 15)) + post5.save() + + post6 = BlogPost(title='Test 1', published=False) + post6.save() + + # Check ObjectId lookup works + obj = BlogPost.objects(id=post1.id).first() + self.assertEqual(obj, post1) + + # Check Q object combination with one does not exist + q = BlogPost.objects(Q(title='Test 5') | Q(published=True)) + posts = [post.id for post in q] + + published_posts = (post2, post3) + self.assertTrue(all(obj.id in posts for obj in published_posts)) + + q = BlogPost.objects(Q(title='Test 1') | Q(published=True)) + posts = [post.id for post in q] + published_posts = (post1, post2, post3, post5, post6) + self.assertTrue(all(obj.id in posts for obj in published_posts)) + + # Check Q object combination + date = datetime(2010, 1, 10) + q = BlogPost.objects(Q(publish_date__lte=date) | Q(published=True)) + posts = [post.id for post in q] + + published_posts = (post1, post2, post3, post4) + 
self.assertTrue(all(obj.id in posts for obj in published_posts)) + + self.assertFalse(any(obj.id in posts for obj in [post5, post6])) + + BlogPost.drop_collection() + + # Check the 'in' operator + self.Person(name='user1', age=20).save() + self.Person(name='user2', age=20).save() + self.Person(name='user3', age=30).save() + self.Person(name='user4', age=40).save() + + self.assertEqual(self.Person.objects(Q(age__in=[20])).count(), 2) + self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) + + # Test invalid query objs + def wrong_query_objs(): + self.Person.objects('user1') + def wrong_query_objs_filter(): + self.Person.objects('user1') + self.assertRaises(InvalidQueryError, wrong_query_objs) + self.assertRaises(InvalidQueryError, wrong_query_objs_filter) + + def test_q_regex(self): + """Ensure that Q objects can be queried using regexes. + """ + person = self.Person(name='Guido van Rossum') + person.save() + + import re + obj = self.Person.objects(Q(name=re.compile('^Gui'))).first() + self.assertEqual(obj, person) + obj = self.Person.objects(Q(name=re.compile('^gui'))).first() + self.assertEqual(obj, None) + + obj = self.Person.objects(Q(name=re.compile('^gui', re.I))).first() + self.assertEqual(obj, person) + + obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() + self.assertEqual(obj, person) + + obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() + self.assertEqual(obj, None) + + def test_q_lists(self): + """Ensure that Q objects query ListFields correctly. 
+ """ + class BlogPost(Document): + tags = ListField(StringField()) + + BlogPost.drop_collection() + + BlogPost(tags=['python', 'mongo']).save() + BlogPost(tags=['python']).save() + + self.assertEqual(BlogPost.objects(Q(tags='mongo')).count(), 1) + self.assertEqual(BlogPost.objects(Q(tags='python')).count(), 2) + + BlogPost.drop_collection() + + def test_q_merge_queries_edge_case(self): + + class User(Document): + email = EmailField(required=False) + name = StringField() + + User.drop_collection() + pk = ObjectId() + User(email='example@example.com', pk=pk).save() + + self.assertEqual(1, User.objects.filter( + Q(email='example@example.com') | + Q(name='John Doe') + ).limit(2).filter(pk=pk).count()) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_all_warnings.py b/tests/test_all_warnings.py deleted file mode 100644 index 9b090349..00000000 --- a/tests/test_all_warnings.py +++ /dev/null @@ -1,98 +0,0 @@ -import unittest -import warnings - -from mongoengine import * -from mongoengine.tests import query_counter - - -class TestWarnings(unittest.TestCase): - - def setUp(self): - conn = connect(db='mongoenginetest') - self.warning_list = [] - self.showwarning_default = warnings.showwarning - warnings.showwarning = self.append_to_warning_list - - def append_to_warning_list(self, message, category, *args): - self.warning_list.append({"message": message, - "category": category}) - - def tearDown(self): - # restore default handling of warnings - warnings.showwarning = self.showwarning_default - - def test_allow_inheritance_future_warning(self): - """Add FutureWarning for future allow_inhertiance default change. 
- """ - - class SimpleBase(Document): - a = IntField() - - class InheritedClass(SimpleBase): - b = IntField() - - InheritedClass() - self.assertEqual(len(self.warning_list), 1) - warning = self.warning_list[0] - self.assertEqual(FutureWarning, warning["category"]) - self.assertTrue("InheritedClass" in str(warning["message"])) - - def test_dbref_reference_field_future_warning(self): - - class Person(Document): - name = StringField() - parent = ReferenceField('self') - - Person.drop_collection() - - p1 = Person() - p1.parent = None - p1.save() - - p2 = Person(name="Wilson Jr") - p2.parent = p1 - p2.save(cascade=False) - - self.assertTrue(len(self.warning_list) > 0) - warning = self.warning_list[0] - self.assertEqual(FutureWarning, warning["category"]) - self.assertTrue("ReferenceFields will default to using ObjectId" - in str(warning["message"])) - - def test_document_save_cascade_future_warning(self): - - class Person(Document): - name = StringField() - parent = ReferenceField('self') - - Person.drop_collection() - - p1 = Person(name="Wilson Snr") - p1.parent = None - p1.save() - - p2 = Person(name="Wilson Jr") - p2.parent = p1 - p2.parent.name = "Poppa Wilson" - p2.save() - - self.assertTrue(len(self.warning_list) > 0) - if len(self.warning_list) > 1: - print self.warning_list - warning = self.warning_list[0] - self.assertEqual(FutureWarning, warning["category"]) - self.assertTrue("Cascading saves will default to off in 0.8" - in str(warning["message"])) - - def test_document_collection_syntax_warning(self): - - class NonAbstractBase(Document): - pass - - class InheritedDocumentFailTest(NonAbstractBase): - meta = {'collection': 'fail'} - - warning = self.warning_list[0] - self.assertEqual(SyntaxWarning, warning["category"]) - self.assertEqual('non_abstract_base', - InheritedDocumentFailTest._get_collection_name()) diff --git a/tests/test_connection.py b/tests/test_connection.py index cd03df0b..4b8a3d11 100644 --- a/tests/test_connection.py +++ 
b/tests/test_connection.py @@ -1,12 +1,14 @@ -import datetime -import pymongo +from __future__ import with_statement +import sys +sys.path[0:0] = [""] import unittest +import datetime -import mongoengine.connection - +import pymongo from bson.tz_util import utc from mongoengine import * +import mongoengine.connection from mongoengine.connection import get_db, get_connection, ConnectionError @@ -23,7 +25,7 @@ class ConnectionTest(unittest.TestCase): connect('mongoenginetest') conn = get_connection() - self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) db = get_db() self.assertTrue(isinstance(db, pymongo.database.Database)) @@ -31,7 +33,7 @@ class ConnectionTest(unittest.TestCase): connect('mongoenginetest2', alias='testdb') conn = get_connection('testdb') - self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) def test_connect_uri(self): """Ensure that the connect() method works properly with uri's @@ -49,7 +51,7 @@ class ConnectionTest(unittest.TestCase): connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') conn = get_connection() - self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) db = get_db() self.assertTrue(isinstance(db, pymongo.database.Database)) @@ -62,7 +64,7 @@ class ConnectionTest(unittest.TestCase): self.assertRaises(ConnectionError, get_connection) conn = get_connection('testdb') - self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) db = get_db('testdb') self.assertTrue(isinstance(db, pymongo.database.Database)) diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py new file mode 100644 index 00000000..eef63bee --- /dev/null +++ b/tests/test_context_managers.py @@ -0,0 
+1,156 @@ +from __future__ import with_statement +import sys +sys.path[0:0] = [""] +import unittest + +from mongoengine import * +from mongoengine.connection import get_db +from mongoengine.context_managers import (switch_db, switch_collection, + no_dereference, query_counter) + + +class ContextManagersTest(unittest.TestCase): + + def test_switch_db_context_manager(self): + connect('mongoenginetest') + register_connection('testdb-1', 'mongoenginetest2') + + class Group(Document): + name = StringField() + + Group.drop_collection() + + Group(name="hello - default").save() + self.assertEqual(1, Group.objects.count()) + + with switch_db(Group, 'testdb-1') as Group: + + self.assertEqual(0, Group.objects.count()) + + Group(name="hello").save() + + self.assertEqual(1, Group.objects.count()) + + Group.drop_collection() + self.assertEqual(0, Group.objects.count()) + + self.assertEqual(1, Group.objects.count()) + + def test_switch_collection_context_manager(self): + connect('mongoenginetest') + register_connection('testdb-1', 'mongoenginetest2') + + class Group(Document): + name = StringField() + + Group.drop_collection() + with switch_collection(Group, 'group1') as Group: + Group.drop_collection() + + Group(name="hello - group").save() + self.assertEqual(1, Group.objects.count()) + + with switch_collection(Group, 'group1') as Group: + + self.assertEqual(0, Group.objects.count()) + + Group(name="hello - group1").save() + + self.assertEqual(1, Group.objects.count()) + + Group.drop_collection() + self.assertEqual(0, Group.objects.count()) + + self.assertEqual(1, Group.objects.count()) + + def test_no_dereference_context_manager_object_id(self): + """Ensure that DBRef items in ListFields aren't dereferenced. 
+ """ + connect('mongoenginetest') + + class User(Document): + name = StringField() + + class Group(Document): + ref = ReferenceField(User, dbref=False) + generic = GenericReferenceField() + members = ListField(ReferenceField(User, dbref=False)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + User(name='user %s' % i).save() + + user = User.objects.first() + Group(ref=user, members=User.objects, generic=user).save() + + with no_dereference(Group) as NoDeRefGroup: + self.assertTrue(Group._fields['members']._auto_dereference) + self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) + + with no_dereference(Group) as Group: + group = Group.objects.first() + self.assertTrue(all([not isinstance(m, User) + for m in group.members])) + self.assertFalse(isinstance(group.ref, User)) + self.assertFalse(isinstance(group.generic, User)) + + self.assertTrue(all([isinstance(m, User) + for m in group.members])) + self.assertTrue(isinstance(group.ref, User)) + self.assertTrue(isinstance(group.generic, User)) + + def test_no_dereference_context_manager_dbref(self): + """Ensure that DBRef items in ListFields aren't dereferenced. 
+ """ + connect('mongoenginetest') + + class User(Document): + name = StringField() + + class Group(Document): + ref = ReferenceField(User, dbref=True) + generic = GenericReferenceField() + members = ListField(ReferenceField(User, dbref=True)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + User(name='user %s' % i).save() + + user = User.objects.first() + Group(ref=user, members=User.objects, generic=user).save() + + with no_dereference(Group) as NoDeRefGroup: + self.assertTrue(Group._fields['members']._auto_dereference) + self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) + + with no_dereference(Group) as Group: + group = Group.objects.first() + self.assertTrue(all([not isinstance(m, User) + for m in group.members])) + self.assertFalse(isinstance(group.ref, User)) + self.assertFalse(isinstance(group.generic, User)) + + self.assertTrue(all([isinstance(m, User) + for m in group.members])) + self.assertTrue(isinstance(group.ref, User)) + self.assertTrue(isinstance(group.generic, User)) + + def test_query_counter(self): + connect('mongoenginetest') + db = get_db() + db.test.find({}) + + with query_counter() as q: + self.assertEqual(0, q) + + for i in xrange(1, 51): + db.test.find({}).count() + + self.assertEqual(50, q) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_dereference.py b/tests/test_dereference.py index d7438d2f..ef5a10d9 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -1,11 +1,14 @@ +# -*- coding: utf-8 -*- from __future__ import with_statement +import sys +sys.path[0:0] = [""] import unittest from bson import DBRef, ObjectId from mongoengine import * from mongoengine.connection import get_db -from mongoengine.tests import query_counter +from mongoengine.context_managers import query_counter class FieldTest(unittest.TestCase): @@ -123,6 +126,27 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() + def 
test_list_item_dereference_dref_false_stores_as_type(self): + """Ensure that DBRef items are stored as their type + """ + class User(Document): + my_id = IntField(primary_key=True) + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User, dbref=False)) + + User.drop_collection() + Group.drop_collection() + + user = User(my_id=1, name='user 1').save() + + Group(members=User.objects).save() + group = Group.objects.first() + + self.assertEqual(Group._get_collection().find_one()['members'], [1]) + self.assertEqual(group.members, [user]) + def test_handle_old_style_references(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -177,6 +201,10 @@ class FieldTest(unittest.TestCase): raw_data = Group._get_collection().find_one() self.assertTrue(isinstance(raw_data['author'], DBRef)) self.assertTrue(isinstance(raw_data['members'][0], DBRef)) + group = Group.objects.first() + + self.assertEqual(group.author, user) + self.assertEqual(group.members, [user]) # Migrate the model definition class Group(Document): @@ -185,8 +213,9 @@ class FieldTest(unittest.TestCase): # Migrate the data for g in Group.objects(): - g.author = g.author - g.members = g.members + # Explicitly mark as changed so resets + g._mark_as_changed('author') + g._mark_as_changed('members') g.save() group = Group.objects.first() @@ -337,14 +366,10 @@ class FieldTest(unittest.TestCase): return "" % self.name Person.drop_collection() - paul = Person(name="Paul") - paul.save() - maria = Person(name="Maria") - maria.save() - julia = Person(name='Julia') - julia.save() - anna = Person(name='Anna') - anna.save() + paul = Person(name="Paul").save() + maria = Person(name="Maria").save() + julia = Person(name='Julia').save() + anna = Person(name='Anna').save() paul.other.friends = [maria, julia, anna] paul.other.name = "Paul's friends" @@ -997,28 +1022,129 @@ class FieldTest(unittest.TestCase): msg = Message.objects.get(id=1) self.assertEqual(0, msg.comments[0].id) 
self.assertEqual(1, msg.comments[1].id) - + + def test_list_item_dereference_dref_false_save_doesnt_cause_extra_queries(self): + """Ensure that DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + name = StringField() + members = ListField(ReferenceField(User, dbref=False)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + User(name='user %s' % i).save() + + Group(name="Test", members=User.objects).save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + group_obj.name = "new test" + group_obj.save() + + self.assertEqual(q, 2) + + def test_list_item_dereference_dref_true_save_doesnt_cause_extra_queries(self): + """Ensure that DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + name = StringField() + members = ListField(ReferenceField(User, dbref=True)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + User(name='user %s' % i).save() + + Group(name="Test", members=User.objects).save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + group_obj.name = "new test" + group_obj.save() + + self.assertEqual(q, 2) + + def test_generic_reference_save_doesnt_cause_extra_queries(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + name = StringField() + members = ListField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i).save() + b = UserB(name='User B %s' % i).save() + c = UserC(name='User C %s' % i).save() + + members += [a, b, c] + + 
Group(name="test", members=members).save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + group_obj.name = "new test" + group_obj.save() + + self.assertEqual(q, 2) + def test_tuples_as_tuples(self): """ Ensure that tuples remain tuples when they are inside a ComplexBaseField """ from mongoengine.base import BaseField + class EnumField(BaseField): + def __init__(self, **kwargs): - super(EnumField,self).__init__(**kwargs) - + super(EnumField, self).__init__(**kwargs) + def to_mongo(self, value): return value - + def to_python(self, value): return tuple(value) - + class TestDoc(Document): items = ListField(EnumField()) - + TestDoc.drop_collection() - tuples = [(100,'Testing')] + tuples = [(100, 'Testing')] doc = TestDoc() doc.items = tuples doc.save() @@ -1028,3 +1154,29 @@ class FieldTest(unittest.TestCase): self.assertTrue(tuple(x.items[0]) in tuples) self.assertTrue(x.items[0] in tuples) + def test_non_ascii_pk(self): + """ + Ensure that dbref conversion to string does not fail when + non-ascii characters are used in primary key + """ + class Brand(Document): + title = StringField(max_length=255, primary_key=True) + + class BrandGroup(Document): + title = StringField(max_length=255, primary_key=True) + brands = ListField(ReferenceField("Brand", dbref=True)) + + Brand.drop_collection() + BrandGroup.drop_collection() + + brand1 = Brand(title="Moschino").save() + brand2 = Brand(title=u"Денис Симачёв").save() + + BrandGroup(title="top_brands", brands=[brand1, brand2]).save() + brand_groups = BrandGroup.objects().all() + + self.assertEqual(2, len([brand for bg in brand_groups for brand in bg.brands])) + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/test_django.py b/tests/test_django.py index 398fd3e0..573c0728 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -1,4 +1,6 @@ from __future__ import with_statement +import sys +sys.path[0:0] = [""] import unittest 
from nose.plugins.skip import SkipTest from mongoengine.python_support import PY3 @@ -12,8 +14,19 @@ try: from django.conf import settings from django.core.paginator import Paginator - settings.configure() + settings.configure( + USE_TZ=True, + INSTALLED_APPS=('django.contrib.auth', 'mongoengine.django.mongo_auth'), + AUTH_USER_MODEL=('mongo_auth.MongoUser'), + ) + try: + from django.contrib.auth import authenticate, get_user_model + from mongoengine.django.auth import User + from mongoengine.django.mongo_auth.models import MongoUser, MongoUserManager + DJ15 = True + except Exception: + DJ15 = False from django.contrib.sessions.tests import SessionTestsMixin from mongoengine.django.sessions import SessionStore, MongoSession except Exception, err: @@ -24,6 +37,37 @@ except Exception, err: raise err +from datetime import tzinfo, timedelta +ZERO = timedelta(0) + +class FixedOffset(tzinfo): + """Fixed offset in minutes east from UTC.""" + + def __init__(self, offset, name): + self.__offset = timedelta(minutes = offset) + self.__name = name + + def utcoffset(self, dt): + return self.__offset + + def tzname(self, dt): + return self.__name + + def dst(self, dt): + return ZERO + + +def activate_timezone(tz): + """Activate Django timezone support if it is available. + """ + try: + from django.utils import timezone + timezone.deactivate() + timezone.activate(tz) + except ImportError: + pass + + class QuerySetTest(unittest.TestCase): def setUp(self): @@ -103,6 +147,25 @@ class QuerySetTest(unittest.TestCase): start = end - 1 self.assertEqual(t.render(Context(d)), u'%d:%d:' % (start, end)) + def test_nested_queryset_template_iterator(self): + # Try iterating the same queryset twice, nested, in a Django template. 
+ names = ['A', 'B', 'C', 'D'] + + class User(Document): + name = StringField() + + def __unicode__(self): + return self.name + + User.drop_collection() + + for name in names: + User(name=name).save() + + users = User.objects.all().order_by('name') + template = Template("{% for user in users %}{{ user.name }}{% ifequal forloop.counter 2 %} {% for inner_user in users %}{{ inner_user.name }}{% endfor %} {% endifequal %}{% endfor %}") + rendered = template.render(Context({'users': users})) + self.assertEqual(rendered, 'AB ABCD CD') class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase): @@ -115,8 +178,73 @@ class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase): MongoSession.drop_collection() super(MongoDBSessionTest, self).setUp() + def assertIn(self, first, second, msg=None): + self.assertTrue(first in second, msg) + + def assertNotIn(self, first, second, msg=None): + self.assertFalse(first in second, msg) + def test_first_save(self): session = SessionStore() session['test'] = True session.save() self.assertTrue('test' in session) + + def test_session_expiration_tz(self): + activate_timezone(FixedOffset(60, 'UTC+1')) + # create and save new session + session = SessionStore() + session.set_expiry(600) # expire in 600 seconds + session['test_expire'] = True + session.save() + # reload session with key + key = session.session_key + session = SessionStore(key) + self.assertTrue('test_expire' in session, 'Session has expired before it is expected') + + +class MongoAuthTest(unittest.TestCase): + user_data = { + 'username': 'user', + 'email': 'user@example.com', + 'password': 'test', + } + + def setUp(self): + if PY3: + raise SkipTest('django does not have Python 3 support') + if not DJ15: + raise SkipTest('mongo_auth requires Django 1.5') + connect(db='mongoenginetest') + User.drop_collection() + super(MongoAuthTest, self).setUp() + + def test_user_model(self): + self.assertEqual(get_user_model(), MongoUser) + + def test_user_manager(self): + manager = 
get_user_model()._default_manager + self.assertIsInstance(manager, MongoUserManager) + + def test_user_manager_exception(self): + manager = get_user_model()._default_manager + self.assertRaises(MongoUser.DoesNotExist, manager.get, + username='not found') + + def test_create_user(self): + manager = get_user_model()._default_manager + user = manager.create_user(**self.user_data) + self.assertIsInstance(user, User) + db_user = User.objects.get(username='user') + self.assertEqual(user.id, db_user.id) + + def test_authenticate(self): + get_user_model()._default_manager.create_user(**self.user_data) + user = authenticate(username='user', password='fail') + self.assertIsNone(user) + user = authenticate(username='user', password='test') + db_user = User.objects.get(username='user') + self.assertEqual(user.id, db_user.id) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_document.py b/tests/test_document.py deleted file mode 100644 index 3e8d8134..00000000 --- a/tests/test_document.py +++ /dev/null @@ -1,3526 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import with_statement -import bson -import os -import pickle -import pymongo -import sys -import unittest -import uuid -import warnings -import operator - -from nose.plugins.skip import SkipTest -from datetime import datetime - -from tests.fixtures import Base, Mixin, PickleEmbedded, PickleTest - -from mongoengine import * -from mongoengine.base import NotRegistered, InvalidDocumentError, get_document -from mongoengine.queryset import InvalidQueryError -from mongoengine.connection import get_db, get_connection - -TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') - - -class DocumentTest(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - class Person(Document): - name = StringField() - age = IntField() - - meta = {'allow_inheritance': True} - - self.Person = Person - - def tearDown(self): - self.Person.drop_collection() - - def 
test_drop_collection(self): - """Ensure that the collection may be dropped from the database. - """ - self.Person(name='Test').save() - - collection = self.Person._get_collection_name() - self.assertTrue(collection in self.db.collection_names()) - - self.Person.drop_collection() - self.assertFalse(collection in self.db.collection_names()) - - def test_queryset_resurrects_dropped_collection(self): - - self.Person.objects().item_frequencies('name') - self.Person.drop_collection() - - self.assertEqual({}, self.Person.objects().item_frequencies('name')) - - class Actor(self.Person): - pass - - # Ensure works correctly with inhertited classes - Actor.objects().item_frequencies('name') - self.Person.drop_collection() - self.assertEqual({}, Actor.objects().item_frequencies('name')) - - def test_definition(self): - """Ensure that document may be defined using fields. - """ - name_field = StringField() - age_field = IntField() - - class Person(Document): - name = name_field - age = age_field - non_field = True - - self.assertEqual(Person._fields['name'], name_field) - self.assertEqual(Person._fields['age'], age_field) - self.assertFalse('non_field' in Person._fields) - self.assertTrue('id' in Person._fields) - # Test iteration over fields - fields = list(Person()) - self.assertTrue('name' in fields and 'age' in fields) - # Ensure Document isn't treated like an actual document - self.assertFalse(hasattr(Document, '_fields')) - - def test_repr(self): - """Ensure that unicode representation works - """ - class Article(Document): - title = StringField() - - def __unicode__(self): - return self.title - - Article.drop_collection() - - Article(title=u'привет мир').save() - - self.assertEqual('', repr(Article.objects.first())) - self.assertEqual('[]', repr(Article.objects.all())) - - def test_collection_naming(self): - """Ensure that a collection with a specified name may be used. 
- """ - - class DefaultNamingTest(Document): - pass - self.assertEqual('default_naming_test', DefaultNamingTest._get_collection_name()) - - class CustomNamingTest(Document): - meta = {'collection': 'pimp_my_collection'} - - self.assertEqual('pimp_my_collection', CustomNamingTest._get_collection_name()) - - class DynamicNamingTest(Document): - meta = {'collection': lambda c: "DYNAMO"} - self.assertEqual('DYNAMO', DynamicNamingTest._get_collection_name()) - - # Use Abstract class to handle backwards compatibility - class BaseDocument(Document): - meta = { - 'abstract': True, - 'collection': lambda c: c.__name__.lower() - } - - class OldNamingConvention(BaseDocument): - pass - self.assertEqual('oldnamingconvention', OldNamingConvention._get_collection_name()) - - class InheritedAbstractNamingTest(BaseDocument): - meta = {'collection': 'wibble'} - self.assertEqual('wibble', InheritedAbstractNamingTest._get_collection_name()) - - - # Mixin tests - class BaseMixin(object): - meta = { - 'collection': lambda c: c.__name__.lower() - } - - class OldMixinNamingConvention(Document, BaseMixin): - pass - self.assertEqual('oldmixinnamingconvention', OldMixinNamingConvention._get_collection_name()) - - class BaseMixin(object): - meta = { - 'collection': lambda c: c.__name__.lower() - } - - class BaseDocument(Document, BaseMixin): - meta = {'allow_inheritance': True} - - class MyDocument(BaseDocument): - pass - - self.assertEqual('basedocument', MyDocument._get_collection_name()) - - def test_get_superclasses(self): - """Ensure that the correct list of superclasses is assembled. 
- """ - class Animal(Document): - meta = {'allow_inheritance': True} - class Fish(Animal): pass - class Mammal(Animal): pass - class Human(Mammal): pass - class Dog(Mammal): pass - - mammal_superclasses = {'Animal': Animal} - self.assertEqual(Mammal._superclasses, mammal_superclasses) - - dog_superclasses = { - 'Animal': Animal, - 'Animal.Mammal': Mammal, - } - self.assertEqual(Dog._superclasses, dog_superclasses) - - def test_external_superclasses(self): - """Ensure that the correct list of sub and super classes is assembled. - when importing part of the model - """ - class Animal(Base): pass - class Fish(Animal): pass - class Mammal(Animal): pass - class Human(Mammal): pass - class Dog(Mammal): pass - - mammal_superclasses = {'Base': Base, 'Base.Animal': Animal} - self.assertEqual(Mammal._superclasses, mammal_superclasses) - - dog_superclasses = { - 'Base': Base, - 'Base.Animal': Animal, - 'Base.Animal.Mammal': Mammal, - } - self.assertEqual(Dog._superclasses, dog_superclasses) - - Base.drop_collection() - - h = Human() - h.save() - - self.assertEqual(Human.objects.count(), 1) - self.assertEqual(Mammal.objects.count(), 1) - self.assertEqual(Animal.objects.count(), 1) - self.assertEqual(Base.objects.count(), 1) - Base.drop_collection() - - def test_polymorphic_queries(self): - """Ensure that the correct subclasses are returned from a query""" - class Animal(Document): - meta = {'allow_inheritance': True} - class Fish(Animal): pass - class Mammal(Animal): pass - class Human(Mammal): pass - class Dog(Mammal): pass - - Animal.drop_collection() - - Animal().save() - Fish().save() - Mammal().save() - Human().save() - Dog().save() - - classes = [obj.__class__ for obj in Animal.objects] - self.assertEqual(classes, [Animal, Fish, Mammal, Human, Dog]) - - classes = [obj.__class__ for obj in Mammal.objects] - self.assertEqual(classes, [Mammal, Human, Dog]) - - classes = [obj.__class__ for obj in Human.objects] - self.assertEqual(classes, [Human]) - - 
Animal.drop_collection() - - def test_polymorphic_references(self): - """Ensure that the correct subclasses are returned from a query when - using references / generic references - """ - class Animal(Document): - meta = {'allow_inheritance': True} - class Fish(Animal): pass - class Mammal(Animal): pass - class Human(Mammal): pass - class Dog(Mammal): pass - - class Zoo(Document): - animals = ListField(ReferenceField(Animal)) - - Zoo.drop_collection() - Animal.drop_collection() - - Animal().save() - Fish().save() - Mammal().save() - Human().save() - Dog().save() - - # Save a reference to each animal - zoo = Zoo(animals=Animal.objects) - zoo.save() - zoo.reload() - - classes = [a.__class__ for a in Zoo.objects.first().animals] - self.assertEqual(classes, [Animal, Fish, Mammal, Human, Dog]) - - Zoo.drop_collection() - - class Zoo(Document): - animals = ListField(GenericReferenceField(Animal)) - - # Save a reference to each animal - zoo = Zoo(animals=Animal.objects) - zoo.save() - zoo.reload() - - classes = [a.__class__ for a in Zoo.objects.first().animals] - self.assertEqual(classes, [Animal, Fish, Mammal, Human, Dog]) - - Zoo.drop_collection() - Animal.drop_collection() - - def test_reference_inheritance(self): - class Stats(Document): - created = DateTimeField(default=datetime.now) - - meta = {'allow_inheritance': False} - - class CompareStats(Document): - generated = DateTimeField(default=datetime.now) - stats = ListField(ReferenceField(Stats)) - - Stats.drop_collection() - CompareStats.drop_collection() - - list_stats = [] - - for i in xrange(10): - s = Stats() - s.save() - list_stats.append(s) - - cmp_stats = CompareStats(stats=list_stats) - cmp_stats.save() - - self.assertEqual(list_stats, CompareStats.objects.first().stats) - - def test_inheritance(self): - """Ensure that document may inherit fields from a superclass document. 
- """ - class Employee(self.Person): - salary = IntField() - - self.assertTrue('name' in Employee._fields) - self.assertTrue('salary' in Employee._fields) - self.assertEqual(Employee._get_collection_name(), - self.Person._get_collection_name()) - - # Ensure that MRO error is not raised - class A(Document): - meta = {'allow_inheritance': True} - class B(A): pass - class C(B): pass - - def test_allow_inheritance(self): - """Ensure that inheritance may be disabled on simple classes and that - _cls and _types will not be used. - """ - - class Animal(Document): - name = StringField() - meta = {'allow_inheritance': False} - - Animal.drop_collection() - def create_dog_class(): - class Dog(Animal): - pass - self.assertRaises(ValueError, create_dog_class) - - # Check that _cls etc aren't present on simple documents - dog = Animal(name='dog') - dog.save() - collection = self.db[Animal._get_collection_name()] - obj = collection.find_one() - self.assertFalse('_cls' in obj) - self.assertFalse('_types' in obj) - - Animal.drop_collection() - - def create_employee_class(): - class Employee(self.Person): - meta = {'allow_inheritance': False} - self.assertRaises(ValueError, create_employee_class) - - def test_allow_inheritance_abstract_document(self): - """Ensure that abstract documents can set inheritance rules and that - _cls and _types will not be used. 
- """ - class FinalDocument(Document): - meta = {'abstract': True, - 'allow_inheritance': False} - - class Animal(FinalDocument): - name = StringField() - - Animal.drop_collection() - def create_dog_class(): - class Dog(Animal): - pass - self.assertRaises(ValueError, create_dog_class) - - # Check that _cls etc aren't present on simple documents - dog = Animal(name='dog') - dog.save() - collection = self.db[Animal._get_collection_name()] - obj = collection.find_one() - self.assertFalse('_cls' in obj) - self.assertFalse('_types' in obj) - - Animal.drop_collection() - - def test_allow_inheritance_embedded_document(self): - - # Test the same for embedded documents - class Comment(EmbeddedDocument): - content = StringField() - meta = {'allow_inheritance': False} - - def create_special_comment(): - class SpecialComment(Comment): - pass - - self.assertRaises(ValueError, create_special_comment) - - comment = Comment(content='test') - self.assertFalse('_cls' in comment.to_mongo()) - self.assertFalse('_types' in comment.to_mongo()) - - class Comment(EmbeddedDocument): - content = StringField() - meta = {'allow_inheritance': True} - - comment = Comment(content='test') - self.assertTrue('_cls' in comment.to_mongo()) - self.assertTrue('_types' in comment.to_mongo()) - - def test_document_inheritance(self): - """Ensure mutliple inheritance of abstract docs works - """ - class DateCreatedDocument(Document): - meta = { - 'allow_inheritance': True, - 'abstract': True, - } - - class DateUpdatedDocument(Document): - meta = { - 'allow_inheritance': True, - 'abstract': True, - } - - try: - class MyDocument(DateCreatedDocument, DateUpdatedDocument): - pass - except: - self.assertTrue(False, "Couldn't create MyDocument class") - - def test_how_to_turn_off_inheritance(self): - """Demonstrates migrating from allow_inheritance = True to False. 
- """ - class Animal(Document): - name = StringField() - meta = { - 'indexes': ['name'] - } - - self.assertEqual(Animal._meta['index_specs'], - [{'fields': [('_types', 1), ('name', 1)]}]) - - Animal.drop_collection() - - dog = Animal(name='dog') - dog.save() - - collection = self.db[Animal._get_collection_name()] - obj = collection.find_one() - self.assertTrue('_cls' in obj) - self.assertTrue('_types' in obj) - - info = collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertEqual([[('_id', 1)], [('_types', 1), ('name', 1)]], - sorted(info, key=operator.itemgetter(0))) - - # Turn off inheritance - class Animal(Document): - name = StringField() - meta = { - 'allow_inheritance': False, - 'indexes': ['name'] - } - - self.assertEqual(Animal._meta['index_specs'], - [{'fields': [('name', 1)]}]) - collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, multi=True) - - # Confirm extra data is removed - obj = collection.find_one() - self.assertFalse('_cls' in obj) - self.assertFalse('_types' in obj) - - info = collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertEqual([[(u'_id', 1)], [(u'_types', 1), (u'name', 1)]], - sorted(info, key=operator.itemgetter(0))) - - info = collection.index_information() - indexes_to_drop = [key for key, value in info.iteritems() if '_types' in dict(value['key'])] - for index in indexes_to_drop: - collection.drop_index(index) - - info = collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertEqual([[(u'_id', 1)]], - sorted(info, key=operator.itemgetter(0))) - - # Recreate indexes - dog = Animal.objects.first() - dog.save() - info = collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertEqual([[(u'_id', 1)], [(u'name', 1),]], - sorted(info, key=operator.itemgetter(0))) - - Animal.drop_collection() - - def test_abstract_documents(self): - """Ensure that 
a document superclass can be marked as abstract - thereby not using it as the name for the collection.""" - - defaults = {'index_background': True, - 'index_drop_dups': True, - 'index_opts': {'hello': 'world'}, - 'allow_inheritance': True, - 'queryset_class': 'QuerySet', - 'db_alias': 'myDB', - 'shard_key': ('hello', 'world')} - - meta_settings = {'abstract': True} - meta_settings.update(defaults) - - class Animal(Document): - name = StringField() - meta = meta_settings - - class Fish(Animal): pass - class Guppy(Fish): pass - - class Mammal(Animal): - meta = {'abstract': True} - class Human(Mammal): pass - - for k, v in defaults.iteritems(): - for cls in [Animal, Fish, Guppy]: - self.assertEqual(cls._meta[k], v) - - self.assertFalse('collection' in Animal._meta) - self.assertFalse('collection' in Mammal._meta) - - self.assertEqual(Animal._get_collection_name(), None) - self.assertEqual(Mammal._get_collection_name(), None) - - self.assertEqual(Fish._get_collection_name(), 'fish') - self.assertEqual(Guppy._get_collection_name(), 'fish') - self.assertEqual(Human._get_collection_name(), 'human') - - def create_bad_abstract(): - class EvilHuman(Human): - evil = BooleanField(default=True) - meta = {'abstract': True} - self.assertRaises(ValueError, create_bad_abstract) - - def test_collection_name(self): - """Ensure that a collection with a specified name may be used. 
- """ - collection = 'personCollTest' - if collection in self.db.collection_names(): - self.db.drop_collection(collection) - - class Person(Document): - name = StringField() - meta = {'collection': collection} - - user = Person(name="Test User") - user.save() - self.assertTrue(collection in self.db.collection_names()) - - user_obj = self.db[collection].find_one() - self.assertEqual(user_obj['name'], "Test User") - - user_obj = Person.objects[0] - self.assertEqual(user_obj.name, "Test User") - - Person.drop_collection() - self.assertFalse(collection in self.db.collection_names()) - - def test_collection_name_and_primary(self): - """Ensure that a collection with a specified name may be used. - """ - - class Person(Document): - name = StringField(primary_key=True) - meta = {'collection': 'app'} - - user = Person(name="Test User") - user.save() - - user_obj = Person.objects[0] - self.assertEqual(user_obj.name, "Test User") - - Person.drop_collection() - - def test_inherited_collections(self): - """Ensure that subclassed documents don't override parents' collections. 
- """ - - class Drink(Document): - name = StringField() - meta = {'allow_inheritance': True} - - class Drinker(Document): - drink = GenericReferenceField() - - try: - warnings.simplefilter("error") - - class AcloholicDrink(Drink): - meta = {'collection': 'booze'} - - except SyntaxWarning, w: - warnings.simplefilter("ignore") - - class AlcoholicDrink(Drink): - meta = {'collection': 'booze'} - - else: - raise AssertionError("SyntaxWarning should be triggered") - - warnings.resetwarnings() - - Drink.drop_collection() - AlcoholicDrink.drop_collection() - Drinker.drop_collection() - - red_bull = Drink(name='Red Bull') - red_bull.save() - - programmer = Drinker(drink=red_bull) - programmer.save() - - beer = AlcoholicDrink(name='Beer') - beer.save() - real_person = Drinker(drink=beer) - real_person.save() - - self.assertEqual(Drinker.objects[0].drink.name, red_bull.name) - self.assertEqual(Drinker.objects[1].drink.name, beer.name) - - def test_capped_collection(self): - """Ensure that capped collections work properly. 
- """ - class Log(Document): - date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 10, - 'max_size': 90000, - } - - Log.drop_collection() - - # Ensure that the collection handles up to its maximum - for i in range(10): - Log().save() - - self.assertEqual(len(Log.objects), 10) - - # Check that extra documents don't increase the size - Log().save() - self.assertEqual(len(Log.objects), 10) - - options = Log.objects._collection.options() - self.assertEqual(options['capped'], True) - self.assertEqual(options['max'], 10) - self.assertEqual(options['size'], 90000) - - # Check that the document cannot be redefined with different options - def recreate_log_document(): - class Log(Document): - date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 11, - } - # Create the collection by accessing Document.objects - Log.objects - self.assertRaises(InvalidCollectionError, recreate_log_document) - - Log.drop_collection() - - def test_indexes(self): - """Ensure that indexes are used when meta[indexes] is specified. 
- """ - class BlogPost(Document): - date = DateTimeField(db_field='addDate', default=datetime.now) - category = StringField() - tags = ListField(StringField()) - meta = { - 'indexes': [ - '-date', - 'tags', - ('category', '-date') - ], - 'allow_inheritance': True - } - - self.assertEqual(BlogPost._meta['index_specs'], - [{'fields': [('_types', 1), ('addDate', -1)]}, - {'fields': [('tags', 1)]}, - {'fields': [('_types', 1), ('category', 1), - ('addDate', -1)]}]) - - BlogPost.drop_collection() - - info = BlogPost.objects._collection.index_information() - # _id, '-date', 'tags', ('cat', 'date') - # NB: there is no index on _types by itself, since - # the indices on -date and tags will both contain - # _types as first element in the key - self.assertEqual(len(info), 4) - - # Indexes are lazy so use list() to perform query - list(BlogPost.objects) - info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] - in info) - self.assertTrue([('_types', 1), ('addDate', -1)] in info) - # tags is a list field so it shouldn't have _types in the index - self.assertTrue([('tags', 1)] in info) - - class ExtendedBlogPost(BlogPost): - title = StringField() - meta = {'indexes': ['title']} - - self.assertEqual(ExtendedBlogPost._meta['index_specs'], - [{'fields': [('_types', 1), ('addDate', -1)]}, - {'fields': [('tags', 1)]}, - {'fields': [('_types', 1), ('category', 1), - ('addDate', -1)]}, - {'fields': [('_types', 1), ('title', 1)]}]) - - BlogPost.drop_collection() - - list(ExtendedBlogPost.objects) - info = ExtendedBlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] - in info) - self.assertTrue([('_types', 1), ('addDate', -1)] in info) - self.assertTrue([('_types', 1), ('title', 1)] in info) - - BlogPost.drop_collection() - - def 
test_inherited_index(self): - """Ensure index specs are inhertited correctly""" - - class A(Document): - title = StringField() - meta = { - 'indexes': [ - { - 'fields': ('title',), - }, - ], - 'allow_inheritance': True, - } - - class B(A): - description = StringField() - - self.assertEqual(A._meta['index_specs'], B._meta['index_specs']) - self.assertEqual([{'fields': [('_types', 1), ('title', 1)]}], - A._meta['index_specs']) - - def test_build_index_spec_is_not_destructive(self): - - class MyDoc(Document): - keywords = StringField() - - meta = { - 'indexes': ['keywords'], - 'allow_inheritance': False - } - - self.assertEqual(MyDoc._meta['index_specs'], - [{'fields': [('keywords', 1)]}]) - - # Force index creation - MyDoc.objects._ensure_indexes() - - self.assertEqual(MyDoc._meta['index_specs'], - [{'fields': [('keywords', 1)]}]) - - def test_db_field_load(self): - """Ensure we load data correctly - """ - class Person(Document): - name = StringField(required=True) - _rank = StringField(required=False, db_field="rank") - - @property - def rank(self): - return self._rank or "Private" - - Person.drop_collection() - - Person(name="Jack", _rank="Corporal").save() - - Person(name="Fred").save() - - self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") - self.assertEqual(Person.objects.get(name="Fred").rank, "Private") - - def test_db_embedded_doc_field_load(self): - """Ensure we load embedded document data correctly - """ - class Rank(EmbeddedDocument): - title = StringField(required=True) - - class Person(Document): - name = StringField(required=True) - rank_ = EmbeddedDocumentField(Rank, required=False, db_field='rank') - - @property - def rank(self): - return self.rank_.title if self.rank_ is not None else "Private" - - Person.drop_collection() - - Person(name="Jack", rank_=Rank(title="Corporal")).save() - - Person(name="Fred").save() - - self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") - 
self.assertEqual(Person.objects.get(name="Fred").rank, "Private") - - def test_embedded_document_index_meta(self): - """Ensure that embedded document indexes are created explicitly - """ - class Rank(EmbeddedDocument): - title = StringField(required=True) - - class Person(Document): - name = StringField(required=True) - rank = EmbeddedDocumentField(Rank, required=False) - - meta = { - 'indexes': [ - 'rank.title', - ], - 'allow_inheritance': False - } - - self.assertEqual([{'fields': [('rank.title', 1)]}], - Person._meta['index_specs']) - - Person.drop_collection() - - # Indexes are lazy so use list() to perform query - list(Person.objects) - info = Person.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('rank.title', 1)] in info) - - def test_explicit_geo2d_index(self): - """Ensure that geo2d indexes work when created via meta[indexes] - """ - class Place(Document): - location = DictField() - meta = { - 'indexes': [ - '*location.point', - ], - } - - self.assertEqual([{'fields': [('location.point', '2d')]}], - Place._meta['index_specs']) - - Place.drop_collection() - - info = Place.objects._collection.index_information() - # Indexes are lazy so use list() to perform query - list(Place.objects) - info = Place.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - - self.assertTrue([('location.point', '2d')] in info) - - def test_dictionary_indexes(self): - """Ensure that indexes are used when meta[indexes] contains dictionaries - instead of lists. 
- """ - class BlogPost(Document): - date = DateTimeField(db_field='addDate', default=datetime.now) - category = StringField() - tags = ListField(StringField()) - meta = { - 'indexes': [ - {'fields': ['-date'], 'unique': True, - 'sparse': True, 'types': False }, - ], - } - - self.assertEqual([{'fields': [('addDate', -1)], 'unique': True, - 'sparse': True, 'types': False}], - BlogPost._meta['index_specs']) - - BlogPost.drop_collection() - - info = BlogPost.objects._collection.index_information() - # _id, '-date' - self.assertEqual(len(info), 3) - - # Indexes are lazy so use list() to perform query - list(BlogPost.objects) - info = BlogPost.objects._collection.index_information() - info = [(value['key'], - value.get('unique', False), - value.get('sparse', False)) - for key, value in info.iteritems()] - self.assertTrue(([('addDate', -1)], True, True) in info) - - BlogPost.drop_collection() - - def test_abstract_index_inheritance(self): - - class UserBase(Document): - meta = { - 'abstract': True, - 'indexes': ['user_guid'] - } - - user_guid = StringField(required=True) - - class Person(UserBase): - meta = { - 'indexes': ['name'], - } - - name = StringField() - - Person.drop_collection() - - p = Person(name="test", user_guid='123') - p.save() - - self.assertEqual(1, Person.objects.count()) - info = Person.objects._collection.index_information() - self.assertEqual(sorted(info.keys()), ['_id_', '_types_1_name_1', '_types_1_user_guid_1']) - Person.drop_collection() - - def test_disable_index_creation(self): - """Tests setting auto_create_index to False on the connection will - disable any index generation. 
- """ - class User(Document): - meta = { - 'indexes': ['user_guid'], - 'auto_create_index': False - } - user_guid = StringField(required=True) - - - User.drop_collection() - - u = User(user_guid='123') - u.save() - - self.assertEqual(1, User.objects.count()) - info = User.objects._collection.index_information() - self.assertEqual(info.keys(), ['_id_']) - User.drop_collection() - - def test_embedded_document_index(self): - """Tests settings an index on an embedded document - """ - class Date(EmbeddedDocument): - year = IntField(db_field='yr') - - class BlogPost(Document): - title = StringField() - date = EmbeddedDocumentField(Date) - - meta = { - 'indexes': [ - '-date.year' - ], - } - - BlogPost.drop_collection() - - info = BlogPost.objects._collection.index_information() - self.assertEqual(sorted(info.keys()), [ '_id_', '_types_1_date.yr_-1']) - BlogPost.drop_collection() - - def test_list_embedded_document_index(self): - """Ensure list embedded documents can be indexed - """ - class Tag(EmbeddedDocument): - name = StringField(db_field='tag') - - class BlogPost(Document): - title = StringField() - tags = ListField(EmbeddedDocumentField(Tag)) - - meta = { - 'indexes': [ - 'tags.name' - ], - } - - BlogPost.drop_collection() - - info = BlogPost.objects._collection.index_information() - # we don't use _types in with list fields by default - self.assertEqual(sorted(info.keys()), - ['_id_', '_types_1', 'tags.tag_1']) - - post1 = BlogPost(title="Embedded Indexes tests in place", - tags=[Tag(name="about"), Tag(name="time")] - ) - post1.save() - BlogPost.drop_collection() - - def test_recursive_embedded_objects_dont_break_indexes(self): - - class RecursiveObject(EmbeddedDocument): - obj = EmbeddedDocumentField('self') - - class RecursiveDocument(Document): - recursive_obj = EmbeddedDocumentField(RecursiveObject) - - info = RecursiveDocument.objects._collection.index_information() - self.assertEqual(sorted(info.keys()), ['_id_', '_types_1']) - - def 
test_geo_indexes_recursion(self): - - class Location(Document): - name = StringField() - location = GeoPointField() - - class Parent(Document): - name = StringField() - location = ReferenceField(Location) - - Location.drop_collection() - Parent.drop_collection() - - list(Parent.objects) - - collection = Parent._get_collection() - info = collection.index_information() - - self.assertFalse('location_2d' in info) - - self.assertEqual(len(Parent._geo_indices()), 0) - self.assertEqual(len(Location._geo_indices()), 1) - - def test_covered_index(self): - """Ensure that covered indexes can be used - """ - - class Test(Document): - a = IntField() - - meta = { - 'indexes': ['a'], - 'allow_inheritance': False - } - - Test.drop_collection() - - obj = Test(a=1) - obj.save() - - # Need to be explicit about covered indexes as mongoDB doesn't know if - # the documents returned might have more keys in that here. - query_plan = Test.objects(id=obj.id).exclude('a').explain() - self.assertFalse(query_plan['indexOnly']) - - query_plan = Test.objects(id=obj.id).only('id').explain() - self.assertTrue(query_plan['indexOnly']) - - query_plan = Test.objects(a=1).only('a').exclude('id').explain() - self.assertTrue(query_plan['indexOnly']) - - def test_index_on_id(self): - - class BlogPost(Document): - meta = { - 'indexes': [ - ['categories', 'id'] - ], - 'allow_inheritance': False - } - - title = StringField(required=True) - description = StringField(required=True) - categories = ListField() - - BlogPost.drop_collection() - - indexes = BlogPost.objects._collection.index_information() - self.assertEqual(indexes['categories_1__id_1']['key'], - [('categories', 1), ('_id', 1)]) - - def test_hint(self): - - class BlogPost(Document): - tags = ListField(StringField()) - meta = { - 'indexes': [ - 'tags', - ], - } - - BlogPost.drop_collection() - - for i in xrange(0, 10): - tags = [("tag %i" % n) for n in xrange(0, i % 2)] - BlogPost(tags=tags).save() - - self.assertEqual(BlogPost.objects.count(), 
10) - self.assertEqual(BlogPost.objects.hint().count(), 10) - self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10) - - self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).count(), 10) - - def invalid_index(): - BlogPost.objects.hint('tags') - self.assertRaises(TypeError, invalid_index) - - def invalid_index_2(): - return BlogPost.objects.hint(('tags', 1)) - self.assertRaises(TypeError, invalid_index_2) - - def test_unique(self): - """Ensure that uniqueness constraints are applied to fields. - """ - class BlogPost(Document): - title = StringField() - slug = StringField(unique=True) - - BlogPost.drop_collection() - - post1 = BlogPost(title='test1', slug='test') - post1.save() - - # Two posts with the same slug is not allowed - post2 = BlogPost(title='test2', slug='test') - self.assertRaises(NotUniqueError, post2.save) - - # Ensure backwards compatibilty for errors - self.assertRaises(OperationError, post2.save) - - def test_unique_with(self): - """Ensure that unique_with constraints are applied to fields. - """ - class Date(EmbeddedDocument): - year = IntField(db_field='yr') - - class BlogPost(Document): - title = StringField() - date = EmbeddedDocumentField(Date) - slug = StringField(unique_with='date.year') - - BlogPost.drop_collection() - - post1 = BlogPost(title='test1', date=Date(year=2009), slug='test') - post1.save() - - # day is different so won't raise exception - post2 = BlogPost(title='test2', date=Date(year=2010), slug='test') - post2.save() - - # Now there will be two docs with the same slug and the same day: fail - post3 = BlogPost(title='test3', date=Date(year=2010), slug='test') - self.assertRaises(OperationError, post3.save) - - BlogPost.drop_collection() - - def test_unique_embedded_document(self): - """Ensure that uniqueness constraints are applied to fields on embedded documents. 
- """ - class SubDocument(EmbeddedDocument): - year = IntField(db_field='yr') - slug = StringField(unique=True) - - class BlogPost(Document): - title = StringField() - sub = EmbeddedDocumentField(SubDocument) - - BlogPost.drop_collection() - - post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) - post1.save() - - # sub.slug is different so won't raise exception - post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) - post2.save() - - # Now there will be two docs with the same sub.slug - post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) - self.assertRaises(NotUniqueError, post3.save) - - BlogPost.drop_collection() - - def test_unique_with_embedded_document_and_embedded_unique(self): - """Ensure that uniqueness constraints are applied to fields on - embedded documents. And work with unique_with as well. - """ - class SubDocument(EmbeddedDocument): - year = IntField(db_field='yr') - slug = StringField(unique=True) - - class BlogPost(Document): - title = StringField(unique_with='sub.year') - sub = EmbeddedDocumentField(SubDocument) - - BlogPost.drop_collection() - - post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) - post1.save() - - # sub.slug is different so won't raise exception - post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) - post2.save() - - # Now there will be two docs with the same sub.slug - post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) - self.assertRaises(NotUniqueError, post3.save) - - # Now there will be two docs with the same title and year - post3 = BlogPost(title='test1', sub=SubDocument(year=2009, slug='test-1')) - self.assertRaises(NotUniqueError, post3.save) - - BlogPost.drop_collection() - - def test_ttl_indexes(self): - - class Log(Document): - created = DateTimeField(default=datetime.now) - meta = { - 'indexes': [ - {'fields': ['created'], 'expireAfterSeconds': 3600} - ] - } - - 
Log.drop_collection() - - if pymongo.version_tuple[0] < 2 and pymongo.version_tuple[1] < 3: - raise SkipTest('pymongo needs to be 2.3 or higher for this test') - - connection = get_connection() - version_array = connection.server_info()['versionArray'] - if version_array[0] < 2 and version_array[1] < 2: - raise SkipTest('MongoDB needs to be 2.2 or higher for this test') - - # Indexes are lazy so use list() to perform query - list(Log.objects) - info = Log.objects._collection.index_information() - self.assertEqual(3600, - info['_types_1_created_1']['expireAfterSeconds']) - - def test_unique_and_indexes(self): - """Ensure that 'unique' constraints aren't overridden by - meta.indexes. - """ - class Customer(Document): - cust_id = IntField(unique=True, required=True) - meta = { - 'indexes': ['cust_id'], - 'allow_inheritance': False, - } - - Customer.drop_collection() - cust = Customer(cust_id=1) - cust.save() - - cust_dupe = Customer(cust_id=1) - try: - cust_dupe.save() - raise AssertionError, "We saved a dupe!" - except NotUniqueError: - pass - Customer.drop_collection() - - def test_unique_and_primary(self): - """If you set a field as primary, then unexpected behaviour can occur. - You won't create a duplicate but you will update an existing document. - """ - - class User(Document): - name = StringField(primary_key=True, unique=True) - password = StringField() - - User.drop_collection() - - user = User(name='huangz', password='secret') - user.save() - - user = User(name='huangz', password='secret2') - user.save() - - self.assertEqual(User.objects.count(), 1) - self.assertEqual(User.objects.get().password, 'secret2') - - User.drop_collection() - - def test_custom_id_field(self): - """Ensure that documents may be created with custom primary keys. 
- """ - class User(Document): - username = StringField(primary_key=True) - name = StringField() - - meta = {'allow_inheritance': True} - - User.drop_collection() - - self.assertEqual(User._fields['username'].db_field, '_id') - self.assertEqual(User._meta['id_field'], 'username') - - def create_invalid_user(): - User(name='test').save() # no primary key field - self.assertRaises(ValidationError, create_invalid_user) - - def define_invalid_user(): - class EmailUser(User): - email = StringField(primary_key=True) - self.assertRaises(ValueError, define_invalid_user) - - class EmailUser(User): - email = StringField() - - user = User(username='test', name='test user') - user.save() - - user_obj = User.objects.first() - self.assertEqual(user_obj.id, 'test') - self.assertEqual(user_obj.pk, 'test') - - user_son = User.objects._collection.find_one() - self.assertEqual(user_son['_id'], 'test') - self.assertTrue('username' not in user_son['_id']) - - User.drop_collection() - - user = User(pk='mongo', name='mongo user') - user.save() - - user_obj = User.objects.first() - self.assertEqual(user_obj.id, 'mongo') - self.assertEqual(user_obj.pk, 'mongo') - - user_son = User.objects._collection.find_one() - self.assertEqual(user_son['_id'], 'mongo') - self.assertTrue('username' not in user_son['_id']) - - User.drop_collection() - - def test_document_not_registered(self): - - class Place(Document): - name = StringField() - - meta = {'allow_inheritance': True} - - class NicePlace(Place): - pass - - Place.drop_collection() - - Place(name="London").save() - NicePlace(name="Buckingham Palace").save() - - # Mimic Place and NicePlace definitions being in a different file - # and the NicePlace model not being imported in at query time. 
- from mongoengine.base import _document_registry - del(_document_registry['Place.NicePlace']) - - def query_without_importing_nice_place(): - print Place.objects.all() - self.assertRaises(NotRegistered, query_without_importing_nice_place) - - def test_document_registry_regressions(self): - - class Location(Document): - name = StringField() - meta = {'allow_inheritance': True} - - class Area(Location): - location = ReferenceField('Location', dbref=True) - - Location.drop_collection() - - self.assertEquals(Area, get_document("Area")) - self.assertEquals(Area, get_document("Location.Area")) - - def test_creation(self): - """Ensure that document may be created using keyword arguments. - """ - person = self.Person(name="Test User", age=30) - self.assertEqual(person.name, "Test User") - self.assertEqual(person.age, 30) - - def test_to_dbref(self): - """Ensure that you can get a dbref of a document""" - person = self.Person(name="Test User", age=30) - self.assertRaises(OperationError, person.to_dbref) - person.save() - - person.to_dbref() - - def test_reload(self): - """Ensure that attributes may be reloaded. 
- """ - person = self.Person(name="Test User", age=20) - person.save() - - person_obj = self.Person.objects.first() - person_obj.name = "Mr Test User" - person_obj.age = 21 - person_obj.save() - - self.assertEqual(person.name, "Test User") - self.assertEqual(person.age, 20) - - person.reload() - self.assertEqual(person.name, "Mr Test User") - self.assertEqual(person.age, 21) - - def test_reload_sharded(self): - class Animal(Document): - superphylum = StringField() - meta = {'shard_key': ('superphylum',)} - - Animal.drop_collection() - doc = Animal(superphylum = 'Deuterostomia') - doc.save() - doc.reload() - Animal.drop_collection() - - def test_reload_referencing(self): - """Ensures reloading updates weakrefs correctly - """ - class Embedded(EmbeddedDocument): - dict_field = DictField() - list_field = ListField() - - class Doc(Document): - dict_field = DictField() - list_field = ListField() - embedded_field = EmbeddedDocumentField(Embedded) - - Doc.drop_collection() - doc = Doc() - doc.dict_field = {'hello': 'world'} - doc.list_field = ['1', 2, {'hello': 'world'}] - - embedded_1 = Embedded() - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] - doc.embedded_field = embedded_1 - doc.save() - - doc = doc.reload(10) - doc.list_field.append(1) - doc.dict_field['woot'] = "woot" - doc.embedded_field.list_field.append(1) - doc.embedded_field.dict_field['woot'] = "woot" - - self.assertEqual(doc._get_changed_fields(), [ - 'list_field', 'dict_field', 'embedded_field.list_field', - 'embedded_field.dict_field']) - doc.save() - - doc = doc.reload(10) - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(len(doc.list_field), 4) - self.assertEqual(len(doc.dict_field), 2) - self.assertEqual(len(doc.embedded_field.list_field), 4) - self.assertEqual(len(doc.embedded_field.dict_field), 2) - - def test_dictionary_access(self): - """Ensure that dictionary-style field access works properly. 
- """ - person = self.Person(name='Test User', age=30) - self.assertEqual(person['name'], 'Test User') - - self.assertRaises(KeyError, person.__getitem__, 'salary') - self.assertRaises(KeyError, person.__setitem__, 'salary', 50) - - person['name'] = 'Another User' - self.assertEqual(person['name'], 'Another User') - - # Length = length(assigned fields + id) - self.assertEqual(len(person), 3) - - self.assertTrue('age' in person) - person.age = None - self.assertFalse('age' in person) - self.assertFalse('nationality' in person) - - def test_embedded_document(self): - """Ensure that embedded documents are set up correctly. - """ - class Comment(EmbeddedDocument): - content = StringField() - - self.assertTrue('content' in Comment._fields) - self.assertFalse('id' in Comment._fields) - - def test_embedded_document_validation(self): - """Ensure that embedded documents may be validated. - """ - class Comment(EmbeddedDocument): - date = DateTimeField() - content = StringField(required=True) - - comment = Comment() - self.assertRaises(ValidationError, comment.validate) - - comment.content = 'test' - comment.validate() - - comment.date = 4 - self.assertRaises(ValidationError, comment.validate) - - comment.date = datetime.now() - comment.validate() - - def test_embedded_db_field_validate(self): - - class SubDoc(EmbeddedDocument): - val = IntField() - - class Doc(Document): - e = EmbeddedDocumentField(SubDoc, db_field='eb') - - Doc.drop_collection() - - Doc(e=SubDoc(val=15)).save() - - doc = Doc.objects.first() - doc.validate() - keys = doc._data.keys() - self.assertEqual(2, len(keys)) - self.assertTrue(None in keys) - self.assertTrue('e' in keys) - - def test_save(self): - """Ensure that a document may be saved in the database. 
- """ - # Create person object and save it to the database - person = self.Person(name='Test User', age=30) - person.save() - # Ensure that the object is in the database - collection = self.db[self.Person._get_collection_name()] - person_obj = collection.find_one({'name': 'Test User'}) - self.assertEqual(person_obj['name'], 'Test User') - self.assertEqual(person_obj['age'], 30) - self.assertEqual(person_obj['_id'], person.id) - # Test skipping validation on save - class Recipient(Document): - email = EmailField(required=True) - - recipient = Recipient(email='root@localhost') - self.assertRaises(ValidationError, recipient.save) - try: - recipient.save(validate=False) - except ValidationError: - self.fail() - - def test_save_to_a_value_that_equates_to_false(self): - - class Thing(EmbeddedDocument): - count = IntField() - - class User(Document): - thing = EmbeddedDocumentField(Thing) - - User.drop_collection() - - user = User(thing=Thing(count=1)) - user.save() - user.reload() - - user.thing.count = 0 - user.save() - - user.reload() - self.assertEqual(user.thing.count, 0) - - def test_save_max_recursion_not_hit(self): - - class Person(Document): - name = StringField() - parent = ReferenceField('self') - friend = ReferenceField('self') - - Person.drop_collection() - - p1 = Person(name="Wilson Snr") - p1.parent = None - p1.save() - - p2 = Person(name="Wilson Jr") - p2.parent = p1 - p2.save() - - p1.friend = p2 - p1.save() - - # Confirm can save and it resets the changed fields without hitting - # max recursion error - p0 = Person.objects.first() - p0.name = 'wpjunior' - p0.save() - - def test_save_max_recursion_not_hit_with_file_field(self): - - class Foo(Document): - name = StringField() - picture = FileField() - bar = ReferenceField('self') - - Foo.drop_collection() - - a = Foo(name='hello') - a.save() - - a.bar = a - with open(TEST_IMAGE_PATH, 'rb') as test_image: - a.picture = test_image - a.save() - - # Confirm can save and it resets the changed fields without 
hitting - # max recursion error - b = Foo.objects.with_id(a.id) - b.name='world' - b.save() - - self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture) - - def test_save_cascades(self): - - class Person(Document): - name = StringField() - parent = ReferenceField('self') - - Person.drop_collection() - - p1 = Person(name="Wilson Snr") - p1.parent = None - p1.save() - - p2 = Person(name="Wilson Jr") - p2.parent = p1 - p2.save() - - p = Person.objects(name="Wilson Jr").get() - p.parent.name = "Daddy Wilson" - p.save() - - p1.reload() - self.assertEqual(p1.name, p.parent.name) - - def test_save_cascade_kwargs(self): - - class Person(Document): - name = StringField() - parent = ReferenceField('self') - - Person.drop_collection() - - p1 = Person(name="Wilson Snr") - p1.parent = None - p1.save() - - p2 = Person(name="Wilson Jr") - p2.parent = p1 - p2.save(force_insert=True, cascade_kwargs={"force_insert": False}) - - p = Person.objects(name="Wilson Jr").get() - p.parent.name = "Daddy Wilson" - p.save() - - p1.reload() - self.assertEqual(p1.name, p.parent.name) - - def test_save_cascade_meta(self): - - class Person(Document): - name = StringField() - parent = ReferenceField('self') - - meta = {'cascade': False} - - Person.drop_collection() - - p1 = Person(name="Wilson Snr") - p1.parent = None - p1.save() - - p2 = Person(name="Wilson Jr") - p2.parent = p1 - p2.save() - - p = Person.objects(name="Wilson Jr").get() - p.parent.name = "Daddy Wilson" - p.save() - - p1.reload() - self.assertNotEqual(p1.name, p.parent.name) - - p.save(cascade=True) - p1.reload() - self.assertEqual(p1.name, p.parent.name) - - def test_save_cascades_generically(self): - - class Person(Document): - name = StringField() - parent = GenericReferenceField() - - Person.drop_collection() - - p1 = Person(name="Wilson Snr") - p1.save() - - p2 = Person(name="Wilson Jr") - p2.parent = p1 - p2.save() - - p = Person.objects(name="Wilson Jr").get() - p.parent.name = "Daddy Wilson" - p.save() - - p1.reload() 
- self.assertEqual(p1.name, p.parent.name) - - def test_update(self): - """Ensure that an existing document is updated instead of be overwritten. - """ - # Create person object and save it to the database - person = self.Person(name='Test User', age=30) - person.save() - - # Create same person object, with same id, without age - same_person = self.Person(name='Test') - same_person.id = person.id - same_person.save() - - # Confirm only one object - self.assertEqual(self.Person.objects.count(), 1) - - # reload - person.reload() - same_person.reload() - - # Confirm the same - self.assertEqual(person, same_person) - self.assertEqual(person.name, same_person.name) - self.assertEqual(person.age, same_person.age) - - # Confirm the saved values - self.assertEqual(person.name, 'Test') - self.assertEqual(person.age, 30) - - # Test only / exclude only updates included fields - person = self.Person.objects.only('name').get() - person.name = 'User' - person.save() - - person.reload() - self.assertEqual(person.name, 'User') - self.assertEqual(person.age, 30) - - # test exclude only updates set fields - person = self.Person.objects.exclude('name').get() - person.age = 21 - person.save() - - person.reload() - self.assertEqual(person.name, 'User') - self.assertEqual(person.age, 21) - - # Test only / exclude can set non excluded / included fields - person = self.Person.objects.only('name').get() - person.name = 'Test' - person.age = 30 - person.save() - - person.reload() - self.assertEqual(person.name, 'Test') - self.assertEqual(person.age, 30) - - # test exclude only updates set fields - person = self.Person.objects.exclude('name').get() - person.name = 'User' - person.age = 21 - person.save() - - person.reload() - self.assertEqual(person.name, 'User') - self.assertEqual(person.age, 21) - - # Confirm does remove unrequired fields - person = self.Person.objects.exclude('name').get() - person.age = None - person.save() - - person.reload() - self.assertEqual(person.name, 'User') - 
self.assertEqual(person.age, None) - - person = self.Person.objects.get() - person.name = None - person.age = None - person.save() - - person.reload() - self.assertEqual(person.name, None) - self.assertEqual(person.age, None) - - def test_can_save_if_not_included(self): - - class EmbeddedDoc(EmbeddedDocument): - pass - - class Simple(Document): - pass - - class Doc(Document): - string_field = StringField(default='1') - int_field = IntField(default=1) - float_field = FloatField(default=1.1) - boolean_field = BooleanField(default=True) - datetime_field = DateTimeField(default=datetime.now) - embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, default=lambda: EmbeddedDoc()) - list_field = ListField(default=lambda: [1, 2, 3]) - dict_field = DictField(default=lambda: {"hello": "world"}) - objectid_field = ObjectIdField(default=bson.ObjectId) - reference_field = ReferenceField(Simple, default=lambda: Simple().save()) - map_field = MapField(IntField(), default=lambda: {"simple": 1}) - decimal_field = DecimalField(default=1.0) - complex_datetime_field = ComplexDateTimeField(default=datetime.now) - url_field = URLField(default="http://mongoengine.org") - dynamic_field = DynamicField(default=1) - generic_reference_field = GenericReferenceField(default=lambda: Simple().save()) - sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3]) - email_field = EmailField(default="ross@example.com") - geo_point_field = GeoPointField(default=lambda: [1, 2]) - sequence_field = SequenceField() - uuid_field = UUIDField(default=uuid.uuid4) - generic_embedded_document_field = GenericEmbeddedDocumentField(default=lambda: EmbeddedDoc()) - - - Simple.drop_collection() - Doc.drop_collection() - - Doc().save() - - my_doc = Doc.objects.only("string_field").first() - my_doc.string_field = "string" - my_doc.save() - - my_doc = Doc.objects.get(string_field="string") - self.assertEqual(my_doc.string_field, "string") - self.assertEqual(my_doc.int_field, 1) - - def 
test_document_update(self): - - def update_not_saved_raises(): - person = self.Person(name='dcrosta') - person.update(set__name='Dan Crosta') - - self.assertRaises(OperationError, update_not_saved_raises) - - author = self.Person(name='dcrosta') - author.save() - - author.update(set__name='Dan Crosta') - author.reload() - - p1 = self.Person.objects.first() - self.assertEqual(p1.name, author.name) - - def update_no_value_raises(): - person = self.Person.objects.first() - person.update() - - self.assertRaises(OperationError, update_no_value_raises) - - def update_no_op_raises(): - person = self.Person.objects.first() - person.update(name="Dan") - - self.assertRaises(InvalidQueryError, update_no_op_raises) - - def test_embedded_update(self): - """ - Test update on `EmbeddedDocumentField` fields - """ - - class Page(EmbeddedDocument): - log_message = StringField(verbose_name="Log message", - required=True) - - class Site(Document): - page = EmbeddedDocumentField(Page) - - - Site.drop_collection() - site = Site(page=Page(log_message="Warning: Dummy message")) - site.save() - - # Update - site = Site.objects.first() - site.page.log_message = "Error: Dummy message" - site.save() - - site = Site.objects.first() - self.assertEqual(site.page.log_message, "Error: Dummy message") - - def test_embedded_update_db_field(self): - """ - Test update on `EmbeddedDocumentField` fields when db_field is other - than default. 
- """ - - class Page(EmbeddedDocument): - log_message = StringField(verbose_name="Log message", - db_field="page_log_message", - required=True) - - class Site(Document): - page = EmbeddedDocumentField(Page) - - - Site.drop_collection() - - site = Site(page=Page(log_message="Warning: Dummy message")) - site.save() - - # Update - site = Site.objects.first() - site.page.log_message = "Error: Dummy message" - site.save() - - site = Site.objects.first() - self.assertEqual(site.page.log_message, "Error: Dummy message") - - def test_circular_reference_deltas(self): - - class Person(Document): - name = StringField() - owns = ListField(ReferenceField('Organization')) - - class Organization(Document): - name = StringField() - owner = ReferenceField('Person') - - Person.drop_collection() - Organization.drop_collection() - - person = Person(name="owner") - person.save() - organization = Organization(name="company") - organization.save() - - person.owns.append(organization) - organization.owner = person - - person.save() - organization.save() - - p = Person.objects[0].select_related() - o = Organization.objects.first() - self.assertEqual(p.owns[0], o) - self.assertEqual(o.owner, p) - - def test_circular_reference_deltas_2(self): - - class Person(Document): - name = StringField() - owns = ListField( ReferenceField( 'Organization' ) ) - employer = ReferenceField( 'Organization' ) - - class Organization( Document ): - name = StringField() - owner = ReferenceField( 'Person' ) - employees = ListField( ReferenceField( 'Person' ) ) - - Person.drop_collection() - Organization.drop_collection() - - person = Person( name="owner" ) - person.save() - - employee = Person( name="employee" ) - employee.save() - - organization = Organization( name="company" ) - organization.save() - - person.owns.append( organization ) - organization.owner = person - - organization.employees.append( employee ) - employee.employer = organization - - person.save() - organization.save() - employee.save() - - p = 
Person.objects.get(name="owner") - e = Person.objects.get(name="employee") - o = Organization.objects.first() - - self.assertEqual(p.owns[0], o) - self.assertEqual(o.owner, p) - self.assertEqual(e.employer, o) - - def test_delta(self): - - class Doc(Document): - string_field = StringField() - int_field = IntField() - dict_field = DictField() - list_field = ListField() - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['string_field']) - self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) - - doc._changed_fields = [] - doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['int_field']) - self.assertEqual(doc._delta(), ({'int_field': 1}, {})) - - doc._changed_fields = [] - dict_value = {'hello': 'world', 'ping': 'pong'} - doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) - - doc._changed_fields = [] - list_value = ['1', 2, {'hello': 'world'}] - doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) - - # Test unsetting - doc._changed_fields = [] - doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) - - doc._changed_fields = [] - doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({}, {'list_field': 1})) - - def test_delta_recursive(self): - - class Embedded(EmbeddedDocument): - string_field = StringField() - int_field = IntField() - dict_field = DictField() - list_field = ListField() - - class Doc(Document): - string_field = StringField() - int_field = IntField() - dict_field = DictField() - 
list_field = ListField() - embedded_field = EmbeddedDocumentField(Embedded) - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] - doc.embedded_field = embedded_1 - - self.assertEqual(doc._get_changed_fields(), ['embedded_field']) - - embedded_delta = { - 'string_field': 'hello', - 'int_field': 1, - 'dict_field': {'hello': 'world'}, - 'list_field': ['1', 2, {'hello': 'world'}] - } - self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - embedded_delta.update({ - '_types': ['Embedded'], - '_cls': 'Embedded', - }) - self.assertEqual(doc._delta(), ({'embedded_field': embedded_delta}, {})) - - doc.save() - doc = doc.reload(10) - - doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['embedded_field.dict_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) - self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.dict_field, {}) - - doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) - self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field, []) - - embedded_2 = Embedded() - embedded_2.string_field = 'hello' - embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] - - doc.embedded_field.list_field = ['1', 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - 
self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - doc.save() - doc = doc.reload(10) - - self.assertEqual(doc.embedded_field.list_field[0], '1') - self.assertEqual(doc.embedded_field.list_field[1], 2) - for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) - - doc.embedded_field.list_field[2].string_field = 'world' - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field.2.string_field']) - self.assertEqual(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'world') - - # Test multiple assignments - doc.embedded_field.list_field[2].string_field = 'hello world' - doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}}]}, {})) - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 
'dict_field': {'hello': 'world'}} - ]}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'hello world') - - # Test list native methods - doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {})) - doc.save() - doc = doc.reload(10) - - doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) - - doc.embedded_field.list_field[2].list_field.sort(key=str) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) - - del(doc.embedded_field.list_field[2].list_field[2]['hello']) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) - doc.save() - doc = doc.reload(10) - - del(doc.embedded_field.list_field[2].list_field) - self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1})) - - doc.save() - doc = doc.reload(10) - - doc.dict_field['Embedded'] = embedded_1 - doc.save() - doc = doc.reload(10) - - doc.dict_field['Embedded'].string_field = 'Hello World' - self.assertEqual(doc._get_changed_fields(), ['dict_field.Embedded.string_field']) - self.assertEqual(doc._delta(), ({'dict_field.Embedded.string_field': 'Hello World'}, {})) - - - def test_delta_db_field(self): - - class Doc(Document): - string_field = StringField(db_field='db_string_field') - int_field = IntField(db_field='db_int_field') - dict_field = DictField(db_field='db_dict_field') - list_field = ListField(db_field='db_list_field') - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - 
self.assertEqual(doc._delta(), ({}, {})) - - doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['db_string_field']) - self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {})) - - doc._changed_fields = [] - doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['db_int_field']) - self.assertEqual(doc._delta(), ({'db_int_field': 1}, {})) - - doc._changed_fields = [] - dict_value = {'hello': 'world', 'ping': 'pong'} - doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) - self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {})) - - doc._changed_fields = [] - list_value = ['1', 2, {'hello': 'world'}] - doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['db_list_field']) - self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {})) - - # Test unsetting - doc._changed_fields = [] - doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) - self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1})) - - doc._changed_fields = [] - doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['db_list_field']) - self.assertEqual(doc._delta(), ({}, {'db_list_field': 1})) - - # Test it saves that data - doc = Doc() - doc.save() - - doc.string_field = 'hello' - doc.int_field = 1 - doc.dict_field = {'hello': 'world'} - doc.list_field = ['1', 2, {'hello': 'world'}] - doc.save() - doc = doc.reload(10) - - self.assertEqual(doc.string_field, 'hello') - self.assertEqual(doc.int_field, 1) - self.assertEqual(doc.dict_field, {'hello': 'world'}) - self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}]) - - def test_delta_recursive_db_field(self): - - class Embedded(EmbeddedDocument): - string_field = StringField(db_field='db_string_field') - int_field = IntField(db_field='db_int_field') - dict_field = DictField(db_field='db_dict_field') - list_field = ListField(db_field='db_list_field') - - class Doc(Document): - 
string_field = StringField(db_field='db_string_field') - int_field = IntField(db_field='db_int_field') - dict_field = DictField(db_field='db_dict_field') - list_field = ListField(db_field='db_list_field') - embedded_field = EmbeddedDocumentField(Embedded, db_field='db_embedded_field') - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] - doc.embedded_field = embedded_1 - - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field']) - - embedded_delta = { - 'db_string_field': 'hello', - 'db_int_field': 1, - 'db_dict_field': {'hello': 'world'}, - 'db_list_field': ['1', 2, {'hello': 'world'}] - } - self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - embedded_delta.update({ - '_types': ['Embedded'], - '_cls': 'Embedded', - }) - self.assertEqual(doc._delta(), ({'db_embedded_field': embedded_delta}, {})) - - doc.save() - doc = doc.reload(10) - - doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field.db_dict_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'db_dict_field': 1})) - self.assertEqual(doc._delta(), ({}, {'db_embedded_field.db_dict_field': 1})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.dict_field, {}) - - doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'db_list_field': 1})) - self.assertEqual(doc._delta(), ({}, {'db_embedded_field.db_list_field': 1})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field, []) - - embedded_2 = Embedded() - embedded_2.string_field = 'hello' - 
embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] - - doc.embedded_field.list_field = ['1', 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'db_list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'db_string_field': 'hello', - 'db_dict_field': {'hello': 'world'}, - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - - self.assertEqual(doc._delta(), ({ - 'db_embedded_field.db_list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'db_string_field': 'hello', - 'db_dict_field': {'hello': 'world'}, - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - doc.save() - doc = doc.reload(10) - - self.assertEqual(doc.embedded_field.list_field[0], '1') - self.assertEqual(doc.embedded_field.list_field[1], 2) - for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) - - doc.embedded_field.list_field[2].string_field = 'world' - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field.db_list_field.2.db_string_field']) - self.assertEqual(doc.embedded_field._delta(), ({'db_list_field.2.db_string_field': 'world'}, {})) - self.assertEqual(doc._delta(), ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'world') - - # Test multiple assignments - doc.embedded_field.list_field[2].string_field = 'hello world' - doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'db_list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'db_string_field': 'hello 
world', - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - 'db_dict_field': {'hello': 'world'}}]}, {})) - self.assertEqual(doc._delta(), ({ - 'db_embedded_field.db_list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'db_string_field': 'hello world', - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - 'db_dict_field': {'hello': 'world'}} - ]}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'hello world') - - # Test list native methods - doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}]}, {})) - doc.save() - doc = doc.reload(10) - - doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}, 1]}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) - - doc.embedded_field.list_field[2].list_field.sort(key=str) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) - - del(doc.embedded_field.list_field[2].list_field[2]['hello']) - self.assertEqual(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [1, 2, {}]}, {})) - doc.save() - doc = doc.reload(10) - - del(doc.embedded_field.list_field[2].list_field) - self.assertEqual(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1})) - - def test_save_only_changed_fields(self): - """Ensure save only sets / unsets changed fields - """ - - class User(self.Person): - active = BooleanField(default=True) - - - User.drop_collection() - - # Create person object and save it to the database - user = User(name='Test User', age=30, active=True) - user.save() - user.reload() - - # Simulated Race condition - same_person = 
self.Person.objects.get() - same_person.active = False - - user.age = 21 - user.save() - - same_person.name = 'User' - same_person.save() - - person = self.Person.objects.get() - self.assertEqual(person.name, 'User') - self.assertEqual(person.age, 21) - self.assertEqual(person.active, False) - - def test_save_only_changed_fields_recursive(self): - """Ensure save only sets / unsets changed fields - """ - - class Comment(EmbeddedDocument): - published = BooleanField(default=True) - - class User(self.Person): - comments_dict = DictField() - comments = ListField(EmbeddedDocumentField(Comment)) - active = BooleanField(default=True) - - User.drop_collection() - - # Create person object and save it to the database - person = User(name='Test User', age=30, active=True) - person.comments.append(Comment()) - person.save() - person.reload() - - person = self.Person.objects.get() - self.assertTrue(person.comments[0].published) - - person.comments[0].published = False - person.save() - - person = self.Person.objects.get() - self.assertFalse(person.comments[0].published) - - # Simple dict w - person.comments_dict['first_post'] = Comment() - person.save() - - person = self.Person.objects.get() - self.assertTrue(person.comments_dict['first_post'].published) - - person.comments_dict['first_post'].published = False - person.save() - - person = self.Person.objects.get() - self.assertFalse(person.comments_dict['first_post'].published) - - def test_delete(self): - """Ensure that document may be deleted using the delete method. - """ - person = self.Person(name="Test User", age=30) - person.save() - self.assertEqual(len(self.Person.objects), 1) - person.delete() - self.assertEqual(len(self.Person.objects), 0) - - def test_save_custom_id(self): - """Ensure that a document may be saved with a custom _id. 
- """ - # Create person object and save it to the database - person = self.Person(name='Test User', age=30, - id='497ce96f395f2f052a494fd4') - person.save() - # Ensure that the object is in the database with the correct _id - collection = self.db[self.Person._get_collection_name()] - person_obj = collection.find_one({'name': 'Test User'}) - self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') - - def test_save_custom_pk(self): - """Ensure that a document may be saved with a custom _id using pk alias. - """ - # Create person object and save it to the database - person = self.Person(name='Test User', age=30, - pk='497ce96f395f2f052a494fd4') - person.save() - # Ensure that the object is in the database with the correct _id - collection = self.db[self.Person._get_collection_name()] - person_obj = collection.find_one({'name': 'Test User'}) - self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') - - def test_save_list(self): - """Ensure that a list field may be properly saved. 
- """ - class Comment(EmbeddedDocument): - content = StringField() - - class BlogPost(Document): - content = StringField() - comments = ListField(EmbeddedDocumentField(Comment)) - tags = ListField(StringField()) - - BlogPost.drop_collection() - - post = BlogPost(content='Went for a walk today...') - post.tags = tags = ['fun', 'leisure'] - comments = [Comment(content='Good for you'), Comment(content='Yay.')] - post.comments = comments - post.save() - - collection = self.db[BlogPost._get_collection_name()] - post_obj = collection.find_one() - self.assertEqual(post_obj['tags'], tags) - for comment_obj, comment in zip(post_obj['comments'], comments): - self.assertEqual(comment_obj['content'], comment['content']) - - BlogPost.drop_collection() - - def test_list_search_by_embedded(self): - class User(Document): - username = StringField(required=True) - - meta = {'allow_inheritance': False} - - class Comment(EmbeddedDocument): - comment = StringField() - user = ReferenceField(User, - required=True) - - meta = {'allow_inheritance': False} - - class Page(Document): - comments = ListField(EmbeddedDocumentField(Comment)) - meta = {'allow_inheritance': False, - 'indexes': [ - {'fields': ['comments.user']} - ]} - - User.drop_collection() - Page.drop_collection() - - u1 = User(username="wilson") - u1.save() - - u2 = User(username="rozza") - u2.save() - - u3 = User(username="hmarr") - u3.save() - - p1 = Page(comments = [Comment(user=u1, comment="Its very good"), - Comment(user=u2, comment="Hello world"), - Comment(user=u3, comment="Ping Pong"), - Comment(user=u1, comment="I like a beer")]) - p1.save() - - p2 = Page(comments = [Comment(user=u1, comment="Its very good"), - Comment(user=u2, comment="Hello world")]) - p2.save() - - p3 = Page(comments = [Comment(user=u3, comment="Its very good")]) - p3.save() - - p4 = Page(comments = [Comment(user=u2, comment="Heavy Metal song")]) - p4.save() - - self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1))) - 
self.assertEqual([p1, p2, p4], list(Page.objects.filter(comments__user=u2))) - self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3))) - - def test_save_embedded_document(self): - """Ensure that a document with an embedded document field may be - saved in the database. - """ - class EmployeeDetails(EmbeddedDocument): - position = StringField() - - class Employee(self.Person): - salary = IntField() - details = EmbeddedDocumentField(EmployeeDetails) - - # Create employee object and save it to the database - employee = Employee(name='Test Employee', age=50, salary=20000) - employee.details = EmployeeDetails(position='Developer') - employee.save() - - # Ensure that the object is in the database - collection = self.db[self.Person._get_collection_name()] - employee_obj = collection.find_one({'name': 'Test Employee'}) - self.assertEqual(employee_obj['name'], 'Test Employee') - self.assertEqual(employee_obj['age'], 50) - # Ensure that the 'details' embedded object saved correctly - self.assertEqual(employee_obj['details']['position'], 'Developer') - - def test_embedded_update_after_save(self): - """ - Test update of `EmbeddedDocumentField` attached to a newly saved - document. - """ - class Page(EmbeddedDocument): - log_message = StringField(verbose_name="Log message", - required=True) - - class Site(Document): - page = EmbeddedDocumentField(Page) - - - Site.drop_collection() - site = Site(page=Page(log_message="Warning: Dummy message")) - site.save() - - # Update - site.page.log_message = "Error: Dummy message" - site.save() - - site = Site.objects.first() - self.assertEqual(site.page.log_message, "Error: Dummy message") - - def test_updating_an_embedded_document(self): - """Ensure that a document with an embedded document field may be - saved in the database. 
- """ - class EmployeeDetails(EmbeddedDocument): - position = StringField() - - class Employee(self.Person): - salary = IntField() - details = EmbeddedDocumentField(EmployeeDetails) - - # Create employee object and save it to the database - employee = Employee(name='Test Employee', age=50, salary=20000) - employee.details = EmployeeDetails(position='Developer') - employee.save() - - # Test updating an embedded document - promoted_employee = Employee.objects.get(name='Test Employee') - promoted_employee.details.position = 'Senior Developer' - promoted_employee.save() - - promoted_employee.reload() - self.assertEqual(promoted_employee.name, 'Test Employee') - self.assertEqual(promoted_employee.age, 50) - - # Ensure that the 'details' embedded object saved correctly - self.assertEqual(promoted_employee.details.position, 'Senior Developer') - - # Test removal - promoted_employee.details = None - promoted_employee.save() - - promoted_employee.reload() - self.assertEqual(promoted_employee.details, None) - - def test_mixins_dont_add_to_types(self): - - class Mixin(object): - name = StringField() - - class Person(Document, Mixin): - pass - - Person.drop_collection() - - self.assertEqual(sorted(Person._fields.keys()), ['id', 'name']) - - Person(name="Rozza").save() - - collection = self.db[Person._get_collection_name()] - obj = collection.find_one() - self.assertEqual(obj['_cls'], 'Person') - self.assertEqual(obj['_types'], ['Person']) - - self.assertEqual(Person.objects.count(), 1) - - Person.drop_collection() - - def test_object_mixins(self): - - class NameMixin(object): - name = StringField() - - class Foo(EmbeddedDocument, NameMixin): - quantity = IntField() - - self.assertEqual(['name', 'quantity'], sorted(Foo._fields.keys())) - - class Bar(Document, NameMixin): - widgets = StringField() - - self.assertEqual(['id', 'name', 'widgets'], sorted(Bar._fields.keys())) - - def test_mixin_inheritance(self): - class BaseMixIn(object): - count = IntField() - data = StringField() 
- - class DoubleMixIn(BaseMixIn): - comment = StringField() - - class TestDoc(Document, DoubleMixIn): - age = IntField() - - TestDoc.drop_collection() - t = TestDoc(count=12, data="test", - comment="great!", age=19) - - t.save() - - t = TestDoc.objects.first() - - self.assertEqual(t.age, 19) - self.assertEqual(t.comment, "great!") - self.assertEqual(t.data, "test") - self.assertEqual(t.count, 12) - - def test_save_reference(self): - """Ensure that a document reference field may be saved in the database. - """ - - class BlogPost(Document): - meta = {'collection': 'blogpost_1'} - content = StringField() - author = ReferenceField(self.Person) - - BlogPost.drop_collection() - - author = self.Person(name='Test User') - author.save() - - post = BlogPost(content='Watched some TV today... how exciting.') - # Should only reference author when saving - post.author = author - post.save() - - post_obj = BlogPost.objects.first() - - # Test laziness - self.assertTrue(isinstance(post_obj._data['author'], - bson.DBRef)) - self.assertTrue(isinstance(post_obj.author, self.Person)) - self.assertEqual(post_obj.author.name, 'Test User') - - # Ensure that the dereferenced object may be changed and saved - post_obj.author.age = 25 - post_obj.author.save() - - author = list(self.Person.objects(name='Test User'))[-1] - self.assertEqual(author.age, 25) - - BlogPost.drop_collection() - - def test_cannot_perform_joins_references(self): - - class BlogPost(Document): - author = ReferenceField(self.Person) - author2 = GenericReferenceField() - - def test_reference(): - list(BlogPost.objects(author__name="test")) - - self.assertRaises(InvalidQueryError, test_reference) - - def test_generic_reference(): - list(BlogPost.objects(author2__name="test")) - - self.assertRaises(InvalidQueryError, test_generic_reference) - - def test_duplicate_db_fields_raise_invalid_document_error(self): - """Ensure a InvalidDocumentError is thrown if duplicate fields - declare the same db_field""" - - def 
throw_invalid_document_error(): - class Foo(Document): - name = StringField() - name2 = StringField(db_field='name') - - self.assertRaises(InvalidDocumentError, throw_invalid_document_error) - - def test_invalid_son(self): - """Raise an error if loading invalid data""" - class Occurrence(EmbeddedDocument): - number = IntField() - - class Word(Document): - stem = StringField() - count = IntField(default=1) - forms = ListField(StringField(), default=list) - occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) - - def raise_invalid_document(): - Word._from_son({'stem': [1,2,3], 'forms': 1, 'count': 'one', 'occurs': {"hello": None}}) - - self.assertRaises(InvalidDocumentError, raise_invalid_document) - - def test_reverse_delete_rule_cascade_and_nullify(self): - """Ensure that a referenced document is also deleted upon deletion. - """ - - class BlogPost(Document): - content = StringField() - author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) - reviewer = ReferenceField(self.Person, reverse_delete_rule=NULLIFY) - - self.Person.drop_collection() - BlogPost.drop_collection() - - author = self.Person(name='Test User') - author.save() - - reviewer = self.Person(name='Re Viewer') - reviewer.save() - - post = BlogPost(content = 'Watched some TV') - post.author = author - post.reviewer = reviewer - post.save() - - reviewer.delete() - self.assertEqual(len(BlogPost.objects), 1) # No effect on the BlogPost - self.assertEqual(BlogPost.objects.get().reviewer, None) - - # Delete the Person, which should lead to deletion of the BlogPost, too - author.delete() - self.assertEqual(len(BlogPost.objects), 0) - - def test_reverse_delete_rule_cascade_and_nullify_complex_field(self): - """Ensure that a referenced document is also deleted upon deletion for - complex fields. 
- """ - - class BlogPost(Document): - content = StringField() - authors = ListField(ReferenceField(self.Person, reverse_delete_rule=CASCADE)) - reviewers = ListField(ReferenceField(self.Person, reverse_delete_rule=NULLIFY)) - - self.Person.drop_collection() - - BlogPost.drop_collection() - - author = self.Person(name='Test User') - author.save() - - reviewer = self.Person(name='Re Viewer') - reviewer.save() - - post = BlogPost(content='Watched some TV') - post.authors = [author] - post.reviewers = [reviewer] - post.save() - - # Deleting the reviewer should have no effect on the BlogPost - reviewer.delete() - self.assertEqual(len(BlogPost.objects), 1) - self.assertEqual(BlogPost.objects.get().reviewers, []) - - # Delete the Person, which should lead to deletion of the BlogPost, too - author.delete() - self.assertEqual(len(BlogPost.objects), 0) - - def test_two_way_reverse_delete_rule(self): - """Ensure that Bi-Directional relationships work with - reverse_delete_rule - """ - - class Bar(Document): - content = StringField() - foo = ReferenceField('Foo') - - class Foo(Document): - content = StringField() - bar = ReferenceField(Bar) - - Bar.register_delete_rule(Foo, 'bar', NULLIFY) - Foo.register_delete_rule(Bar, 'foo', NULLIFY) - - - Bar.drop_collection() - Foo.drop_collection() - - b = Bar(content="Hello") - b.save() - - f = Foo(content="world", bar=b) - f.save() - - b.foo = f - b.save() - - f.delete() - - self.assertEqual(len(Bar.objects), 1) # No effect on the BlogPost - self.assertEqual(Bar.objects.get().foo, None) - - def test_invalid_reverse_delete_rules_raise_errors(self): - - def throw_invalid_document_error(): - class Blog(Document): - content = StringField() - authors = MapField(ReferenceField(self.Person, reverse_delete_rule=CASCADE)) - reviewers = DictField(field=ReferenceField(self.Person, reverse_delete_rule=NULLIFY)) - - self.assertRaises(InvalidDocumentError, throw_invalid_document_error) - - def throw_invalid_document_error_embedded(): - class 
Parents(EmbeddedDocument): - father = ReferenceField('Person', reverse_delete_rule=DENY) - mother = ReferenceField('Person', reverse_delete_rule=DENY) - - self.assertRaises(InvalidDocumentError, throw_invalid_document_error_embedded) - - def test_reverse_delete_rule_cascade_recurs(self): - """Ensure that a chain of documents is also deleted upon cascaded - deletion. - """ - - class BlogPost(Document): - content = StringField() - author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) - - class Comment(Document): - text = StringField() - post = ReferenceField(BlogPost, reverse_delete_rule=CASCADE) - - self.Person.drop_collection() - BlogPost.drop_collection() - Comment.drop_collection() - - author = self.Person(name='Test User') - author.save() - - post = BlogPost(content = 'Watched some TV') - post.author = author - post.save() - - comment = Comment(text = 'Kudos.') - comment.post = post - comment.save() - - # Delete the Person, which should lead to deletion of the BlogPost, and, - # recursively to the Comment, too - author.delete() - self.assertEqual(len(Comment.objects), 0) - - self.Person.drop_collection() - BlogPost.drop_collection() - Comment.drop_collection() - - def test_reverse_delete_rule_deny(self): - """Ensure that a document cannot be referenced if there are still - documents referring to it. 
- """ - - class BlogPost(Document): - content = StringField() - author = ReferenceField(self.Person, reverse_delete_rule=DENY) - - self.Person.drop_collection() - BlogPost.drop_collection() - - author = self.Person(name='Test User') - author.save() - - post = BlogPost(content = 'Watched some TV') - post.author = author - post.save() - - # Delete the Person should be denied - self.assertRaises(OperationError, author.delete) # Should raise denied error - self.assertEqual(len(BlogPost.objects), 1) # No objects may have been deleted - self.assertEqual(len(self.Person.objects), 1) - - # Other users, that don't have BlogPosts must be removable, like normal - author = self.Person(name='Another User') - author.save() - - self.assertEqual(len(self.Person.objects), 2) - author.delete() - self.assertEqual(len(self.Person.objects), 1) - - self.Person.drop_collection() - BlogPost.drop_collection() - - def subclasses_and_unique_keys_works(self): - - class A(Document): - pass - - class B(A): - foo = BooleanField(unique=True) - - A.drop_collection() - B.drop_collection() - - A().save() - A().save() - B(foo=True).save() - - self.assertEqual(A.objects.count(), 2) - self.assertEqual(B.objects.count(), 1) - A.drop_collection() - B.drop_collection() - - def test_document_hash(self): - """Test document in list, dict, set - """ - class User(Document): - pass - - class BlogPost(Document): - pass - - # Clear old datas - User.drop_collection() - BlogPost.drop_collection() - - u1 = User.objects.create() - u2 = User.objects.create() - u3 = User.objects.create() - u4 = User() # New object - - b1 = BlogPost.objects.create() - b2 = BlogPost.objects.create() - - # in List - all_user_list = list(User.objects.all()) - - self.assertTrue(u1 in all_user_list) - self.assertTrue(u2 in all_user_list) - self.assertTrue(u3 in all_user_list) - self.assertFalse(u4 in all_user_list) # New object - self.assertFalse(b1 in all_user_list) # Other object - self.assertFalse(b2 in all_user_list) # Other object - - # 
in Dict - all_user_dic = {} - for u in User.objects.all(): - all_user_dic[u] = "OK" - - self.assertEqual(all_user_dic.get(u1, False), "OK" ) - self.assertEqual(all_user_dic.get(u2, False), "OK" ) - self.assertEqual(all_user_dic.get(u3, False), "OK" ) - self.assertEqual(all_user_dic.get(u4, False), False ) # New object - self.assertEqual(all_user_dic.get(b1, False), False ) # Other object - self.assertEqual(all_user_dic.get(b2, False), False ) # Other object - - # in Set - all_user_set = set(User.objects.all()) - - self.assertTrue(u1 in all_user_set ) - - def test_picklable(self): - - pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) - pickle_doc.embedded = PickleEmbedded() - pickle_doc.save() - - pickled_doc = pickle.dumps(pickle_doc) - resurrected = pickle.loads(pickled_doc) - - self.assertEqual(resurrected, pickle_doc) - - resurrected.string = "Two" - resurrected.save() - - pickle_doc = pickle_doc.reload() - self.assertEqual(resurrected, pickle_doc) - - def test_throw_invalid_document_error(self): - - # test handles people trying to upsert - def throw_invalid_document_error(): - class Blog(Document): - validate = DictField() - - self.assertRaises(InvalidDocumentError, throw_invalid_document_error) - - def test_mutating_documents(self): - - class B(EmbeddedDocument): - field1 = StringField(default='field1') - - class A(Document): - b = EmbeddedDocumentField(B, default=lambda: B()) - - A.drop_collection() - a = A() - a.save() - a.reload() - self.assertEqual(a.b.field1, 'field1') - - class C(EmbeddedDocument): - c_field = StringField(default='cfield') - - class B(EmbeddedDocument): - field1 = StringField(default='field1') - field2 = EmbeddedDocumentField(C, default=lambda: C()) - - class A(Document): - b = EmbeddedDocumentField(B, default=lambda: B()) - - a = A.objects()[0] - a.b.field2.c_field = 'new value' - a.save() - - a.reload() - self.assertEqual(a.b.field2.c_field, 'new value') - - def test_can_save_false_values(self): - """Ensures you can 
save False values on save""" - class Doc(Document): - foo = StringField() - archived = BooleanField(default=False, required=True) - - Doc.drop_collection() - d = Doc() - d.save() - d.archived = False - d.save() - - self.assertEqual(Doc.objects(archived=False).count(), 1) - - - def test_can_save_false_values_dynamic(self): - """Ensures you can save False values on dynamic docs""" - class Doc(DynamicDocument): - foo = StringField() - - Doc.drop_collection() - d = Doc() - d.save() - d.archived = False - d.save() - - self.assertEqual(Doc.objects(archived=False).count(), 1) - - def test_do_not_save_unchanged_references(self): - """Ensures cascading saves dont auto update""" - class Job(Document): - name = StringField() - - class Person(Document): - name = StringField() - age = IntField() - job = ReferenceField(Job) - - Job.drop_collection() - Person.drop_collection() - - job = Job(name="Job 1") - # job should not have any changed fields after the save - job.save() - - person = Person(name="name", age=10, job=job) - - from pymongo.collection import Collection - orig_update = Collection.update - try: - def fake_update(*args, **kwargs): - self.fail("Unexpected update for %s" % args[0].name) - return orig_update(*args, **kwargs) - - Collection.update = fake_update - person.save() - finally: - Collection.update = orig_update - - def test_db_alias_tests(self): - """ DB Alias tests """ - # mongoenginetest - Is default connection alias from setUp() - # Register Aliases - register_connection('testdb-1', 'mongoenginetest2') - register_connection('testdb-2', 'mongoenginetest3') - register_connection('testdb-3', 'mongoenginetest4') - - class User(Document): - name = StringField() - meta = {"db_alias": "testdb-1"} - - class Book(Document): - name = StringField() - meta = {"db_alias": "testdb-2"} - - # Drops - User.drop_collection() - Book.drop_collection() - - # Create - bob = User.objects.create(name="Bob") - hp = Book.objects.create(name="Harry Potter") - - # Selects - 
self.assertEqual(User.objects.first(), bob) - self.assertEqual(Book.objects.first(), hp) - - # DeReference - class AuthorBooks(Document): - author = ReferenceField(User) - book = ReferenceField(Book) - meta = {"db_alias": "testdb-3"} - - # Drops - AuthorBooks.drop_collection() - - ab = AuthorBooks.objects.create(author=bob, book=hp) - - # select - self.assertEqual(AuthorBooks.objects.first(), ab) - self.assertEqual(AuthorBooks.objects.first().book, hp) - self.assertEqual(AuthorBooks.objects.first().author, bob) - self.assertEqual(AuthorBooks.objects.filter(author=bob).first(), ab) - self.assertEqual(AuthorBooks.objects.filter(book=hp).first(), ab) - - # DB Alias - self.assertEqual(User._get_db(), get_db("testdb-1")) - self.assertEqual(Book._get_db(), get_db("testdb-2")) - self.assertEqual(AuthorBooks._get_db(), get_db("testdb-3")) - - # Collections - self.assertEqual(User._get_collection(), get_db("testdb-1")[User._get_collection_name()]) - self.assertEqual(Book._get_collection(), get_db("testdb-2")[Book._get_collection_name()]) - self.assertEqual(AuthorBooks._get_collection(), get_db("testdb-3")[AuthorBooks._get_collection_name()]) - - def test_db_alias_propagates(self): - """db_alias propagates? 
- """ - class A(Document): - name = StringField() - meta = {"db_alias": "testdb-1", "allow_inheritance": True} - - class B(A): - pass - - self.assertEqual('testdb-1', B._meta.get('db_alias')) - - def test_db_ref_usage(self): - """ DB Ref usage in dict_fields""" - - class User(Document): - name = StringField() - - class Book(Document): - name = StringField() - author = ReferenceField(User) - extra = DictField() - meta = { - 'ordering': ['+name'] - } - - def __unicode__(self): - return self.name - - def __str__(self): - return self.name - - # Drops - User.drop_collection() - Book.drop_collection() - - # Authors - bob = User.objects.create(name="Bob") - jon = User.objects.create(name="Jon") - - # Redactors - karl = User.objects.create(name="Karl") - susan = User.objects.create(name="Susan") - peter = User.objects.create(name="Peter") - - # Bob - Book.objects.create(name="1", author=bob, extra={"a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]}) - Book.objects.create(name="2", author=bob, extra={"a": bob.to_dbref(), "b": karl.to_dbref()} ) - Book.objects.create(name="3", author=bob, extra={"a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]}) - Book.objects.create(name="4", author=bob) - - # Jon - Book.objects.create(name="5", author=jon) - Book.objects.create(name="6", author=peter) - Book.objects.create(name="7", author=jon) - Book.objects.create(name="8", author=jon) - Book.objects.create(name="9", author=jon, extra={"a": peter.to_dbref()}) - - # Checks - self.assertEqual(u",".join([str(b) for b in Book.objects.all()] ) , "1,2,3,4,5,6,7,8,9" ) - # bob related books - self.assertEqual(u",".join([str(b) for b in Book.objects.filter( - Q(extra__a=bob ) | - Q(author=bob) | - Q(extra__b=bob))]) , - "1,2,3,4") - - # Susan & Karl related books - self.assertEqual(u",".join([str(b) for b in Book.objects.filter( - Q(extra__a__all=[karl, susan] ) | - Q(author__all=[karl, susan ] ) | - Q(extra__b__all=[karl.to_dbref(), susan.to_dbref()] ) - ) ] ) , "1" ) - 
- # $Where - self.assertEqual(u",".join([str(b) for b in Book.objects.filter( - __raw__={ - "$where": """ - function(){ - return this.name == '1' || - this.name == '2';}""" - } - ) ]), "1,2") - - -class ValidatorErrorTest(unittest.TestCase): - - def test_to_dict(self): - """Ensure a ValidationError handles error to_dict correctly. - """ - error = ValidationError('root') - self.assertEqual(error.to_dict(), {}) - - # 1st level error schema - error.errors = {'1st': ValidationError('bad 1st'), } - self.assertTrue('1st' in error.to_dict()) - self.assertEqual(error.to_dict()['1st'], 'bad 1st') - - # 2nd level error schema - error.errors = {'1st': ValidationError('bad 1st', errors={ - '2nd': ValidationError('bad 2nd'), - })} - self.assertTrue('1st' in error.to_dict()) - self.assertTrue(isinstance(error.to_dict()['1st'], dict)) - self.assertTrue('2nd' in error.to_dict()['1st']) - self.assertEqual(error.to_dict()['1st']['2nd'], 'bad 2nd') - - # moar levels - error.errors = {'1st': ValidationError('bad 1st', errors={ - '2nd': ValidationError('bad 2nd', errors={ - '3rd': ValidationError('bad 3rd', errors={ - '4th': ValidationError('Inception'), - }), - }), - })} - self.assertTrue('1st' in error.to_dict()) - self.assertTrue('2nd' in error.to_dict()['1st']) - self.assertTrue('3rd' in error.to_dict()['1st']['2nd']) - self.assertTrue('4th' in error.to_dict()['1st']['2nd']['3rd']) - self.assertEqual(error.to_dict()['1st']['2nd']['3rd']['4th'], - 'Inception') - - self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])") - - def test_model_validation(self): - - class User(Document): - username = StringField(primary_key=True) - name = StringField(required=True) - - try: - User().validate() - except ValidationError, e: - expected_error_message = """ValidationError(Field is required""" - self.assertTrue(expected_error_message in e.message) - self.assertEqual(e.to_dict(), { - 'username': 'Field is required', - 'name': 'Field is required'}) - - def 
test_spaces_in_keys(self): - - class Embedded(DynamicEmbeddedDocument): - pass - - class Doc(DynamicDocument): - pass - - Doc.drop_collection() - doc = Doc() - setattr(doc, 'hello world', 1) - doc.save() - - one = Doc.objects.filter(**{'hello world': 1}).count() - self.assertEqual(1, one) - - def test_fields_rewrite(self): - class BasePerson(Document): - name = StringField() - age = IntField() - meta = {'abstract': True} - - class Person(BasePerson): - name = StringField(required=True) - - p = Person(age=15) - self.assertRaises(ValidationError, p.validate) - - def test_cascaded_save_wrong_reference(self): - - class ADocument(Document): - val = IntField() - - class BDocument(Document): - a = ReferenceField(ADocument) - - ADocument.drop_collection() - BDocument.drop_collection() - - a = ADocument() - a.val = 15 - a.save() - - b = BDocument() - b.a = a - b.save() - - a.delete() - - b = BDocument.objects.first() - b.save(cascade=True) - - def test_shard_key(self): - class LogEntry(Document): - machine = StringField() - log = StringField() - - meta = { - 'shard_key': ('machine',) - } - - LogEntry.drop_collection() - - log = LogEntry() - log.machine = "Localhost" - log.save() - - log.log = "Saving" - log.save() - - def change_shard_key(): - log.machine = "127.0.0.1" - - self.assertRaises(OperationError, change_shard_key) - - def test_shard_key_primary(self): - class LogEntry(Document): - machine = StringField(primary_key=True) - log = StringField() - - meta = { - 'shard_key': ('machine',) - } - - LogEntry.drop_collection() - - log = LogEntry() - log.machine = "Localhost" - log.save() - - log.log = "Saving" - log.save() - - def change_shard_key(): - log.machine = "127.0.0.1" - - self.assertRaises(OperationError, change_shard_key) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_dynamic_document.py b/tests/test_dynamic_document.py deleted file mode 100644 index 23762a34..00000000 --- a/tests/test_dynamic_document.py +++ /dev/null @@ -1,533 +0,0 @@ 
-import unittest - -from mongoengine import * -from mongoengine.connection import get_db - - -class DynamicDocTest(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - class Person(DynamicDocument): - name = StringField() - meta = {'allow_inheritance': True} - - Person.drop_collection() - - self.Person = Person - - def test_simple_dynamic_document(self): - """Ensures simple dynamic documents are saved correctly""" - - p = self.Person() - p.name = "James" - p.age = 34 - - self.assertEqual(p.to_mongo(), - {"_types": ["Person"], "_cls": "Person", - "name": "James", "age": 34} - ) - - p.save() - - self.assertEqual(self.Person.objects.first().age, 34) - - # Confirm no changes to self.Person - self.assertFalse(hasattr(self.Person, 'age')) - - def test_dynamic_document_delta(self): - """Ensures simple dynamic documents can delta correctly""" - p = self.Person(name="James", age=34) - self.assertEqual(p._delta(), ({'_types': ['Person'], 'age': 34, 'name': 'James', '_cls': 'Person'}, {})) - - p.doc = 123 - del(p.doc) - self.assertEqual(p._delta(), ({'_types': ['Person'], 'age': 34, 'name': 'James', '_cls': 'Person'}, {'doc': 1})) - - def test_change_scope_of_variable(self): - """Test changing the scope of a dynamic field has no adverse effects""" - p = self.Person() - p.name = "Dean" - p.misc = 22 - p.save() - - p = self.Person.objects.get() - p.misc = {'hello': 'world'} - p.save() - - p = self.Person.objects.get() - self.assertEqual(p.misc, {'hello': 'world'}) - - def test_delete_dynamic_field(self): - """Test deleting a dynamic field works""" - self.Person.drop_collection() - p = self.Person() - p.name = "Dean" - p.misc = 22 - p.save() - - p = self.Person.objects.get() - p.misc = {'hello': 'world'} - p.save() - - p = self.Person.objects.get() - self.assertEqual(p.misc, {'hello': 'world'}) - collection = self.db[self.Person._get_collection_name()] - obj = collection.find_one() - self.assertEqual(sorted(obj.keys()), ['_cls', 
'_id', '_types', 'misc', 'name']) - - del(p.misc) - p.save() - - p = self.Person.objects.get() - self.assertFalse(hasattr(p, 'misc')) - - obj = collection.find_one() - self.assertEqual(sorted(obj.keys()), ['_cls', '_id', '_types', 'name']) - - def test_dynamic_document_queries(self): - """Ensure we can query dynamic fields""" - p = self.Person() - p.name = "Dean" - p.age = 22 - p.save() - - self.assertEqual(1, self.Person.objects(age=22).count()) - p = self.Person.objects(age=22) - p = p.get() - self.assertEqual(22, p.age) - - def test_complex_dynamic_document_queries(self): - class Person(DynamicDocument): - name = StringField() - - Person.drop_collection() - - p = Person(name="test") - p.age = "ten" - p.save() - - p1 = Person(name="test1") - p1.age = "less then ten and a half" - p1.save() - - p2 = Person(name="test2") - p2.age = 10 - p2.save() - - self.assertEqual(Person.objects(age__icontains='ten').count(), 2) - self.assertEqual(Person.objects(age__gte=10).count(), 1) - - def test_complex_data_lookups(self): - """Ensure you can query dynamic document dynamic fields""" - p = self.Person() - p.misc = {'hello': 'world'} - p.save() - - self.assertEqual(1, self.Person.objects(misc__hello='world').count()) - - def test_inheritance(self): - """Ensure that dynamic document plays nice with inheritance""" - class Employee(self.Person): - salary = IntField() - - Employee.drop_collection() - - self.assertTrue('name' in Employee._fields) - self.assertTrue('salary' in Employee._fields) - self.assertEqual(Employee._get_collection_name(), - self.Person._get_collection_name()) - - joe_bloggs = Employee() - joe_bloggs.name = "Joe Bloggs" - joe_bloggs.salary = 10 - joe_bloggs.age = 20 - joe_bloggs.save() - - self.assertEqual(1, self.Person.objects(age=20).count()) - self.assertEqual(1, Employee.objects(age=20).count()) - - joe_bloggs = self.Person.objects.first() - self.assertTrue(isinstance(joe_bloggs, Employee)) - - def test_embedded_dynamic_document(self): - """Test dynamic 
embedded documents""" - class Embedded(DynamicEmbeddedDocument): - pass - - class Doc(DynamicDocument): - pass - - Doc.drop_collection() - doc = Doc() - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] - doc.embedded_field = embedded_1 - - self.assertEqual(doc.to_mongo(), {"_types": ['Doc'], "_cls": "Doc", - "embedded_field": { - "_types": ['Embedded'], "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ['1', 2, {'hello': 'world'}] - } - }) - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc.embedded_field.__class__, Embedded) - self.assertEqual(doc.embedded_field.string_field, "hello") - self.assertEqual(doc.embedded_field.int_field, 1) - self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'}) - self.assertEqual(doc.embedded_field.list_field, ['1', 2, {'hello': 'world'}]) - - def test_complex_embedded_documents(self): - """Test complex dynamic embedded documents setups""" - class Embedded(DynamicEmbeddedDocument): - pass - - class Doc(DynamicDocument): - pass - - Doc.drop_collection() - doc = Doc() - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - - embedded_2 = Embedded() - embedded_2.string_field = 'hello' - embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] - - embedded_1.list_field = ['1', 2, embedded_2] - doc.embedded_field = embedded_1 - - self.assertEqual(doc.to_mongo(), {"_types": ['Doc'], "_cls": "Doc", - "embedded_field": { - "_types": ['Embedded'], "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ['1', 2, - {"_types": ['Embedded'], "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - 
"dict_field": {"hello": "world"}, - "list_field": ['1', 2, {'hello': 'world'}]} - ] - } - }) - doc.save() - doc = Doc.objects.first() - self.assertEqual(doc.embedded_field.__class__, Embedded) - self.assertEqual(doc.embedded_field.string_field, "hello") - self.assertEqual(doc.embedded_field.int_field, 1) - self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'}) - self.assertEqual(doc.embedded_field.list_field[0], '1') - self.assertEqual(doc.embedded_field.list_field[1], 2) - - embedded_field = doc.embedded_field.list_field[2] - - self.assertEqual(embedded_field.__class__, Embedded) - self.assertEqual(embedded_field.string_field, "hello") - self.assertEqual(embedded_field.int_field, 1) - self.assertEqual(embedded_field.dict_field, {'hello': 'world'}) - self.assertEqual(embedded_field.list_field, ['1', 2, {'hello': 'world'}]) - - def test_delta_for_dynamic_documents(self): - p = self.Person() - p.name = "Dean" - p.age = 22 - p.save() - - p.age = 24 - self.assertEqual(p.age, 24) - self.assertEqual(p._get_changed_fields(), ['age']) - self.assertEqual(p._delta(), ({'age': 24}, {})) - - p = self.Person.objects(age=22).get() - p.age = 24 - self.assertEqual(p.age, 24) - self.assertEqual(p._get_changed_fields(), ['age']) - self.assertEqual(p._delta(), ({'age': 24}, {})) - - p.save() - self.assertEqual(1, self.Person.objects(age=24).count()) - - def test_delta(self): - - class Doc(DynamicDocument): - pass - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['string_field']) - self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) - - doc._changed_fields = [] - doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['int_field']) - self.assertEqual(doc._delta(), ({'int_field': 1}, {})) - - doc._changed_fields = [] - dict_value = {'hello': 'world', 
'ping': 'pong'} - doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) - - doc._changed_fields = [] - list_value = ['1', 2, {'hello': 'world'}] - doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) - - # Test unsetting - doc._changed_fields = [] - doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) - - doc._changed_fields = [] - doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({}, {'list_field': 1})) - - def test_delta_recursive(self): - """Testing deltaing works with dynamic documents""" - class Embedded(DynamicEmbeddedDocument): - pass - - class Doc(DynamicDocument): - pass - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] - doc.embedded_field = embedded_1 - - self.assertEqual(doc._get_changed_fields(), ['embedded_field']) - - embedded_delta = { - 'string_field': 'hello', - 'int_field': 1, - 'dict_field': {'hello': 'world'}, - 'list_field': ['1', 2, {'hello': 'world'}] - } - self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - embedded_delta.update({ - '_types': ['Embedded'], - '_cls': 'Embedded', - }) - self.assertEqual(doc._delta(), ({'embedded_field': embedded_delta}, {})) - - doc.save() - doc.reload() - - doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['embedded_field.dict_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, 
{'dict_field': 1})) - - self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) - doc.save() - doc.reload() - - doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) - self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) - doc.save() - doc.reload() - - embedded_2 = Embedded() - embedded_2.string_field = 'hello' - embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] - - doc.embedded_field.list_field = ['1', 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - doc.save() - doc.reload() - - self.assertEqual(doc.embedded_field.list_field[2]._changed_fields, []) - self.assertEqual(doc.embedded_field.list_field[0], '1') - self.assertEqual(doc.embedded_field.list_field[1], 2) - for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) - - doc.embedded_field.list_field[2].string_field = 'world' - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field.2.string_field']) - self.assertEqual(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {})) - doc.save() - doc.reload() - 
self.assertEqual(doc.embedded_field.list_field[2].string_field, 'world') - - # Test multiple assignments - doc.embedded_field.list_field[2].string_field = 'hello world' - doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}}]}, {})) - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}} - ]}, {})) - doc.save() - doc.reload() - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'hello world') - - # Test list native methods - doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {})) - doc.save() - doc.reload() - - doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {})) - doc.save() - doc.reload() - self.assertEqual(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) - - doc.embedded_field.list_field[2].list_field.sort(key=str)# use str as a key to allow comparing uncomperable types - doc.save() - doc.reload() - self.assertEqual(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) - - del(doc.embedded_field.list_field[2].list_field[2]['hello']) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) - doc.save() - doc.reload() - - del(doc.embedded_field.list_field[2].list_field) - self.assertEqual(doc._delta(), ({}, 
{'embedded_field.list_field.2.list_field': 1})) - - doc.save() - doc.reload() - - doc.dict_field = {'embedded': embedded_1} - doc.save() - doc.reload() - - doc.dict_field['embedded'].string_field = 'Hello World' - self.assertEqual(doc._get_changed_fields(), ['dict_field.embedded.string_field']) - self.assertEqual(doc._delta(), ({'dict_field.embedded.string_field': 'Hello World'}, {})) - - def test_indexes(self): - """Ensure that indexes are used when meta[indexes] is specified. - """ - class BlogPost(DynamicDocument): - meta = { - 'indexes': [ - '-date', - ('category', '-date') - ], - } - - BlogPost.drop_collection() - - info = BlogPost.objects._collection.index_information() - # _id, '-date', ('cat', 'date') - # NB: there is no index on _types by itself, since - # the indices on -date and tags will both contain - # _types as first element in the key - self.assertEqual(len(info), 3) - - # Indexes are lazy so use list() to perform query - list(BlogPost.objects) - info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('date', -1)] - in info) - self.assertTrue([('_types', 1), ('date', -1)] in info) - - def test_dynamic_and_embedded(self): - """Ensure embedded documents play nicely""" - - class Address(EmbeddedDocument): - city = StringField() - - class Person(DynamicDocument): - name = StringField() - meta = {'allow_inheritance': True} - - Person.drop_collection() - - Person(name="Ross", address=Address(city="London")).save() - - person = Person.objects.first() - person.address.city = "Lundenne" - person.save() - - self.assertEqual(Person.objects.first().address.city, "Lundenne") - - person = Person.objects.first() - person.address = Address(city="Londinium") - person.save() - - self.assertEqual(Person.objects.first().address.city, "Londinium") - - person = Person.objects.first() - person.age = 35 - person.save() - 
self.assertEqual(Person.objects.first().age, 35) diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index 3118c5a4..d27960f7 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -1,3 +1,5 @@ +import sys +sys.path[0:0] = [""] import unittest import pymongo diff --git a/tests/test_signals.py b/tests/test_signals.py index d1199248..32517ddf 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import sys +sys.path[0:0] = [""] import unittest from mongoengine import * @@ -21,6 +23,7 @@ class SignalTests(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') + class Author(Document): name = StringField() @@ -69,7 +72,7 @@ class SignalTests(unittest.TestCase): else: signal_output.append('Not loaded') self.Author = Author - + Author.drop_collection() class Another(Document): name = StringField() @@ -108,8 +111,24 @@ class SignalTests(unittest.TestCase): signal_output.append('post_delete Another signal, %s' % document) self.Another = Another - # Save up the number of connected signals so that we can check at the end - # that all the signals we register get properly unregistered + Another.drop_collection() + + class ExplicitId(Document): + id = IntField(primary_key=True) + + @classmethod + def post_save(cls, sender, document, **kwargs): + if 'created' in kwargs: + if kwargs['created']: + signal_output.append('Is created') + else: + signal_output.append('Is updated') + + self.ExplicitId = ExplicitId + ExplicitId.drop_collection() + + # Save up the number of connected signals so that we can check at the + # end that all the signals we register get properly unregistered self.pre_signals = ( len(signals.pre_init.receivers), len(signals.post_init.receivers), @@ -137,6 +156,8 @@ class SignalTests(unittest.TestCase): signals.pre_delete.connect(Another.pre_delete, sender=Another) signals.post_delete.connect(Another.post_delete, sender=Another) + 
signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId) + def tearDown(self): signals.pre_init.disconnect(self.Author.pre_init) signals.post_init.disconnect(self.Author.post_init) @@ -154,6 +175,8 @@ class SignalTests(unittest.TestCase): signals.post_save.disconnect(self.Another.post_save) signals.pre_save.disconnect(self.Another.pre_save) + signals.post_save.disconnect(self.ExplicitId.post_save) + # Check that all our signals got disconnected properly. post_signals = ( len(signals.pre_init.receivers), @@ -166,13 +189,15 @@ class SignalTests(unittest.TestCase): len(signals.post_bulk_insert.receivers), ) + self.ExplicitId.objects.delete() + self.assertEqual(self.pre_signals, post_signals) def test_model_signals(self): """ Model saves should throw some signals. """ def create_author(): - a1 = self.Author(name='Bill Shakespeare') + self.Author(name='Bill Shakespeare') def bulk_create_author_with_load(): a1 = self.Author(name='Bill Shakespeare') @@ -196,7 +221,7 @@ class SignalTests(unittest.TestCase): ]) a1.reload() - a1.name='William Shakespeare' + a1.name = 'William Shakespeare' self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, William Shakespeare", "post_save signal, William Shakespeare", @@ -228,3 +253,15 @@ class SignalTests(unittest.TestCase): ]) self.Author.objects.delete() + + def test_signals_with_explicit_doc_ids(self): + """ Model saves must have a created flag the first time.""" + ei = self.ExplicitId(id=123) + # post save must received the created flag, even if there's already + # an object id present + self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + # second time, it must be an update + self.assertEqual(self.get_signal_output(ei.save), ['Is updated']) + +if __name__ == '__main__': + unittest.main()