diff --git a/.travis.yml b/.travis.yml index 5739909b..40736165 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,26 +6,28 @@ python: - "2.7" - "3.2" - "3.3" + - "3.4" + - "pypy" env: - - PYMONGO=dev DJANGO=1.6 - - PYMONGO=dev DJANGO=1.5.5 - - PYMONGO=dev DJANGO=1.4.10 - - PYMONGO=2.5 DJANGO=1.6 - - PYMONGO=2.5 DJANGO=1.5.5 - - PYMONGO=2.5 DJANGO=1.4.10 - - PYMONGO=3.2 DJANGO=1.6 - - PYMONGO=3.2 DJANGO=1.5.5 - - PYMONGO=3.3 DJANGO=1.6 - - PYMONGO=3.3 DJANGO=1.5.5 + - PYMONGO=dev DJANGO=1.6.5 + - PYMONGO=dev DJANGO=1.5.8 + - PYMONGO=2.7.1 DJANGO=1.6.5 + - PYMONGO=2.7.1 DJANGO=1.5.8 + +matrix: + fast_finish: true + install: - - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi - - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi + - sudo apt-get install python-dev python3-dev libopenjpeg-dev zlib1g-dev libjpeg-turbo8-dev libtiff4-dev libjpeg8-dev libfreetype6-dev liblcms2-dev libwebp-dev tcl8.5-dev tk8.5-dev python-tk - if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi - - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi + - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO; true; fi + - pip install Django==$DJANGO - pip install https://pypi.python.org/packages/source/p/python-dateutil/python-dateutil-2.1.tar.gz#md5=1534bb15cf311f07afaa3aacba1c028b - python setup.py install script: - python setup.py test + - if [[ $TRAVIS_PYTHON_VERSION == '3.'* ]]; then 2to3 . -w; fi; + - python benchmark.py notifications: irc: "irc.freenode.org#mongoengine" branches: diff --git a/AUTHORS b/AUTHORS index d6994d50..c7d57d01 100644 --- a/AUTHORS +++ b/AUTHORS @@ -171,7 +171,7 @@ that much better: * Michael Bartnett (https://github.com/michaelbartnett) * Alon Horev (https://github.com/alonho) * Kelvin Hammond (https://github.com/kelvinhammond) - * Jatin- (https://github.com/jatin-) + * Jatin Chopra (https://github.com/jatin) * Paul Uithol (https://github.com/PaulUithol) * Thom Knowles (https://github.com/fleat) * Paul (https://github.com/squamous) @@ -189,3 +189,17 @@ that much better: * Tom (https://github.com/tomprimozic) * j0hnsmith (https://github.com/j0hnsmith) * Damien Churchill (https://github.com/damoxc) + * Jonathan Simon Prates (https://github.com/jonathansp) + * Thiago Papageorgiou (https://github.com/tmpapageorgiou) + * Omer Katz (https://github.com/thedrow) + * Falcon Dai (https://github.com/falcondai) + * Polyrabbit (https://github.com/polyrabbit) + * Sagiv Malihi (https://github.com/sagivmalihi) + * Dmitry Konishchev (https://github.com/KonishchevDmitry) + * Martyn Smith (https://github.com/martynsmith) + * Andrei Zbikowski (https://github.com/b1naryth1ef) + * Ronald van Rij (https://github.com/ronaldvanrij) + * François Schmidts (https://github.com/jaesivsm) + * Eric Plumb (https://github.com/professorplumb) + * Damien Churchill (https://github.com/damoxc) + * Aleksandr Sorokoumov (https://github.com/Gerrrr) \ No newline at end of file diff --git a/README.rst b/README.rst index cc4524ae..8c3ee26e 100644 --- a/README.rst +++ b/README.rst @@ -29,9 +29,18 @@ setup.py install``. Dependencies ============ -- pymongo 2.5+ +- pymongo>=2.5 - sphinx (optional - for documentation generation) +Optional Dependencies +--------------------- +- **Django Integration:** Django>=1.4.0 for Python 2.x or PyPy and Django>=1.5.0 for Python 3.x +- **Image Fields**: Pillow>=2.0.0 or PIL (not recommended since MongoEngine is tested with Pillow) +- dateutil>=2.1.0 + +.. note + MongoEngine always runs it's test suite against the latest patch version of each dependecy. e.g.: Django 1.6.5 + Examples ======== Some simple examples of what MongoEngine code looks like:: diff --git a/benchmark.py b/benchmark.py index 16b2fd47..53ecf32c 100644 --- a/benchmark.py +++ b/benchmark.py @@ -15,7 +15,7 @@ def cprofile_main(): class Noddy(Document): fields = DictField() - for i in xrange(1): + for i in range(1): noddy = Noddy() for j in range(20): noddy.fields["key" + str(j)] = "value " + str(j) @@ -113,6 +113,7 @@ def main(): 4.68946313858 ---------------------------------------------------------------------------------------------------- """ + print("Benchmarking...") setup = """ from pymongo import MongoClient @@ -127,7 +128,7 @@ connection = MongoClient() db = connection.timeit_test noddy = db.noddy -for i in xrange(10000): +for i in range(10000): example = {'fields': {}} for j in range(20): example['fields']["key"+str(j)] = "value "+str(j) @@ -138,10 +139,10 @@ myNoddys = noddy.find() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - Pymongo""" + print("-" * 100) + print("""Creating 10000 dictionaries - Pymongo""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ from pymongo import MongoClient @@ -150,7 +151,7 @@ connection = MongoClient() db = connection.timeit_test noddy = db.noddy -for i in xrange(10000): +for i in range(10000): example = {'fields': {}} for j in range(20): example['fields']["key"+str(j)] = "value "+str(j) @@ -161,10 +162,10 @@ myNoddys = noddy.find() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - Pymongo write_concern={"w": 0}""" + print("-" * 100) + print("""Creating 10000 dictionaries - Pymongo write_concern={"w": 0}""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) setup = """ from pymongo import MongoClient @@ -180,7 +181,7 @@ class Noddy(Document): """ stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) @@ -190,13 +191,13 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine""" + print("-" * 100) + print("""Creating 10000 dictionaries - MongoEngine""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() fields = {} for j in range(20): @@ -208,13 +209,13 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries without continual assign - MongoEngine""" + print("-" * 100) + print("""Creating 10000 dictionaries without continual assign - MongoEngine""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) @@ -224,13 +225,13 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True""" + print("-" * 100) + print("""Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) @@ -240,13 +241,13 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True""" + print("-" * 100) + print("""Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) @@ -256,13 +257,13 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False""" + print("-" * 100) + print("""Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) @@ -272,11 +273,11 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False""" + print("-" * 100) + print("""Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/docs/changelog.rst b/docs/changelog.rst index 51134238..0c9c4825 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,6 +2,40 @@ Changelog ========= + +Changes in 0.9.X - DEV +====================== + +- Fixed nested reference field distinct error #583 +- Fixed change tracking on nested MapFields #539 +- Dynamic fields in embedded documents now visible to queryset.only() / qs.exclude() #425 #507 +- Add authentication_source option to register_connection #178 #464 #573 #580 #590 +- Implemented equality between Documents and DBRefs #597 +- Fixed ReferenceField inside nested ListFields dereferencing problem #368 +- Added the ability to reload specific document fields #100 +- Added db_alias support and fixes for custom map/reduce output #586 +- post_save signal now has access to delta information about field changes #594 #589 +- Don't query with $orderby for qs.get() #600 +- Fix id shard key save issue #636 +- Fixes issue with recursive embedded document errors #557 +- Fix clear_changed_fields() clearing unsaved documents bug #602 +- Removing support for Django 1.4.x, pymongo 2.5.x, pymongo 2.6.x. +- Removing support for Python < 2.6.6 +- Fixed $maxDistance location for geoJSON $near queries with MongoDB 2.6+ #664 +- QuerySet.modify() method to provide find_and_modify() like behaviour #677 +- Added support for the using() method on a queryset #676 +- PYPY support #673 +- Connection pooling #674 +- Avoid to open all documents from cursors in an if stmt #655 +- Ability to clear the ordering #657 +- Raise NotUniqueError in Document.update() on pymongo.errors.DuplicateKeyError #626 +- Slots - memory improvements #625 +- Fixed incorrectly split a query key when it ends with "_" #619 +- Geo docs updates #613 +- Workaround a dateutil bug #608 +- Conditional save for atomic-style operations #511 +- Allow dynamic dictionary-style field access #559 + Changes in 0.8.7 ================ - Calling reload on deleted / nonexistant documents raises DoesNotExist (#538) diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 5d8b628a..07bce3bb 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -531,6 +531,8 @@ field name to the index definition. Sometimes its more efficient to index parts of Embedded / dictionary fields, in this case use 'dot' notation to identify the value to index eg: `rank.title` +.. _geospatial-indexes: + Geospatial indexes ------------------ diff --git a/docs/guide/gridfs.rst b/docs/guide/gridfs.rst index 596585de..68e7a6d2 100644 --- a/docs/guide/gridfs.rst +++ b/docs/guide/gridfs.rst @@ -46,7 +46,7 @@ slightly different manner. First, a new file must be created by calling the marmot.photo.write('some_more_image_data') marmot.photo.close() - marmot.photo.save() + marmot.save() Deletion -------- diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 32cbb94e..34996dc6 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -488,8 +488,9 @@ calling it with keyword arguments:: Atomic updates ============== Documents may be updated atomically by using the -:meth:`~mongoengine.queryset.QuerySet.update_one` and -:meth:`~mongoengine.queryset.QuerySet.update` methods on a +:meth:`~mongoengine.queryset.QuerySet.update_one`, +:meth:`~mongoengine.queryset.QuerySet.update` and +:meth:`~mongoengine.queryset.QuerySet.modify` methods on a :meth:`~mongoengine.queryset.QuerySet`. There are several different "modifiers" that you may use with these methods: @@ -499,11 +500,13 @@ that you may use with these methods: * ``dec`` -- decrement a value by a given amount * ``push`` -- append a value to a list * ``push_all`` -- append several values to a list -* ``pop`` -- remove the first or last element of a list +* ``pop`` -- remove the first or last element of a list `depending on the value`_ * ``pull`` -- remove a value from a list * ``pull_all`` -- remove several values from a list * ``add_to_set`` -- add value to a list only if its not in the list already +.. _depending on the value: http://docs.mongodb.org/manual/reference/operator/update/pop/ + The syntax for atomic updates is similar to the querying syntax, but the modifier comes before the field, not after it:: diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 4652fb56..dc92d183 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -1,12 +1,13 @@ import weakref +import functools +import itertools from mongoengine.common import _import_class __all__ = ("BaseDict", "BaseList") class BaseDict(dict): - """A special dict so we can watch any changes - """ + """A special dict so we can watch any changes""" _dereferenced = False _instance = None @@ -21,29 +22,37 @@ class BaseDict(dict): self._name = name return super(BaseDict, self).__init__(dict_items) - def __getitem__(self, *args, **kwargs): - value = super(BaseDict, self).__getitem__(*args, **kwargs) + def __getitem__(self, key, *args, **kwargs): + value = super(BaseDict, self).__getitem__(key) EmbeddedDocument = _import_class('EmbeddedDocument') if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance + elif not isinstance(value, BaseDict) and isinstance(value, dict): + value = BaseDict(value, None, '%s.%s' % (self._name, key)) + super(BaseDict, self).__setitem__(key, value) + value._instance = self._instance + elif not isinstance(value, BaseList) and isinstance(value, list): + value = BaseList(value, None, '%s.%s' % (self._name, key)) + super(BaseDict, self).__setitem__(key, value) + value._instance = self._instance return value - def __setitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__setitem__(*args, **kwargs) + def __setitem__(self, key, value, *args, **kwargs): + self._mark_as_changed(key) + return super(BaseDict, self).__setitem__(key, value) def __delete__(self, *args, **kwargs): self._mark_as_changed() return super(BaseDict, self).__delete__(*args, **kwargs) - def __delitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__delitem__(*args, **kwargs) + def __delitem__(self, key, *args, **kwargs): + self._mark_as_changed(key) + return super(BaseDict, self).__delitem__(key) - def __delattr__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__delattr__(*args, **kwargs) + def __delattr__(self, key, *args, **kwargs): + self._mark_as_changed(key) + return super(BaseDict, self).__delattr__(key) def __getstate__(self): self.instance = None @@ -70,9 +79,12 @@ class BaseDict(dict): self._mark_as_changed() return super(BaseDict, self).update(*args, **kwargs) - def _mark_as_changed(self): + def _mark_as_changed(self, key=None): if hasattr(self._instance, '_mark_as_changed'): - self._instance._mark_as_changed(self._name) + if key: + self._instance._mark_as_changed('%s.%s' % (self._name, key)) + else: + self._instance._mark_as_changed(self._name) class BaseList(list): @@ -92,21 +104,35 @@ class BaseList(list): self._name = name return super(BaseList, self).__init__(list_items) - def __getitem__(self, *args, **kwargs): - value = super(BaseList, self).__getitem__(*args, **kwargs) + def __getitem__(self, key, *args, **kwargs): + value = super(BaseList, self).__getitem__(key) EmbeddedDocument = _import_class('EmbeddedDocument') if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance + elif not isinstance(value, BaseDict) and isinstance(value, dict): + value = BaseDict(value, None, '%s.%s' % (self._name, key)) + super(BaseList, self).__setitem__(key, value) + value._instance = self._instance + elif not isinstance(value, BaseList) and isinstance(value, list): + value = BaseList(value, None, '%s.%s' % (self._name, key)) + super(BaseList, self).__setitem__(key, value) + value._instance = self._instance return value - def __setitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__setitem__(*args, **kwargs) + def __setitem__(self, key, value, *args, **kwargs): + if isinstance(key, slice): + self._mark_as_changed() + else: + self._mark_as_changed(key) + return super(BaseList, self).__setitem__(key, value) - def __delitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__delitem__(*args, **kwargs) + def __delitem__(self, key, *args, **kwargs): + if isinstance(key, slice): + self._mark_as_changed() + else: + self._mark_as_changed(key) + return super(BaseList, self).__delitem__(key) def __setslice__(self, *args, **kwargs): self._mark_as_changed() @@ -153,6 +179,103 @@ class BaseList(list): self._mark_as_changed() return super(BaseList, self).sort(*args, **kwargs) - def _mark_as_changed(self): + def _mark_as_changed(self, key=None): if hasattr(self._instance, '_mark_as_changed'): - self._instance._mark_as_changed(self._name) + if key: + self._instance._mark_as_changed('%s.%s' % (self._name, key)) + else: + self._instance._mark_as_changed(self._name) + + +class StrictDict(object): + __slots__ = () + _special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create']) + _classes = {} + def __init__(self, **kwargs): + for k,v in kwargs.iteritems(): + setattr(self, k, v) + def __getitem__(self, key): + key = '_reserved_' + key if key in self._special_fields else key + try: + return getattr(self, key) + except AttributeError: + raise KeyError(key) + def __setitem__(self, key, value): + key = '_reserved_' + key if key in self._special_fields else key + return setattr(self, key, value) + def __contains__(self, key): + return hasattr(self, key) + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + def pop(self, key, default=None): + v = self.get(key, default) + try: + delattr(self, key) + except AttributeError: + pass + return v + def iteritems(self): + for key in self: + yield key, self[key] + def items(self): + return [(k, self[k]) for k in iter(self)] + def keys(self): + return list(iter(self)) + def __iter__(self): + return (key for key in self.__slots__ if hasattr(self, key)) + def __len__(self): + return len(list(self.iteritems())) + def __eq__(self, other): + return self.items() == other.items() + def __neq__(self, other): + return self.items() != other.items() + + @classmethod + def create(cls, allowed_keys): + allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys) + allowed_keys = frozenset(allowed_keys_tuple) + if allowed_keys not in cls._classes: + class SpecificStrictDict(cls): + __slots__ = allowed_keys_tuple + cls._classes[allowed_keys] = SpecificStrictDict + return cls._classes[allowed_keys] + + +class SemiStrictDict(StrictDict): + __slots__ = ('_extras') + _classes = {} + def __getattr__(self, attr): + try: + super(SemiStrictDict, self).__getattr__(attr) + except AttributeError: + try: + return self.__getattribute__('_extras')[attr] + except KeyError as e: + raise AttributeError(e) + def __setattr__(self, attr, value): + try: + super(SemiStrictDict, self).__setattr__(attr, value) + except AttributeError: + try: + self._extras[attr] = value + except AttributeError: + self._extras = {attr: value} + + def __delattr__(self, attr): + try: + super(SemiStrictDict, self).__delattr__(attr) + except AttributeError: + try: + del self._extras[attr] + except KeyError as e: + raise AttributeError(e) + + def __iter__(self): + try: + extras_iter = iter(self.__getattribute__('_extras')) + except AttributeError: + extras_iter = () + return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index f5eae8ff..8ec370b6 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -13,24 +13,23 @@ from mongoengine import signals from mongoengine.common import _import_class from mongoengine.errors import (ValidationError, InvalidDocumentError, LookUpError) -from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, - to_str_keys_recursive) +from mongoengine.python_support import PY3, txt_type from mongoengine.base.common import get_document, ALLOW_INHERITANCE -from mongoengine.base.datastructures import BaseDict, BaseList +from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict from mongoengine.base.fields import ComplexBaseField __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') NON_FIELD_ERRORS = '__all__' - class BaseDocument(object): + __slots__ = ('_changed_fields', '_initialised', '_created', '_data', + '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') _dynamic = False - _created = True _dynamic_lock = True - _initialised = False + STRICT = False def __init__(self, *args, **values): """ @@ -39,6 +38,8 @@ class BaseDocument(object): :param __auto_convert: Try and will cast python objects to Object types :param values: A dictionary of values for the document """ + self._initialised = False + self._created = True if args: # Combine positional arguments with named arguments. # We only want named arguments. @@ -54,7 +55,11 @@ class BaseDocument(object): __auto_convert = values.pop("__auto_convert", True) signals.pre_init.send(self.__class__, document=self, values=values) - self._data = {} + if self.STRICT and not self._dynamic: + self._data = StrictDict.create(allowed_keys=self._fields.keys())() + else: + self._data = SemiStrictDict.create(allowed_keys=self._fields.keys())() + self._dynamic_fields = SON() # Assign default values to instance @@ -130,17 +135,25 @@ class BaseDocument(object): self._data[name] = value if hasattr(self, '_changed_fields'): self._mark_as_changed(name) + try: + self__created = self._created + except AttributeError: + self__created = True - if (self._is_document and not self._created and + if (self._is_document and not self__created and name in self._meta.get('shard_key', tuple()) and self._data.get(name) != value): OperationError = _import_class('OperationError') msg = "Shard Keys are immutable. Tried to update %s" % name raise OperationError(msg) + try: + self__initialised = self._initialised + except AttributeError: + self__initialised = False # Check if the user has created a new instance of a class - if (self._is_document and self._initialised - and self._created and name == self._meta['id_field']): + if (self._is_document and self__initialised + and self__created and name == self._meta['id_field']): super(BaseDocument, self).__setattr__('_created', False) super(BaseDocument, self).__setattr__(name, value) @@ -158,9 +171,11 @@ class BaseDocument(object): if isinstance(data["_data"], SON): data["_data"] = self.__class__._from_son(data["_data"])._data for k in ('_changed_fields', '_initialised', '_created', '_data', - '_fields_ordered', '_dynamic_fields'): + '_dynamic_fields'): if k in data: setattr(self, k, data[k]) + if '_fields_ordered' in data: + setattr(type(self), '_fields_ordered', data['_fields_ordered']) dynamic_fields = data.get('_dynamic_fields') or SON() for k in dynamic_fields.keys(): setattr(self, k, data["_data"].get(k)) @@ -182,7 +197,7 @@ class BaseDocument(object): """Dictionary-style field access, set a field's value. """ # Ensure that the field exists before settings its value - if name not in self._fields: + if not self._dynamic and name not in self._fields: raise KeyError(name) return setattr(self, name, value) @@ -214,8 +229,9 @@ class BaseDocument(object): def __eq__(self, other): if isinstance(other, self.__class__) and hasattr(other, 'id'): - if self.id == other.id: - return True + return self.id == other.id + if isinstance(other, DBRef): + return self._get_collection_name() == other.collection and self.id == other.id return False def __ne__(self, other): @@ -317,7 +333,7 @@ class BaseDocument(object): pk = "None" if hasattr(self, 'pk'): pk = self.pk - elif self._instance: + elif self._instance and hasattr(self._instance, 'pk'): pk = self._instance.pk message = "ValidationError (%s:%s) " % (self._class_name, pk) raise ValidationError(message, errors=errors) @@ -370,9 +386,18 @@ class BaseDocument(object): """ if not key: return - key = self._db_field_map.get(key, key) - if (hasattr(self, '_changed_fields') and - key not in self._changed_fields): + + if not hasattr(self, '_changed_fields'): + return + + if '.' in key: + key, rest = key.split('.', 1) + key = self._db_field_map.get(key, key) + key = '%s.%s' % (key, rest) + else: + key = self._db_field_map.get(key, key) + + if key not in self._changed_fields: self._changed_fields.append(key) def _clear_changed_fields(self): @@ -392,6 +417,8 @@ class BaseDocument(object): else: data = getattr(data, part, None) if hasattr(data, "_changed_fields"): + if hasattr(data, "_is_document") and data._is_document: + continue data._changed_fields = [] self._changed_fields = [] @@ -405,6 +432,10 @@ class BaseDocument(object): for index, value in iterator: list_key = "%s%s." % (key, index) + # don't check anything lower if this key is already marked + # as changed. + if list_key[:-1] in changed_fields: + continue if hasattr(value, '_get_changed_fields'): changed = value._get_changed_fields(inspected) changed_fields += ["%s%s" % (list_key, k) @@ -420,6 +451,7 @@ class BaseDocument(object): ReferenceField = _import_class("ReferenceField") changed_fields = [] changed_fields += getattr(self, '_changed_fields', []) + inspected = inspected or set() if hasattr(self, 'id') and isinstance(self.id, Hashable): if self.id in inspected: @@ -472,7 +504,10 @@ class BaseDocument(object): if isinstance(d, (ObjectId, DBRef)): break elif isinstance(d, list) and p.isdigit(): - d = d[int(p)] + try: + d = d[int(p)] + except IndexError: + d = None elif hasattr(d, 'get'): d = d.get(p) new_path.append(p) @@ -545,10 +580,6 @@ class BaseDocument(object): # class if unavailable class_name = son.get('_cls', cls._class_name) data = dict(("%s" % key, value) for key, value in son.iteritems()) - if not UNICODE_KWARGS: - # python 2.6.4 and lower cannot handle unicode keys - # passed to class constructor example: cls(**data) - to_str_keys_recursive(data) # Return correct subclass for document type if class_name != cls._class_name: @@ -586,6 +617,8 @@ class BaseDocument(object): % (cls._class_name, errors)) raise InvalidDocumentError(msg) + if cls.STRICT: + data = dict((k, v) for k,v in data.iteritems() if k in cls._fields) obj = cls(__auto_convert=False, **data) obj._changed_fields = changed_fields obj._created = False @@ -804,8 +837,17 @@ class BaseDocument(object): # Look up subfield on the previous field new_field = field.lookup_member(field_name) if not new_field and isinstance(field, ComplexBaseField): - fields.append(field_name) - continue + if hasattr(field.field, 'document_type') and cls._dynamic \ + and field.field.document_type._dynamic: + DynamicField = _import_class('DynamicField') + new_field = DynamicField(db_field=field_name) + else: + fields.append(field_name) + continue + elif not new_field and hasattr(field, 'document_type') and cls._dynamic \ + and field.document_type._dynamic: + DynamicField = _import_class('DynamicField') + new_field = DynamicField(db_field=field_name) elif not new_field: raise LookUpError('Cannot resolve field "%s"' % field_name) @@ -825,7 +867,11 @@ class BaseDocument(object): """Dynamically set the display value for a field with choices""" for attr_name, field in self._fields.items(): if field.choices: - setattr(self, + if self._dynamic: + obj = self + else: + obj = type(self) + setattr(obj, 'get_%s_display' % attr_name, partial(self.__get_field_display, field=field)) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index ff5afddf..4b2e8b9b 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -359,7 +359,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class.id = field # Set primary key if not defined by the document - new_class._auto_id_field = False + new_class._auto_id_field = getattr(parent_doc_cls, + '_auto_id_field', False) if not new_class._meta.get('id_field'): new_class._auto_id_field = True new_class._meta['id_field'] = 'id' diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 7cc626f4..fbba3caa 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -20,7 +20,8 @@ _dbs = {} def register_connection(alias, name, host=None, port=None, is_slave=False, read_preference=False, slaves=None, - username=None, password=None, **kwargs): + username=None, password=None, authentication_source=None, + **kwargs): """Add a connection. :param alias: the name that will be used to refer to this connection @@ -36,6 +37,7 @@ def register_connection(alias, name, host=None, port=None, be a registered connection that has :attr:`is_slave` set to ``True`` :param username: username to authenticate with :param password: password to authenticate with + :param authentication_source: database to authenticate against :param kwargs: allow ad-hoc parameters to be passed into the pymongo driver """ @@ -46,10 +48,11 @@ def register_connection(alias, name, host=None, port=None, 'host': host or 'localhost', 'port': port or 27017, 'is_slave': is_slave, + 'read_preference': read_preference, 'slaves': slaves or [], 'username': username, 'password': password, - 'read_preference': read_preference + 'authentication_source': authentication_source } # Handle uri style connections @@ -93,20 +96,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): raise ConnectionError(msg) conn_settings = _connection_settings[alias].copy() - if hasattr(pymongo, 'version_tuple'): # Support for 2.1+ - conn_settings.pop('name', None) - conn_settings.pop('slaves', None) - conn_settings.pop('is_slave', None) - conn_settings.pop('username', None) - conn_settings.pop('password', None) - else: - # Get all the slave connections - if 'slaves' in conn_settings: - slaves = [] - for slave_alias in conn_settings['slaves']: - slaves.append(get_connection(slave_alias)) - conn_settings['slaves'] = slaves - conn_settings.pop('read_preference', None) + conn_settings.pop('name', None) + conn_settings.pop('slaves', None) + conn_settings.pop('is_slave', None) + conn_settings.pop('username', None) + conn_settings.pop('password', None) + conn_settings.pop('authentication_source', None) connection_class = MongoClient if 'replicaSet' in conn_settings: @@ -119,7 +114,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): connection_class = MongoReplicaSetClient try: - _connections[alias] = connection_class(**conn_settings) + connection = None + connection_settings_iterator = ((alias, settings.copy()) for alias, settings in _connection_settings.iteritems()) + for alias, connection_settings in connection_settings_iterator: + connection_settings.pop('name', None) + connection_settings.pop('slaves', None) + connection_settings.pop('is_slave', None) + connection_settings.pop('username', None) + connection_settings.pop('password', None) + if conn_settings == connection_settings and _connections.get(alias, None): + connection = _connections[alias] + break + + _connections[alias] = connection if connection else connection_class(**conn_settings) except Exception, e: raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) return _connections[alias] @@ -137,7 +144,8 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): # Authenticate if necessary if conn_settings['username'] and conn_settings['password']: db.authenticate(conn_settings['username'], - conn_settings['password']) + conn_settings['password'], + source=conn_settings['authentication_source']) _dbs[alias] = db return _dbs[alias] diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index 13ed1009..cc860066 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -1,6 +1,5 @@ from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db -from mongoengine.queryset import QuerySet __all__ = ("switch_db", "switch_collection", "no_dereference", @@ -162,12 +161,6 @@ class no_sub_classes(object): return self.cls -class QuerySetNoDeRef(QuerySet): - """Special no_dereference QuerySet""" - def __dereference(items, max_depth=1, instance=None, name=None): - return items - - class query_counter(object): """ Query_counter context manager to get the number of queries. """ diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index ceda403e..08eac7dd 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -36,7 +36,7 @@ class DeReference(object): if instance and isinstance(instance, (Document, EmbeddedDocument, TopLevelDocumentMetaclass)): doc_type = instance._fields.get(name) - if hasattr(doc_type, 'field'): + while hasattr(doc_type, 'field'): doc_type = doc_type.field if isinstance(doc_type, ReferenceField): @@ -51,9 +51,19 @@ class DeReference(object): return items elif not field.dbref: if not hasattr(items, 'items'): - items = [field.to_python(v) - if not isinstance(v, (DBRef, Document)) else v - for v in items] + + def _get_items(items): + new_items = [] + for v in items: + if isinstance(v, list): + new_items.append(_get_items(v)) + elif not isinstance(v, (DBRef, Document)): + new_items.append(field.to_python(v)) + else: + new_items.append(v) + return new_items + + items = _get_items(items) else: items = dict([ (k, field.to_python(v)) @@ -114,11 +124,11 @@ class DeReference(object): """Fetch all references and convert to their document objects """ object_map = {} - for col, dbrefs in self.reference_map.iteritems(): + for collection, dbrefs in self.reference_map.iteritems(): keys = object_map.keys() refs = list(set([dbref for dbref in dbrefs if unicode(dbref).encode('utf-8') not in keys])) - if hasattr(col, 'objects'): # We have a document class for the refs - references = col.objects.in_bulk(refs) + if hasattr(collection, 'objects'): # We have a document class for the refs + references = collection.objects.in_bulk(refs) for key, doc in references.iteritems(): object_map[key] = doc else: # Generic reference: use the refs data to convert to document @@ -126,19 +136,19 @@ class DeReference(object): continue if doc_type: - references = doc_type._get_db()[col].find({'_id': {'$in': refs}}) + references = doc_type._get_db()[collection].find({'_id': {'$in': refs}}) for ref in references: doc = doc_type._from_son(ref) object_map[doc.id] = doc else: - references = get_db()[col].find({'_id': {'$in': refs}}) + references = get_db()[collection].find({'_id': {'$in': refs}}) for ref in references: if '_cls' in ref: doc = get_document(ref["_cls"])._from_son(ref) elif doc_type is None: doc = get_document( ''.join(x.capitalize() - for x in col.split('_')))._from_son(ref) + for x in collection.split('_')))._from_son(ref) else: doc = doc_type._from_son(ref) object_map[doc.id] = doc @@ -204,7 +214,8 @@ class DeReference(object): elif isinstance(v, (list, tuple)) and depth <= self.max_depth: data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: - data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) + item_name = '%s.%s' % (name, k) if name else name + data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name) elif hasattr(v, 'id'): data[k] = self.object_map.get(v.id, v) diff --git a/mongoengine/document.py b/mongoengine/document.py index 114778eb..d969d638 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -13,7 +13,8 @@ from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList, ALLOW_INHERITANCE, get_document) from mongoengine.errors import ValidationError -from mongoengine.queryset import OperationError, NotUniqueError, QuerySet +from mongoengine.queryset import (OperationError, NotUniqueError, + QuerySet, transform) from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME from mongoengine.context_managers import switch_db, switch_collection @@ -54,20 +55,21 @@ class EmbeddedDocument(BaseDocument): dictionary. """ + __slots__ = ('_instance') + # The __metaclass__ attribute is removed by 2to3 when running with Python3 # my_metaclass is defined so that metaclass can be queried in Python 2 & 3 my_metaclass = DocumentMetaclass __metaclass__ = DocumentMetaclass - _instance = None - def __init__(self, *args, **kwargs): super(EmbeddedDocument, self).__init__(*args, **kwargs) + self._instance = None self._changed_fields = [] def __eq__(self, other): if isinstance(other, self.__class__): - return self.to_mongo() == other.to_mongo() + return self._data == other._data return False def __ne__(self, other): @@ -125,6 +127,8 @@ class Document(BaseDocument): my_metaclass = TopLevelDocumentMetaclass __metaclass__ = TopLevelDocumentMetaclass + __slots__ = ('__objects' ) + def pk(): """Primary key alias """ @@ -180,7 +184,7 @@ class Document(BaseDocument): def save(self, force_insert=False, validate=True, clean=True, write_concern=None, cascade=None, cascade_kwargs=None, - _refs=None, **kwargs): + _refs=None, save_condition=None, **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. @@ -203,7 +207,8 @@ class Document(BaseDocument): :param cascade_kwargs: (optional) kwargs dictionary to be passed throw to cascading saves. Implies ``cascade=True``. :param _refs: A list of processed references used in cascading saves - + :param save_condition: only perform save if matching record in db + satisfies condition(s) (e.g., version number) .. versionchanged:: 0.5 In existing documents it only saves changed fields using set / unset. Saves are cascaded and any @@ -217,6 +222,9 @@ class Document(BaseDocument): meta['cascade'] = True. Also you can pass different kwargs to the cascade save using cascade_kwargs which overwrites the existing kwargs with custom values. + .. versionchanged:: 0.8.5 + Optional save_condition that only overwrites existing documents + if the condition is satisfied in the current db record. """ signals.pre_save.send(self.__class__, document=self) @@ -230,7 +238,8 @@ class Document(BaseDocument): created = ('_id' not in doc or self._created or force_insert) - signals.pre_save_post_validation.send(self.__class__, document=self, created=created) + signals.pre_save_post_validation.send(self.__class__, document=self, + created=created) try: collection = self._get_collection() @@ -243,7 +252,12 @@ class Document(BaseDocument): object_id = doc['_id'] updates, removals = self._delta() # Need to add shard key to query, or you get an error - select_dict = {'_id': object_id} + if save_condition is not None: + select_dict = transform.query(self.__class__, + **save_condition) + else: + select_dict = {} + select_dict['_id'] = object_id shard_key = self.__class__._meta.get('shard_key', tuple()) for k in shard_key: actual_key = self._db_field_map.get(k, k) @@ -263,10 +277,12 @@ class Document(BaseDocument): if removals: update_query["$unset"] = removals if updates or removals: + upsert = save_condition is None last_error = collection.update(select_dict, update_query, - upsert=True, **write_concern) + upsert=upsert, **write_concern) created = is_new_object(last_error) + if cascade is None: cascade = self._meta.get('cascade', False) or cascade_kwargs is not None @@ -293,12 +309,12 @@ class Document(BaseDocument): raise NotUniqueError(message % unicode(err)) raise OperationError(message % unicode(err)) id_field = self._meta['id_field'] - if id_field not in self._meta.get('shard_key', []): + if created or id_field not in self._meta.get('shard_key', []): self[id_field] = self._fields[id_field].to_python(object_id) + signals.post_save.send(self.__class__, document=self, created=created) self._clear_changed_fields() self._created = False - signals.post_save.send(self.__class__, document=self, created=created) return self def cascade_save(self, *args, **kwargs): @@ -447,27 +463,41 @@ class Document(BaseDocument): DeReference()([self], max_depth + 1) return self - def reload(self, max_depth=1): + def reload(self, *fields, **kwargs): """Reloads all attributes from the database. + :param fields: (optional) args list of fields to reload + :param max_depth: (optional) depth of dereferencing to follow + .. versionadded:: 0.1.2 .. versionchanged:: 0.6 Now chainable + .. versionchanged:: 0.9 Can provide specific fields to reload """ + max_depth = 1 + if fields and isinstance(fields[0], int): + max_depth = fields[0] + fields = fields[1:] + elif "max_depth" in kwargs: + max_depth = kwargs["max_depth"] + if not self.pk: raise self.DoesNotExist("Document does not exist") obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( - **self._object_key).limit(1).select_related(max_depth=max_depth) - + **self._object_key).only(*fields).limit(1 + ).select_related(max_depth=max_depth) if obj: obj = obj[0] else: raise self.DoesNotExist("Document does not exist") + for field in self._fields_ordered: - setattr(self, field, self._reload(field, obj[field])) + if not fields or field in fields: + setattr(self, field, self._reload(field, obj[field])) + self._changed_fields = obj._changed_fields self._created = False - return obj + return self def _reload(self, key, value): """Used by :meth:`~mongoengine.Document.reload` to ensure the diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 82642cda..6bf9ce4b 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -391,7 +391,7 @@ class DateTimeField(BaseField): if dateutil: try: return dateutil.parser.parse(value) - except ValueError: + except (TypeError, ValueError): return None # split usecs, because they are not recognized by strptime. @@ -760,7 +760,7 @@ class DictField(ComplexBaseField): similar to an embedded document, but the structure is not defined. .. note:: - Required means it cannot be empty - as the default for ListFields is [] + Required means it cannot be empty - as the default for DictFields is {} .. versionadded:: 0.3 .. versionchanged:: 0.5 - Can now handle complex / varying types of data @@ -1554,6 +1554,14 @@ class SequenceField(BaseField): return super(SequenceField, self).__set__(instance, value) + def prepare_query_value(self, op, value): + """ + This method is overriden in order to convert the query value into to required + type. We need to do this in order to be able to successfully compare query + values passed as string, the base implementation returns the value as is. + """ + return self.value_decorator(value) + def to_python(self, value): if value is None: value = self.generate() @@ -1613,7 +1621,12 @@ class UUIDField(BaseField): class GeoPointField(BaseField): - """A list storing a latitude and longitude. + """A list storing a longitude and latitude coordinate. + + .. note:: this represents a generic point in a 2D plane and a legacy way of + representing a geo point. It admits 2d indexes but not "2dsphere" indexes + in MongoDB > 2.4 which are more natural for modeling geospatial points. + See :ref:`geospatial-indexes` .. versionadded:: 0.4 """ @@ -1635,7 +1648,7 @@ class GeoPointField(BaseField): class PointField(GeoJsonBaseField): - """A geo json field storing a latitude and longitude. + """A GeoJSON field storing a longitude and latitude coordinate. The data is represented as: @@ -1654,7 +1667,7 @@ class PointField(GeoJsonBaseField): class LineStringField(GeoJsonBaseField): - """A geo json field storing a line of latitude and longitude coordinates. + """A GeoJSON field storing a line of longitude and latitude coordinates. The data is represented as: @@ -1672,7 +1685,7 @@ class LineStringField(GeoJsonBaseField): class PolygonField(GeoJsonBaseField): - """A geo json field storing a polygon of latitude and longitude coordinates. + """A GeoJSON field storing a polygon of longitude and latitude coordinates. The data is represented as: diff --git a/mongoengine/python_support.py b/mongoengine/python_support.py index 097740eb..2c4df00c 100644 --- a/mongoengine/python_support.py +++ b/mongoengine/python_support.py @@ -3,8 +3,6 @@ import sys PY3 = sys.version_info[0] == 3 -PY25 = sys.version_info[:2] == (2, 5) -UNICODE_KWARGS = int(''.join([str(x) for x in sys.version_info[:3]])) > 264 if PY3: import codecs @@ -29,33 +27,3 @@ else: txt_type = unicode str_types = (bin_type, txt_type) - -if PY25: - def product(*args, **kwds): - pools = map(tuple, args) * kwds.get('repeat', 1) - result = [[]] - for pool in pools: - result = [x + [y] for x in result for y in pool] - for prod in result: - yield tuple(prod) - reduce = reduce -else: - from itertools import product - from functools import reduce - - -# For use with Python 2.5 -# converts all keys from unicode to str for d and all nested dictionaries -def to_str_keys_recursive(d): - if isinstance(d, list): - for val in d: - if isinstance(val, (dict, list)): - to_str_keys_recursive(val) - elif isinstance(d, dict): - for key, val in d.items(): - if isinstance(val, (dict, list)): - to_str_keys_recursive(val) - if isinstance(key, unicode): - d[str(key)] = d.pop(key) - else: - raise ValueError("non list/dict parameter not allowed") diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index c2ad027e..4b7ec491 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -7,17 +7,20 @@ import pprint import re import warnings +from bson import SON from bson.code import Code from bson import json_util import pymongo +import pymongo.errors from pymongo.common import validate_read_preference from mongoengine import signals +from mongoengine.connection import get_db +from mongoengine.context_managers import switch_db from mongoengine.common import _import_class from mongoengine.base.common import get_document from mongoengine.errors import (OperationError, NotUniqueError, InvalidQueryError, LookUpError) - from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode @@ -50,7 +53,7 @@ class BaseQuerySet(object): self._initial_query = {} self._where_clause = None self._loaded_fields = QueryFieldList() - self._ordering = [] + self._ordering = None self._snapshot = False self._timeout = True self._class_check = True @@ -146,7 +149,7 @@ class BaseQuerySet(object): queryset._document._from_son(queryset._cursor[key], _auto_dereference=self._auto_dereference)) if queryset._as_pymongo: - return queryset._get_as_pymongo(queryset._cursor.next()) + return queryset._get_as_pymongo(queryset._cursor[key]) return queryset._document._from_son(queryset._cursor[key], _auto_dereference=self._auto_dereference) raise AttributeError @@ -154,6 +157,22 @@ class BaseQuerySet(object): def __iter__(self): raise NotImplementedError + def _has_data(self): + """ Retrieves whether cursor has any data. """ + + queryset = self.order_by() + return False if queryset.first() is None else True + + def __nonzero__(self): + """ Avoid to open all records in an if stmt in Py2. """ + + return self._has_data() + + def __bool__(self): + """ Avoid to open all records in an if stmt in Py3. """ + + return self._has_data() + # Core functions def all(self): @@ -175,7 +194,7 @@ class BaseQuerySet(object): .. versionadded:: 0.3 """ queryset = self.clone() - queryset = queryset.limit(2) + queryset = queryset.order_by().limit(2) queryset = queryset.filter(*q_objs, **query) try: @@ -389,7 +408,7 @@ class BaseQuerySet(object): ref_q = document_cls.objects(**{field_name + '__in': self}) ref_q_count = ref_q.count() if (doc != document_cls and ref_q_count > 0 - or (doc == document_cls and ref_q_count > 0)): + or (doc == document_cls and ref_q_count > 0)): ref_q.delete(write_concern=write_concern) elif rule == NULLIFY: document_cls.objects(**{field_name + '__in': self}).update( @@ -443,6 +462,8 @@ class BaseQuerySet(object): return result elif result: return result['n'] + except pymongo.errors.DuplicateKeyError, err: + raise NotUniqueError(u'Update failed (%s)' % unicode(err)) except pymongo.errors.OperationFailure, err: if unicode(err) == u'multi not coded yet': message = u'update() method requires MongoDB 1.1.3+' @@ -466,6 +487,59 @@ class BaseQuerySet(object): return self.update( upsert=upsert, multi=False, write_concern=write_concern, **update) + def modify(self, upsert=False, full_response=False, remove=False, new=False, **update): + """Update and return the updated document. + + Returns either the document before or after modification based on `new` + parameter. If no documents match the query and `upsert` is false, + returns ``None``. If upserting and `new` is false, returns ``None``. + + If the full_response parameter is ``True``, the return value will be + the entire response object from the server, including the 'ok' and + 'lastErrorObject' fields, rather than just the modified document. + This is useful mainly because the 'lastErrorObject' document holds + information about the command's execution. + + :param upsert: insert if document doesn't exist (default ``False``) + :param full_response: return the entire response object from the + server (default ``False``) + :param remove: remove rather than updating (default ``False``) + :param new: return updated rather than original document + (default ``False``) + :param update: Django-style update keyword arguments + + .. versionadded:: 0.9 + """ + + if remove and new: + raise OperationError("Conflicting parameters: remove and new") + + if not update and not upsert and not remove: + raise OperationError("No update parameters, must either update or remove") + + queryset = self.clone() + query = queryset._query + update = transform.update(queryset._document, **update) + sort = queryset._ordering + + try: + result = queryset._collection.find_and_modify( + query, update, upsert=upsert, sort=sort, remove=remove, new=new, + full_response=full_response, **self._cursor_args) + except pymongo.errors.DuplicateKeyError, err: + raise NotUniqueError(u"Update failed (%s)" % err) + except pymongo.errors.OperationFailure, err: + raise OperationError(u"Update failed (%s)" % err) + + if full_response: + if result["value"] is not None: + result["value"] = self._document._from_son(result["value"]) + else: + if result is not None: + result = self._document._from_son(result) + + return result + def with_id(self, object_id): """Retrieve the object matching the id provided. Uses `object_id` only and raises InvalidQueryError if a filter has been applied. Returns @@ -522,6 +596,19 @@ class BaseQuerySet(object): return self + def using(self, alias): + """This method is for controlling which database the QuerySet will be evaluated against if you are using more than one database. + + :param alias: The database alias + + .. versionadded:: 0.8 + """ + + with switch_db(self._document, alias) as cls: + collection = cls._get_collection() + + return self.clone_into(self.__class__(self._document, collection)) + def clone(self): """Creates a copy of the current :class:`~mongoengine.queryset.QuerySet` @@ -630,7 +717,10 @@ class BaseQuerySet(object): # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) doc_field = getattr(self._document._fields.get(field), "field", None) instance = getattr(doc_field, "document_type", False) - if instance: + EmbeddedDocumentField = _import_class('EmbeddedDocumentField') + GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField') + if instance and isinstance(doc_field, (EmbeddedDocumentField, + GenericEmbeddedDocumentField)): distinct = [instance(**doc) for doc in distinct] return distinct @@ -923,10 +1013,39 @@ class BaseQuerySet(object): map_reduce_function = 'inline_map_reduce' else: map_reduce_function = 'map_reduce' - mr_args['out'] = output + + if isinstance(output, basestring): + mr_args['out'] = output + + elif isinstance(output, dict): + ordered_output = [] + + for part in ('replace', 'merge', 'reduce'): + value = output.get(part) + if value: + ordered_output.append((part, value)) + break + + else: + raise OperationError("actionData not specified for output") + + db_alias = output.get('db_alias') + remaing_args = ['db', 'sharded', 'nonAtomic'] + + if db_alias: + ordered_output.append(('db', get_db(db_alias).name)) + del remaing_args[0] + + + for part in remaing_args: + value = output.get(part) + if value: + ordered_output.append((part, value)) + + mr_args['out'] = SON(ordered_output) results = getattr(queryset._collection, map_reduce_function)( - map_f, reduce_f, **mr_args) + map_f, reduce_f, **mr_args) if map_reduce_function == 'map_reduce': results = results.find() @@ -1189,8 +1308,9 @@ class BaseQuerySet(object): if self._ordering: # Apply query ordering self._cursor_obj.sort(self._ordering) - elif self._document._meta['ordering']: - # Otherwise, apply the ordering from the document model + elif self._ordering is None and self._document._meta['ordering']: + # Otherwise, apply the ordering from the document model, unless + # it's been explicitly cleared via order_by with no arguments order = self._get_order_by(self._document._meta['ordering']) self._cursor_obj.sort(order) @@ -1362,7 +1482,7 @@ class BaseQuerySet(object): for subdoc in subclasses: try: subfield = ".".join(f.db_field for f in - subdoc._lookup_field(field.split('.'))) + subdoc._lookup_field(field.split('.'))) ret.append(subfield) found = True break @@ -1392,7 +1512,7 @@ class BaseQuerySet(object): pass key_list.append((key, direction)) - if self._cursor_obj: + if self._cursor_obj and key_list: self._cursor_obj.sort(key_list) return key_list @@ -1450,6 +1570,7 @@ class BaseQuerySet(object): # type of this field and use the corresponding # .to_python(...) from mongoengine.fields import EmbeddedDocumentField + obj = self._document for chunk in path.split('.'): obj = getattr(obj, chunk, None) @@ -1460,6 +1581,7 @@ class BaseQuerySet(object): if obj and data is not None: data = obj.to_python(data) return data + return clean(row) def _sub_js_fields(self, code): @@ -1468,6 +1590,7 @@ class BaseQuerySet(object): substituted for the MongoDB name of the field (specified using the :attr:`name` keyword argument in a field's constructor). """ + def field_sub(match): # Extract just the field name, and look up the field objects field_name = match.group(1).split('.') diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 1437e76b..cebfcc50 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -155,3 +155,10 @@ class QuerySetNoCache(BaseQuerySet): queryset = self.clone() queryset.rewind() return queryset + + +class QuerySetNoDeRef(QuerySet): + """Special no_dereference QuerySet""" + + def __dereference(items, max_depth=1, instance=None, name=None): + return items \ No newline at end of file diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index e31a8b7d..d72d97a3 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -3,6 +3,7 @@ from collections import defaultdict import pymongo from bson import SON +from mongoengine.connection import get_connection from mongoengine.common import _import_class from mongoengine.errors import InvalidQueryError, LookUpError @@ -38,7 +39,7 @@ def query(_doc_cls=None, _field_operation=False, **query): mongo_query.update(value) continue - parts = key.split('__') + parts = key.rsplit('__') indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] parts = [part for part in parts if not part.isdigit()] # Check for an operator and transform to mongo-style if there is @@ -115,14 +116,26 @@ def query(_doc_cls=None, _field_operation=False, **query): if key in mongo_query and isinstance(mongo_query[key], dict): mongo_query[key].update(value) # $maxDistance needs to come last - convert to SON - if '$maxDistance' in mongo_query[key]: - value_dict = mongo_query[key] + value_dict = mongo_query[key] + if ('$maxDistance' in value_dict and '$near' in value_dict): value_son = SON() - for k, v in value_dict.iteritems(): - if k == '$maxDistance': - continue - value_son[k] = v - value_son['$maxDistance'] = value_dict['$maxDistance'] + if isinstance(value_dict['$near'], dict): + for k, v in value_dict.iteritems(): + if k == '$maxDistance': + continue + value_son[k] = v + if (get_connection().max_wire_version <= 1): + value_son['$maxDistance'] = value_dict['$maxDistance'] + else: + value_son['$near'] = SON(value_son['$near']) + value_son['$near']['$maxDistance'] = value_dict['$maxDistance'] + else: + for k, v in value_dict.iteritems(): + if k == '$maxDistance': + continue + value_son[k] = v + value_son['$maxDistance'] = value_dict['$maxDistance'] + mongo_query[key] = value_son else: # Store for manually merging later diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 41f4ebf8..a39b05f0 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -1,8 +1,9 @@ import copy -from mongoengine.errors import InvalidQueryError -from mongoengine.python_support import product, reduce +from itertools import product +from functools import reduce +from mongoengine.errors import InvalidQueryError from mongoengine.queryset import transform __all__ = ('Q',) diff --git a/setup.py b/setup.py index 85707d00..7270331a 100644 --- a/setup.py +++ b/setup.py @@ -38,12 +38,14 @@ CLASSIFIERS = [ 'Operating System :: OS Independent', 'Programming Language :: Python', "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.6", + "Programming Language :: Python :: 2.6.6", "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.1", "Programming Language :: Python :: 3.2", + "Programming Language :: Python :: 3.3", + "Programming Language :: Python :: 3.4", "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", 'Topic :: Database', 'Topic :: Software Development :: Libraries :: Python Modules', ] @@ -51,12 +53,15 @@ CLASSIFIERS = [ extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} if sys.version_info[0] == 3: extra_opts['use_2to3'] = True - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'django>=1.5.1'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'Pillow>=2.0.0', 'django>=1.5.1'] if "test" in sys.argv or "nosetests" in sys.argv: extra_opts['packages'] = find_packages() extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} else: - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6', 'python-dateutil'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'Pillow>=2.0.0', 'jinja2>=2.6', 'python-dateutil'] + + if sys.version_info[0] == 2 and sys.version_info[1] == 6: + extra_opts['tests_require'].append('unittest2') setup(name='mongoengine', version=VERSION, @@ -72,7 +77,7 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo>=2.5'], + install_requires=['pymongo>=2.7'], test_suite='nose.collector', **extra_opts ) diff --git a/tests/document/delta.py b/tests/document/delta.py index b0f5f01a..24910627 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -207,22 +207,21 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}}]}, {})) - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { + ['embedded_field.list_field.2']) + self.assertEqual(doc.embedded_field._delta(), ({'list_field.2': { '_cls': 'Embedded', 'string_field': 'hello world', 'int_field': 1, 'list_field': ['1', 2, {'hello': 'world'}], 'dict_field': {'hello': 'world'}} - ]}, {})) + }, {})) + self.assertEqual(doc._delta(), ({'embedded_field.list_field.2': { + '_cls': 'Embedded', + 'string_field': 'hello world', + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + 'dict_field': {'hello': 'world'}} + }, {})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field[2].string_field, @@ -253,7 +252,7 @@ class DeltaTest(unittest.TestCase): del(doc.embedded_field.list_field[2].list_field[2]['hello']) self.assertEqual(doc._delta(), - ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) + ({}, {'embedded_field.list_field.2.list_field.2.hello': 1})) doc.save() doc = doc.reload(10) @@ -548,22 +547,21 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] self.assertEqual(doc._get_changed_fields(), - ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'db_list_field': ['1', 2, { + ['db_embedded_field.db_list_field.2']) + self.assertEqual(doc.embedded_field._delta(), ({'db_list_field.2': { '_cls': 'Embedded', 'db_string_field': 'hello world', 'db_int_field': 1, 'db_list_field': ['1', 2, {'hello': 'world'}], - 'db_dict_field': {'hello': 'world'}}]}, {})) + 'db_dict_field': {'hello': 'world'}}}, {})) self.assertEqual(doc._delta(), ({ - 'db_embedded_field.db_list_field': ['1', 2, { + 'db_embedded_field.db_list_field.2': { '_cls': 'Embedded', 'db_string_field': 'hello world', 'db_int_field': 1, 'db_list_field': ['1', 2, {'hello': 'world'}], 'db_dict_field': {'hello': 'world'}} - ]}, {})) + }, {})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field[2].string_field, @@ -594,8 +592,7 @@ class DeltaTest(unittest.TestCase): del(doc.embedded_field.list_field[2].list_field[2]['hello']) self.assertEqual(doc._delta(), - ({'db_embedded_field.db_list_field.2.db_list_field': - [1, 2, {}]}, {})) + ({}, {'db_embedded_field.db_list_field.2.db_list_field.2.hello': 1})) doc.save() doc = doc.reload(10) @@ -735,5 +732,47 @@ class DeltaTest(unittest.TestCase): mydoc._clear_changed_fields() self.assertEqual([], mydoc._get_changed_fields()) + def test_referenced_object_changed_attributes(self): + """Ensures that when you save a new reference to a field, the referenced object isn't altered""" + + class Organization(Document): + name = StringField() + + class User(Document): + name = StringField() + org = ReferenceField('Organization', required=True) + + Organization.drop_collection() + User.drop_collection() + + org1 = Organization(name='Org 1') + org1.save() + + org2 = Organization(name='Org 2') + org2.save() + + user = User(name='Fred', org=org1) + user.save() + + org1.reload() + org2.reload() + user.reload() + self.assertEqual(org1.name, 'Org 1') + self.assertEqual(org2.name, 'Org 2') + self.assertEqual(user.name, 'Fred') + + user.name = 'Harold' + user.org = org2 + + org2.name = 'New Org 2' + self.assertEqual(org2.name, 'New Org 2') + + user.save() + org2.save() + + self.assertEqual(org2.name, 'New Org 2') + org2.reload() + self.assertEqual(org2.name, 'New Org 2') + if __name__ == '__main__': unittest.main() diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py index 6263e68c..a0bdb136 100644 --- a/tests/document/dynamic.py +++ b/tests/document/dynamic.py @@ -292,6 +292,59 @@ class DynamicTest(unittest.TestCase): person.save() self.assertEqual(Person.objects.first().age, 35) + def test_dynamic_embedded_works_with_only(self): + """Ensure custom fieldnames on a dynamic embedded document are found by qs.only()""" + + class Address(DynamicEmbeddedDocument): + city = StringField() + + class Person(DynamicDocument): + address = EmbeddedDocumentField(Address) + + Person.drop_collection() + + Person(name="Eric", address=Address(city="San Francisco", street_number="1337")).save() + + self.assertEqual(Person.objects.first().address.street_number, '1337') + self.assertEqual(Person.objects.only('address__street_number').first().address.street_number, '1337') + + def test_dynamic_and_embedded_dict_access(self): + """Ensure embedded dynamic documents work with dict[] style access""" + + class Address(EmbeddedDocument): + city = StringField() + + class Person(DynamicDocument): + name = StringField() + + Person.drop_collection() + + Person(name="Ross", address=Address(city="London")).save() + + person = Person.objects.first() + person.attrval = "This works" + + person["phone"] = "555-1212" # but this should too + + # Same thing two levels deep + person["address"]["city"] = "Lundenne" + person.save() + + self.assertEqual(Person.objects.first().address.city, "Lundenne") + + self.assertEqual(Person.objects.first().phone, "555-1212") + + person = Person.objects.first() + person.address = Address(city="Londinium") + person.save() + + self.assertEqual(Person.objects.first().address.city, "Londinium") + + + person = Person.objects.first() + person["age"] = 35 + person.save() + self.assertEqual(Person.objects.first().age, 35) if __name__ == '__main__': unittest.main() diff --git a/tests/document/instance.py b/tests/document/instance.py index 07db85a0..3ecbbec1 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -15,7 +15,7 @@ from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, from mongoengine import * from mongoengine.errors import (NotRegistered, InvalidDocumentError, - InvalidQueryError) + InvalidQueryError, NotUniqueError) from mongoengine.queryset import NULLIFY, Q from mongoengine.connection import get_db from mongoengine.base import get_document @@ -57,7 +57,7 @@ class InstanceTest(unittest.TestCase): date = DateTimeField(default=datetime.now) meta = { 'max_documents': 10, - 'max_size': 90000, + 'max_size': 4096, } Log.drop_collection() @@ -75,7 +75,7 @@ class InstanceTest(unittest.TestCase): options = Log.objects._collection.options() self.assertEqual(options['capped'], True) self.assertEqual(options['max'], 10) - self.assertEqual(options['size'], 90000) + self.assertTrue(options['size'] >= 4096) # Check that the document cannot be redefined with different options def recreate_log_document(): @@ -353,6 +353,14 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 20) + person.reload('age') + self.assertEqual(person.name, "Test User") + self.assertEqual(person.age, 21) + + person.reload() + self.assertEqual(person.name, "Mr Test User") + self.assertEqual(person.age, 21) + person.reload() self.assertEqual(person.name, "Mr Test User") self.assertEqual(person.age, 21) @@ -398,10 +406,11 @@ class InstanceTest(unittest.TestCase): doc.embedded_field.dict_field['woot'] = "woot" self.assertEqual(doc._get_changed_fields(), [ - 'list_field', 'dict_field', 'embedded_field.list_field', - 'embedded_field.dict_field']) + 'list_field', 'dict_field.woot', 'embedded_field.list_field', + 'embedded_field.dict_field.woot']) doc.save() + self.assertEqual(len(doc.list_field), 4) doc = doc.reload(10) self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(len(doc.list_field), 4) @@ -409,6 +418,16 @@ class InstanceTest(unittest.TestCase): self.assertEqual(len(doc.embedded_field.list_field), 4) self.assertEqual(len(doc.embedded_field.dict_field), 2) + doc.list_field.append(1) + doc.save() + doc.dict_field['extra'] = 1 + doc = doc.reload(10, 'list_field') + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(len(doc.list_field), 5) + self.assertEqual(len(doc.dict_field), 3) + self.assertEqual(len(doc.embedded_field.list_field), 4) + self.assertEqual(len(doc.embedded_field.dict_field), 2) + def test_reload_doesnt_exist(self): class Foo(Document): pass @@ -515,9 +534,6 @@ class InstanceTest(unittest.TestCase): class Email(EmbeddedDocument): email = EmailField() - def clean(self): - print "instance:" - print self._instance class Account(Document): email = EmbeddedDocumentField(Email) @@ -820,6 +836,80 @@ class InstanceTest(unittest.TestCase): p1.reload() self.assertEqual(p1.name, p.parent.name) + def test_save_atomicity_condition(self): + + class Widget(Document): + toggle = BooleanField(default=False) + count = IntField(default=0) + save_id = UUIDField() + + def flip(widget): + widget.toggle = not widget.toggle + widget.count += 1 + + def UUID(i): + return uuid.UUID(int=i) + + Widget.drop_collection() + + w1 = Widget(toggle=False, save_id=UUID(1)) + + # ignore save_condition on new record creation + w1.save(save_condition={'save_id':UUID(42)}) + w1.reload() + self.assertFalse(w1.toggle) + self.assertEqual(w1.save_id, UUID(1)) + self.assertEqual(w1.count, 0) + + # mismatch in save_condition prevents save + flip(w1) + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 1) + w1.save(save_condition={'save_id':UUID(42)}) + w1.reload() + self.assertFalse(w1.toggle) + self.assertEqual(w1.count, 0) + + # matched save_condition allows save + flip(w1) + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 1) + w1.save(save_condition={'save_id':UUID(1)}) + w1.reload() + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 1) + + # save_condition can be used to ensure atomic read & updates + # i.e., prevent interleaved reads and writes from separate contexts + w2 = Widget.objects.get() + self.assertEqual(w1, w2) + old_id = w1.save_id + + flip(w1) + w1.save_id = UUID(2) + w1.save(save_condition={'save_id':old_id}) + w1.reload() + self.assertFalse(w1.toggle) + self.assertEqual(w1.count, 2) + flip(w2) + flip(w2) + w2.save(save_condition={'save_id':old_id}) + w2.reload() + self.assertFalse(w2.toggle) + self.assertEqual(w2.count, 2) + + # save_condition uses mongoengine-style operator syntax + flip(w1) + w1.save(save_condition={'count__lt':w1.count}) + w1.reload() + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 3) + flip(w1) + w1.save(save_condition={'count__gte':w1.count}) + w1.reload() + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 3) + def test_update(self): """Ensure that an existing document is updated instead of be overwritten.""" @@ -990,6 +1080,16 @@ class InstanceTest(unittest.TestCase): self.assertRaises(InvalidQueryError, update_no_op_raises) + def test_update_unique_field(self): + class Doc(Document): + name = StringField(unique=True) + + doc1 = Doc(name="first").save() + doc2 = Doc(name="second").save() + + self.assertRaises(NotUniqueError, lambda: + doc2.update(set__name=doc1.name)) + def test_embedded_update(self): """ Test update on `EmbeddedDocumentField` fields @@ -2281,6 +2381,8 @@ class InstanceTest(unittest.TestCase): log.machine = "Localhost" log.save() + self.assertTrue(log.id is not None) + log.log = "Saving" log.save() @@ -2304,6 +2406,8 @@ class InstanceTest(unittest.TestCase): log.machine = "Localhost" log.save() + self.assertTrue(log.id is not None) + log.log = "Saving" log.save() @@ -2411,7 +2515,7 @@ class InstanceTest(unittest.TestCase): for parameter_name, parameter in self.parameters.iteritems(): parameter.expand() - class System(Document): + class NodesSystem(Document): name = StringField(required=True) nodes = MapField(ReferenceField(Node, dbref=False)) @@ -2419,18 +2523,18 @@ class InstanceTest(unittest.TestCase): for node_name, node in self.nodes.iteritems(): node.expand() node.save(*args, **kwargs) - super(System, self).save(*args, **kwargs) + super(NodesSystem, self).save(*args, **kwargs) - System.drop_collection() + NodesSystem.drop_collection() Node.drop_collection() - system = System(name="system") + system = NodesSystem(name="system") system.nodes["node"] = Node() system.save() system.nodes["node"].parameters["param"] = Parameter() system.save() - system = System.objects.first() + system = NodesSystem.objects.first() self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value) def test_embedded_document_equality(self): @@ -2452,5 +2556,65 @@ class InstanceTest(unittest.TestCase): f1.ref # Dereferences lazily self.assertEqual(f1, f2) + def test_dbref_equality(self): + class Test2(Document): + name = StringField() + + class Test3(Document): + name = StringField() + + class Test(Document): + name = StringField() + test2 = ReferenceField('Test2') + test3 = ReferenceField('Test3') + + Test.drop_collection() + Test2.drop_collection() + Test3.drop_collection() + + t2 = Test2(name='a') + t2.save() + + t3 = Test3(name='x') + t3.id = t2.id + t3.save() + + t = Test(name='b', test2=t2, test3=t3) + + f = Test._from_son(t.to_mongo()) + + dbref2 = f._data['test2'] + obj2 = f.test2 + self.assertTrue(isinstance(dbref2, DBRef)) + self.assertTrue(isinstance(obj2, Test2)) + self.assertTrue(obj2.id == dbref2.id) + self.assertTrue(obj2 == dbref2) + self.assertTrue(dbref2 == obj2) + + dbref3 = f._data['test3'] + obj3 = f.test3 + self.assertTrue(isinstance(dbref3, DBRef)) + self.assertTrue(isinstance(obj3, Test3)) + self.assertTrue(obj3.id == dbref3.id) + self.assertTrue(obj3 == dbref3) + self.assertTrue(dbref3 == obj3) + + self.assertTrue(obj2.id == obj3.id) + self.assertTrue(dbref2.id == dbref3.id) + self.assertFalse(dbref2 == dbref3) + self.assertFalse(dbref3 == dbref2) + self.assertTrue(dbref2 != dbref3) + self.assertTrue(dbref3 != dbref2) + + self.assertFalse(obj2 == dbref3) + self.assertFalse(dbref3 == obj2) + self.assertTrue(obj2 != dbref3) + self.assertTrue(dbref3 != obj2) + + self.assertFalse(obj3 == dbref2) + self.assertFalse(dbref2 == obj3) + self.assertTrue(obj3 != dbref2) + self.assertTrue(dbref2 != obj3) + if __name__ == '__main__': unittest.main() diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index 902b1512..7ae53e8a 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -279,7 +279,7 @@ class FileTest(unittest.TestCase): t.image.put(f) self.fail("Should have raised an invalidation error") except ValidationError, e: - self.assertEqual("%s" % e, "Invalid image: cannot identify image file") + self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) t = TestImage() t.image.put(open(TEST_IMAGE_PATH, 'rb')) diff --git a/tests/queryset/__init__.py b/tests/queryset/__init__.py index 8a93c19f..c36b2684 100644 --- a/tests/queryset/__init__.py +++ b/tests/queryset/__init__.py @@ -3,3 +3,4 @@ from field_list import * from queryset import * from visitor import * from geo import * +from modify import * \ No newline at end of file diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py index 65ab519a..5148a48e 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/geo.py @@ -5,6 +5,8 @@ import unittest from datetime import datetime, timedelta from mongoengine import * +from nose.plugins.skip import SkipTest + __all__ = ("GeoQueriesTest",) @@ -139,6 +141,7 @@ class GeoQueriesTest(unittest.TestCase): def test_spherical_geospatial_operators(self): """Ensure that spherical geospatial queries are working """ + raise SkipTest("https://jira.mongodb.org/browse/SERVER-14039") class Point(Document): location = GeoPointField() diff --git a/tests/queryset/modify.py b/tests/queryset/modify.py new file mode 100644 index 00000000..e0c7d1fe --- /dev/null +++ b/tests/queryset/modify.py @@ -0,0 +1,102 @@ +import sys +sys.path[0:0] = [""] + +import unittest + +from mongoengine import connect, Document, IntField + +__all__ = ("FindAndModifyTest",) + + +class Doc(Document): + id = IntField(primary_key=True) + value = IntField() + + +class FindAndModifyTest(unittest.TestCase): + + def setUp(self): + connect(db="mongoenginetest") + Doc.drop_collection() + + def assertDbEqual(self, docs): + self.assertEqual(list(Doc._collection.find().sort("id")), docs) + + def test_modify(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).modify(set__value=-1) + self.assertEqual(old_doc.to_json(), doc.to_json()) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + def test_modify_with_new(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + new_doc = Doc.objects(id=1).modify(set__value=-1, new=True) + doc.value = -1 + self.assertEqual(new_doc.to_json(), doc.to_json()) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + def test_modify_not_existing(self): + Doc(id=0, value=0).save() + self.assertEqual(Doc.objects(id=1).modify(set__value=-1), None) + self.assertDbEqual([{"_id": 0, "value": 0}]) + + def test_modify_with_upsert(self): + Doc(id=0, value=0).save() + old_doc = Doc.objects(id=1).modify(set__value=1, upsert=True) + self.assertEqual(old_doc, None) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) + + def test_modify_with_upsert_existing(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).modify(set__value=-1, upsert=True) + self.assertEqual(old_doc.to_json(), doc.to_json()) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + def test_modify_with_upsert_with_new(self): + Doc(id=0, value=0).save() + new_doc = Doc.objects(id=1).modify(upsert=True, new=True, set__value=1) + self.assertEqual(new_doc.to_mongo(), {"_id": 1, "value": 1}) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) + + def test_modify_with_remove(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).modify(remove=True) + self.assertEqual(old_doc.to_json(), doc.to_json()) + self.assertDbEqual([{"_id": 0, "value": 0}]) + + def test_find_and_modify_with_remove_not_existing(self): + Doc(id=0, value=0).save() + self.assertEqual(Doc.objects(id=1).modify(remove=True), None) + self.assertDbEqual([{"_id": 0, "value": 0}]) + + def test_modify_with_order_by(self): + Doc(id=0, value=3).save() + Doc(id=1, value=2).save() + Doc(id=2, value=1).save() + doc = Doc(id=3, value=0).save() + + old_doc = Doc.objects().order_by("-id").modify(set__value=-1) + self.assertEqual(old_doc.to_json(), doc.to_json()) + self.assertDbEqual([ + {"_id": 0, "value": 3}, {"_id": 1, "value": 2}, + {"_id": 2, "value": 1}, {"_id": 3, "value": -1}]) + + def test_modify_with_fields(self): + Doc(id=0, value=0).save() + Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).only("id").modify(set__value=-1) + self.assertEqual(old_doc.to_mongo(), {"_id": 1}) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 7ff2965d..0999a2d6 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -14,9 +14,9 @@ from pymongo.read_preferences import ReadPreference from bson import ObjectId from mongoengine import * -from mongoengine.connection import get_connection +from mongoengine.connection import get_connection, get_db from mongoengine.python_support import PY3 -from mongoengine.context_managers import query_counter +from mongoengine.context_managers import query_counter, switch_db from mongoengine.queryset import (QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, queryset_manager) @@ -25,10 +25,17 @@ from mongoengine.errors import InvalidQueryError __all__ = ("QuerySetTest",) +class db_ops_tracker(query_counter): + def get_ops(self): + ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} + return list(self.db.system.profile.find(ignore_query)) + + class QuerySetTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') + connect(db='mongoenginetest2', alias='test2') class PersonMeta(EmbeddedDocument): weight = IntField() @@ -650,7 +657,10 @@ class QuerySetTest(unittest.TestCase): blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) Blog.objects.insert(blogs, load_bulk=False) - self.assertEqual(q, 1) # 1 for the insert + if (get_connection().max_wire_version <= 1): + self.assertEqual(q, 1) + else: + self.assertEqual(q, 99) # profiling logs each doc now in the bulk op Blog.drop_collection() Blog.ensure_indexes() @@ -659,7 +669,10 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(q, 0) Blog.objects.insert(blogs) - self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch + if (get_connection().max_wire_version <= 1): + self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch + else: + self.assertEqual(q, 100) # 99 for insert, and 1 for in bulk fetch Blog.drop_collection() @@ -1040,6 +1053,54 @@ class QuerySetTest(unittest.TestCase): expected = [blog_post_1, blog_post_2, blog_post_3] self.assertSequence(qs, expected) + def test_clear_ordering(self): + """ Ensure that the default ordering can be cleared by calling order_by(). + """ + class BlogPost(Document): + title = StringField() + published_date = DateTimeField() + + meta = { + 'ordering': ['-published_date'] + } + + BlogPost.drop_collection() + + with db_ops_tracker() as q: + BlogPost.objects.filter(title='whatever').first() + self.assertEqual(len(q.get_ops()), 1) + self.assertEqual(q.get_ops()[0]['query']['$orderby'], {u'published_date': -1}) + + with db_ops_tracker() as q: + BlogPost.objects.filter(title='whatever').order_by().first() + self.assertEqual(len(q.get_ops()), 1) + print q.get_ops()[0]['query'] + self.assertFalse('$orderby' in q.get_ops()[0]['query']) + + def test_no_ordering_for_get(self): + """ Ensure that Doc.objects.get doesn't use any ordering. + """ + class BlogPost(Document): + title = StringField() + published_date = DateTimeField() + + meta = { + 'ordering': ['-published_date'] + } + + BlogPost.objects.create(title='whatever', published_date=datetime.utcnow()) + + with db_ops_tracker() as q: + BlogPost.objects.get(title='whatever') + self.assertEqual(len(q.get_ops()), 1) + self.assertFalse('$orderby' in q.get_ops()[0]['query']) + + # Ordering should be ignored for .get even if we set it explicitly + with db_ops_tracker() as q: + BlogPost.objects.order_by('-title').get(title='whatever') + self.assertEqual(len(q.get_ops()), 1) + self.assertFalse('$orderby' in q.get_ops()[0]['query']) + def test_find_embedded(self): """Ensure that an embedded document is properly returned from a query. """ @@ -1925,6 +1986,140 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() + def test_map_reduce_custom_output(self): + """ + Test map/reduce custom output + """ + register_connection('test2', 'mongoenginetest2') + + class Family(Document): + id = IntField( + primary_key=True) + log = StringField() + + class Person(Document): + id = IntField( + primary_key=True) + name = StringField() + age = IntField() + family = ReferenceField(Family) + + Family.drop_collection() + Person.drop_collection() + + # creating first family + f1 = Family(id=1, log="Trav 02 de Julho") + f1.save() + + # persons of first family + Person(id=1, family=f1, name=u"Wilson Jr", age=21).save() + Person(id=2, family=f1, name=u"Wilson Father", age=45).save() + Person(id=3, family=f1, name=u"Eliana Costa", age=40).save() + Person(id=4, family=f1, name=u"Tayza Mariana", age=17).save() + + # creating second family + f2 = Family(id=2, log="Av prof frasc brunno") + f2.save() + + #persons of second family + Person(id=5, family=f2, name="Isabella Luanna", age=16).save() + Person(id=6, family=f2, name="Sandra Mara", age=36).save() + Person(id=7, family=f2, name="Igor Gabriel", age=10).save() + + # creating third family + f3 = Family(id=3, log="Av brazil") + f3.save() + + #persons of thrird family + Person(id=8, family=f3, name="Arthur WA", age=30).save() + Person(id=9, family=f3, name="Paula Leonel", age=25).save() + + # executing join map/reduce + map_person = """ + function () { + emit(this.family, { + totalAge: this.age, + persons: [{ + name: this.name, + age: this.age + }]}); + } + """ + + map_family = """ + function () { + emit(this._id, { + totalAge: 0, + persons: [] + }); + } + """ + + reduce_f = """ + function (key, values) { + var family = {persons: [], totalAge: 0}; + + values.forEach(function(value) { + if (value.persons) { + value.persons.forEach(function (person) { + family.persons.push(person); + family.totalAge += person.age; + }); + } + }); + + return family; + } + """ + cursor = Family.objects.map_reduce( + map_f=map_family, + reduce_f=reduce_f, + output={'replace': 'family_map', 'db_alias': 'test2'}) + + # start a map/reduce + cursor.next() + + results = Person.objects.map_reduce( + map_f=map_person, + reduce_f=reduce_f, + output={'reduce': 'family_map', 'db_alias': 'test2'}) + + results = list(results) + collection = get_db('test2').family_map + + self.assertEqual( + collection.find_one({'_id': 1}), { + '_id': 1, + 'value': { + 'persons': [ + {'age': 21, 'name': u'Wilson Jr'}, + {'age': 45, 'name': u'Wilson Father'}, + {'age': 40, 'name': u'Eliana Costa'}, + {'age': 17, 'name': u'Tayza Mariana'}], + 'totalAge': 123} + }) + + self.assertEqual( + collection.find_one({'_id': 2}), { + '_id': 2, + 'value': { + 'persons': [ + {'age': 16, 'name': u'Isabella Luanna'}, + {'age': 36, 'name': u'Sandra Mara'}, + {'age': 10, 'name': u'Igor Gabriel'}], + 'totalAge': 62} + }) + + self.assertEqual( + collection.find_one({'_id': 3}), { + '_id': 3, + 'value': { + 'persons': [ + {'age': 30, 'name': u'Arthur WA'}, + {'age': 25, 'name': u'Paula Leonel'}], + 'totalAge': 55} + }) + def test_map_reduce_finalize(self): """Ensure that map, reduce, and finalize run and introduce "scope" by simulating "hotness" ranking with Reddit algorithm. @@ -2540,6 +2735,27 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(authors, [mark_twain, john_tolkien]) + def test_distinct_ListField_ReferenceField(self): + class Foo(Document): + bar_lst = ListField(ReferenceField('Bar')) + + class Bar(Document): + text = StringField() + + Bar.drop_collection() + Foo.drop_collection() + + bar_1 = Bar(text="hi") + bar_1.save() + + bar_2 = Bar(text="bye") + bar_2.save() + + foo = Foo(bar=bar_1, bar_lst=[bar_1, bar_2]) + foo.save() + + self.assertEqual(Foo.objects.distinct("bar_lst"), [bar_1, bar_2]) + def test_custom_manager(self): """Ensure that custom QuerySetManager instances work as expected. """ @@ -2957,6 +3173,23 @@ class QuerySetTest(unittest.TestCase): Number.drop_collection() + def test_using(self): + """Ensure that switching databases for a queryset is possible + """ + class Number2(Document): + n = IntField() + + Number2.drop_collection() + with switch_db(Number2, 'test2') as Number2: + Number2.drop_collection() + + for i in range(1, 10): + t = Number2(n=i) + t.switch_db('test2') + t.save() + + self.assertEqual(len(Number2.objects.using('test2')), 9) + def test_unset_reference(self): class Comment(Document): text = StringField() @@ -3586,7 +3819,13 @@ class QuerySetTest(unittest.TestCase): [x for x in people] self.assertEqual(100, len(people._result_cache)) - self.assertEqual(None, people._len) + + import platform + + if platform.python_implementation() != "PyPy": + # PyPy evaluates __len__ when iterating with list comprehensions while CPython does not. + # This may be a bug in PyPy (PyPy/#1802) but it does not affect the behavior of MongoEngine. + self.assertEqual(None, people._len) self.assertEqual(q, 1) list(people) @@ -3814,6 +4053,111 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(Example.objects(size=instance_size).count(), 1) self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1) + def test_cursor_in_an_if_stmt(self): + + class Test(Document): + test_field = StringField() + + Test.drop_collection() + queryset = Test.objects + + if queryset: + raise AssertionError('Empty cursor returns True') + + test = Test() + test.test_field = 'test' + test.save() + + queryset = Test.objects + if not test: + raise AssertionError('Cursor has data and returned False') + + queryset.next() + if not queryset: + raise AssertionError('Cursor has data and it must returns True,' + ' even in the last item.') + + def test_bool_performance(self): + + class Person(Document): + name = StringField() + + Person.drop_collection() + for i in xrange(100): + Person(name="No: %s" % i).save() + + with query_counter() as q: + if Person.objects: + pass + + self.assertEqual(q, 1) + op = q.db.system.profile.find({"ns": + {"$ne": "%s.system.indexes" % q.db.name}})[0] + + self.assertEqual(op['nreturned'], 1) + + + def test_bool_with_ordering(self): + + class Person(Document): + name = StringField() + + Person.drop_collection() + Person(name="Test").save() + + qs = Person.objects.order_by('name') + + with query_counter() as q: + + if qs: + pass + + op = q.db.system.profile.find({"ns": + {"$ne": "%s.system.indexes" % q.db.name}})[0] + + self.assertFalse('$orderby' in op['query'], + 'BaseQuerySet cannot use orderby in if stmt') + + with query_counter() as p: + + for x in qs: + pass + + op = p.db.system.profile.find({"ns": + {"$ne": "%s.system.indexes" % q.db.name}})[0] + + self.assertTrue('$orderby' in op['query'], + 'BaseQuerySet cannot remove orderby in for loop') + + def test_bool_with_ordering_from_meta_dict(self): + + class Person(Document): + name = StringField() + meta = { + 'ordering': ['name'] + } + + Person.drop_collection() + + Person(name="B").save() + Person(name="C").save() + Person(name="A").save() + + with query_counter() as q: + + if Person.objects: + pass + + op = q.db.system.profile.find({"ns": + {"$ne": "%s.system.indexes" % q.db.name}})[0] + + self.assertFalse('$orderby' in op['query'], + 'BaseQuerySet must remove orderby from meta in boolen test') + + self.assertEqual(Person.objects.first().name, 'A') + self.assertTrue(Person.objects._has_data(), + 'Cursor has data and returned False') + if __name__ == '__main__': unittest.main() diff --git a/tests/test_connection.py b/tests/test_connection.py index 96135bc5..6cdbd654 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,6 +1,11 @@ import sys sys.path[0:0] = [""] -import unittest + +try: + import unittest2 as unittest +except ImportError: + import unittest + import datetime import pymongo @@ -34,6 +39,17 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb') self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) + def test_sharing_connections(self): + """Ensure that connections are shared when the connection settings are exactly the same + """ + connect('mongoenginetest', alias='testdb1') + + expected_connection = get_connection('testdb1') + + connect('mongoenginetest', alias='testdb2') + actual_connection = get_connection('testdb2') + self.assertEqual(expected_connection, actual_connection) + def test_connect_uri(self): """Ensure that the connect() method works properly with uri's """ diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py new file mode 100644 index 00000000..c761a41e --- /dev/null +++ b/tests/test_datastructures.py @@ -0,0 +1,107 @@ +import unittest +from mongoengine.base.datastructures import StrictDict, SemiStrictDict + +class TestStrictDict(unittest.TestCase): + def strict_dict_class(self, *args, **kwargs): + return StrictDict.create(*args, **kwargs) + def setUp(self): + self.dtype = self.strict_dict_class(("a", "b", "c")) + def test_init(self): + d = self.dtype(a=1, b=1, c=1) + self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) + + def test_init_fails_on_nonexisting_attrs(self): + self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) + + def test_eq(self): + d = self.dtype(a=1, b=1, c=1) + dd = self.dtype(a=1, b=1, c=1) + e = self.dtype(a=1, b=1, c=3) + f = self.dtype(a=1, b=1) + g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1) + h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) + i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) + + self.assertEqual(d, dd) + self.assertNotEqual(d, e) + self.assertNotEqual(d, f) + self.assertNotEqual(d, g) + self.assertNotEqual(f, d) + self.assertEqual(d, h) + self.assertNotEqual(d, i) + + def test_setattr_getattr(self): + d = self.dtype() + d.a = 1 + self.assertEqual(d.a, 1) + self.assertRaises(AttributeError, lambda: d.b) + + def test_setattr_raises_on_nonexisting_attr(self): + d = self.dtype() + def _f(): + d.x=1 + self.assertRaises(AttributeError, _f) + + def test_setattr_getattr_special(self): + d = self.strict_dict_class(["items"]) + d.items = 1 + self.assertEqual(d.items, 1) + + def test_get(self): + d = self.dtype(a=1) + self.assertEqual(d.get('a'), 1) + self.assertEqual(d.get('b', 'bla'), 'bla') + + def test_items(self): + d = self.dtype(a=1) + self.assertEqual(d.items(), [('a', 1)]) + d = self.dtype(a=1, b=2) + self.assertEqual(d.items(), [('a', 1), ('b', 2)]) + + def test_mappings_protocol(self): + d = self.dtype(a=1, b=2) + assert dict(d) == {'a': 1, 'b': 2} + assert dict(**d) == {'a': 1, 'b': 2} + + +class TestSemiSrictDict(TestStrictDict): + def strict_dict_class(self, *args, **kwargs): + return SemiStrictDict.create(*args, **kwargs) + + def test_init_fails_on_nonexisting_attrs(self): + # disable irrelevant test + pass + + def test_setattr_raises_on_nonexisting_attr(self): + # disable irrelevant test + pass + + def test_setattr_getattr_nonexisting_attr_succeeds(self): + d = self.dtype() + d.x = 1 + self.assertEqual(d.x, 1) + + def test_init_succeeds_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2)) + + def test_iter_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + self.assertEqual(list(d), ['a', 'b', 'c', 'x']) + + def test_iteritems_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + self.assertEqual(list(d.iteritems()), [('a', 1), ('b', 1), ('c', 1), ('x', 2)]) + + def tets_cmp_with_strict_dicts(self): + d = self.dtype(a=1, b=1, c=1) + dd = StrictDict.create(("a", "b", "c"))(a=1, b=1, c=1) + self.assertEqual(d, dd) + + def test_cmp_with_strict_dict_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + dd = StrictDict.create(("a", "b", "c", "x"))(a=1, b=1, c=1, x=2) + self.assertEqual(d, dd) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 6f2664a3..dc416007 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -291,6 +291,30 @@ class FieldTest(unittest.TestCase): self.assertEqual(employee.friends, friends) self.assertEqual(q, 2) + def test_list_of_lists_of_references(self): + + class User(Document): + name = StringField() + + class Post(Document): + user_lists = ListField(ListField(ReferenceField(User))) + + class SimpleList(Document): + users = ListField(ReferenceField(User)) + + User.drop_collection() + Post.drop_collection() + + u1 = User.objects.create(name='u1') + u2 = User.objects.create(name='u2') + u3 = User.objects.create(name='u3') + + SimpleList.objects.create(users=[u1, u2, u3]) + self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3]) + + Post.objects.create(user_lists=[[u1, u2], [u3]]) + self.assertEqual(Post.objects.all()[0].user_lists, [[u1, u2], [u3]]) + def test_circular_reference(self): """Ensure you can handle circular references """ diff --git a/tests/test_signals.py b/tests/test_signals.py index 50e5e6b8..3d0cbb3e 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -54,7 +54,9 @@ class SignalTests(unittest.TestCase): @classmethod def post_save(cls, sender, document, **kwargs): + dirty_keys = document._delta()[0].keys() + document._delta()[1].keys() signal_output.append('post_save signal, %s' % document) + signal_output.append('post_save dirty keys, %s' % dirty_keys) if 'created' in kwargs: if kwargs['created']: signal_output.append('Is created') @@ -203,6 +205,7 @@ class SignalTests(unittest.TestCase): "pre_save_post_validation signal, Bill Shakespeare", "Is created", "post_save signal, Bill Shakespeare", + "post_save dirty keys, ['name']", "Is created" ]) @@ -213,6 +216,7 @@ class SignalTests(unittest.TestCase): "pre_save_post_validation signal, William Shakespeare", "Is updated", "post_save signal, William Shakespeare", + "post_save dirty keys, ['name']", "Is updated" ])