Merge branch 'master' into pr/590
Conflicts: mongoengine/connection.py
This commit is contained in:
		
							
								
								
									
										28
									
								
								.travis.yml
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								.travis.yml
									
									
									
									
									
								
							| @@ -6,26 +6,28 @@ python: | |||||||
|     - "2.7" |     - "2.7" | ||||||
|     - "3.2" |     - "3.2" | ||||||
|     - "3.3" |     - "3.3" | ||||||
|  |     - "3.4" | ||||||
|  |     - "pypy" | ||||||
| env: | env: | ||||||
|   - PYMONGO=dev DJANGO=1.6 |   - PYMONGO=dev DJANGO=1.6.5 | ||||||
|   - PYMONGO=dev DJANGO=1.5.5 |   - PYMONGO=dev DJANGO=1.5.8 | ||||||
|   - PYMONGO=dev DJANGO=1.4.10 |   - PYMONGO=2.7.1 DJANGO=1.6.5 | ||||||
|   - PYMONGO=2.5 DJANGO=1.6 |   - PYMONGO=2.7.1 DJANGO=1.5.8 | ||||||
|   - PYMONGO=2.5 DJANGO=1.5.5 |  | ||||||
|   - PYMONGO=2.5 DJANGO=1.4.10 | matrix: | ||||||
|   - PYMONGO=3.2 DJANGO=1.6 |     fast_finish: true | ||||||
|   - PYMONGO=3.2 DJANGO=1.5.5 |  | ||||||
|   - PYMONGO=3.3 DJANGO=1.6 |  | ||||||
|   - PYMONGO=3.3 DJANGO=1.5.5 |  | ||||||
| install: | install: | ||||||
|     - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi |     - sudo apt-get install python-dev python3-dev libopenjpeg-dev zlib1g-dev libjpeg-turbo8-dev libtiff4-dev libjpeg8-dev libfreetype6-dev liblcms2-dev libwebp-dev tcl8.5-dev tk8.5-dev python-tk | ||||||
|     - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi |  | ||||||
|     - if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi |     - if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi | ||||||
|     - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi |     - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO; true; fi | ||||||
|  |     - pip install Django==$DJANGO | ||||||
|     - pip install https://pypi.python.org/packages/source/p/python-dateutil/python-dateutil-2.1.tar.gz#md5=1534bb15cf311f07afaa3aacba1c028b |     - pip install https://pypi.python.org/packages/source/p/python-dateutil/python-dateutil-2.1.tar.gz#md5=1534bb15cf311f07afaa3aacba1c028b | ||||||
|     - python setup.py install |     - python setup.py install | ||||||
| script: | script: | ||||||
|     - python setup.py test |     - python setup.py test | ||||||
|  |     - if [[ $TRAVIS_PYTHON_VERSION == '3.'* ]]; then 2to3 . -w; fi; | ||||||
|  |     - python benchmark.py | ||||||
| notifications: | notifications: | ||||||
|   irc: "irc.freenode.org#mongoengine" |   irc: "irc.freenode.org#mongoengine" | ||||||
| branches: | branches: | ||||||
|   | |||||||
							
								
								
									
										13
									
								
								AUTHORS
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								AUTHORS
									
									
									
									
									
								
							| @@ -171,7 +171,7 @@ that much better: | |||||||
|  * Michael Bartnett (https://github.com/michaelbartnett) |  * Michael Bartnett (https://github.com/michaelbartnett) | ||||||
|  * Alon Horev (https://github.com/alonho) |  * Alon Horev (https://github.com/alonho) | ||||||
|  * Kelvin Hammond (https://github.com/kelvinhammond) |  * Kelvin Hammond (https://github.com/kelvinhammond) | ||||||
|  * Jatin- (https://github.com/jatin-) |  * Jatin Chopra (https://github.com/jatin) | ||||||
|  * Paul Uithol (https://github.com/PaulUithol) |  * Paul Uithol (https://github.com/PaulUithol) | ||||||
|  * Thom Knowles (https://github.com/fleat) |  * Thom Knowles (https://github.com/fleat) | ||||||
|  * Paul (https://github.com/squamous) |  * Paul (https://github.com/squamous) | ||||||
| @@ -189,3 +189,14 @@ that much better: | |||||||
|  * Tom (https://github.com/tomprimozic) |  * Tom (https://github.com/tomprimozic) | ||||||
|  * j0hnsmith (https://github.com/j0hnsmith) |  * j0hnsmith (https://github.com/j0hnsmith) | ||||||
|  * Damien Churchill (https://github.com/damoxc) |  * Damien Churchill (https://github.com/damoxc) | ||||||
|  |  * Jonathan Simon Prates (https://github.com/jonathansp) | ||||||
|  |  * Thiago Papageorgiou (https://github.com/tmpapageorgiou) | ||||||
|  |  * Omer Katz (https://github.com/thedrow) | ||||||
|  |  * Falcon Dai (https://github.com/falcondai) | ||||||
|  |  * Polyrabbit (https://github.com/polyrabbit) | ||||||
|  |  * Sagiv Malihi (https://github.com/sagivmalihi) | ||||||
|  |  * Dmitry Konishchev (https://github.com/KonishchevDmitry) | ||||||
|  |  * Martyn Smith (https://github.com/martynsmith) | ||||||
|  |  * Andrei Zbikowski (https://github.com/b1naryth1ef) | ||||||
|  |  * Ronald van Rij (https://github.com/ronaldvanrij) | ||||||
|  |  * François Schmidts (https://github.com/jaesivsm) | ||||||
|   | |||||||
							
								
								
									
										11
									
								
								README.rst
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								README.rst
									
									
									
									
									
								
							| @@ -29,9 +29,18 @@ setup.py install``. | |||||||
|  |  | ||||||
| Dependencies | Dependencies | ||||||
| ============ | ============ | ||||||
| - pymongo 2.5+ | - pymongo>=2.5 | ||||||
| - sphinx (optional - for documentation generation) | - sphinx (optional - for documentation generation) | ||||||
|  |  | ||||||
|  | Optional Dependencies | ||||||
|  | --------------------- | ||||||
|  | - **Django Integration:** Django>=1.4.0 for Python 2.x or PyPy and Django>=1.5.0 for Python 3.x | ||||||
|  | - **Image Fields**: Pillow>=2.0.0 or PIL (not recommended since MongoEngine is tested with Pillow) | ||||||
|  | - dateutil>=2.1.0 | ||||||
|  |  | ||||||
|  | .. note | ||||||
|  |    MongoEngine always runs it's test suite against the latest patch version of each dependecy. e.g.: Django 1.6.5 | ||||||
|  |  | ||||||
| Examples | Examples | ||||||
| ======== | ======== | ||||||
| Some simple examples of what MongoEngine code looks like:: | Some simple examples of what MongoEngine code looks like:: | ||||||
|   | |||||||
							
								
								
									
										67
									
								
								benchmark.py
									
									
									
									
									
								
							
							
						
						
									
										67
									
								
								benchmark.py
									
									
									
									
									
								
							| @@ -15,7 +15,7 @@ def cprofile_main(): | |||||||
|     class Noddy(Document): |     class Noddy(Document): | ||||||
|         fields = DictField() |         fields = DictField() | ||||||
|  |  | ||||||
|     for i in xrange(1): |     for i in range(1): | ||||||
|         noddy = Noddy() |         noddy = Noddy() | ||||||
|         for j in range(20): |         for j in range(20): | ||||||
|             noddy.fields["key" + str(j)] = "value " + str(j) |             noddy.fields["key" + str(j)] = "value " + str(j) | ||||||
| @@ -113,6 +113,7 @@ def main(): | |||||||
|     4.68946313858 |     4.68946313858 | ||||||
|     ---------------------------------------------------------------------------------------------------- |     ---------------------------------------------------------------------------------------------------- | ||||||
|     """ |     """ | ||||||
|  |     print("Benchmarking...") | ||||||
|  |  | ||||||
|     setup = """ |     setup = """ | ||||||
| from pymongo import MongoClient | from pymongo import MongoClient | ||||||
| @@ -127,7 +128,7 @@ connection = MongoClient() | |||||||
| db = connection.timeit_test | db = connection.timeit_test | ||||||
| noddy = db.noddy | noddy = db.noddy | ||||||
|  |  | ||||||
| for i in xrange(10000): | for i in range(10000): | ||||||
|     example = {'fields': {}} |     example = {'fields': {}} | ||||||
|     for j in range(20): |     for j in range(20): | ||||||
|         example['fields']["key"+str(j)] = "value "+str(j) |         example['fields']["key"+str(j)] = "value "+str(j) | ||||||
| @@ -138,10 +139,10 @@ myNoddys = noddy.find() | |||||||
| [n for n in myNoddys] # iterate | [n for n in myNoddys] # iterate | ||||||
| """ | """ | ||||||
|  |  | ||||||
|     print "-" * 100 |     print("-" * 100) | ||||||
|     print """Creating 10000 dictionaries - Pymongo""" |     print("""Creating 10000 dictionaries - Pymongo""") | ||||||
|     t = timeit.Timer(stmt=stmt, setup=setup) |     t = timeit.Timer(stmt=stmt, setup=setup) | ||||||
|     print t.timeit(1) |     print(t.timeit(1)) | ||||||
|  |  | ||||||
|     stmt = """ |     stmt = """ | ||||||
| from pymongo import MongoClient | from pymongo import MongoClient | ||||||
| @@ -150,7 +151,7 @@ connection = MongoClient() | |||||||
| db = connection.timeit_test | db = connection.timeit_test | ||||||
| noddy = db.noddy | noddy = db.noddy | ||||||
|  |  | ||||||
| for i in xrange(10000): | for i in range(10000): | ||||||
|     example = {'fields': {}} |     example = {'fields': {}} | ||||||
|     for j in range(20): |     for j in range(20): | ||||||
|         example['fields']["key"+str(j)] = "value "+str(j) |         example['fields']["key"+str(j)] = "value "+str(j) | ||||||
| @@ -161,10 +162,10 @@ myNoddys = noddy.find() | |||||||
| [n for n in myNoddys] # iterate | [n for n in myNoddys] # iterate | ||||||
| """ | """ | ||||||
|  |  | ||||||
|     print "-" * 100 |     print("-" * 100) | ||||||
|     print """Creating 10000 dictionaries - Pymongo write_concern={"w": 0}""" |     print("""Creating 10000 dictionaries - Pymongo write_concern={"w": 0}""") | ||||||
|     t = timeit.Timer(stmt=stmt, setup=setup) |     t = timeit.Timer(stmt=stmt, setup=setup) | ||||||
|     print t.timeit(1) |     print(t.timeit(1)) | ||||||
|  |  | ||||||
|     setup = """ |     setup = """ | ||||||
| from pymongo import MongoClient | from pymongo import MongoClient | ||||||
| @@ -180,7 +181,7 @@ class Noddy(Document): | |||||||
| """ | """ | ||||||
|  |  | ||||||
|     stmt = """ |     stmt = """ | ||||||
| for i in xrange(10000): | for i in range(10000): | ||||||
|     noddy = Noddy() |     noddy = Noddy() | ||||||
|     for j in range(20): |     for j in range(20): | ||||||
|         noddy.fields["key"+str(j)] = "value "+str(j) |         noddy.fields["key"+str(j)] = "value "+str(j) | ||||||
| @@ -190,13 +191,13 @@ myNoddys = Noddy.objects() | |||||||
| [n for n in myNoddys] # iterate | [n for n in myNoddys] # iterate | ||||||
| """ | """ | ||||||
|  |  | ||||||
|     print "-" * 100 |     print("-" * 100) | ||||||
|     print """Creating 10000 dictionaries - MongoEngine""" |     print("""Creating 10000 dictionaries - MongoEngine""") | ||||||
|     t = timeit.Timer(stmt=stmt, setup=setup) |     t = timeit.Timer(stmt=stmt, setup=setup) | ||||||
|     print t.timeit(1) |     print(t.timeit(1)) | ||||||
|  |  | ||||||
|     stmt = """ |     stmt = """ | ||||||
| for i in xrange(10000): | for i in range(10000): | ||||||
|     noddy = Noddy() |     noddy = Noddy() | ||||||
|     fields = {} |     fields = {} | ||||||
|     for j in range(20): |     for j in range(20): | ||||||
| @@ -208,13 +209,13 @@ myNoddys = Noddy.objects() | |||||||
| [n for n in myNoddys] # iterate | [n for n in myNoddys] # iterate | ||||||
| """ | """ | ||||||
|  |  | ||||||
|     print "-" * 100 |     print("-" * 100) | ||||||
|     print """Creating 10000 dictionaries without continual assign - MongoEngine""" |     print("""Creating 10000 dictionaries without continual assign - MongoEngine""") | ||||||
|     t = timeit.Timer(stmt=stmt, setup=setup) |     t = timeit.Timer(stmt=stmt, setup=setup) | ||||||
|     print t.timeit(1) |     print(t.timeit(1)) | ||||||
|  |  | ||||||
|     stmt = """ |     stmt = """ | ||||||
| for i in xrange(10000): | for i in range(10000): | ||||||
|     noddy = Noddy() |     noddy = Noddy() | ||||||
|     for j in range(20): |     for j in range(20): | ||||||
|         noddy.fields["key"+str(j)] = "value "+str(j) |         noddy.fields["key"+str(j)] = "value "+str(j) | ||||||
| @@ -224,13 +225,13 @@ myNoddys = Noddy.objects() | |||||||
| [n for n in myNoddys] # iterate | [n for n in myNoddys] # iterate | ||||||
| """ | """ | ||||||
|  |  | ||||||
|     print "-" * 100 |     print("-" * 100) | ||||||
|     print """Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True""" |     print("""Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True""") | ||||||
|     t = timeit.Timer(stmt=stmt, setup=setup) |     t = timeit.Timer(stmt=stmt, setup=setup) | ||||||
|     print t.timeit(1) |     print(t.timeit(1)) | ||||||
|  |  | ||||||
|     stmt = """ |     stmt = """ | ||||||
| for i in xrange(10000): | for i in range(10000): | ||||||
|     noddy = Noddy() |     noddy = Noddy() | ||||||
|     for j in range(20): |     for j in range(20): | ||||||
|         noddy.fields["key"+str(j)] = "value "+str(j) |         noddy.fields["key"+str(j)] = "value "+str(j) | ||||||
| @@ -240,13 +241,13 @@ myNoddys = Noddy.objects() | |||||||
| [n for n in myNoddys] # iterate | [n for n in myNoddys] # iterate | ||||||
| """ | """ | ||||||
|  |  | ||||||
|     print "-" * 100 |     print("-" * 100) | ||||||
|     print """Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True""" |     print("""Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True""") | ||||||
|     t = timeit.Timer(stmt=stmt, setup=setup) |     t = timeit.Timer(stmt=stmt, setup=setup) | ||||||
|     print t.timeit(1) |     print(t.timeit(1)) | ||||||
|  |  | ||||||
|     stmt = """ |     stmt = """ | ||||||
| for i in xrange(10000): | for i in range(10000): | ||||||
|     noddy = Noddy() |     noddy = Noddy() | ||||||
|     for j in range(20): |     for j in range(20): | ||||||
|         noddy.fields["key"+str(j)] = "value "+str(j) |         noddy.fields["key"+str(j)] = "value "+str(j) | ||||||
| @@ -256,13 +257,13 @@ myNoddys = Noddy.objects() | |||||||
| [n for n in myNoddys] # iterate | [n for n in myNoddys] # iterate | ||||||
| """ | """ | ||||||
|  |  | ||||||
|     print "-" * 100 |     print("-" * 100) | ||||||
|     print """Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False""" |     print("""Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False""") | ||||||
|     t = timeit.Timer(stmt=stmt, setup=setup) |     t = timeit.Timer(stmt=stmt, setup=setup) | ||||||
|     print t.timeit(1) |     print(t.timeit(1)) | ||||||
|  |  | ||||||
|     stmt = """ |     stmt = """ | ||||||
| for i in xrange(10000): | for i in range(10000): | ||||||
|     noddy = Noddy() |     noddy = Noddy() | ||||||
|     for j in range(20): |     for j in range(20): | ||||||
|         noddy.fields["key"+str(j)] = "value "+str(j) |         noddy.fields["key"+str(j)] = "value "+str(j) | ||||||
| @@ -272,10 +273,10 @@ myNoddys = Noddy.objects() | |||||||
| [n for n in myNoddys] # iterate | [n for n in myNoddys] # iterate | ||||||
| """ | """ | ||||||
|  |  | ||||||
|     print "-" * 100 |     print("-" * 100) | ||||||
|     print """Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False""" |     print("""Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False""") | ||||||
|     t = timeit.Timer(stmt=stmt, setup=setup) |     t = timeit.Timer(stmt=stmt, setup=setup) | ||||||
|     print t.timeit(1) |     print(t.timeit(1)) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|   | |||||||
| @@ -2,6 +2,36 @@ | |||||||
| Changelog | Changelog | ||||||
| ========= | ========= | ||||||
|  |  | ||||||
|  |  | ||||||
|  | Changes in 0.9.X - DEV | ||||||
|  | ====================== | ||||||
|  |  | ||||||
|  | - Implemented equality between Documents and DBRefs #597 | ||||||
|  | - Fixed ReferenceField inside nested ListFields dereferencing problem #368 | ||||||
|  | - Added the ability to reload specific document fields #100 | ||||||
|  | - Added db_alias support and fixes for custom map/reduce output #586 | ||||||
|  | - post_save signal now has access to delta information about field changes #594 #589 | ||||||
|  | - Don't query with $orderby for qs.get() #600 | ||||||
|  | - Fix id shard key save issue #636 | ||||||
|  | - Fixes issue with recursive embedded document errors #557 | ||||||
|  | - Fix clear_changed_fields() clearing unsaved documents bug #602 | ||||||
|  | - Removing support for Django 1.4.x, pymongo 2.5.x, pymongo 2.6.x. | ||||||
|  | - Removing support for Python < 2.6.6 | ||||||
|  | - Fixed $maxDistance location for geoJSON $near queries with MongoDB 2.6+ #664 | ||||||
|  | - QuerySet.modify() method to provide find_and_modify() like behaviour #677 | ||||||
|  | - Added support for the using() method on a queryset #676 | ||||||
|  | - PYPY support #673 | ||||||
|  | - Connection pooling #674 | ||||||
|  | - Avoid to open all documents from cursors in an if stmt #655 | ||||||
|  | - Ability to clear the ordering #657 | ||||||
|  | - Raise NotUniqueError in Document.update() on pymongo.errors.DuplicateKeyError #626 | ||||||
|  | - Slots - memory improvements #625 | ||||||
|  | - Fixed incorrectly split a query key when it ends with "_" #619 | ||||||
|  | - Geo docs updates #613 | ||||||
|  | - Workaround a dateutil bug #608 | ||||||
|  | - Conditional save for atomic-style operations #511 | ||||||
|  | - Allow dynamic dictionary-style field access #559 | ||||||
|  |  | ||||||
| Changes in 0.8.7 | Changes in 0.8.7 | ||||||
| ================ | ================ | ||||||
| - Calling reload on deleted / nonexistant documents raises DoesNotExist (#538) | - Calling reload on deleted / nonexistant documents raises DoesNotExist (#538) | ||||||
|   | |||||||
| @@ -531,6 +531,8 @@ field name to the index definition. | |||||||
| Sometimes its more efficient to index parts of Embedded / dictionary fields, | Sometimes its more efficient to index parts of Embedded / dictionary fields, | ||||||
| in this case use 'dot' notation to identify the value to index eg: `rank.title` | in this case use 'dot' notation to identify the value to index eg: `rank.title` | ||||||
|  |  | ||||||
|  | .. _geospatial-indexes: | ||||||
|  |  | ||||||
| Geospatial indexes | Geospatial indexes | ||||||
| ------------------ | ------------------ | ||||||
|  |  | ||||||
|   | |||||||
| @@ -46,7 +46,7 @@ slightly different manner.  First, a new file must be created by calling the | |||||||
|     marmot.photo.write('some_more_image_data') |     marmot.photo.write('some_more_image_data') | ||||||
|     marmot.photo.close() |     marmot.photo.close() | ||||||
|  |  | ||||||
|     marmot.photo.save() |     marmot.save() | ||||||
|  |  | ||||||
| Deletion | Deletion | ||||||
| -------- | -------- | ||||||
|   | |||||||
| @@ -488,8 +488,9 @@ calling it with keyword arguments:: | |||||||
| Atomic updates | Atomic updates | ||||||
| ============== | ============== | ||||||
| Documents may be updated atomically by using the | Documents may be updated atomically by using the | ||||||
| :meth:`~mongoengine.queryset.QuerySet.update_one` and | :meth:`~mongoengine.queryset.QuerySet.update_one`, | ||||||
| :meth:`~mongoengine.queryset.QuerySet.update` methods on a | :meth:`~mongoengine.queryset.QuerySet.update` and | ||||||
|  | :meth:`~mongoengine.queryset.QuerySet.modify` methods on a | ||||||
| :meth:`~mongoengine.queryset.QuerySet`. There are several different "modifiers" | :meth:`~mongoengine.queryset.QuerySet`. There are several different "modifiers" | ||||||
| that you may use with these methods: | that you may use with these methods: | ||||||
|  |  | ||||||
| @@ -499,11 +500,13 @@ that you may use with these methods: | |||||||
| * ``dec`` -- decrement a value by a given amount | * ``dec`` -- decrement a value by a given amount | ||||||
| * ``push`` -- append a value to a list | * ``push`` -- append a value to a list | ||||||
| * ``push_all`` -- append several values to a list | * ``push_all`` -- append several values to a list | ||||||
| * ``pop`` -- remove the first or last element of a list | * ``pop`` -- remove the first or last element of a list `depending on the value`_ | ||||||
| * ``pull`` -- remove a value from a list | * ``pull`` -- remove a value from a list | ||||||
| * ``pull_all`` -- remove several values from a list | * ``pull_all`` -- remove several values from a list | ||||||
| * ``add_to_set`` -- add value to a list only if its not in the list already | * ``add_to_set`` -- add value to a list only if its not in the list already | ||||||
|  |  | ||||||
|  | .. _depending on the value: http://docs.mongodb.org/manual/reference/operator/update/pop/ | ||||||
|  |  | ||||||
| The syntax for atomic updates is similar to the querying syntax, but the | The syntax for atomic updates is similar to the querying syntax, but the | ||||||
| modifier comes before the field, not after it:: | modifier comes before the field, not after it:: | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,4 +1,6 @@ | |||||||
| import weakref | import weakref | ||||||
|  | import functools | ||||||
|  | import itertools | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
|  |  | ||||||
| __all__ = ("BaseDict", "BaseList") | __all__ = ("BaseDict", "BaseList") | ||||||
| @@ -156,3 +158,98 @@ class BaseList(list): | |||||||
|     def _mark_as_changed(self): |     def _mark_as_changed(self): | ||||||
|         if hasattr(self._instance, '_mark_as_changed'): |         if hasattr(self._instance, '_mark_as_changed'): | ||||||
|             self._instance._mark_as_changed(self._name) |             self._instance._mark_as_changed(self._name) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class StrictDict(object): | ||||||
|  |     __slots__ = ()  | ||||||
|  |     _special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create']) | ||||||
|  |     _classes = {} | ||||||
|  |     def __init__(self, **kwargs): | ||||||
|  |         for k,v in kwargs.iteritems(): | ||||||
|  |             setattr(self, k, v) | ||||||
|  |     def __getitem__(self, key): | ||||||
|  |         key = '_reserved_' + key if key in self._special_fields else key | ||||||
|  |         try: | ||||||
|  |             return getattr(self, key) | ||||||
|  |         except AttributeError: | ||||||
|  |             raise KeyError(key) | ||||||
|  |     def __setitem__(self, key, value): | ||||||
|  |         key = '_reserved_' + key if key in self._special_fields else key | ||||||
|  |         return setattr(self, key, value) | ||||||
|  |     def __contains__(self, key): | ||||||
|  |         return hasattr(self, key) | ||||||
|  |     def get(self, key, default=None): | ||||||
|  |         try: | ||||||
|  |             return self[key] | ||||||
|  |         except KeyError: | ||||||
|  |             return default | ||||||
|  |     def pop(self, key, default=None): | ||||||
|  |         v = self.get(key, default) | ||||||
|  |         try: | ||||||
|  |             delattr(self, key) | ||||||
|  |         except AttributeError: | ||||||
|  |             pass | ||||||
|  |         return v | ||||||
|  |     def iteritems(self): | ||||||
|  |         for key in self: | ||||||
|  |             yield key, self[key] | ||||||
|  |     def items(self): | ||||||
|  |         return [(k, self[k]) for k in iter(self)] | ||||||
|  |     def keys(self): | ||||||
|  |         return list(iter(self)) | ||||||
|  |     def __iter__(self): | ||||||
|  |         return (key for key in self.__slots__ if hasattr(self, key)) | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(list(self.iteritems())) | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         return self.items() == other.items() | ||||||
|  |     def __neq__(self, other): | ||||||
|  |         return self.items() != other.items() | ||||||
|  |      | ||||||
|  |     @classmethod | ||||||
|  |     def create(cls, allowed_keys): | ||||||
|  |         allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys) | ||||||
|  |         allowed_keys = frozenset(allowed_keys_tuple) | ||||||
|  |         if allowed_keys not in cls._classes: | ||||||
|  |             class SpecificStrictDict(cls): | ||||||
|  |                 __slots__ = allowed_keys_tuple | ||||||
|  |             cls._classes[allowed_keys] = SpecificStrictDict  | ||||||
|  |         return cls._classes[allowed_keys] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SemiStrictDict(StrictDict): | ||||||
|  |     __slots__ = ('_extras') | ||||||
|  |     _classes = {} | ||||||
|  |     def __getattr__(self, attr): | ||||||
|  |         try: | ||||||
|  |             super(SemiStrictDict, self).__getattr__(attr) | ||||||
|  |         except AttributeError: | ||||||
|  |             try: | ||||||
|  |                 return self.__getattribute__('_extras')[attr] | ||||||
|  |             except KeyError as e: | ||||||
|  |                 raise AttributeError(e) | ||||||
|  |     def __setattr__(self, attr, value): | ||||||
|  |         try: | ||||||
|  |             super(SemiStrictDict, self).__setattr__(attr, value) | ||||||
|  |         except AttributeError: | ||||||
|  |             try: | ||||||
|  |                 self._extras[attr] = value | ||||||
|  |             except AttributeError: | ||||||
|  |                 self._extras = {attr: value} | ||||||
|  |  | ||||||
|  |     def __delattr__(self, attr): | ||||||
|  |         try: | ||||||
|  |             super(SemiStrictDict, self).__delattr__(attr) | ||||||
|  |         except AttributeError: | ||||||
|  |             try: | ||||||
|  |                 del self._extras[attr] | ||||||
|  |             except KeyError as e: | ||||||
|  |                 raise AttributeError(e) | ||||||
|  |  | ||||||
|  |     def __iter__(self): | ||||||
|  |         try: | ||||||
|  |             extras_iter = iter(self.__getattribute__('_extras')) | ||||||
|  |         except AttributeError: | ||||||
|  |             extras_iter = () | ||||||
|  |         return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -13,24 +13,23 @@ from mongoengine import signals | |||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.errors import (ValidationError, InvalidDocumentError, | from mongoengine.errors import (ValidationError, InvalidDocumentError, | ||||||
|                                 LookUpError) |                                 LookUpError) | ||||||
| from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, | from mongoengine.python_support import PY3, txt_type | ||||||
|                                         to_str_keys_recursive) |  | ||||||
|  |  | ||||||
| from mongoengine.base.common import get_document, ALLOW_INHERITANCE | from mongoengine.base.common import get_document, ALLOW_INHERITANCE | ||||||
| from mongoengine.base.datastructures import BaseDict, BaseList | from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict | ||||||
| from mongoengine.base.fields import ComplexBaseField | from mongoengine.base.fields import ComplexBaseField | ||||||
|  |  | ||||||
| __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') | __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') | ||||||
|  |  | ||||||
| NON_FIELD_ERRORS = '__all__' | NON_FIELD_ERRORS = '__all__' | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseDocument(object): | class BaseDocument(object): | ||||||
|  |     __slots__ = ('_changed_fields', '_initialised', '_created', '_data', | ||||||
|  |                   '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') | ||||||
|  |  | ||||||
|     _dynamic = False |     _dynamic = False | ||||||
|     _created = True |  | ||||||
|     _dynamic_lock = True |     _dynamic_lock = True | ||||||
|     _initialised = False |     STRICT = False | ||||||
|  |  | ||||||
|     def __init__(self, *args, **values): |     def __init__(self, *args, **values): | ||||||
|         """ |         """ | ||||||
| @@ -39,6 +38,8 @@ class BaseDocument(object): | |||||||
|         :param __auto_convert: Try and will cast python objects to Object types |         :param __auto_convert: Try and will cast python objects to Object types | ||||||
|         :param values: A dictionary of values for the document |         :param values: A dictionary of values for the document | ||||||
|         """ |         """ | ||||||
|  |         self._initialised = False | ||||||
|  |         self._created = True | ||||||
|         if args: |         if args: | ||||||
|             # Combine positional arguments with named arguments. |             # Combine positional arguments with named arguments. | ||||||
|             # We only want named arguments. |             # We only want named arguments. | ||||||
| @@ -54,7 +55,11 @@ class BaseDocument(object): | |||||||
|         __auto_convert = values.pop("__auto_convert", True) |         __auto_convert = values.pop("__auto_convert", True) | ||||||
|         signals.pre_init.send(self.__class__, document=self, values=values) |         signals.pre_init.send(self.__class__, document=self, values=values) | ||||||
|  |  | ||||||
|         self._data = {} |         if self.STRICT and not self._dynamic: | ||||||
|  |             self._data = StrictDict.create(allowed_keys=self._fields.keys())() | ||||||
|  |         else: | ||||||
|  |             self._data = SemiStrictDict.create(allowed_keys=self._fields.keys())() | ||||||
|  |  | ||||||
|         self._dynamic_fields = SON() |         self._dynamic_fields = SON() | ||||||
|  |  | ||||||
|         # Assign default values to instance |         # Assign default values to instance | ||||||
| @@ -130,17 +135,25 @@ class BaseDocument(object): | |||||||
|                 self._data[name] = value |                 self._data[name] = value | ||||||
|                 if hasattr(self, '_changed_fields'): |                 if hasattr(self, '_changed_fields'): | ||||||
|                     self._mark_as_changed(name) |                     self._mark_as_changed(name) | ||||||
|  |         try: | ||||||
|  |             self__created = self._created | ||||||
|  |         except AttributeError: | ||||||
|  |             self__created = True | ||||||
|  |  | ||||||
|         if (self._is_document and not self._created and |         if (self._is_document and not self__created and | ||||||
|            name in self._meta.get('shard_key', tuple()) and |            name in self._meta.get('shard_key', tuple()) and | ||||||
|            self._data.get(name) != value): |            self._data.get(name) != value): | ||||||
|             OperationError = _import_class('OperationError') |             OperationError = _import_class('OperationError') | ||||||
|             msg = "Shard Keys are immutable. Tried to update %s" % name |             msg = "Shard Keys are immutable. Tried to update %s" % name | ||||||
|             raise OperationError(msg) |             raise OperationError(msg) | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             self__initialised = self._initialised | ||||||
|  |         except AttributeError: | ||||||
|  |             self__initialised = False | ||||||
|         # Check if the user has created a new instance of a class |         # Check if the user has created a new instance of a class | ||||||
|         if (self._is_document and self._initialised |         if (self._is_document and self__initialised | ||||||
|            and self._created and name == self._meta['id_field']): |            and self__created and name == self._meta['id_field']): | ||||||
|                 super(BaseDocument, self).__setattr__('_created', False) |                 super(BaseDocument, self).__setattr__('_created', False) | ||||||
|  |  | ||||||
|         super(BaseDocument, self).__setattr__(name, value) |         super(BaseDocument, self).__setattr__(name, value) | ||||||
| @@ -158,9 +171,11 @@ class BaseDocument(object): | |||||||
|         if isinstance(data["_data"], SON): |         if isinstance(data["_data"], SON): | ||||||
|             data["_data"] = self.__class__._from_son(data["_data"])._data |             data["_data"] = self.__class__._from_son(data["_data"])._data | ||||||
|         for k in ('_changed_fields', '_initialised', '_created', '_data', |         for k in ('_changed_fields', '_initialised', '_created', '_data', | ||||||
|                   '_fields_ordered', '_dynamic_fields'): |                    '_dynamic_fields'): | ||||||
|             if k in data: |             if k in data: | ||||||
|                 setattr(self, k, data[k]) |                 setattr(self, k, data[k]) | ||||||
|  |         if '_fields_ordered' in data: | ||||||
|  |             setattr(type(self), '_fields_ordered', data['_fields_ordered']) | ||||||
|         dynamic_fields = data.get('_dynamic_fields') or SON() |         dynamic_fields = data.get('_dynamic_fields') or SON() | ||||||
|         for k in dynamic_fields.keys(): |         for k in dynamic_fields.keys(): | ||||||
|             setattr(self, k, data["_data"].get(k)) |             setattr(self, k, data["_data"].get(k)) | ||||||
| @@ -182,7 +197,7 @@ class BaseDocument(object): | |||||||
|         """Dictionary-style field access, set a field's value. |         """Dictionary-style field access, set a field's value. | ||||||
|         """ |         """ | ||||||
|         # Ensure that the field exists before settings its value |         # Ensure that the field exists before settings its value | ||||||
|         if name not in self._fields: |         if not self._dynamic and name not in self._fields: | ||||||
|             raise KeyError(name) |             raise KeyError(name) | ||||||
|         return setattr(self, name, value) |         return setattr(self, name, value) | ||||||
|  |  | ||||||
| @@ -214,8 +229,9 @@ class BaseDocument(object): | |||||||
|  |  | ||||||
|     def __eq__(self, other): |     def __eq__(self, other): | ||||||
|         if isinstance(other, self.__class__) and hasattr(other, 'id'): |         if isinstance(other, self.__class__) and hasattr(other, 'id'): | ||||||
|             if self.id == other.id: |             return self.id == other.id | ||||||
|                 return True |         if isinstance(other, DBRef): | ||||||
|  |             return self._get_collection_name() == other.collection and self.id == other.id | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     def __ne__(self, other): |     def __ne__(self, other): | ||||||
| @@ -317,7 +333,7 @@ class BaseDocument(object): | |||||||
|             pk = "None" |             pk = "None" | ||||||
|             if hasattr(self, 'pk'): |             if hasattr(self, 'pk'): | ||||||
|                 pk = self.pk |                 pk = self.pk | ||||||
|             elif self._instance: |             elif self._instance and hasattr(self._instance, 'pk'): | ||||||
|                 pk = self._instance.pk |                 pk = self._instance.pk | ||||||
|             message = "ValidationError (%s:%s) " % (self._class_name, pk) |             message = "ValidationError (%s:%s) " % (self._class_name, pk) | ||||||
|             raise ValidationError(message, errors=errors) |             raise ValidationError(message, errors=errors) | ||||||
| @@ -392,6 +408,8 @@ class BaseDocument(object): | |||||||
|                 else: |                 else: | ||||||
|                     data = getattr(data, part, None) |                     data = getattr(data, part, None) | ||||||
|                 if hasattr(data, "_changed_fields"): |                 if hasattr(data, "_changed_fields"): | ||||||
|  |                     if hasattr(data, "_is_document") and data._is_document: | ||||||
|  |                         continue | ||||||
|                     data._changed_fields = [] |                     data._changed_fields = [] | ||||||
|         self._changed_fields = [] |         self._changed_fields = [] | ||||||
|  |  | ||||||
| @@ -545,10 +563,6 @@ class BaseDocument(object): | |||||||
|         # class if unavailable |         # class if unavailable | ||||||
|         class_name = son.get('_cls', cls._class_name) |         class_name = son.get('_cls', cls._class_name) | ||||||
|         data = dict(("%s" % key, value) for key, value in son.iteritems()) |         data = dict(("%s" % key, value) for key, value in son.iteritems()) | ||||||
|         if not UNICODE_KWARGS: |  | ||||||
|             # python 2.6.4 and lower cannot handle unicode keys |  | ||||||
|             # passed to class constructor example: cls(**data) |  | ||||||
|             to_str_keys_recursive(data) |  | ||||||
|  |  | ||||||
|         # Return correct subclass for document type |         # Return correct subclass for document type | ||||||
|         if class_name != cls._class_name: |         if class_name != cls._class_name: | ||||||
| @@ -586,6 +600,8 @@ class BaseDocument(object): | |||||||
|                    % (cls._class_name, errors)) |                    % (cls._class_name, errors)) | ||||||
|             raise InvalidDocumentError(msg) |             raise InvalidDocumentError(msg) | ||||||
|  |  | ||||||
|  |         if cls.STRICT: | ||||||
|  |             data = dict((k, v) for k,v in data.iteritems() if k in cls._fields) | ||||||
|         obj = cls(__auto_convert=False, **data) |         obj = cls(__auto_convert=False, **data) | ||||||
|         obj._changed_fields = changed_fields |         obj._changed_fields = changed_fields | ||||||
|         obj._created = False |         obj._created = False | ||||||
| @@ -825,7 +841,11 @@ class BaseDocument(object): | |||||||
|         """Dynamically set the display value for a field with choices""" |         """Dynamically set the display value for a field with choices""" | ||||||
|         for attr_name, field in self._fields.items(): |         for attr_name, field in self._fields.items(): | ||||||
|             if field.choices: |             if field.choices: | ||||||
|                 setattr(self, |                 if self._dynamic: | ||||||
|  |                     obj = self | ||||||
|  |                 else: | ||||||
|  |                     obj = type(self) | ||||||
|  |                 setattr(obj, | ||||||
|                         'get_%s_display' % attr_name, |                         'get_%s_display' % attr_name, | ||||||
|                         partial(self.__get_field_display, field=field)) |                         partial(self.__get_field_display, field=field)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -359,7 +359,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | |||||||
|                     new_class.id = field |                     new_class.id = field | ||||||
|  |  | ||||||
|         # Set primary key if not defined by the document |         # Set primary key if not defined by the document | ||||||
|         new_class._auto_id_field = False |         new_class._auto_id_field = getattr(parent_doc_cls, | ||||||
|  |                                            '_auto_id_field', False) | ||||||
|         if not new_class._meta.get('id_field'): |         if not new_class._meta.get('id_field'): | ||||||
|             new_class._auto_id_field = True |             new_class._auto_id_field = True | ||||||
|             new_class._meta['id_field'] = 'id' |             new_class._meta['id_field'] = 'id' | ||||||
|   | |||||||
| @@ -96,21 +96,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | |||||||
|             raise ConnectionError(msg) |             raise ConnectionError(msg) | ||||||
|         conn_settings = _connection_settings[alias].copy() |         conn_settings = _connection_settings[alias].copy() | ||||||
|  |  | ||||||
|         if hasattr(pymongo, 'version_tuple'):  # Support for 2.1+ |  | ||||||
|         conn_settings.pop('name', None) |         conn_settings.pop('name', None) | ||||||
|         conn_settings.pop('slaves', None) |         conn_settings.pop('slaves', None) | ||||||
|         conn_settings.pop('is_slave', None) |         conn_settings.pop('is_slave', None) | ||||||
|         conn_settings.pop('username', None) |         conn_settings.pop('username', None) | ||||||
|         conn_settings.pop('password', None) |         conn_settings.pop('password', None) | ||||||
|         conn_settings.pop('authentication_source', None) |         conn_settings.pop('authentication_source', None) | ||||||
|         else: |  | ||||||
|             # Get all the slave connections |  | ||||||
|             if 'slaves' in conn_settings: |  | ||||||
|                 slaves = [] |  | ||||||
|                 for slave_alias in conn_settings['slaves']: |  | ||||||
|                     slaves.append(get_connection(slave_alias)) |  | ||||||
|                 conn_settings['slaves'] = slaves |  | ||||||
|                 conn_settings.pop('read_preference', None) |  | ||||||
|  |  | ||||||
|         connection_class = MongoClient |         connection_class = MongoClient | ||||||
|         if 'replicaSet' in conn_settings: |         if 'replicaSet' in conn_settings: | ||||||
| @@ -123,7 +114,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | |||||||
|             connection_class = MongoReplicaSetClient |             connection_class = MongoReplicaSetClient | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             _connections[alias] = connection_class(**conn_settings) |             connection = None | ||||||
|  |             connection_settings_iterator = ((alias, settings.copy()) for alias, settings in _connection_settings.iteritems()) | ||||||
|  |             for alias, connection_settings in connection_settings_iterator: | ||||||
|  |                 connection_settings.pop('name', None) | ||||||
|  |                 connection_settings.pop('slaves', None) | ||||||
|  |                 connection_settings.pop('is_slave', None) | ||||||
|  |                 connection_settings.pop('username', None) | ||||||
|  |                 connection_settings.pop('password', None) | ||||||
|  |                 if conn_settings == connection_settings and _connections.get(alias, None): | ||||||
|  |                     connection = _connections[alias] | ||||||
|  |                     break | ||||||
|  |  | ||||||
|  |             _connections[alias] = connection if connection else connection_class(**conn_settings) | ||||||
|         except Exception, e: |         except Exception, e: | ||||||
|             raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) |             raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) | ||||||
|     return _connections[alias] |     return _connections[alias] | ||||||
|   | |||||||
| @@ -1,6 +1,5 @@ | |||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db | ||||||
| from mongoengine.queryset import QuerySet |  | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ("switch_db", "switch_collection", "no_dereference", | __all__ = ("switch_db", "switch_collection", "no_dereference", | ||||||
| @@ -162,12 +161,6 @@ class no_sub_classes(object): | |||||||
|         return self.cls |         return self.cls | ||||||
|  |  | ||||||
|  |  | ||||||
| class QuerySetNoDeRef(QuerySet): |  | ||||||
|     """Special no_dereference QuerySet""" |  | ||||||
|     def __dereference(items, max_depth=1, instance=None, name=None): |  | ||||||
|             return items |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class query_counter(object): | class query_counter(object): | ||||||
|     """ Query_counter context manager to get the number of queries. """ |     """ Query_counter context manager to get the number of queries. """ | ||||||
|  |  | ||||||
|   | |||||||
| @@ -36,7 +36,7 @@ class DeReference(object): | |||||||
|         if instance and isinstance(instance, (Document, EmbeddedDocument, |         if instance and isinstance(instance, (Document, EmbeddedDocument, | ||||||
|                                               TopLevelDocumentMetaclass)): |                                               TopLevelDocumentMetaclass)): | ||||||
|             doc_type = instance._fields.get(name) |             doc_type = instance._fields.get(name) | ||||||
|             if hasattr(doc_type, 'field'): |             while hasattr(doc_type, 'field'): | ||||||
|                 doc_type = doc_type.field |                 doc_type = doc_type.field | ||||||
|  |  | ||||||
|             if isinstance(doc_type, ReferenceField): |             if isinstance(doc_type, ReferenceField): | ||||||
| @@ -51,9 +51,19 @@ class DeReference(object): | |||||||
|                     return items |                     return items | ||||||
|                 elif not field.dbref: |                 elif not field.dbref: | ||||||
|                     if not hasattr(items, 'items'): |                     if not hasattr(items, 'items'): | ||||||
|                         items = [field.to_python(v) |  | ||||||
|                              if not isinstance(v, (DBRef, Document)) else v |                         def _get_items(items): | ||||||
|                              for v in items] |                             new_items = [] | ||||||
|  |                             for v in items: | ||||||
|  |                                 if isinstance(v, list): | ||||||
|  |                                     new_items.append(_get_items(v)) | ||||||
|  |                                 elif not isinstance(v, (DBRef, Document)): | ||||||
|  |                                     new_items.append(field.to_python(v)) | ||||||
|  |                                 else: | ||||||
|  |                                     new_items.append(v) | ||||||
|  |                             return new_items | ||||||
|  |  | ||||||
|  |                         items = _get_items(items) | ||||||
|                     else: |                     else: | ||||||
|                         items = dict([ |                         items = dict([ | ||||||
|                             (k, field.to_python(v)) |                             (k, field.to_python(v)) | ||||||
| @@ -114,11 +124,11 @@ class DeReference(object): | |||||||
|         """Fetch all references and convert to their document objects |         """Fetch all references and convert to their document objects | ||||||
|         """ |         """ | ||||||
|         object_map = {} |         object_map = {} | ||||||
|         for col, dbrefs in self.reference_map.iteritems(): |         for collection, dbrefs in self.reference_map.iteritems(): | ||||||
|             keys = object_map.keys() |             keys = object_map.keys() | ||||||
|             refs = list(set([dbref for dbref in dbrefs if unicode(dbref).encode('utf-8') not in keys])) |             refs = list(set([dbref for dbref in dbrefs if unicode(dbref).encode('utf-8') not in keys])) | ||||||
|             if hasattr(col, 'objects'):  # We have a document class for the refs |             if hasattr(collection, 'objects'):  # We have a document class for the refs | ||||||
|                 references = col.objects.in_bulk(refs) |                 references = collection.objects.in_bulk(refs) | ||||||
|                 for key, doc in references.iteritems(): |                 for key, doc in references.iteritems(): | ||||||
|                     object_map[key] = doc |                     object_map[key] = doc | ||||||
|             else:  # Generic reference: use the refs data to convert to document |             else:  # Generic reference: use the refs data to convert to document | ||||||
| @@ -126,19 +136,19 @@ class DeReference(object): | |||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
|                 if doc_type: |                 if doc_type: | ||||||
|                     references = doc_type._get_db()[col].find({'_id': {'$in': refs}}) |                     references = doc_type._get_db()[collection].find({'_id': {'$in': refs}}) | ||||||
|                     for ref in references: |                     for ref in references: | ||||||
|                         doc = doc_type._from_son(ref) |                         doc = doc_type._from_son(ref) | ||||||
|                         object_map[doc.id] = doc |                         object_map[doc.id] = doc | ||||||
|                 else: |                 else: | ||||||
|                     references = get_db()[col].find({'_id': {'$in': refs}}) |                     references = get_db()[collection].find({'_id': {'$in': refs}}) | ||||||
|                     for ref in references: |                     for ref in references: | ||||||
|                         if '_cls' in ref: |                         if '_cls' in ref: | ||||||
|                             doc = get_document(ref["_cls"])._from_son(ref) |                             doc = get_document(ref["_cls"])._from_son(ref) | ||||||
|                         elif doc_type is None: |                         elif doc_type is None: | ||||||
|                             doc = get_document( |                             doc = get_document( | ||||||
|                                 ''.join(x.capitalize() |                                 ''.join(x.capitalize() | ||||||
|                                     for x in col.split('_')))._from_son(ref) |                                     for x in collection.split('_')))._from_son(ref) | ||||||
|                         else: |                         else: | ||||||
|                             doc = doc_type._from_son(ref) |                             doc = doc_type._from_son(ref) | ||||||
|                         object_map[doc.id] = doc |                         object_map[doc.id] = doc | ||||||
|   | |||||||
| @@ -13,7 +13,8 @@ from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, | |||||||
|                               BaseDocument, BaseDict, BaseList, |                               BaseDocument, BaseDict, BaseList, | ||||||
|                               ALLOW_INHERITANCE, get_document) |                               ALLOW_INHERITANCE, get_document) | ||||||
| from mongoengine.errors import ValidationError | from mongoengine.errors import ValidationError | ||||||
| from mongoengine.queryset import OperationError, NotUniqueError, QuerySet | from mongoengine.queryset import (OperationError, NotUniqueError, | ||||||
|  |                                   QuerySet, transform) | ||||||
| from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME | from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME | ||||||
| from mongoengine.context_managers import switch_db, switch_collection | from mongoengine.context_managers import switch_db, switch_collection | ||||||
|  |  | ||||||
| @@ -54,20 +55,21 @@ class EmbeddedDocument(BaseDocument): | |||||||
|     dictionary. |     dictionary. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  |     __slots__ = ('_instance') | ||||||
|  |  | ||||||
|     # The __metaclass__ attribute is removed by 2to3 when running with Python3 |     # The __metaclass__ attribute is removed by 2to3 when running with Python3 | ||||||
|     # my_metaclass is defined so that metaclass can be queried in Python 2 & 3 |     # my_metaclass is defined so that metaclass can be queried in Python 2 & 3 | ||||||
|     my_metaclass  = DocumentMetaclass |     my_metaclass  = DocumentMetaclass | ||||||
|     __metaclass__ = DocumentMetaclass |     __metaclass__ = DocumentMetaclass | ||||||
|  |  | ||||||
|     _instance = None |  | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(EmbeddedDocument, self).__init__(*args, **kwargs) |         super(EmbeddedDocument, self).__init__(*args, **kwargs) | ||||||
|  |         self._instance = None | ||||||
|         self._changed_fields = [] |         self._changed_fields = [] | ||||||
|  |  | ||||||
|     def __eq__(self, other): |     def __eq__(self, other): | ||||||
|         if isinstance(other, self.__class__): |         if isinstance(other, self.__class__): | ||||||
|             return self.to_mongo() == other.to_mongo() |             return self._data == other._data | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     def __ne__(self, other): |     def __ne__(self, other): | ||||||
| @@ -125,6 +127,8 @@ class Document(BaseDocument): | |||||||
|     my_metaclass  = TopLevelDocumentMetaclass |     my_metaclass  = TopLevelDocumentMetaclass | ||||||
|     __metaclass__ = TopLevelDocumentMetaclass |     __metaclass__ = TopLevelDocumentMetaclass | ||||||
|  |  | ||||||
|  |     __slots__ = ('__objects' ) | ||||||
|  |  | ||||||
|     def pk(): |     def pk(): | ||||||
|         """Primary key alias |         """Primary key alias | ||||||
|         """ |         """ | ||||||
| @@ -180,7 +184,7 @@ class Document(BaseDocument): | |||||||
|  |  | ||||||
|     def save(self, force_insert=False, validate=True, clean=True, |     def save(self, force_insert=False, validate=True, clean=True, | ||||||
|              write_concern=None,  cascade=None, cascade_kwargs=None, |              write_concern=None,  cascade=None, cascade_kwargs=None, | ||||||
|              _refs=None, **kwargs): |              _refs=None, save_condition=None, **kwargs): | ||||||
|         """Save the :class:`~mongoengine.Document` to the database. If the |         """Save the :class:`~mongoengine.Document` to the database. If the | ||||||
|         document already exists, it will be updated, otherwise it will be |         document already exists, it will be updated, otherwise it will be | ||||||
|         created. |         created. | ||||||
| @@ -203,7 +207,8 @@ class Document(BaseDocument): | |||||||
|         :param cascade_kwargs: (optional) kwargs dictionary to be passed throw |         :param cascade_kwargs: (optional) kwargs dictionary to be passed throw | ||||||
|             to cascading saves.  Implies ``cascade=True``. |             to cascading saves.  Implies ``cascade=True``. | ||||||
|         :param _refs: A list of processed references used in cascading saves |         :param _refs: A list of processed references used in cascading saves | ||||||
|  |         :param save_condition: only perform save if matching record in db | ||||||
|  |             satisfies condition(s) (e.g., version number) | ||||||
|         .. versionchanged:: 0.5 |         .. versionchanged:: 0.5 | ||||||
|             In existing documents it only saves changed fields using |             In existing documents it only saves changed fields using | ||||||
|             set / unset.  Saves are cascaded and any |             set / unset.  Saves are cascaded and any | ||||||
| @@ -217,6 +222,9 @@ class Document(BaseDocument): | |||||||
|             meta['cascade'] = True.  Also you can pass different kwargs to |             meta['cascade'] = True.  Also you can pass different kwargs to | ||||||
|             the cascade save using cascade_kwargs which overwrites the |             the cascade save using cascade_kwargs which overwrites the | ||||||
|             existing kwargs with custom values. |             existing kwargs with custom values. | ||||||
|  |         .. versionchanged:: 0.8.5 | ||||||
|  |             Optional save_condition that only overwrites existing documents | ||||||
|  |             if the condition is satisfied in the current db record. | ||||||
|         """ |         """ | ||||||
|         signals.pre_save.send(self.__class__, document=self) |         signals.pre_save.send(self.__class__, document=self) | ||||||
|  |  | ||||||
| @@ -230,7 +238,8 @@ class Document(BaseDocument): | |||||||
|  |  | ||||||
|         created = ('_id' not in doc or self._created or force_insert) |         created = ('_id' not in doc or self._created or force_insert) | ||||||
|  |  | ||||||
|         signals.pre_save_post_validation.send(self.__class__, document=self, created=created) |         signals.pre_save_post_validation.send(self.__class__, document=self, | ||||||
|  |                                               created=created) | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             collection = self._get_collection() |             collection = self._get_collection() | ||||||
| @@ -243,7 +252,12 @@ class Document(BaseDocument): | |||||||
|                 object_id = doc['_id'] |                 object_id = doc['_id'] | ||||||
|                 updates, removals = self._delta() |                 updates, removals = self._delta() | ||||||
|                 # Need to add shard key to query, or you get an error |                 # Need to add shard key to query, or you get an error | ||||||
|                 select_dict = {'_id': object_id} |                 if save_condition is not None: | ||||||
|  |                     select_dict = transform.query(self.__class__, | ||||||
|  |                                                   **save_condition) | ||||||
|  |                 else: | ||||||
|  |                     select_dict = {} | ||||||
|  |                 select_dict['_id'] = object_id | ||||||
|                 shard_key = self.__class__._meta.get('shard_key', tuple()) |                 shard_key = self.__class__._meta.get('shard_key', tuple()) | ||||||
|                 for k in shard_key: |                 for k in shard_key: | ||||||
|                     actual_key = self._db_field_map.get(k, k) |                     actual_key = self._db_field_map.get(k, k) | ||||||
| @@ -263,10 +277,12 @@ class Document(BaseDocument): | |||||||
|                 if removals: |                 if removals: | ||||||
|                     update_query["$unset"] = removals |                     update_query["$unset"] = removals | ||||||
|                 if updates or removals: |                 if updates or removals: | ||||||
|  |                     upsert = save_condition is None | ||||||
|                     last_error = collection.update(select_dict, update_query, |                     last_error = collection.update(select_dict, update_query, | ||||||
|                                                    upsert=True, **write_concern) |                                                    upsert=upsert, **write_concern) | ||||||
|                     created = is_new_object(last_error) |                     created = is_new_object(last_error) | ||||||
|  |  | ||||||
|  |  | ||||||
|             if cascade is None: |             if cascade is None: | ||||||
|                 cascade = self._meta.get('cascade', False) or cascade_kwargs is not None |                 cascade = self._meta.get('cascade', False) or cascade_kwargs is not None | ||||||
|  |  | ||||||
| @@ -293,12 +309,12 @@ class Document(BaseDocument): | |||||||
|                 raise NotUniqueError(message % unicode(err)) |                 raise NotUniqueError(message % unicode(err)) | ||||||
|             raise OperationError(message % unicode(err)) |             raise OperationError(message % unicode(err)) | ||||||
|         id_field = self._meta['id_field'] |         id_field = self._meta['id_field'] | ||||||
|         if id_field not in self._meta.get('shard_key', []): |         if created or id_field not in self._meta.get('shard_key', []): | ||||||
|             self[id_field] = self._fields[id_field].to_python(object_id) |             self[id_field] = self._fields[id_field].to_python(object_id) | ||||||
|  |  | ||||||
|  |         signals.post_save.send(self.__class__, document=self, created=created) | ||||||
|         self._clear_changed_fields() |         self._clear_changed_fields() | ||||||
|         self._created = False |         self._created = False | ||||||
|         signals.post_save.send(self.__class__, document=self, created=created) |  | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def cascade_save(self, *args, **kwargs): |     def cascade_save(self, *args, **kwargs): | ||||||
| @@ -447,27 +463,41 @@ class Document(BaseDocument): | |||||||
|         DeReference()([self], max_depth + 1) |         DeReference()([self], max_depth + 1) | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def reload(self, max_depth=1): |     def reload(self, *fields, **kwargs): | ||||||
|         """Reloads all attributes from the database. |         """Reloads all attributes from the database. | ||||||
|  |  | ||||||
|  |         :param fields: (optional) args list of fields to reload | ||||||
|  |         :param max_depth: (optional) depth of dereferencing to follow | ||||||
|  |  | ||||||
|         .. versionadded:: 0.1.2 |         .. versionadded:: 0.1.2 | ||||||
|         .. versionchanged:: 0.6  Now chainable |         .. versionchanged:: 0.6  Now chainable | ||||||
|  |         .. versionchanged:: 0.9  Can provide specific fields to reload | ||||||
|         """ |         """ | ||||||
|  |         max_depth = 1 | ||||||
|  |         if fields and isinstance(fields[0], int): | ||||||
|  |             max_depth = fields[0] | ||||||
|  |             fields = fields[1:] | ||||||
|  |         elif "max_depth" in kwargs: | ||||||
|  |             max_depth = kwargs["max_depth"] | ||||||
|  |  | ||||||
|         if not self.pk: |         if not self.pk: | ||||||
|             raise self.DoesNotExist("Document does not exist") |             raise self.DoesNotExist("Document does not exist") | ||||||
|         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( |         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( | ||||||
|                     **self._object_key).limit(1).select_related(max_depth=max_depth) |                     **self._object_key).only(*fields).limit(1 | ||||||
|  |                     ).select_related(max_depth=max_depth) | ||||||
|  |  | ||||||
|         if obj: |         if obj: | ||||||
|             obj = obj[0] |             obj = obj[0] | ||||||
|         else: |         else: | ||||||
|             raise self.DoesNotExist("Document does not exist") |             raise self.DoesNotExist("Document does not exist") | ||||||
|  |  | ||||||
|         for field in self._fields_ordered: |         for field in self._fields_ordered: | ||||||
|  |             if not fields or field in fields: | ||||||
|                 setattr(self, field, self._reload(field, obj[field])) |                 setattr(self, field, self._reload(field, obj[field])) | ||||||
|  |  | ||||||
|         self._changed_fields = obj._changed_fields |         self._changed_fields = obj._changed_fields | ||||||
|         self._created = False |         self._created = False | ||||||
|         return obj |         return self | ||||||
|  |  | ||||||
|     def _reload(self, key, value): |     def _reload(self, key, value): | ||||||
|         """Used by :meth:`~mongoengine.Document.reload` to ensure the |         """Used by :meth:`~mongoengine.Document.reload` to ensure the | ||||||
|   | |||||||
| @@ -391,7 +391,7 @@ class DateTimeField(BaseField): | |||||||
|         if dateutil: |         if dateutil: | ||||||
|             try: |             try: | ||||||
|                 return dateutil.parser.parse(value) |                 return dateutil.parser.parse(value) | ||||||
|             except ValueError: |             except (TypeError, ValueError): | ||||||
|                 return None |                 return None | ||||||
|  |  | ||||||
|         # split usecs, because they are not recognized by strptime. |         # split usecs, because they are not recognized by strptime. | ||||||
| @@ -760,7 +760,7 @@ class DictField(ComplexBaseField): | |||||||
|     similar to an embedded document, but the structure is not defined. |     similar to an embedded document, but the structure is not defined. | ||||||
|  |  | ||||||
|     .. note:: |     .. note:: | ||||||
|         Required means it cannot be empty - as the default for ListFields is [] |         Required means it cannot be empty - as the default for DictFields is {} | ||||||
|  |  | ||||||
|     .. versionadded:: 0.3 |     .. versionadded:: 0.3 | ||||||
|     .. versionchanged:: 0.5 - Can now handle complex / varying types of data |     .. versionchanged:: 0.5 - Can now handle complex / varying types of data | ||||||
| @@ -1554,6 +1554,14 @@ class SequenceField(BaseField): | |||||||
|  |  | ||||||
|         return super(SequenceField, self).__set__(instance, value) |         return super(SequenceField, self).__set__(instance, value) | ||||||
|  |  | ||||||
|  |     def prepare_query_value(self, op, value): | ||||||
|  |         """ | ||||||
|  |         This method is overriden in order to convert the query value into to required | ||||||
|  |         type. We need to do this in order to be able to successfully compare query    | ||||||
|  |         values passed as string, the base implementation returns the value as is. | ||||||
|  |         """ | ||||||
|  |         return self.value_decorator(value) | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         if value is None: |         if value is None: | ||||||
|             value = self.generate() |             value = self.generate() | ||||||
| @@ -1613,7 +1621,12 @@ class UUIDField(BaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class GeoPointField(BaseField): | class GeoPointField(BaseField): | ||||||
|     """A list storing a latitude and longitude. |     """A list storing a longitude and latitude coordinate.  | ||||||
|  |  | ||||||
|  |     .. note:: this represents a generic point in a 2D plane and a legacy way of  | ||||||
|  |         representing a geo point. It admits 2d indexes but not "2dsphere" indexes  | ||||||
|  |         in MongoDB > 2.4 which are more natural for modeling geospatial points.  | ||||||
|  |         See :ref:`geospatial-indexes`  | ||||||
|  |  | ||||||
|     .. versionadded:: 0.4 |     .. versionadded:: 0.4 | ||||||
|     """ |     """ | ||||||
| @@ -1635,7 +1648,7 @@ class GeoPointField(BaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class PointField(GeoJsonBaseField): | class PointField(GeoJsonBaseField): | ||||||
|     """A geo json field storing a latitude and longitude. |     """A GeoJSON field storing a longitude and latitude coordinate. | ||||||
|  |  | ||||||
|     The data is represented as: |     The data is represented as: | ||||||
|  |  | ||||||
| @@ -1654,7 +1667,7 @@ class PointField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class LineStringField(GeoJsonBaseField): | class LineStringField(GeoJsonBaseField): | ||||||
|     """A geo json field storing a line of latitude and longitude coordinates. |     """A GeoJSON field storing a line of longitude and latitude coordinates. | ||||||
|  |  | ||||||
|     The data is represented as: |     The data is represented as: | ||||||
|  |  | ||||||
| @@ -1672,7 +1685,7 @@ class LineStringField(GeoJsonBaseField): | |||||||
|  |  | ||||||
|  |  | ||||||
| class PolygonField(GeoJsonBaseField): | class PolygonField(GeoJsonBaseField): | ||||||
|     """A geo json field storing a polygon of latitude and longitude coordinates. |     """A GeoJSON field storing a polygon of longitude and latitude coordinates. | ||||||
|  |  | ||||||
|     The data is represented as: |     The data is represented as: | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,8 +3,6 @@ | |||||||
| import sys | import sys | ||||||
|  |  | ||||||
| PY3 = sys.version_info[0] == 3 | PY3 = sys.version_info[0] == 3 | ||||||
| PY25 = sys.version_info[:2] == (2, 5) |  | ||||||
| UNICODE_KWARGS = int(''.join([str(x) for x in sys.version_info[:3]])) > 264 |  | ||||||
|  |  | ||||||
| if PY3: | if PY3: | ||||||
|     import codecs |     import codecs | ||||||
| @@ -29,33 +27,3 @@ else: | |||||||
|     txt_type = unicode |     txt_type = unicode | ||||||
|  |  | ||||||
| str_types = (bin_type, txt_type) | str_types = (bin_type, txt_type) | ||||||
|  |  | ||||||
| if PY25: |  | ||||||
|     def product(*args, **kwds): |  | ||||||
|         pools = map(tuple, args) * kwds.get('repeat', 1) |  | ||||||
|         result = [[]] |  | ||||||
|         for pool in pools: |  | ||||||
|             result = [x + [y] for x in result for y in pool] |  | ||||||
|         for prod in result: |  | ||||||
|             yield tuple(prod) |  | ||||||
|     reduce = reduce |  | ||||||
| else: |  | ||||||
|     from itertools import product |  | ||||||
|     from functools import reduce |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # For use with Python 2.5 |  | ||||||
| # converts all keys from unicode to str for d and all nested dictionaries |  | ||||||
| def to_str_keys_recursive(d): |  | ||||||
|     if isinstance(d, list): |  | ||||||
|         for val in d: |  | ||||||
|             if isinstance(val, (dict, list)): |  | ||||||
|                 to_str_keys_recursive(val) |  | ||||||
|     elif isinstance(d, dict): |  | ||||||
|         for key, val in d.items(): |  | ||||||
|             if isinstance(val, (dict, list)): |  | ||||||
|                 to_str_keys_recursive(val) |  | ||||||
|             if isinstance(key, unicode): |  | ||||||
|                 d[str(key)] = d.pop(key) |  | ||||||
|     else: |  | ||||||
|         raise ValueError("non list/dict parameter not allowed") |  | ||||||
|   | |||||||
| @@ -7,17 +7,20 @@ import pprint | |||||||
| import re | import re | ||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
|  | from bson import SON | ||||||
| from bson.code import Code | from bson.code import Code | ||||||
| from bson import json_util | from bson import json_util | ||||||
| import pymongo | import pymongo | ||||||
|  | import pymongo.errors | ||||||
| from pymongo.common import validate_read_preference | from pymongo.common import validate_read_preference | ||||||
|  |  | ||||||
| from mongoengine import signals | from mongoengine import signals | ||||||
|  | from mongoengine.connection import get_db | ||||||
|  | from mongoengine.context_managers import switch_db | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.base.common import get_document | from mongoengine.base.common import get_document | ||||||
| from mongoengine.errors import (OperationError, NotUniqueError, | from mongoengine.errors import (OperationError, NotUniqueError, | ||||||
|                                 InvalidQueryError, LookUpError) |                                 InvalidQueryError, LookUpError) | ||||||
|  |  | ||||||
| from mongoengine.queryset import transform | from mongoengine.queryset import transform | ||||||
| from mongoengine.queryset.field_list import QueryFieldList | from mongoengine.queryset.field_list import QueryFieldList | ||||||
| from mongoengine.queryset.visitor import Q, QNode | from mongoengine.queryset.visitor import Q, QNode | ||||||
| @@ -50,7 +53,7 @@ class BaseQuerySet(object): | |||||||
|         self._initial_query = {} |         self._initial_query = {} | ||||||
|         self._where_clause = None |         self._where_clause = None | ||||||
|         self._loaded_fields = QueryFieldList() |         self._loaded_fields = QueryFieldList() | ||||||
|         self._ordering = [] |         self._ordering = None | ||||||
|         self._snapshot = False |         self._snapshot = False | ||||||
|         self._timeout = True |         self._timeout = True | ||||||
|         self._class_check = True |         self._class_check = True | ||||||
| @@ -146,7 +149,7 @@ class BaseQuerySet(object): | |||||||
|                     queryset._document._from_son(queryset._cursor[key], |                     queryset._document._from_son(queryset._cursor[key], | ||||||
|                                                  _auto_dereference=self._auto_dereference)) |                                                  _auto_dereference=self._auto_dereference)) | ||||||
|             if queryset._as_pymongo: |             if queryset._as_pymongo: | ||||||
|                 return queryset._get_as_pymongo(queryset._cursor.next()) |                 return queryset._get_as_pymongo(queryset._cursor[key]) | ||||||
|             return queryset._document._from_son(queryset._cursor[key], |             return queryset._document._from_son(queryset._cursor[key], | ||||||
|                                                 _auto_dereference=self._auto_dereference) |                                                 _auto_dereference=self._auto_dereference) | ||||||
|         raise AttributeError |         raise AttributeError | ||||||
| @@ -154,6 +157,22 @@ class BaseQuerySet(object): | |||||||
|     def __iter__(self): |     def __iter__(self): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def _has_data(self): | ||||||
|  |         """ Retrieves whether cursor has any data. """ | ||||||
|  |  | ||||||
|  |         queryset = self.order_by() | ||||||
|  |         return False if queryset.first() is None else True | ||||||
|  |  | ||||||
|  |     def __nonzero__(self): | ||||||
|  |         """ Avoid to open all records in an if stmt in Py2. """ | ||||||
|  |  | ||||||
|  |         return self._has_data() | ||||||
|  |  | ||||||
|  |     def __bool__(self): | ||||||
|  |         """ Avoid to open all records in an if stmt in Py3. """ | ||||||
|  |  | ||||||
|  |         return self._has_data() | ||||||
|  |  | ||||||
|     # Core functions |     # Core functions | ||||||
|  |  | ||||||
|     def all(self): |     def all(self): | ||||||
| @@ -175,7 +194,7 @@ class BaseQuerySet(object): | |||||||
|         .. versionadded:: 0.3 |         .. versionadded:: 0.3 | ||||||
|         """ |         """ | ||||||
|         queryset = self.clone() |         queryset = self.clone() | ||||||
|         queryset = queryset.limit(2) |         queryset = queryset.order_by().limit(2) | ||||||
|         queryset = queryset.filter(*q_objs, **query) |         queryset = queryset.filter(*q_objs, **query) | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
| @@ -443,6 +462,8 @@ class BaseQuerySet(object): | |||||||
|                 return result |                 return result | ||||||
|             elif result: |             elif result: | ||||||
|                 return result['n'] |                 return result['n'] | ||||||
|  |         except pymongo.errors.DuplicateKeyError, err: | ||||||
|  |             raise NotUniqueError(u'Update failed (%s)' % unicode(err)) | ||||||
|         except pymongo.errors.OperationFailure, err: |         except pymongo.errors.OperationFailure, err: | ||||||
|             if unicode(err) == u'multi not coded yet': |             if unicode(err) == u'multi not coded yet': | ||||||
|                 message = u'update() method requires MongoDB 1.1.3+' |                 message = u'update() method requires MongoDB 1.1.3+' | ||||||
| @@ -466,6 +487,59 @@ class BaseQuerySet(object): | |||||||
|         return self.update( |         return self.update( | ||||||
|             upsert=upsert, multi=False, write_concern=write_concern, **update) |             upsert=upsert, multi=False, write_concern=write_concern, **update) | ||||||
|  |  | ||||||
|  |     def modify(self, upsert=False, full_response=False, remove=False, new=False, **update): | ||||||
|  |         """Update and return the updated document. | ||||||
|  |  | ||||||
|  |         Returns either the document before or after modification based on `new` | ||||||
|  |         parameter. If no documents match the query and `upsert` is false, | ||||||
|  |         returns ``None``. If upserting and `new` is false, returns ``None``. | ||||||
|  |  | ||||||
|  |         If the full_response parameter is ``True``, the return value will be | ||||||
|  |         the entire response object from the server, including the 'ok' and | ||||||
|  |         'lastErrorObject' fields, rather than just the modified document. | ||||||
|  |         This is useful mainly because the 'lastErrorObject' document holds | ||||||
|  |         information about the command's execution. | ||||||
|  |  | ||||||
|  |         :param upsert: insert if document doesn't exist (default ``False``) | ||||||
|  |         :param full_response: return the entire response object from the | ||||||
|  |             server (default ``False``) | ||||||
|  |         :param remove: remove rather than updating (default ``False``) | ||||||
|  |         :param new: return updated rather than original document | ||||||
|  |             (default ``False``) | ||||||
|  |         :param update: Django-style update keyword arguments | ||||||
|  |  | ||||||
|  |         .. versionadded:: 0.9 | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         if remove and new: | ||||||
|  |             raise OperationError("Conflicting parameters: remove and new") | ||||||
|  |  | ||||||
|  |         if not update and not upsert and not remove: | ||||||
|  |             raise OperationError("No update parameters, must either update or remove") | ||||||
|  |  | ||||||
|  |         queryset = self.clone() | ||||||
|  |         query = queryset._query | ||||||
|  |         update = transform.update(queryset._document, **update) | ||||||
|  |         sort = queryset._ordering | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             result = queryset._collection.find_and_modify( | ||||||
|  |                 query, update, upsert=upsert, sort=sort, remove=remove, new=new, | ||||||
|  |                 full_response=full_response, **self._cursor_args) | ||||||
|  |         except pymongo.errors.DuplicateKeyError, err: | ||||||
|  |             raise NotUniqueError(u"Update failed (%s)" % err) | ||||||
|  |         except pymongo.errors.OperationFailure, err: | ||||||
|  |             raise OperationError(u"Update failed (%s)" % err) | ||||||
|  |  | ||||||
|  |         if full_response: | ||||||
|  |             if result["value"] is not None: | ||||||
|  |                 result["value"] = self._document._from_son(result["value"]) | ||||||
|  |         else: | ||||||
|  |             if result is not None: | ||||||
|  |                 result = self._document._from_son(result) | ||||||
|  |  | ||||||
|  |         return result | ||||||
|  |  | ||||||
|     def with_id(self, object_id): |     def with_id(self, object_id): | ||||||
|         """Retrieve the object matching the id provided.  Uses `object_id` only |         """Retrieve the object matching the id provided.  Uses `object_id` only | ||||||
|         and raises InvalidQueryError if a filter has been applied. Returns |         and raises InvalidQueryError if a filter has been applied. Returns | ||||||
| @@ -522,6 +596,19 @@ class BaseQuerySet(object): | |||||||
|  |  | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|  |     def using(self, alias): | ||||||
|  |         """This method is for controlling which database the QuerySet will be evaluated against if you are using more than one database. | ||||||
|  |  | ||||||
|  |         :param alias: The database alias | ||||||
|  |  | ||||||
|  |         .. versionadded:: 0.8 | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         with switch_db(self._document, alias) as cls: | ||||||
|  |             collection = cls._get_collection() | ||||||
|  |  | ||||||
|  |         return self.clone_into(self.__class__(self._document, collection)) | ||||||
|  |  | ||||||
|     def clone(self): |     def clone(self): | ||||||
|         """Creates a copy of the current |         """Creates a copy of the current | ||||||
|           :class:`~mongoengine.queryset.QuerySet` |           :class:`~mongoengine.queryset.QuerySet` | ||||||
| @@ -923,8 +1010,37 @@ class BaseQuerySet(object): | |||||||
|             map_reduce_function = 'inline_map_reduce' |             map_reduce_function = 'inline_map_reduce' | ||||||
|         else: |         else: | ||||||
|             map_reduce_function = 'map_reduce' |             map_reduce_function = 'map_reduce' | ||||||
|  |  | ||||||
|  |             if isinstance(output, basestring): | ||||||
|                 mr_args['out'] = output |                 mr_args['out'] = output | ||||||
|  |  | ||||||
|  |             elif isinstance(output, dict): | ||||||
|  |                 ordered_output = [] | ||||||
|  |  | ||||||
|  |                 for part in ('replace', 'merge', 'reduce'): | ||||||
|  |                     value = output.get(part) | ||||||
|  |                     if value: | ||||||
|  |                         ordered_output.append((part, value)) | ||||||
|  |                         break | ||||||
|  |  | ||||||
|  |                 else: | ||||||
|  |                     raise OperationError("actionData not specified for output") | ||||||
|  |  | ||||||
|  |                 db_alias = output.get('db_alias') | ||||||
|  |                 remaing_args = ['db', 'sharded', 'nonAtomic'] | ||||||
|  |  | ||||||
|  |                 if db_alias: | ||||||
|  |                     ordered_output.append(('db', get_db(db_alias).name)) | ||||||
|  |                     del remaing_args[0] | ||||||
|  |  | ||||||
|  |  | ||||||
|  |                 for part in remaing_args: | ||||||
|  |                     value = output.get(part) | ||||||
|  |                     if value: | ||||||
|  |                         ordered_output.append((part, value)) | ||||||
|  |  | ||||||
|  |                 mr_args['out'] = SON(ordered_output) | ||||||
|  |  | ||||||
|         results = getattr(queryset._collection, map_reduce_function)( |         results = getattr(queryset._collection, map_reduce_function)( | ||||||
|             map_f, reduce_f, **mr_args) |             map_f, reduce_f, **mr_args) | ||||||
|  |  | ||||||
| @@ -1189,8 +1305,9 @@ class BaseQuerySet(object): | |||||||
|             if self._ordering: |             if self._ordering: | ||||||
|                 # Apply query ordering |                 # Apply query ordering | ||||||
|                 self._cursor_obj.sort(self._ordering) |                 self._cursor_obj.sort(self._ordering) | ||||||
|             elif self._document._meta['ordering']: |             elif self._ordering is None and self._document._meta['ordering']: | ||||||
|                 # Otherwise, apply the ordering from the document model |                 # Otherwise, apply the ordering from the document model, unless | ||||||
|  |                 # it's been explicitly cleared via order_by with no arguments | ||||||
|                 order = self._get_order_by(self._document._meta['ordering']) |                 order = self._get_order_by(self._document._meta['ordering']) | ||||||
|                 self._cursor_obj.sort(order) |                 self._cursor_obj.sort(order) | ||||||
|  |  | ||||||
| @@ -1392,7 +1509,7 @@ class BaseQuerySet(object): | |||||||
|                 pass |                 pass | ||||||
|             key_list.append((key, direction)) |             key_list.append((key, direction)) | ||||||
|  |  | ||||||
|         if self._cursor_obj: |         if self._cursor_obj and key_list: | ||||||
|             self._cursor_obj.sort(key_list) |             self._cursor_obj.sort(key_list) | ||||||
|         return key_list |         return key_list | ||||||
|  |  | ||||||
| @@ -1450,6 +1567,7 @@ class BaseQuerySet(object): | |||||||
|                     # type of this field and use the corresponding |                     # type of this field and use the corresponding | ||||||
|                     # .to_python(...) |                     # .to_python(...) | ||||||
|                     from mongoengine.fields import EmbeddedDocumentField |                     from mongoengine.fields import EmbeddedDocumentField | ||||||
|  |  | ||||||
|                     obj = self._document |                     obj = self._document | ||||||
|                     for chunk in path.split('.'): |                     for chunk in path.split('.'): | ||||||
|                         obj = getattr(obj, chunk, None) |                         obj = getattr(obj, chunk, None) | ||||||
| @@ -1460,6 +1578,7 @@ class BaseQuerySet(object): | |||||||
|                     if obj and data is not None: |                     if obj and data is not None: | ||||||
|                         data = obj.to_python(data) |                         data = obj.to_python(data) | ||||||
|             return data |             return data | ||||||
|  |  | ||||||
|         return clean(row) |         return clean(row) | ||||||
|  |  | ||||||
|     def _sub_js_fields(self, code): |     def _sub_js_fields(self, code): | ||||||
| @@ -1468,6 +1587,7 @@ class BaseQuerySet(object): | |||||||
|         substituted for the MongoDB name of the field (specified using the |         substituted for the MongoDB name of the field (specified using the | ||||||
|         :attr:`name` keyword argument in a field's constructor). |         :attr:`name` keyword argument in a field's constructor). | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         def field_sub(match): |         def field_sub(match): | ||||||
|             # Extract just the field name, and look up the field objects |             # Extract just the field name, and look up the field objects | ||||||
|             field_name = match.group(1).split('.') |             field_name = match.group(1).split('.') | ||||||
|   | |||||||
| @@ -155,3 +155,10 @@ class QuerySetNoCache(BaseQuerySet): | |||||||
|             queryset = self.clone() |             queryset = self.clone() | ||||||
|         queryset.rewind() |         queryset.rewind() | ||||||
|         return queryset |         return queryset | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class QuerySetNoDeRef(QuerySet): | ||||||
|  |     """Special no_dereference QuerySet""" | ||||||
|  |  | ||||||
|  |     def __dereference(items, max_depth=1, instance=None, name=None): | ||||||
|  |         return items | ||||||
| @@ -3,6 +3,7 @@ from collections import defaultdict | |||||||
| import pymongo | import pymongo | ||||||
| from bson import SON | from bson import SON | ||||||
|  |  | ||||||
|  | from mongoengine.connection import get_connection | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.errors import InvalidQueryError, LookUpError | from mongoengine.errors import InvalidQueryError, LookUpError | ||||||
|  |  | ||||||
| @@ -38,7 +39,7 @@ def query(_doc_cls=None, _field_operation=False, **query): | |||||||
|             mongo_query.update(value) |             mongo_query.update(value) | ||||||
|             continue |             continue | ||||||
|  |  | ||||||
|         parts = key.split('__') |         parts = key.rsplit('__') | ||||||
|         indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] |         indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] | ||||||
|         parts = [part for part in parts if not part.isdigit()] |         parts = [part for part in parts if not part.isdigit()] | ||||||
|         # Check for an operator and transform to mongo-style if there is |         # Check for an operator and transform to mongo-style if there is | ||||||
| @@ -115,14 +116,26 @@ def query(_doc_cls=None, _field_operation=False, **query): | |||||||
|             if key in mongo_query and isinstance(mongo_query[key], dict): |             if key in mongo_query and isinstance(mongo_query[key], dict): | ||||||
|                 mongo_query[key].update(value) |                 mongo_query[key].update(value) | ||||||
|                 # $maxDistance needs to come last - convert to SON |                 # $maxDistance needs to come last - convert to SON | ||||||
|                 if '$maxDistance' in mongo_query[key]: |  | ||||||
|                 value_dict = mongo_query[key] |                 value_dict = mongo_query[key] | ||||||
|  |                 if ('$maxDistance' in value_dict and '$near' in value_dict): | ||||||
|                     value_son = SON() |                     value_son = SON() | ||||||
|  |                     if isinstance(value_dict['$near'], dict): | ||||||
|  |                         for k, v in value_dict.iteritems(): | ||||||
|  |                             if k == '$maxDistance': | ||||||
|  |                                 continue | ||||||
|  |                             value_son[k] = v | ||||||
|  |                         if (get_connection().max_wire_version <= 1): | ||||||
|  |                             value_son['$maxDistance'] = value_dict['$maxDistance'] | ||||||
|  |                         else: | ||||||
|  |                             value_son['$near'] = SON(value_son['$near']) | ||||||
|  |                             value_son['$near']['$maxDistance'] = value_dict['$maxDistance'] | ||||||
|  |                     else: | ||||||
|                         for k, v in value_dict.iteritems(): |                         for k, v in value_dict.iteritems(): | ||||||
|                             if k == '$maxDistance': |                             if k == '$maxDistance': | ||||||
|                                 continue |                                 continue | ||||||
|                             value_son[k] = v |                             value_son[k] = v | ||||||
|                         value_son['$maxDistance'] = value_dict['$maxDistance'] |                         value_son['$maxDistance'] = value_dict['$maxDistance'] | ||||||
|  |  | ||||||
|                     mongo_query[key] = value_son |                     mongo_query[key] = value_son | ||||||
|             else: |             else: | ||||||
|                 # Store for manually merging later |                 # Store for manually merging later | ||||||
|   | |||||||
| @@ -1,8 +1,9 @@ | |||||||
| import copy | import copy | ||||||
|  |  | ||||||
| from mongoengine.errors import InvalidQueryError | from itertools import product | ||||||
| from mongoengine.python_support import product, reduce | from functools import reduce | ||||||
|  |  | ||||||
|  | from mongoengine.errors import InvalidQueryError | ||||||
| from mongoengine.queryset import transform | from mongoengine.queryset import transform | ||||||
|  |  | ||||||
| __all__ = ('Q',) | __all__ = ('Q',) | ||||||
|   | |||||||
							
								
								
									
										15
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								setup.py
									
									
									
									
									
								
							| @@ -38,12 +38,14 @@ CLASSIFIERS = [ | |||||||
|     'Operating System :: OS Independent', |     'Operating System :: OS Independent', | ||||||
|     'Programming Language :: Python', |     'Programming Language :: Python', | ||||||
|     "Programming Language :: Python :: 2", |     "Programming Language :: Python :: 2", | ||||||
|     "Programming Language :: Python :: 2.6", |     "Programming Language :: Python :: 2.6.6", | ||||||
|     "Programming Language :: Python :: 2.7", |     "Programming Language :: Python :: 2.7", | ||||||
|     "Programming Language :: Python :: 3", |     "Programming Language :: Python :: 3", | ||||||
|     "Programming Language :: Python :: 3.1", |  | ||||||
|     "Programming Language :: Python :: 3.2", |     "Programming Language :: Python :: 3.2", | ||||||
|  |     "Programming Language :: Python :: 3.3", | ||||||
|  |     "Programming Language :: Python :: 3.4", | ||||||
|     "Programming Language :: Python :: Implementation :: CPython", |     "Programming Language :: Python :: Implementation :: CPython", | ||||||
|  |     "Programming Language :: Python :: Implementation :: PyPy", | ||||||
|     'Topic :: Database', |     'Topic :: Database', | ||||||
|     'Topic :: Software Development :: Libraries :: Python Modules', |     'Topic :: Software Development :: Libraries :: Python Modules', | ||||||
| ] | ] | ||||||
| @@ -51,12 +53,15 @@ CLASSIFIERS = [ | |||||||
| extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} | extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} | ||||||
| if sys.version_info[0] == 3: | if sys.version_info[0] == 3: | ||||||
|     extra_opts['use_2to3'] = True |     extra_opts['use_2to3'] = True | ||||||
|     extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'django>=1.5.1'] |     extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'Pillow>=2.0.0', 'django>=1.5.1'] | ||||||
|     if "test" in sys.argv or "nosetests" in sys.argv: |     if "test" in sys.argv or "nosetests" in sys.argv: | ||||||
|         extra_opts['packages'] = find_packages() |         extra_opts['packages'] = find_packages() | ||||||
|         extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} |         extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} | ||||||
| else: | else: | ||||||
|     extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6', 'python-dateutil'] |     extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'Pillow>=2.0.0', 'jinja2>=2.6', 'python-dateutil'] | ||||||
|  |  | ||||||
|  |     if sys.version_info[0] == 2 and sys.version_info[1] == 6: | ||||||
|  |         extra_opts['tests_require'].append('unittest2') | ||||||
|  |  | ||||||
| setup(name='mongoengine', | setup(name='mongoengine', | ||||||
|       version=VERSION, |       version=VERSION, | ||||||
| @@ -72,7 +77,7 @@ setup(name='mongoengine', | |||||||
|       long_description=LONG_DESCRIPTION, |       long_description=LONG_DESCRIPTION, | ||||||
|       platforms=['any'], |       platforms=['any'], | ||||||
|       classifiers=CLASSIFIERS, |       classifiers=CLASSIFIERS, | ||||||
|       install_requires=['pymongo>=2.5'], |       install_requires=['pymongo>=2.7'], | ||||||
|       test_suite='nose.collector', |       test_suite='nose.collector', | ||||||
|       **extra_opts |       **extra_opts | ||||||
| ) | ) | ||||||
|   | |||||||
| @@ -735,5 +735,47 @@ class DeltaTest(unittest.TestCase): | |||||||
|         mydoc._clear_changed_fields() |         mydoc._clear_changed_fields() | ||||||
|         self.assertEqual([], mydoc._get_changed_fields()) |         self.assertEqual([], mydoc._get_changed_fields()) | ||||||
|  |  | ||||||
|  |     def test_referenced_object_changed_attributes(self): | ||||||
|  |         """Ensures that when you save a new reference to a field, the referenced object isn't altered""" | ||||||
|  |  | ||||||
|  |         class Organization(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class User(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             org = ReferenceField('Organization', required=True) | ||||||
|  |  | ||||||
|  |         Organization.drop_collection() | ||||||
|  |         User.drop_collection() | ||||||
|  |  | ||||||
|  |         org1 = Organization(name='Org 1') | ||||||
|  |         org1.save() | ||||||
|  |  | ||||||
|  |         org2 = Organization(name='Org 2') | ||||||
|  |         org2.save() | ||||||
|  |  | ||||||
|  |         user = User(name='Fred', org=org1) | ||||||
|  |         user.save() | ||||||
|  |  | ||||||
|  |         org1.reload() | ||||||
|  |         org2.reload() | ||||||
|  |         user.reload() | ||||||
|  |         self.assertEqual(org1.name, 'Org 1') | ||||||
|  |         self.assertEqual(org2.name, 'Org 2') | ||||||
|  |         self.assertEqual(user.name, 'Fred') | ||||||
|  |  | ||||||
|  |         user.name = 'Harold' | ||||||
|  |         user.org = org2 | ||||||
|  |  | ||||||
|  |         org2.name = 'New Org 2' | ||||||
|  |         self.assertEqual(org2.name, 'New Org 2') | ||||||
|  |  | ||||||
|  |         user.save() | ||||||
|  |         org2.save() | ||||||
|  |  | ||||||
|  |         self.assertEqual(org2.name, 'New Org 2') | ||||||
|  |         org2.reload() | ||||||
|  |         self.assertEqual(org2.name, 'New Org 2') | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -292,6 +292,44 @@ class DynamicTest(unittest.TestCase): | |||||||
|         person.save() |         person.save() | ||||||
|         self.assertEqual(Person.objects.first().age, 35) |         self.assertEqual(Person.objects.first().age, 35) | ||||||
|  |  | ||||||
|  |     def test_dynamic_and_embedded_dict_access(self): | ||||||
|  |         """Ensure embedded dynamic documents work with dict[] style access""" | ||||||
|  |  | ||||||
|  |         class Address(EmbeddedDocument): | ||||||
|  |             city = StringField() | ||||||
|  |  | ||||||
|  |         class Person(DynamicDocument): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         Person.drop_collection() | ||||||
|  |  | ||||||
|  |         Person(name="Ross", address=Address(city="London")).save() | ||||||
|  |  | ||||||
|  |         person = Person.objects.first() | ||||||
|  |         person.attrval = "This works" | ||||||
|  |  | ||||||
|  |         person["phone"] = "555-1212" # but this should too | ||||||
|  |  | ||||||
|  |         # Same thing two levels deep | ||||||
|  |         person["address"]["city"] = "Lundenne" | ||||||
|  |         person.save() | ||||||
|  |  | ||||||
|  |         self.assertEqual(Person.objects.first().address.city, "Lundenne") | ||||||
|  |  | ||||||
|  |         self.assertEqual(Person.objects.first().phone, "555-1212") | ||||||
|  |  | ||||||
|  |         person = Person.objects.first() | ||||||
|  |         person.address = Address(city="Londinium") | ||||||
|  |         person.save() | ||||||
|  |  | ||||||
|  |         self.assertEqual(Person.objects.first().address.city, "Londinium") | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         person = Person.objects.first() | ||||||
|  |         person["age"] = 35 | ||||||
|  |         person.save() | ||||||
|  |         self.assertEqual(Person.objects.first().age, 35) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -15,7 +15,7 @@ from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, | |||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.errors import (NotRegistered, InvalidDocumentError, | from mongoengine.errors import (NotRegistered, InvalidDocumentError, | ||||||
|                                 InvalidQueryError) |                                 InvalidQueryError, NotUniqueError) | ||||||
| from mongoengine.queryset import NULLIFY, Q | from mongoengine.queryset import NULLIFY, Q | ||||||
| from mongoengine.connection import get_db | from mongoengine.connection import get_db | ||||||
| from mongoengine.base import get_document | from mongoengine.base import get_document | ||||||
| @@ -57,7 +57,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|             date = DateTimeField(default=datetime.now) |             date = DateTimeField(default=datetime.now) | ||||||
|             meta = { |             meta = { | ||||||
|                 'max_documents': 10, |                 'max_documents': 10, | ||||||
|                 'max_size': 90000, |                 'max_size': 4096, | ||||||
|             } |             } | ||||||
|  |  | ||||||
|         Log.drop_collection() |         Log.drop_collection() | ||||||
| @@ -75,7 +75,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|         options = Log.objects._collection.options() |         options = Log.objects._collection.options() | ||||||
|         self.assertEqual(options['capped'], True) |         self.assertEqual(options['capped'], True) | ||||||
|         self.assertEqual(options['max'], 10) |         self.assertEqual(options['max'], 10) | ||||||
|         self.assertEqual(options['size'], 90000) |         self.assertTrue(options['size'] >= 4096) | ||||||
|  |  | ||||||
|         # Check that the document cannot be redefined with different options |         # Check that the document cannot be redefined with different options | ||||||
|         def recreate_log_document(): |         def recreate_log_document(): | ||||||
| @@ -353,6 +353,14 @@ class InstanceTest(unittest.TestCase): | |||||||
|         self.assertEqual(person.name, "Test User") |         self.assertEqual(person.name, "Test User") | ||||||
|         self.assertEqual(person.age, 20) |         self.assertEqual(person.age, 20) | ||||||
|  |  | ||||||
|  |         person.reload('age') | ||||||
|  |         self.assertEqual(person.name, "Test User") | ||||||
|  |         self.assertEqual(person.age, 21) | ||||||
|  |  | ||||||
|  |         person.reload() | ||||||
|  |         self.assertEqual(person.name, "Mr Test User") | ||||||
|  |         self.assertEqual(person.age, 21) | ||||||
|  |  | ||||||
|         person.reload() |         person.reload() | ||||||
|         self.assertEqual(person.name, "Mr Test User") |         self.assertEqual(person.name, "Mr Test User") | ||||||
|         self.assertEqual(person.age, 21) |         self.assertEqual(person.age, 21) | ||||||
| @@ -402,6 +410,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|             'embedded_field.dict_field']) |             'embedded_field.dict_field']) | ||||||
|         doc.save() |         doc.save() | ||||||
|  |  | ||||||
|  |         self.assertEqual(len(doc.list_field), 4) | ||||||
|         doc = doc.reload(10) |         doc = doc.reload(10) | ||||||
|         self.assertEqual(doc._get_changed_fields(), []) |         self.assertEqual(doc._get_changed_fields(), []) | ||||||
|         self.assertEqual(len(doc.list_field), 4) |         self.assertEqual(len(doc.list_field), 4) | ||||||
| @@ -409,6 +418,16 @@ class InstanceTest(unittest.TestCase): | |||||||
|         self.assertEqual(len(doc.embedded_field.list_field), 4) |         self.assertEqual(len(doc.embedded_field.list_field), 4) | ||||||
|         self.assertEqual(len(doc.embedded_field.dict_field), 2) |         self.assertEqual(len(doc.embedded_field.dict_field), 2) | ||||||
|  |  | ||||||
|  |         doc.list_field.append(1) | ||||||
|  |         doc.save() | ||||||
|  |         doc.dict_field['extra'] = 1 | ||||||
|  |         doc = doc.reload(10, 'list_field') | ||||||
|  |         self.assertEqual(doc._get_changed_fields(), []) | ||||||
|  |         self.assertEqual(len(doc.list_field), 5) | ||||||
|  |         self.assertEqual(len(doc.dict_field), 3) | ||||||
|  |         self.assertEqual(len(doc.embedded_field.list_field), 4) | ||||||
|  |         self.assertEqual(len(doc.embedded_field.dict_field), 2) | ||||||
|  |  | ||||||
|     def test_reload_doesnt_exist(self): |     def test_reload_doesnt_exist(self): | ||||||
|         class Foo(Document): |         class Foo(Document): | ||||||
|             pass |             pass | ||||||
| @@ -515,9 +534,6 @@ class InstanceTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         class Email(EmbeddedDocument): |         class Email(EmbeddedDocument): | ||||||
|             email = EmailField() |             email = EmailField() | ||||||
|             def clean(self): |  | ||||||
|                 print "instance:" |  | ||||||
|                 print self._instance |  | ||||||
|  |  | ||||||
|         class Account(Document): |         class Account(Document): | ||||||
|             email = EmbeddedDocumentField(Email) |             email = EmbeddedDocumentField(Email) | ||||||
| @@ -820,6 +836,80 @@ class InstanceTest(unittest.TestCase): | |||||||
|         p1.reload() |         p1.reload() | ||||||
|         self.assertEqual(p1.name, p.parent.name) |         self.assertEqual(p1.name, p.parent.name) | ||||||
|  |  | ||||||
|  |     def test_save_atomicity_condition(self): | ||||||
|  |  | ||||||
|  |         class Widget(Document): | ||||||
|  |             toggle = BooleanField(default=False) | ||||||
|  |             count = IntField(default=0) | ||||||
|  |             save_id = UUIDField() | ||||||
|  |  | ||||||
|  |         def flip(widget): | ||||||
|  |             widget.toggle = not widget.toggle | ||||||
|  |             widget.count += 1 | ||||||
|  |  | ||||||
|  |         def UUID(i): | ||||||
|  |             return uuid.UUID(int=i) | ||||||
|  |  | ||||||
|  |         Widget.drop_collection() | ||||||
|  |  | ||||||
|  |         w1 = Widget(toggle=False, save_id=UUID(1)) | ||||||
|  |  | ||||||
|  |         # ignore save_condition on new record creation | ||||||
|  |         w1.save(save_condition={'save_id':UUID(42)}) | ||||||
|  |         w1.reload() | ||||||
|  |         self.assertFalse(w1.toggle) | ||||||
|  |         self.assertEqual(w1.save_id, UUID(1)) | ||||||
|  |         self.assertEqual(w1.count, 0) | ||||||
|  |  | ||||||
|  |         # mismatch in save_condition prevents save | ||||||
|  |         flip(w1) | ||||||
|  |         self.assertTrue(w1.toggle) | ||||||
|  |         self.assertEqual(w1.count, 1) | ||||||
|  |         w1.save(save_condition={'save_id':UUID(42)}) | ||||||
|  |         w1.reload() | ||||||
|  |         self.assertFalse(w1.toggle) | ||||||
|  |         self.assertEqual(w1.count, 0) | ||||||
|  |  | ||||||
|  |         # matched save_condition allows save | ||||||
|  |         flip(w1) | ||||||
|  |         self.assertTrue(w1.toggle) | ||||||
|  |         self.assertEqual(w1.count, 1) | ||||||
|  |         w1.save(save_condition={'save_id':UUID(1)}) | ||||||
|  |         w1.reload() | ||||||
|  |         self.assertTrue(w1.toggle) | ||||||
|  |         self.assertEqual(w1.count, 1) | ||||||
|  |  | ||||||
|  |         # save_condition can be used to ensure atomic read & updates | ||||||
|  |         # i.e., prevent interleaved reads and writes from separate contexts | ||||||
|  |         w2 = Widget.objects.get() | ||||||
|  |         self.assertEqual(w1, w2) | ||||||
|  |         old_id = w1.save_id | ||||||
|  |  | ||||||
|  |         flip(w1) | ||||||
|  |         w1.save_id = UUID(2) | ||||||
|  |         w1.save(save_condition={'save_id':old_id}) | ||||||
|  |         w1.reload() | ||||||
|  |         self.assertFalse(w1.toggle) | ||||||
|  |         self.assertEqual(w1.count, 2) | ||||||
|  |         flip(w2) | ||||||
|  |         flip(w2) | ||||||
|  |         w2.save(save_condition={'save_id':old_id}) | ||||||
|  |         w2.reload() | ||||||
|  |         self.assertFalse(w2.toggle) | ||||||
|  |         self.assertEqual(w2.count, 2) | ||||||
|  |  | ||||||
|  |         # save_condition uses mongoengine-style operator syntax | ||||||
|  |         flip(w1) | ||||||
|  |         w1.save(save_condition={'count__lt':w1.count}) | ||||||
|  |         w1.reload() | ||||||
|  |         self.assertTrue(w1.toggle) | ||||||
|  |         self.assertEqual(w1.count, 3) | ||||||
|  |         flip(w1) | ||||||
|  |         w1.save(save_condition={'count__gte':w1.count}) | ||||||
|  |         w1.reload() | ||||||
|  |         self.assertTrue(w1.toggle) | ||||||
|  |         self.assertEqual(w1.count, 3) | ||||||
|  |  | ||||||
|     def test_update(self): |     def test_update(self): | ||||||
|         """Ensure that an existing document is updated instead of be |         """Ensure that an existing document is updated instead of be | ||||||
|         overwritten.""" |         overwritten.""" | ||||||
| @@ -990,6 +1080,16 @@ class InstanceTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         self.assertRaises(InvalidQueryError, update_no_op_raises) |         self.assertRaises(InvalidQueryError, update_no_op_raises) | ||||||
|  |  | ||||||
|  |     def test_update_unique_field(self): | ||||||
|  |         class Doc(Document): | ||||||
|  |             name = StringField(unique=True) | ||||||
|  |  | ||||||
|  |         doc1 = Doc(name="first").save() | ||||||
|  |         doc2 = Doc(name="second").save() | ||||||
|  |  | ||||||
|  |         self.assertRaises(NotUniqueError, lambda: | ||||||
|  |                           doc2.update(set__name=doc1.name)) | ||||||
|  |  | ||||||
|     def test_embedded_update(self): |     def test_embedded_update(self): | ||||||
|         """ |         """ | ||||||
|         Test update on `EmbeddedDocumentField` fields |         Test update on `EmbeddedDocumentField` fields | ||||||
| @@ -2281,6 +2381,8 @@ class InstanceTest(unittest.TestCase): | |||||||
|         log.machine = "Localhost" |         log.machine = "Localhost" | ||||||
|         log.save() |         log.save() | ||||||
|  |  | ||||||
|  |         self.assertTrue(log.id is not None) | ||||||
|  |  | ||||||
|         log.log = "Saving" |         log.log = "Saving" | ||||||
|         log.save() |         log.save() | ||||||
|  |  | ||||||
| @@ -2304,6 +2406,8 @@ class InstanceTest(unittest.TestCase): | |||||||
|         log.machine = "Localhost" |         log.machine = "Localhost" | ||||||
|         log.save() |         log.save() | ||||||
|  |  | ||||||
|  |         self.assertTrue(log.id is not None) | ||||||
|  |  | ||||||
|         log.log = "Saving" |         log.log = "Saving" | ||||||
|         log.save() |         log.save() | ||||||
|  |  | ||||||
| @@ -2411,7 +2515,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|                 for parameter_name, parameter in self.parameters.iteritems(): |                 for parameter_name, parameter in self.parameters.iteritems(): | ||||||
|                     parameter.expand() |                     parameter.expand() | ||||||
|  |  | ||||||
|         class System(Document): |         class NodesSystem(Document): | ||||||
|             name = StringField(required=True) |             name = StringField(required=True) | ||||||
|             nodes = MapField(ReferenceField(Node, dbref=False)) |             nodes = MapField(ReferenceField(Node, dbref=False)) | ||||||
|  |  | ||||||
| @@ -2419,18 +2523,18 @@ class InstanceTest(unittest.TestCase): | |||||||
|                 for node_name, node in self.nodes.iteritems(): |                 for node_name, node in self.nodes.iteritems(): | ||||||
|                     node.expand() |                     node.expand() | ||||||
|                     node.save(*args, **kwargs) |                     node.save(*args, **kwargs) | ||||||
|                 super(System, self).save(*args, **kwargs) |                 super(NodesSystem, self).save(*args, **kwargs) | ||||||
|  |  | ||||||
|         System.drop_collection() |         NodesSystem.drop_collection() | ||||||
|         Node.drop_collection() |         Node.drop_collection() | ||||||
|  |  | ||||||
|         system = System(name="system") |         system = NodesSystem(name="system") | ||||||
|         system.nodes["node"] = Node() |         system.nodes["node"] = Node() | ||||||
|         system.save() |         system.save() | ||||||
|         system.nodes["node"].parameters["param"] = Parameter() |         system.nodes["node"].parameters["param"] = Parameter() | ||||||
|         system.save() |         system.save() | ||||||
|  |  | ||||||
|         system = System.objects.first() |         system = NodesSystem.objects.first() | ||||||
|         self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value) |         self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value) | ||||||
|  |  | ||||||
|     def test_embedded_document_equality(self): |     def test_embedded_document_equality(self): | ||||||
| @@ -2452,5 +2556,65 @@ class InstanceTest(unittest.TestCase): | |||||||
|         f1.ref  # Dereferences lazily |         f1.ref  # Dereferences lazily | ||||||
|         self.assertEqual(f1, f2) |         self.assertEqual(f1, f2) | ||||||
|  |  | ||||||
|  |     def test_dbref_equality(self): | ||||||
|  |         class Test2(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class Test3(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class Test(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             test2 = ReferenceField('Test2') | ||||||
|  |             test3 = ReferenceField('Test3') | ||||||
|  |  | ||||||
|  |         Test.drop_collection() | ||||||
|  |         Test2.drop_collection() | ||||||
|  |         Test3.drop_collection() | ||||||
|  |  | ||||||
|  |         t2 = Test2(name='a') | ||||||
|  |         t2.save() | ||||||
|  |  | ||||||
|  |         t3 = Test3(name='x') | ||||||
|  |         t3.id = t2.id | ||||||
|  |         t3.save() | ||||||
|  |  | ||||||
|  |         t = Test(name='b', test2=t2, test3=t3) | ||||||
|  |  | ||||||
|  |         f = Test._from_son(t.to_mongo()) | ||||||
|  |  | ||||||
|  |         dbref2 = f._data['test2'] | ||||||
|  |         obj2 = f.test2 | ||||||
|  |         self.assertTrue(isinstance(dbref2, DBRef)) | ||||||
|  |         self.assertTrue(isinstance(obj2, Test2)) | ||||||
|  |         self.assertTrue(obj2.id == dbref2.id) | ||||||
|  |         self.assertTrue(obj2 == dbref2) | ||||||
|  |         self.assertTrue(dbref2 == obj2) | ||||||
|  |  | ||||||
|  |         dbref3 = f._data['test3'] | ||||||
|  |         obj3 = f.test3 | ||||||
|  |         self.assertTrue(isinstance(dbref3, DBRef)) | ||||||
|  |         self.assertTrue(isinstance(obj3, Test3)) | ||||||
|  |         self.assertTrue(obj3.id == dbref3.id) | ||||||
|  |         self.assertTrue(obj3 == dbref3) | ||||||
|  |         self.assertTrue(dbref3 == obj3) | ||||||
|  |  | ||||||
|  |         self.assertTrue(obj2.id == obj3.id) | ||||||
|  |         self.assertTrue(dbref2.id == dbref3.id) | ||||||
|  |         self.assertFalse(dbref2 == dbref3) | ||||||
|  |         self.assertFalse(dbref3 == dbref2) | ||||||
|  |         self.assertTrue(dbref2 != dbref3) | ||||||
|  |         self.assertTrue(dbref3 != dbref2) | ||||||
|  |  | ||||||
|  |         self.assertFalse(obj2 == dbref3) | ||||||
|  |         self.assertFalse(dbref3 == obj2) | ||||||
|  |         self.assertTrue(obj2 != dbref3) | ||||||
|  |         self.assertTrue(dbref3 != obj2) | ||||||
|  |  | ||||||
|  |         self.assertFalse(obj3 == dbref2) | ||||||
|  |         self.assertFalse(dbref2 == obj3) | ||||||
|  |         self.assertTrue(obj3 != dbref2) | ||||||
|  |         self.assertTrue(dbref2 != obj3) | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -279,7 +279,7 @@ class FileTest(unittest.TestCase): | |||||||
|                 t.image.put(f) |                 t.image.put(f) | ||||||
|                 self.fail("Should have raised an invalidation error") |                 self.fail("Should have raised an invalidation error") | ||||||
|             except ValidationError, e: |             except ValidationError, e: | ||||||
|                 self.assertEqual("%s" % e, "Invalid image: cannot identify image file") |                 self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) | ||||||
|  |  | ||||||
|         t = TestImage() |         t = TestImage() | ||||||
|         t.image.put(open(TEST_IMAGE_PATH, 'rb')) |         t.image.put(open(TEST_IMAGE_PATH, 'rb')) | ||||||
|   | |||||||
| @@ -3,3 +3,4 @@ from field_list import * | |||||||
| from queryset import * | from queryset import * | ||||||
| from visitor import * | from visitor import * | ||||||
| from geo import * | from geo import * | ||||||
|  | from modify import * | ||||||
| @@ -5,6 +5,8 @@ import unittest | |||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
|  |  | ||||||
|  | from nose.plugins.skip import SkipTest | ||||||
|  |  | ||||||
| __all__ = ("GeoQueriesTest",) | __all__ = ("GeoQueriesTest",) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -139,6 +141,7 @@ class GeoQueriesTest(unittest.TestCase): | |||||||
|     def test_spherical_geospatial_operators(self): |     def test_spherical_geospatial_operators(self): | ||||||
|         """Ensure that spherical geospatial queries are working |         """Ensure that spherical geospatial queries are working | ||||||
|         """ |         """ | ||||||
|  |         raise SkipTest("https://jira.mongodb.org/browse/SERVER-14039") | ||||||
|         class Point(Document): |         class Point(Document): | ||||||
|             location = GeoPointField() |             location = GeoPointField() | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										102
									
								
								tests/queryset/modify.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								tests/queryset/modify.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,102 @@ | |||||||
|  | import sys | ||||||
|  | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | from mongoengine import connect, Document, IntField | ||||||
|  |  | ||||||
|  | __all__ = ("FindAndModifyTest",) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Doc(Document): | ||||||
|  |     id = IntField(primary_key=True) | ||||||
|  |     value = IntField() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FindAndModifyTest(unittest.TestCase): | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         connect(db="mongoenginetest") | ||||||
|  |         Doc.drop_collection() | ||||||
|  |  | ||||||
|  |     def assertDbEqual(self, docs): | ||||||
|  |         self.assertEqual(list(Doc._collection.find().sort("id")), docs) | ||||||
|  |  | ||||||
|  |     def test_modify(self): | ||||||
|  |         Doc(id=0, value=0).save() | ||||||
|  |         doc = Doc(id=1, value=1).save() | ||||||
|  |  | ||||||
|  |         old_doc = Doc.objects(id=1).modify(set__value=-1) | ||||||
|  |         self.assertEqual(old_doc.to_json(), doc.to_json()) | ||||||
|  |         self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) | ||||||
|  |  | ||||||
|  |     def test_modify_with_new(self): | ||||||
|  |         Doc(id=0, value=0).save() | ||||||
|  |         doc = Doc(id=1, value=1).save() | ||||||
|  |  | ||||||
|  |         new_doc = Doc.objects(id=1).modify(set__value=-1, new=True) | ||||||
|  |         doc.value = -1 | ||||||
|  |         self.assertEqual(new_doc.to_json(), doc.to_json()) | ||||||
|  |         self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) | ||||||
|  |  | ||||||
|  |     def test_modify_not_existing(self): | ||||||
|  |         Doc(id=0, value=0).save() | ||||||
|  |         self.assertEqual(Doc.objects(id=1).modify(set__value=-1), None) | ||||||
|  |         self.assertDbEqual([{"_id": 0, "value": 0}]) | ||||||
|  |  | ||||||
|  |     def test_modify_with_upsert(self): | ||||||
|  |         Doc(id=0, value=0).save() | ||||||
|  |         old_doc = Doc.objects(id=1).modify(set__value=1, upsert=True) | ||||||
|  |         self.assertEqual(old_doc, None) | ||||||
|  |         self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) | ||||||
|  |  | ||||||
|  |     def test_modify_with_upsert_existing(self): | ||||||
|  |         Doc(id=0, value=0).save() | ||||||
|  |         doc = Doc(id=1, value=1).save() | ||||||
|  |  | ||||||
|  |         old_doc = Doc.objects(id=1).modify(set__value=-1, upsert=True) | ||||||
|  |         self.assertEqual(old_doc.to_json(), doc.to_json()) | ||||||
|  |         self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) | ||||||
|  |  | ||||||
|  |     def test_modify_with_upsert_with_new(self): | ||||||
|  |         Doc(id=0, value=0).save() | ||||||
|  |         new_doc = Doc.objects(id=1).modify(upsert=True, new=True, set__value=1) | ||||||
|  |         self.assertEqual(new_doc.to_mongo(), {"_id": 1, "value": 1}) | ||||||
|  |         self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) | ||||||
|  |  | ||||||
|  |     def test_modify_with_remove(self): | ||||||
|  |         Doc(id=0, value=0).save() | ||||||
|  |         doc = Doc(id=1, value=1).save() | ||||||
|  |  | ||||||
|  |         old_doc = Doc.objects(id=1).modify(remove=True) | ||||||
|  |         self.assertEqual(old_doc.to_json(), doc.to_json()) | ||||||
|  |         self.assertDbEqual([{"_id": 0, "value": 0}]) | ||||||
|  |  | ||||||
|  |     def test_find_and_modify_with_remove_not_existing(self): | ||||||
|  |         Doc(id=0, value=0).save() | ||||||
|  |         self.assertEqual(Doc.objects(id=1).modify(remove=True), None) | ||||||
|  |         self.assertDbEqual([{"_id": 0, "value": 0}]) | ||||||
|  |  | ||||||
|  |     def test_modify_with_order_by(self): | ||||||
|  |         Doc(id=0, value=3).save() | ||||||
|  |         Doc(id=1, value=2).save() | ||||||
|  |         Doc(id=2, value=1).save() | ||||||
|  |         doc = Doc(id=3, value=0).save() | ||||||
|  |  | ||||||
|  |         old_doc = Doc.objects().order_by("-id").modify(set__value=-1) | ||||||
|  |         self.assertEqual(old_doc.to_json(), doc.to_json()) | ||||||
|  |         self.assertDbEqual([ | ||||||
|  |             {"_id": 0, "value": 3}, {"_id": 1, "value": 2}, | ||||||
|  |             {"_id": 2, "value": 1}, {"_id": 3, "value": -1}]) | ||||||
|  |  | ||||||
|  |     def test_modify_with_fields(self): | ||||||
|  |         Doc(id=0, value=0).save() | ||||||
|  |         Doc(id=1, value=1).save() | ||||||
|  |  | ||||||
|  |         old_doc = Doc.objects(id=1).only("id").modify(set__value=-1) | ||||||
|  |         self.assertEqual(old_doc.to_mongo(), {"_id": 1}) | ||||||
|  |         self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     unittest.main() | ||||||
| @@ -14,9 +14,9 @@ from pymongo.read_preferences import ReadPreference | |||||||
| from bson import ObjectId | from bson import ObjectId | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.connection import get_connection | from mongoengine.connection import get_connection, get_db | ||||||
| from mongoengine.python_support import PY3 | from mongoengine.python_support import PY3 | ||||||
| from mongoengine.context_managers import query_counter | from mongoengine.context_managers import query_counter, switch_db | ||||||
| from mongoengine.queryset import (QuerySet, QuerySetManager, | from mongoengine.queryset import (QuerySet, QuerySetManager, | ||||||
|                                   MultipleObjectsReturned, DoesNotExist, |                                   MultipleObjectsReturned, DoesNotExist, | ||||||
|                                   queryset_manager) |                                   queryset_manager) | ||||||
| @@ -25,10 +25,17 @@ from mongoengine.errors import InvalidQueryError | |||||||
| __all__ = ("QuerySetTest",) | __all__ = ("QuerySetTest",) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class db_ops_tracker(query_counter): | ||||||
|  |     def get_ops(self): | ||||||
|  |         ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} | ||||||
|  |         return list(self.db.system.profile.find(ignore_query)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class QuerySetTest(unittest.TestCase): | class QuerySetTest(unittest.TestCase): | ||||||
|  |  | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         connect(db='mongoenginetest') |         connect(db='mongoenginetest') | ||||||
|  |         connect(db='mongoenginetest2', alias='test2') | ||||||
|  |  | ||||||
|         class PersonMeta(EmbeddedDocument): |         class PersonMeta(EmbeddedDocument): | ||||||
|             weight = IntField() |             weight = IntField() | ||||||
| @@ -650,7 +657,10 @@ class QuerySetTest(unittest.TestCase): | |||||||
|                 blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) |                 blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) | ||||||
|  |  | ||||||
|             Blog.objects.insert(blogs, load_bulk=False) |             Blog.objects.insert(blogs, load_bulk=False) | ||||||
|             self.assertEqual(q, 1)  # 1 for the insert |             if (get_connection().max_wire_version <= 1): | ||||||
|  |                 self.assertEqual(q, 1) | ||||||
|  |             else: | ||||||
|  |                 self.assertEqual(q, 99)  # profiling logs each doc now in the bulk op | ||||||
|  |  | ||||||
|         Blog.drop_collection() |         Blog.drop_collection() | ||||||
|         Blog.ensure_indexes() |         Blog.ensure_indexes() | ||||||
| @@ -659,7 +669,10 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             self.assertEqual(q, 0) |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|             Blog.objects.insert(blogs) |             Blog.objects.insert(blogs) | ||||||
|  |             if (get_connection().max_wire_version <= 1): | ||||||
|                 self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch |                 self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch | ||||||
|  |             else: | ||||||
|  |                 self.assertEqual(q, 100)  # 99 for insert, and 1 for in bulk fetch | ||||||
|  |  | ||||||
|         Blog.drop_collection() |         Blog.drop_collection() | ||||||
|  |  | ||||||
| @@ -1040,6 +1053,54 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         expected = [blog_post_1, blog_post_2, blog_post_3] |         expected = [blog_post_1, blog_post_2, blog_post_3] | ||||||
|         self.assertSequence(qs, expected) |         self.assertSequence(qs, expected) | ||||||
|  |  | ||||||
|  |     def test_clear_ordering(self): | ||||||
|  |         """ Ensure that the default ordering can be cleared by calling order_by(). | ||||||
|  |         """ | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             title = StringField() | ||||||
|  |             published_date = DateTimeField() | ||||||
|  |  | ||||||
|  |             meta = { | ||||||
|  |                 'ordering': ['-published_date'] | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         with db_ops_tracker() as q: | ||||||
|  |             BlogPost.objects.filter(title='whatever').first() | ||||||
|  |             self.assertEqual(len(q.get_ops()), 1) | ||||||
|  |             self.assertEqual(q.get_ops()[0]['query']['$orderby'], {u'published_date': -1}) | ||||||
|  |  | ||||||
|  |         with db_ops_tracker() as q: | ||||||
|  |             BlogPost.objects.filter(title='whatever').order_by().first() | ||||||
|  |             self.assertEqual(len(q.get_ops()), 1) | ||||||
|  |             print q.get_ops()[0]['query'] | ||||||
|  |             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||||
|  |  | ||||||
|  |     def test_no_ordering_for_get(self): | ||||||
|  |         """ Ensure that Doc.objects.get doesn't use any ordering. | ||||||
|  |         """ | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             title = StringField() | ||||||
|  |             published_date = DateTimeField() | ||||||
|  |  | ||||||
|  |             meta = { | ||||||
|  |                 'ordering': ['-published_date'] | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         BlogPost.objects.create(title='whatever', published_date=datetime.utcnow()) | ||||||
|  |  | ||||||
|  |         with db_ops_tracker() as q: | ||||||
|  |             BlogPost.objects.get(title='whatever') | ||||||
|  |             self.assertEqual(len(q.get_ops()), 1) | ||||||
|  |             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||||
|  |  | ||||||
|  |         # Ordering should be ignored for .get even if we set it explicitly | ||||||
|  |         with db_ops_tracker() as q: | ||||||
|  |             BlogPost.objects.order_by('-title').get(title='whatever') | ||||||
|  |             self.assertEqual(len(q.get_ops()), 1) | ||||||
|  |             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||||
|  |  | ||||||
|     def test_find_embedded(self): |     def test_find_embedded(self): | ||||||
|         """Ensure that an embedded document is properly returned from a query. |         """Ensure that an embedded document is properly returned from a query. | ||||||
|         """ |         """ | ||||||
| @@ -1925,6 +1986,140 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_map_reduce_custom_output(self): | ||||||
|  |         """ | ||||||
|  |         Test map/reduce custom output | ||||||
|  |         """ | ||||||
|  |         register_connection('test2', 'mongoenginetest2') | ||||||
|  |  | ||||||
|  |         class Family(Document): | ||||||
|  |             id = IntField( | ||||||
|  |                 primary_key=True) | ||||||
|  |             log = StringField() | ||||||
|  |  | ||||||
|  |         class Person(Document): | ||||||
|  |             id = IntField( | ||||||
|  |                 primary_key=True) | ||||||
|  |             name = StringField() | ||||||
|  |             age = IntField() | ||||||
|  |             family = ReferenceField(Family) | ||||||
|  |  | ||||||
|  |         Family.drop_collection() | ||||||
|  |         Person.drop_collection() | ||||||
|  |  | ||||||
|  |         # creating first family | ||||||
|  |         f1 = Family(id=1, log="Trav 02 de Julho") | ||||||
|  |         f1.save() | ||||||
|  |  | ||||||
|  |         # persons of first family | ||||||
|  |         Person(id=1, family=f1, name=u"Wilson Jr", age=21).save() | ||||||
|  |         Person(id=2, family=f1, name=u"Wilson Father", age=45).save() | ||||||
|  |         Person(id=3, family=f1, name=u"Eliana Costa", age=40).save() | ||||||
|  |         Person(id=4, family=f1, name=u"Tayza Mariana", age=17).save() | ||||||
|  |  | ||||||
|  |         # creating second family | ||||||
|  |         f2 = Family(id=2, log="Av prof frasc brunno") | ||||||
|  |         f2.save() | ||||||
|  |  | ||||||
|  |         #persons of second family | ||||||
|  |         Person(id=5, family=f2, name="Isabella Luanna", age=16).save() | ||||||
|  |         Person(id=6, family=f2, name="Sandra Mara", age=36).save() | ||||||
|  |         Person(id=7, family=f2, name="Igor Gabriel", age=10).save() | ||||||
|  |  | ||||||
|  |         # creating third family | ||||||
|  |         f3 = Family(id=3, log="Av brazil") | ||||||
|  |         f3.save() | ||||||
|  |  | ||||||
|  |         #persons of thrird family | ||||||
|  |         Person(id=8, family=f3, name="Arthur WA", age=30).save() | ||||||
|  |         Person(id=9, family=f3, name="Paula Leonel", age=25).save() | ||||||
|  |  | ||||||
|  |         # executing join map/reduce | ||||||
|  |         map_person = """ | ||||||
|  |             function () { | ||||||
|  |                 emit(this.family, { | ||||||
|  |                      totalAge: this.age, | ||||||
|  |                      persons: [{ | ||||||
|  |                         name: this.name, | ||||||
|  |                         age: this.age | ||||||
|  |                 }]}); | ||||||
|  |             } | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         map_family = """ | ||||||
|  |             function () { | ||||||
|  |                 emit(this._id, { | ||||||
|  |                    totalAge: 0, | ||||||
|  |                    persons: [] | ||||||
|  |                 }); | ||||||
|  |             } | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         reduce_f = """ | ||||||
|  |             function (key, values) { | ||||||
|  |                 var family = {persons: [], totalAge: 0}; | ||||||
|  |  | ||||||
|  |                 values.forEach(function(value) { | ||||||
|  |                     if (value.persons) { | ||||||
|  |                         value.persons.forEach(function (person) { | ||||||
|  |                             family.persons.push(person); | ||||||
|  |                             family.totalAge += person.age; | ||||||
|  |                         }); | ||||||
|  |                     } | ||||||
|  |                 }); | ||||||
|  |  | ||||||
|  |                 return family; | ||||||
|  |             } | ||||||
|  |         """ | ||||||
|  |         cursor = Family.objects.map_reduce( | ||||||
|  |             map_f=map_family, | ||||||
|  |             reduce_f=reduce_f, | ||||||
|  |             output={'replace': 'family_map', 'db_alias': 'test2'}) | ||||||
|  |  | ||||||
|  |         # start a map/reduce | ||||||
|  |         cursor.next() | ||||||
|  |  | ||||||
|  |         results = Person.objects.map_reduce( | ||||||
|  |             map_f=map_person, | ||||||
|  |             reduce_f=reduce_f, | ||||||
|  |             output={'reduce': 'family_map', 'db_alias': 'test2'}) | ||||||
|  |  | ||||||
|  |         results = list(results) | ||||||
|  |         collection = get_db('test2').family_map | ||||||
|  |  | ||||||
|  |         self.assertEqual( | ||||||
|  |             collection.find_one({'_id': 1}), { | ||||||
|  |                 '_id': 1, | ||||||
|  |                 'value': { | ||||||
|  |                     'persons': [ | ||||||
|  |                         {'age': 21, 'name': u'Wilson Jr'}, | ||||||
|  |                         {'age': 45, 'name': u'Wilson Father'}, | ||||||
|  |                         {'age': 40, 'name': u'Eliana Costa'}, | ||||||
|  |                         {'age': 17, 'name': u'Tayza Mariana'}], | ||||||
|  |                     'totalAge': 123} | ||||||
|  |                 }) | ||||||
|  |  | ||||||
|  |         self.assertEqual( | ||||||
|  |             collection.find_one({'_id': 2}), { | ||||||
|  |                 '_id': 2, | ||||||
|  |                 'value': { | ||||||
|  |                     'persons': [ | ||||||
|  |                         {'age': 16, 'name': u'Isabella Luanna'}, | ||||||
|  |                         {'age': 36, 'name': u'Sandra Mara'}, | ||||||
|  |                         {'age': 10, 'name': u'Igor Gabriel'}], | ||||||
|  |                     'totalAge': 62} | ||||||
|  |                 }) | ||||||
|  |  | ||||||
|  |         self.assertEqual( | ||||||
|  |             collection.find_one({'_id': 3}), { | ||||||
|  |                 '_id': 3, | ||||||
|  |                 'value': { | ||||||
|  |                     'persons': [ | ||||||
|  |                         {'age': 30, 'name': u'Arthur WA'}, | ||||||
|  |                         {'age': 25, 'name': u'Paula Leonel'}], | ||||||
|  |                     'totalAge': 55} | ||||||
|  |                 }) | ||||||
|  |  | ||||||
|     def test_map_reduce_finalize(self): |     def test_map_reduce_finalize(self): | ||||||
|         """Ensure that map, reduce, and finalize run and introduce "scope" |         """Ensure that map, reduce, and finalize run and introduce "scope" | ||||||
|         by simulating "hotness" ranking with Reddit algorithm. |         by simulating "hotness" ranking with Reddit algorithm. | ||||||
| @@ -2957,6 +3152,23 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         Number.drop_collection() |         Number.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_using(self): | ||||||
|  |         """Ensure that switching databases for a queryset is possible | ||||||
|  |         """ | ||||||
|  |         class Number2(Document): | ||||||
|  |             n = IntField() | ||||||
|  |  | ||||||
|  |         Number2.drop_collection() | ||||||
|  |         with switch_db(Number2, 'test2') as Number2: | ||||||
|  |             Number2.drop_collection() | ||||||
|  |  | ||||||
|  |         for i in range(1, 10): | ||||||
|  |             t = Number2(n=i) | ||||||
|  |             t.switch_db('test2') | ||||||
|  |             t.save() | ||||||
|  |  | ||||||
|  |         self.assertEqual(len(Number2.objects.using('test2')), 9) | ||||||
|  |  | ||||||
|     def test_unset_reference(self): |     def test_unset_reference(self): | ||||||
|         class Comment(Document): |         class Comment(Document): | ||||||
|             text = StringField() |             text = StringField() | ||||||
| @@ -3586,6 +3798,12 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|             [x for x in people] |             [x for x in people] | ||||||
|             self.assertEqual(100, len(people._result_cache)) |             self.assertEqual(100, len(people._result_cache)) | ||||||
|  |  | ||||||
|  |             import platform | ||||||
|  |  | ||||||
|  |             if platform.python_implementation() != "PyPy": | ||||||
|  |                 # PyPy evaluates __len__ when iterating with list comprehensions while CPython does not. | ||||||
|  |                 # This may be a bug in PyPy (PyPy/#1802) but it does not affect the behavior of MongoEngine. | ||||||
|                 self.assertEqual(None, people._len) |                 self.assertEqual(None, people._len) | ||||||
|             self.assertEqual(q, 1) |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
| @@ -3814,6 +4032,111 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         self.assertEqual(Example.objects(size=instance_size).count(), 1) |         self.assertEqual(Example.objects(size=instance_size).count(), 1) | ||||||
|         self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1) |         self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1) | ||||||
|  |  | ||||||
|  |     def test_cursor_in_an_if_stmt(self): | ||||||
|  |  | ||||||
|  |         class Test(Document): | ||||||
|  |             test_field = StringField() | ||||||
|  |  | ||||||
|  |         Test.drop_collection() | ||||||
|  |         queryset = Test.objects | ||||||
|  |  | ||||||
|  |         if queryset: | ||||||
|  |             raise AssertionError('Empty cursor returns True') | ||||||
|  |  | ||||||
|  |         test = Test() | ||||||
|  |         test.test_field = 'test' | ||||||
|  |         test.save() | ||||||
|  |  | ||||||
|  |         queryset = Test.objects | ||||||
|  |         if not test: | ||||||
|  |             raise AssertionError('Cursor has data and returned False') | ||||||
|  |  | ||||||
|  |         queryset.next() | ||||||
|  |         if not queryset: | ||||||
|  |             raise AssertionError('Cursor has data and it must returns True,' | ||||||
|  |                 ' even in the last item.') | ||||||
|  |  | ||||||
|  |     def test_bool_performance(self): | ||||||
|  |  | ||||||
|  |         class Person(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         Person.drop_collection() | ||||||
|  |         for i in xrange(100): | ||||||
|  |             Person(name="No: %s" % i).save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             if Person.objects: | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |             op = q.db.system.profile.find({"ns": | ||||||
|  |                 {"$ne": "%s.system.indexes" % q.db.name}})[0] | ||||||
|  |  | ||||||
|  |             self.assertEqual(op['nreturned'], 1) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def test_bool_with_ordering(self): | ||||||
|  |  | ||||||
|  |         class Person(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         Person.drop_collection() | ||||||
|  |         Person(name="Test").save() | ||||||
|  |  | ||||||
|  |         qs = Person.objects.order_by('name') | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |  | ||||||
|  |             if qs: | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |             op = q.db.system.profile.find({"ns": | ||||||
|  |                 {"$ne": "%s.system.indexes" % q.db.name}})[0] | ||||||
|  |  | ||||||
|  |             self.assertFalse('$orderby' in op['query'], | ||||||
|  |                 'BaseQuerySet cannot use orderby in if stmt') | ||||||
|  |  | ||||||
|  |         with query_counter() as p: | ||||||
|  |  | ||||||
|  |             for x in qs: | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |             op = p.db.system.profile.find({"ns": | ||||||
|  |                 {"$ne": "%s.system.indexes" % q.db.name}})[0] | ||||||
|  |  | ||||||
|  |             self.assertTrue('$orderby' in op['query'], | ||||||
|  |                 'BaseQuerySet cannot remove orderby in for loop') | ||||||
|  |  | ||||||
|  |     def test_bool_with_ordering_from_meta_dict(self): | ||||||
|  |  | ||||||
|  |         class Person(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             meta = { | ||||||
|  |                 'ordering': ['name'] | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         Person.drop_collection() | ||||||
|  |  | ||||||
|  |         Person(name="B").save() | ||||||
|  |         Person(name="C").save() | ||||||
|  |         Person(name="A").save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |  | ||||||
|  |             if Person.objects: | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |             op = q.db.system.profile.find({"ns": | ||||||
|  |                 {"$ne": "%s.system.indexes" % q.db.name}})[0] | ||||||
|  |  | ||||||
|  |             self.assertFalse('$orderby' in op['query'], | ||||||
|  |                 'BaseQuerySet must remove orderby from meta in boolen test') | ||||||
|  |  | ||||||
|  |             self.assertEqual(Person.objects.first().name, 'A') | ||||||
|  |             self.assertTrue(Person.objects._has_data(), | ||||||
|  |                             'Cursor has data and returned False') | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -1,6 +1,11 @@ | |||||||
| import sys | import sys | ||||||
| sys.path[0:0] = [""] | sys.path[0:0] = [""] | ||||||
|  |  | ||||||
|  | try: | ||||||
|  |     import unittest2 as unittest | ||||||
|  | except ImportError: | ||||||
|     import unittest |     import unittest | ||||||
|  |  | ||||||
| import datetime | import datetime | ||||||
|  |  | ||||||
| import pymongo | import pymongo | ||||||
| @@ -34,6 +39,17 @@ class ConnectionTest(unittest.TestCase): | |||||||
|         conn = get_connection('testdb') |         conn = get_connection('testdb') | ||||||
|         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) |         self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) | ||||||
|  |  | ||||||
|  |     def test_sharing_connections(self): | ||||||
|  |         """Ensure that connections are shared when the connection settings are exactly the same | ||||||
|  |         """ | ||||||
|  |         connect('mongoenginetest', alias='testdb1') | ||||||
|  |  | ||||||
|  |         expected_connection = get_connection('testdb1') | ||||||
|  |  | ||||||
|  |         connect('mongoenginetest', alias='testdb2') | ||||||
|  |         actual_connection = get_connection('testdb2') | ||||||
|  |         self.assertIs(expected_connection, actual_connection) | ||||||
|  |  | ||||||
|     def test_connect_uri(self): |     def test_connect_uri(self): | ||||||
|         """Ensure that the connect() method works properly with uri's |         """Ensure that the connect() method works properly with uri's | ||||||
|         """ |         """ | ||||||
|   | |||||||
							
								
								
									
										107
									
								
								tests/test_datastructures.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										107
									
								
								tests/test_datastructures.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,107 @@ | |||||||
|  | import unittest | ||||||
|  | from mongoengine.base.datastructures import StrictDict, SemiStrictDict  | ||||||
|  |  | ||||||
|  | class TestStrictDict(unittest.TestCase): | ||||||
|  |     def strict_dict_class(self, *args, **kwargs): | ||||||
|  |         return StrictDict.create(*args, **kwargs) | ||||||
|  |     def setUp(self): | ||||||
|  |         self.dtype = self.strict_dict_class(("a", "b", "c")) | ||||||
|  |     def test_init(self): | ||||||
|  |         d = self.dtype(a=1, b=1, c=1) | ||||||
|  |         self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) | ||||||
|  |  | ||||||
|  |     def test_init_fails_on_nonexisting_attrs(self): | ||||||
|  |         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) | ||||||
|  |          | ||||||
|  |     def test_eq(self): | ||||||
|  |         d = self.dtype(a=1, b=1, c=1) | ||||||
|  |         dd = self.dtype(a=1, b=1, c=1) | ||||||
|  |         e = self.dtype(a=1, b=1, c=3) | ||||||
|  |         f = self.dtype(a=1, b=1) | ||||||
|  |         g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1) | ||||||
|  |         h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) | ||||||
|  |         i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) | ||||||
|  |          | ||||||
|  |         self.assertEqual(d, dd) | ||||||
|  |         self.assertNotEqual(d, e) | ||||||
|  |         self.assertNotEqual(d, f) | ||||||
|  |         self.assertNotEqual(d, g) | ||||||
|  |         self.assertNotEqual(f, d) | ||||||
|  |         self.assertEqual(d, h) | ||||||
|  |         self.assertNotEqual(d, i) | ||||||
|  |  | ||||||
|  |     def test_setattr_getattr(self): | ||||||
|  |         d = self.dtype() | ||||||
|  |         d.a = 1 | ||||||
|  |         self.assertEqual(d.a, 1) | ||||||
|  |         self.assertRaises(AttributeError, lambda: d.b) | ||||||
|  |      | ||||||
|  |     def test_setattr_raises_on_nonexisting_attr(self): | ||||||
|  |         d = self.dtype() | ||||||
|  |         def _f(): | ||||||
|  |             d.x=1 | ||||||
|  |         self.assertRaises(AttributeError, _f) | ||||||
|  |      | ||||||
|  |     def test_setattr_getattr_special(self): | ||||||
|  |         d = self.strict_dict_class(["items"]) | ||||||
|  |         d.items = 1 | ||||||
|  |         self.assertEqual(d.items, 1) | ||||||
|  |      | ||||||
|  |     def test_get(self): | ||||||
|  |         d = self.dtype(a=1) | ||||||
|  |         self.assertEqual(d.get('a'), 1) | ||||||
|  |         self.assertEqual(d.get('b', 'bla'), 'bla') | ||||||
|  |  | ||||||
|  |     def test_items(self): | ||||||
|  |         d = self.dtype(a=1) | ||||||
|  |         self.assertEqual(d.items(), [('a', 1)]) | ||||||
|  |         d = self.dtype(a=1, b=2) | ||||||
|  |         self.assertEqual(d.items(), [('a', 1), ('b', 2)]) | ||||||
|  |  | ||||||
|  |     def test_mappings_protocol(self): | ||||||
|  |         d = self.dtype(a=1, b=2) | ||||||
|  |         assert dict(d) == {'a': 1, 'b': 2} | ||||||
|  |         assert dict(**d) == {'a': 1, 'b': 2} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestSemiSrictDict(TestStrictDict): | ||||||
|  |     def strict_dict_class(self, *args, **kwargs): | ||||||
|  |         return SemiStrictDict.create(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def test_init_fails_on_nonexisting_attrs(self): | ||||||
|  |         # disable irrelevant test | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     def test_setattr_raises_on_nonexisting_attr(self): | ||||||
|  |         # disable irrelevant test | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     def test_setattr_getattr_nonexisting_attr_succeeds(self): | ||||||
|  |         d = self.dtype() | ||||||
|  |         d.x = 1 | ||||||
|  |         self.assertEqual(d.x, 1) | ||||||
|  |  | ||||||
|  |     def test_init_succeeds_with_nonexisting_attrs(self): | ||||||
|  |         d = self.dtype(a=1, b=1, c=1, x=2) | ||||||
|  |         self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2)) | ||||||
|  |     | ||||||
|  |     def test_iter_with_nonexisting_attrs(self): | ||||||
|  |         d = self.dtype(a=1, b=1, c=1, x=2) | ||||||
|  |         self.assertEqual(list(d), ['a', 'b', 'c', 'x']) | ||||||
|  |  | ||||||
|  |     def test_iteritems_with_nonexisting_attrs(self): | ||||||
|  |         d = self.dtype(a=1, b=1, c=1, x=2) | ||||||
|  |         self.assertEqual(list(d.iteritems()), [('a', 1), ('b', 1), ('c', 1), ('x', 2)]) | ||||||
|  |  | ||||||
|  |     def tets_cmp_with_strict_dicts(self): | ||||||
|  |         d = self.dtype(a=1, b=1, c=1) | ||||||
|  |         dd = StrictDict.create(("a", "b", "c"))(a=1, b=1, c=1) | ||||||
|  |         self.assertEqual(d, dd) | ||||||
|  |  | ||||||
|  |     def test_cmp_with_strict_dict_with_nonexisting_attrs(self): | ||||||
|  |         d = self.dtype(a=1, b=1, c=1, x=2) | ||||||
|  |         dd = StrictDict.create(("a", "b", "c", "x"))(a=1, b=1, c=1, x=2) | ||||||
|  |         self.assertEqual(d, dd) | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     unittest.main() | ||||||
| @@ -291,6 +291,30 @@ class FieldTest(unittest.TestCase): | |||||||
|                 self.assertEqual(employee.friends, friends) |                 self.assertEqual(employee.friends, friends) | ||||||
|                 self.assertEqual(q, 2) |                 self.assertEqual(q, 2) | ||||||
|  |  | ||||||
|  |     def test_list_of_lists_of_references(self): | ||||||
|  |  | ||||||
|  |         class User(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class Post(Document): | ||||||
|  |             user_lists = ListField(ListField(ReferenceField(User))) | ||||||
|  |  | ||||||
|  |         class SimpleList(Document): | ||||||
|  |             users = ListField(ReferenceField(User)) | ||||||
|  |  | ||||||
|  |         User.drop_collection() | ||||||
|  |         Post.drop_collection() | ||||||
|  |  | ||||||
|  |         u1 = User.objects.create(name='u1') | ||||||
|  |         u2 = User.objects.create(name='u2') | ||||||
|  |         u3 = User.objects.create(name='u3') | ||||||
|  |  | ||||||
|  |         SimpleList.objects.create(users=[u1, u2, u3]) | ||||||
|  |         self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3]) | ||||||
|  |  | ||||||
|  |         Post.objects.create(user_lists=[[u1, u2], [u3]]) | ||||||
|  |         self.assertEqual(Post.objects.all()[0].user_lists, [[u1, u2], [u3]]) | ||||||
|  |  | ||||||
|     def test_circular_reference(self): |     def test_circular_reference(self): | ||||||
|         """Ensure you can handle circular references |         """Ensure you can handle circular references | ||||||
|         """ |         """ | ||||||
|   | |||||||
| @@ -54,7 +54,9 @@ class SignalTests(unittest.TestCase): | |||||||
|  |  | ||||||
|             @classmethod |             @classmethod | ||||||
|             def post_save(cls, sender, document, **kwargs): |             def post_save(cls, sender, document, **kwargs): | ||||||
|  |                 dirty_keys = document._delta()[0].keys() + document._delta()[1].keys() | ||||||
|                 signal_output.append('post_save signal, %s' % document) |                 signal_output.append('post_save signal, %s' % document) | ||||||
|  |                 signal_output.append('post_save dirty keys, %s' % dirty_keys) | ||||||
|                 if 'created' in kwargs: |                 if 'created' in kwargs: | ||||||
|                     if kwargs['created']: |                     if kwargs['created']: | ||||||
|                         signal_output.append('Is created') |                         signal_output.append('Is created') | ||||||
| @@ -203,6 +205,7 @@ class SignalTests(unittest.TestCase): | |||||||
|             "pre_save_post_validation signal, Bill Shakespeare", |             "pre_save_post_validation signal, Bill Shakespeare", | ||||||
|             "Is created", |             "Is created", | ||||||
|             "post_save signal, Bill Shakespeare", |             "post_save signal, Bill Shakespeare", | ||||||
|  |             "post_save dirty keys, ['name']", | ||||||
|             "Is created" |             "Is created" | ||||||
|         ]) |         ]) | ||||||
|  |  | ||||||
| @@ -213,6 +216,7 @@ class SignalTests(unittest.TestCase): | |||||||
|             "pre_save_post_validation signal, William Shakespeare", |             "pre_save_post_validation signal, William Shakespeare", | ||||||
|             "Is updated", |             "Is updated", | ||||||
|             "post_save signal, William Shakespeare", |             "post_save signal, William Shakespeare", | ||||||
|  |             "post_save dirty keys, ['name']", | ||||||
|             "Is updated" |             "Is updated" | ||||||
|         ]) |         ]) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user