diff --git a/.travis.yml b/.travis.yml index 5739909b..40736165 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,26 +6,28 @@ python: - "2.7" - "3.2" - "3.3" + - "3.4" + - "pypy" env: - - PYMONGO=dev DJANGO=1.6 - - PYMONGO=dev DJANGO=1.5.5 - - PYMONGO=dev DJANGO=1.4.10 - - PYMONGO=2.5 DJANGO=1.6 - - PYMONGO=2.5 DJANGO=1.5.5 - - PYMONGO=2.5 DJANGO=1.4.10 - - PYMONGO=3.2 DJANGO=1.6 - - PYMONGO=3.2 DJANGO=1.5.5 - - PYMONGO=3.3 DJANGO=1.6 - - PYMONGO=3.3 DJANGO=1.5.5 + - PYMONGO=dev DJANGO=1.6.5 + - PYMONGO=dev DJANGO=1.5.8 + - PYMONGO=2.7.1 DJANGO=1.6.5 + - PYMONGO=2.7.1 DJANGO=1.5.8 + +matrix: + fast_finish: true + install: - - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then cp /usr/lib/*/libz.so $VIRTUAL_ENV/lib/; fi - - if [[ $TRAVIS_PYTHON_VERSION == '2.'* ]]; then pip install pil --use-mirrors ; true; fi + - sudo apt-get install python-dev python3-dev libopenjpeg-dev zlib1g-dev libjpeg-turbo8-dev libtiff4-dev libjpeg8-dev libfreetype6-dev liblcms2-dev libwebp-dev tcl8.5-dev tk8.5-dev python-tk - if [[ $PYMONGO == 'dev' ]]; then pip install https://github.com/mongodb/mongo-python-driver/tarball/master; true; fi - - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO --use-mirrors; true; fi + - if [[ $PYMONGO != 'dev' ]]; then pip install pymongo==$PYMONGO; true; fi + - pip install Django==$DJANGO - pip install https://pypi.python.org/packages/source/p/python-dateutil/python-dateutil-2.1.tar.gz#md5=1534bb15cf311f07afaa3aacba1c028b - python setup.py install script: - python setup.py test + - if [[ $TRAVIS_PYTHON_VERSION == '3.'* ]]; then 2to3 . -w; fi; + - python benchmark.py notifications: irc: "irc.freenode.org#mongoengine" branches: diff --git a/AUTHORS b/AUTHORS index d6994d50..c6c47d79 100644 --- a/AUTHORS +++ b/AUTHORS @@ -171,7 +171,7 @@ that much better: * Michael Bartnett (https://github.com/michaelbartnett) * Alon Horev (https://github.com/alonho) * Kelvin Hammond (https://github.com/kelvinhammond) - * Jatin- (https://github.com/jatin-) + * Jatin Chopra (https://github.com/jatin) * Paul Uithol (https://github.com/PaulUithol) * Thom Knowles (https://github.com/fleat) * Paul (https://github.com/squamous) @@ -189,3 +189,14 @@ that much better: * Tom (https://github.com/tomprimozic) * j0hnsmith (https://github.com/j0hnsmith) * Damien Churchill (https://github.com/damoxc) + * Jonathan Simon Prates (https://github.com/jonathansp) + * Thiago Papageorgiou (https://github.com/tmpapageorgiou) + * Omer Katz (https://github.com/thedrow) + * Falcon Dai (https://github.com/falcondai) + * Polyrabbit (https://github.com/polyrabbit) + * Sagiv Malihi (https://github.com/sagivmalihi) + * Dmitry Konishchev (https://github.com/KonishchevDmitry) + * Martyn Smith (https://github.com/martynsmith) + * Andrei Zbikowski (https://github.com/b1naryth1ef) + * Ronald van Rij (https://github.com/ronaldvanrij) + * François Schmidts (https://github.com/jaesivsm) diff --git a/README.rst b/README.rst index cc4524ae..8c3ee26e 100644 --- a/README.rst +++ b/README.rst @@ -29,9 +29,18 @@ setup.py install``. Dependencies ============ -- pymongo 2.5+ +- pymongo>=2.5 - sphinx (optional - for documentation generation) +Optional Dependencies +--------------------- +- **Django Integration:** Django>=1.4.0 for Python 2.x or PyPy and Django>=1.5.0 for Python 3.x +- **Image Fields**: Pillow>=2.0.0 or PIL (not recommended since MongoEngine is tested with Pillow) +- dateutil>=2.1.0 + +.. note + MongoEngine always runs it's test suite against the latest patch version of each dependecy. e.g.: Django 1.6.5 + Examples ======== Some simple examples of what MongoEngine code looks like:: diff --git a/benchmark.py b/benchmark.py index 16b2fd47..53ecf32c 100644 --- a/benchmark.py +++ b/benchmark.py @@ -15,7 +15,7 @@ def cprofile_main(): class Noddy(Document): fields = DictField() - for i in xrange(1): + for i in range(1): noddy = Noddy() for j in range(20): noddy.fields["key" + str(j)] = "value " + str(j) @@ -113,6 +113,7 @@ def main(): 4.68946313858 ---------------------------------------------------------------------------------------------------- """ + print("Benchmarking...") setup = """ from pymongo import MongoClient @@ -127,7 +128,7 @@ connection = MongoClient() db = connection.timeit_test noddy = db.noddy -for i in xrange(10000): +for i in range(10000): example = {'fields': {}} for j in range(20): example['fields']["key"+str(j)] = "value "+str(j) @@ -138,10 +139,10 @@ myNoddys = noddy.find() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - Pymongo""" + print("-" * 100) + print("""Creating 10000 dictionaries - Pymongo""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ from pymongo import MongoClient @@ -150,7 +151,7 @@ connection = MongoClient() db = connection.timeit_test noddy = db.noddy -for i in xrange(10000): +for i in range(10000): example = {'fields': {}} for j in range(20): example['fields']["key"+str(j)] = "value "+str(j) @@ -161,10 +162,10 @@ myNoddys = noddy.find() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - Pymongo write_concern={"w": 0}""" + print("-" * 100) + print("""Creating 10000 dictionaries - Pymongo write_concern={"w": 0}""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) setup = """ from pymongo import MongoClient @@ -180,7 +181,7 @@ class Noddy(Document): """ stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) @@ -190,13 +191,13 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine""" + print("-" * 100) + print("""Creating 10000 dictionaries - MongoEngine""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() fields = {} for j in range(20): @@ -208,13 +209,13 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries without continual assign - MongoEngine""" + print("-" * 100) + print("""Creating 10000 dictionaries without continual assign - MongoEngine""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) @@ -224,13 +225,13 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True""" + print("-" * 100) + print("""Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) @@ -240,13 +241,13 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True""" + print("-" * 100) + print("""Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) @@ -256,13 +257,13 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False""" + print("-" * 100) + print("""Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) stmt = """ -for i in xrange(10000): +for i in range(10000): noddy = Noddy() for j in range(20): noddy.fields["key"+str(j)] = "value "+str(j) @@ -272,11 +273,11 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print "-" * 100 - print """Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False""" + print("-" * 100) + print("""Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False""") t = timeit.Timer(stmt=stmt, setup=setup) - print t.timeit(1) + print(t.timeit(1)) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/docs/changelog.rst b/docs/changelog.rst index 51134238..c722b592 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,6 +2,32 @@ Changelog ========= + +Changes in 0.9.X - DEV +====================== + +- post_save signal now has access to delta information about field changes #594 #589 +- Don't query with $orderby for qs.get() #600 +- Fix id shard key save issue #636 +- Fixes issue with recursive embedded document errors #557 +- Fix clear_changed_fields() clearing unsaved documents bug #602 +- Removing support for Django 1.4.x, pymongo 2.5.x, pymongo 2.6.x. +- Removing support for Python < 2.6.6 +- Fixed $maxDistance location for geoJSON $near queries with MongoDB 2.6+ #664 +- QuerySet.modify() method to provide find_and_modify() like behaviour #677 +- Added support for the using() method on a queryset #676 +- PYPY support #673 +- Connection pooling #674 +- Avoid to open all documents from cursors in an if stmt #655 +- Ability to clear the ordering #657 +- Raise NotUniqueError in Document.update() on pymongo.errors.DuplicateKeyError #626 +- Slots - memory improvements #625 +- Fixed incorrectly split a query key when it ends with "_" #619 +- Geo docs updates #613 +- Workaround a dateutil bug #608 +- Conditional save for atomic-style operations #511 +- Allow dynamic dictionary-style field access #559 + Changes in 0.8.7 ================ - Calling reload on deleted / nonexistant documents raises DoesNotExist (#538) diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 5d8b628a..07bce3bb 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -531,6 +531,8 @@ field name to the index definition. Sometimes its more efficient to index parts of Embedded / dictionary fields, in this case use 'dot' notation to identify the value to index eg: `rank.title` +.. _geospatial-indexes: + Geospatial indexes ------------------ diff --git a/docs/guide/gridfs.rst b/docs/guide/gridfs.rst index 596585de..68e7a6d2 100644 --- a/docs/guide/gridfs.rst +++ b/docs/guide/gridfs.rst @@ -46,7 +46,7 @@ slightly different manner. First, a new file must be created by calling the marmot.photo.write('some_more_image_data') marmot.photo.close() - marmot.photo.save() + marmot.save() Deletion -------- diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 32cbb94e..34996dc6 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -488,8 +488,9 @@ calling it with keyword arguments:: Atomic updates ============== Documents may be updated atomically by using the -:meth:`~mongoengine.queryset.QuerySet.update_one` and -:meth:`~mongoengine.queryset.QuerySet.update` methods on a +:meth:`~mongoengine.queryset.QuerySet.update_one`, +:meth:`~mongoengine.queryset.QuerySet.update` and +:meth:`~mongoengine.queryset.QuerySet.modify` methods on a :meth:`~mongoengine.queryset.QuerySet`. There are several different "modifiers" that you may use with these methods: @@ -499,11 +500,13 @@ that you may use with these methods: * ``dec`` -- decrement a value by a given amount * ``push`` -- append a value to a list * ``push_all`` -- append several values to a list -* ``pop`` -- remove the first or last element of a list +* ``pop`` -- remove the first or last element of a list `depending on the value`_ * ``pull`` -- remove a value from a list * ``pull_all`` -- remove several values from a list * ``add_to_set`` -- add value to a list only if its not in the list already +.. _depending on the value: http://docs.mongodb.org/manual/reference/operator/update/pop/ + The syntax for atomic updates is similar to the querying syntax, but the modifier comes before the field, not after it:: diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 4652fb56..32a66018 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -1,4 +1,6 @@ import weakref +import functools +import itertools from mongoengine.common import _import_class __all__ = ("BaseDict", "BaseList") @@ -156,3 +158,98 @@ class BaseList(list): def _mark_as_changed(self): if hasattr(self._instance, '_mark_as_changed'): self._instance._mark_as_changed(self._name) + + +class StrictDict(object): + __slots__ = () + _special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create']) + _classes = {} + def __init__(self, **kwargs): + for k,v in kwargs.iteritems(): + setattr(self, k, v) + def __getitem__(self, key): + key = '_reserved_' + key if key in self._special_fields else key + try: + return getattr(self, key) + except AttributeError: + raise KeyError(key) + def __setitem__(self, key, value): + key = '_reserved_' + key if key in self._special_fields else key + return setattr(self, key, value) + def __contains__(self, key): + return hasattr(self, key) + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + def pop(self, key, default=None): + v = self.get(key, default) + try: + delattr(self, key) + except AttributeError: + pass + return v + def iteritems(self): + for key in self: + yield key, self[key] + def items(self): + return [(k, self[k]) for k in iter(self)] + def keys(self): + return list(iter(self)) + def __iter__(self): + return (key for key in self.__slots__ if hasattr(self, key)) + def __len__(self): + return len(list(self.iteritems())) + def __eq__(self, other): + return self.items() == other.items() + def __neq__(self, other): + return self.items() != other.items() + + @classmethod + def create(cls, allowed_keys): + allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys) + allowed_keys = frozenset(allowed_keys_tuple) + if allowed_keys not in cls._classes: + class SpecificStrictDict(cls): + __slots__ = allowed_keys_tuple + cls._classes[allowed_keys] = SpecificStrictDict + return cls._classes[allowed_keys] + + +class SemiStrictDict(StrictDict): + __slots__ = ('_extras') + _classes = {} + def __getattr__(self, attr): + try: + super(SemiStrictDict, self).__getattr__(attr) + except AttributeError: + try: + return self.__getattribute__('_extras')[attr] + except KeyError as e: + raise AttributeError(e) + def __setattr__(self, attr, value): + try: + super(SemiStrictDict, self).__setattr__(attr, value) + except AttributeError: + try: + self._extras[attr] = value + except AttributeError: + self._extras = {attr: value} + + def __delattr__(self, attr): + try: + super(SemiStrictDict, self).__delattr__(attr) + except AttributeError: + try: + del self._extras[attr] + except KeyError as e: + raise AttributeError(e) + + def __iter__(self): + try: + extras_iter = iter(self.__getattribute__('_extras')) + except AttributeError: + extras_iter = () + return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter) + diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index f5eae8ff..e77ea080 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -13,24 +13,23 @@ from mongoengine import signals from mongoengine.common import _import_class from mongoengine.errors import (ValidationError, InvalidDocumentError, LookUpError) -from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, - to_str_keys_recursive) +from mongoengine.python_support import PY3, txt_type from mongoengine.base.common import get_document, ALLOW_INHERITANCE -from mongoengine.base.datastructures import BaseDict, BaseList +from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict from mongoengine.base.fields import ComplexBaseField __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') NON_FIELD_ERRORS = '__all__' - class BaseDocument(object): + __slots__ = ('_changed_fields', '_initialised', '_created', '_data', + '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') _dynamic = False - _created = True _dynamic_lock = True - _initialised = False + STRICT = False def __init__(self, *args, **values): """ @@ -39,6 +38,8 @@ class BaseDocument(object): :param __auto_convert: Try and will cast python objects to Object types :param values: A dictionary of values for the document """ + self._initialised = False + self._created = True if args: # Combine positional arguments with named arguments. # We only want named arguments. @@ -54,7 +55,11 @@ class BaseDocument(object): __auto_convert = values.pop("__auto_convert", True) signals.pre_init.send(self.__class__, document=self, values=values) - self._data = {} + if self.STRICT and not self._dynamic: + self._data = StrictDict.create(allowed_keys=self._fields.keys())() + else: + self._data = SemiStrictDict.create(allowed_keys=self._fields.keys())() + self._dynamic_fields = SON() # Assign default values to instance @@ -130,17 +135,25 @@ class BaseDocument(object): self._data[name] = value if hasattr(self, '_changed_fields'): self._mark_as_changed(name) + try: + self__created = self._created + except AttributeError: + self__created = True - if (self._is_document and not self._created and + if (self._is_document and not self__created and name in self._meta.get('shard_key', tuple()) and self._data.get(name) != value): OperationError = _import_class('OperationError') msg = "Shard Keys are immutable. Tried to update %s" % name raise OperationError(msg) + try: + self__initialised = self._initialised + except AttributeError: + self__initialised = False # Check if the user has created a new instance of a class - if (self._is_document and self._initialised - and self._created and name == self._meta['id_field']): + if (self._is_document and self__initialised + and self__created and name == self._meta['id_field']): super(BaseDocument, self).__setattr__('_created', False) super(BaseDocument, self).__setattr__(name, value) @@ -158,9 +171,11 @@ class BaseDocument(object): if isinstance(data["_data"], SON): data["_data"] = self.__class__._from_son(data["_data"])._data for k in ('_changed_fields', '_initialised', '_created', '_data', - '_fields_ordered', '_dynamic_fields'): + '_dynamic_fields'): if k in data: setattr(self, k, data[k]) + if '_fields_ordered' in data: + setattr(type(self), '_fields_ordered', data['_fields_ordered']) dynamic_fields = data.get('_dynamic_fields') or SON() for k in dynamic_fields.keys(): setattr(self, k, data["_data"].get(k)) @@ -182,7 +197,7 @@ class BaseDocument(object): """Dictionary-style field access, set a field's value. """ # Ensure that the field exists before settings its value - if name not in self._fields: + if not self._dynamic and name not in self._fields: raise KeyError(name) return setattr(self, name, value) @@ -317,7 +332,7 @@ class BaseDocument(object): pk = "None" if hasattr(self, 'pk'): pk = self.pk - elif self._instance: + elif self._instance and hasattr(self._instance, 'pk'): pk = self._instance.pk message = "ValidationError (%s:%s) " % (self._class_name, pk) raise ValidationError(message, errors=errors) @@ -392,6 +407,8 @@ class BaseDocument(object): else: data = getattr(data, part, None) if hasattr(data, "_changed_fields"): + if hasattr(data, "_is_document") and data._is_document: + continue data._changed_fields = [] self._changed_fields = [] @@ -545,10 +562,6 @@ class BaseDocument(object): # class if unavailable class_name = son.get('_cls', cls._class_name) data = dict(("%s" % key, value) for key, value in son.iteritems()) - if not UNICODE_KWARGS: - # python 2.6.4 and lower cannot handle unicode keys - # passed to class constructor example: cls(**data) - to_str_keys_recursive(data) # Return correct subclass for document type if class_name != cls._class_name: @@ -586,6 +599,8 @@ class BaseDocument(object): % (cls._class_name, errors)) raise InvalidDocumentError(msg) + if cls.STRICT: + data = dict((k, v) for k,v in data.iteritems() if k in cls._fields) obj = cls(__auto_convert=False, **data) obj._changed_fields = changed_fields obj._created = False @@ -825,7 +840,11 @@ class BaseDocument(object): """Dynamically set the display value for a field with choices""" for attr_name, field in self._fields.items(): if field.choices: - setattr(self, + if self._dynamic: + obj = self + else: + obj = type(self) + setattr(obj, 'get_%s_display' % attr_name, partial(self.__get_field_display, field=field)) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index ff5afddf..4b2e8b9b 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -359,7 +359,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class.id = field # Set primary key if not defined by the document - new_class._auto_id_field = False + new_class._auto_id_field = getattr(parent_doc_cls, + '_auto_id_field', False) if not new_class._meta.get('id_field'): new_class._auto_id_field = True new_class._meta['id_field'] = 'id' diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 7cc626f4..d3efac62 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -93,20 +93,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): raise ConnectionError(msg) conn_settings = _connection_settings[alias].copy() - if hasattr(pymongo, 'version_tuple'): # Support for 2.1+ - conn_settings.pop('name', None) - conn_settings.pop('slaves', None) - conn_settings.pop('is_slave', None) - conn_settings.pop('username', None) - conn_settings.pop('password', None) - else: - # Get all the slave connections - if 'slaves' in conn_settings: - slaves = [] - for slave_alias in conn_settings['slaves']: - slaves.append(get_connection(slave_alias)) - conn_settings['slaves'] = slaves - conn_settings.pop('read_preference', None) + conn_settings.pop('name', None) + conn_settings.pop('slaves', None) + conn_settings.pop('is_slave', None) + conn_settings.pop('username', None) + conn_settings.pop('password', None) connection_class = MongoClient if 'replicaSet' in conn_settings: @@ -119,7 +110,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): connection_class = MongoReplicaSetClient try: - _connections[alias] = connection_class(**conn_settings) + connection = None + connection_settings_iterator = ((alias, settings.copy()) for alias, settings in _connection_settings.iteritems()) + for alias, connection_settings in connection_settings_iterator: + connection_settings.pop('name', None) + connection_settings.pop('slaves', None) + connection_settings.pop('is_slave', None) + connection_settings.pop('username', None) + connection_settings.pop('password', None) + if conn_settings == connection_settings and _connections.get(alias, None): + connection = _connections[alias] + break + + _connections[alias] = connection if connection else connection_class(**conn_settings) except Exception, e: raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) return _connections[alias] diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index 13ed1009..cc860066 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -1,6 +1,5 @@ from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db -from mongoengine.queryset import QuerySet __all__ = ("switch_db", "switch_collection", "no_dereference", @@ -162,12 +161,6 @@ class no_sub_classes(object): return self.cls -class QuerySetNoDeRef(QuerySet): - """Special no_dereference QuerySet""" - def __dereference(items, max_depth=1, instance=None, name=None): - return items - - class query_counter(object): """ Query_counter context manager to get the number of queries. """ diff --git a/mongoengine/document.py b/mongoengine/document.py index 114778eb..7541ee57 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -13,7 +13,8 @@ from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList, ALLOW_INHERITANCE, get_document) from mongoengine.errors import ValidationError -from mongoengine.queryset import OperationError, NotUniqueError, QuerySet +from mongoengine.queryset import (OperationError, NotUniqueError, + QuerySet, transform) from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME from mongoengine.context_managers import switch_db, switch_collection @@ -53,16 +54,17 @@ class EmbeddedDocument(BaseDocument): `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` dictionary. """ + + __slots__ = ('_instance') # The __metaclass__ attribute is removed by 2to3 when running with Python3 # my_metaclass is defined so that metaclass can be queried in Python 2 & 3 my_metaclass = DocumentMetaclass __metaclass__ = DocumentMetaclass - _instance = None - def __init__(self, *args, **kwargs): super(EmbeddedDocument, self).__init__(*args, **kwargs) + self._instance = None self._changed_fields = [] def __eq__(self, other): @@ -125,6 +127,8 @@ class Document(BaseDocument): my_metaclass = TopLevelDocumentMetaclass __metaclass__ = TopLevelDocumentMetaclass + __slots__ = ('__objects' ) + def pk(): """Primary key alias """ @@ -180,7 +184,7 @@ class Document(BaseDocument): def save(self, force_insert=False, validate=True, clean=True, write_concern=None, cascade=None, cascade_kwargs=None, - _refs=None, **kwargs): + _refs=None, save_condition=None, **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. @@ -203,7 +207,8 @@ class Document(BaseDocument): :param cascade_kwargs: (optional) kwargs dictionary to be passed throw to cascading saves. Implies ``cascade=True``. :param _refs: A list of processed references used in cascading saves - + :param save_condition: only perform save if matching record in db + satisfies condition(s) (e.g., version number) .. versionchanged:: 0.5 In existing documents it only saves changed fields using set / unset. Saves are cascaded and any @@ -217,6 +222,9 @@ class Document(BaseDocument): meta['cascade'] = True. Also you can pass different kwargs to the cascade save using cascade_kwargs which overwrites the existing kwargs with custom values. + .. versionchanged:: 0.8.5 + Optional save_condition that only overwrites existing documents + if the condition is satisfied in the current db record. """ signals.pre_save.send(self.__class__, document=self) @@ -230,7 +238,8 @@ class Document(BaseDocument): created = ('_id' not in doc or self._created or force_insert) - signals.pre_save_post_validation.send(self.__class__, document=self, created=created) + signals.pre_save_post_validation.send(self.__class__, document=self, + created=created) try: collection = self._get_collection() @@ -243,7 +252,12 @@ class Document(BaseDocument): object_id = doc['_id'] updates, removals = self._delta() # Need to add shard key to query, or you get an error - select_dict = {'_id': object_id} + if save_condition is not None: + select_dict = transform.query(self.__class__, + **save_condition) + else: + select_dict = {} + select_dict['_id'] = object_id shard_key = self.__class__._meta.get('shard_key', tuple()) for k in shard_key: actual_key = self._db_field_map.get(k, k) @@ -263,10 +277,12 @@ class Document(BaseDocument): if removals: update_query["$unset"] = removals if updates or removals: + upsert = save_condition is None last_error = collection.update(select_dict, update_query, - upsert=True, **write_concern) + upsert=upsert, **write_concern) created = is_new_object(last_error) + if cascade is None: cascade = self._meta.get('cascade', False) or cascade_kwargs is not None @@ -293,12 +309,12 @@ class Document(BaseDocument): raise NotUniqueError(message % unicode(err)) raise OperationError(message % unicode(err)) id_field = self._meta['id_field'] - if id_field not in self._meta.get('shard_key', []): + if created or id_field not in self._meta.get('shard_key', []): self[id_field] = self._fields[id_field].to_python(object_id) + signals.post_save.send(self.__class__, document=self, created=created) self._clear_changed_fields() self._created = False - signals.post_save.send(self.__class__, document=self, created=created) return self def cascade_save(self, *args, **kwargs): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 82642cda..abadad65 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -391,7 +391,7 @@ class DateTimeField(BaseField): if dateutil: try: return dateutil.parser.parse(value) - except ValueError: + except (TypeError, ValueError): return None # split usecs, because they are not recognized by strptime. @@ -760,7 +760,7 @@ class DictField(ComplexBaseField): similar to an embedded document, but the structure is not defined. .. note:: - Required means it cannot be empty - as the default for ListFields is [] + Required means it cannot be empty - as the default for DictFields is {} .. versionadded:: 0.3 .. versionchanged:: 0.5 - Can now handle complex / varying types of data @@ -1613,7 +1613,12 @@ class UUIDField(BaseField): class GeoPointField(BaseField): - """A list storing a latitude and longitude. + """A list storing a longitude and latitude coordinate. + + .. note:: this represents a generic point in a 2D plane and a legacy way of + representing a geo point. It admits 2d indexes but not "2dsphere" indexes + in MongoDB > 2.4 which are more natural for modeling geospatial points. + See :ref:`geospatial-indexes` .. versionadded:: 0.4 """ @@ -1635,7 +1640,7 @@ class GeoPointField(BaseField): class PointField(GeoJsonBaseField): - """A geo json field storing a latitude and longitude. + """A GeoJSON field storing a longitude and latitude coordinate. The data is represented as: @@ -1654,7 +1659,7 @@ class PointField(GeoJsonBaseField): class LineStringField(GeoJsonBaseField): - """A geo json field storing a line of latitude and longitude coordinates. + """A GeoJSON field storing a line of longitude and latitude coordinates. The data is represented as: @@ -1672,7 +1677,7 @@ class LineStringField(GeoJsonBaseField): class PolygonField(GeoJsonBaseField): - """A geo json field storing a polygon of latitude and longitude coordinates. + """A GeoJSON field storing a polygon of longitude and latitude coordinates. The data is represented as: diff --git a/mongoengine/python_support.py b/mongoengine/python_support.py index 097740eb..2c4df00c 100644 --- a/mongoengine/python_support.py +++ b/mongoengine/python_support.py @@ -3,8 +3,6 @@ import sys PY3 = sys.version_info[0] == 3 -PY25 = sys.version_info[:2] == (2, 5) -UNICODE_KWARGS = int(''.join([str(x) for x in sys.version_info[:3]])) > 264 if PY3: import codecs @@ -29,33 +27,3 @@ else: txt_type = unicode str_types = (bin_type, txt_type) - -if PY25: - def product(*args, **kwds): - pools = map(tuple, args) * kwds.get('repeat', 1) - result = [[]] - for pool in pools: - result = [x + [y] for x in result for y in pool] - for prod in result: - yield tuple(prod) - reduce = reduce -else: - from itertools import product - from functools import reduce - - -# For use with Python 2.5 -# converts all keys from unicode to str for d and all nested dictionaries -def to_str_keys_recursive(d): - if isinstance(d, list): - for val in d: - if isinstance(val, (dict, list)): - to_str_keys_recursive(val) - elif isinstance(d, dict): - for key, val in d.items(): - if isinstance(val, (dict, list)): - to_str_keys_recursive(val) - if isinstance(key, unicode): - d[str(key)] = d.pop(key) - else: - raise ValueError("non list/dict parameter not allowed") diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 4bd7128e..4fb143bb 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -11,15 +11,16 @@ from bson import SON from bson.code import Code from bson import json_util import pymongo +import pymongo.errors from pymongo.common import validate_read_preference from mongoengine import signals from mongoengine.connection import get_db +from mongoengine.context_managers import switch_db from mongoengine.common import _import_class from mongoengine.base.common import get_document from mongoengine.errors import (OperationError, NotUniqueError, InvalidQueryError, LookUpError) - from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode @@ -52,7 +53,7 @@ class BaseQuerySet(object): self._initial_query = {} self._where_clause = None self._loaded_fields = QueryFieldList() - self._ordering = [] + self._ordering = None self._snapshot = False self._timeout = True self._class_check = True @@ -148,7 +149,7 @@ class BaseQuerySet(object): queryset._document._from_son(queryset._cursor[key], _auto_dereference=self._auto_dereference)) if queryset._as_pymongo: - return queryset._get_as_pymongo(queryset._cursor.next()) + return queryset._get_as_pymongo(queryset._cursor[key]) return queryset._document._from_son(queryset._cursor[key], _auto_dereference=self._auto_dereference) raise AttributeError @@ -156,6 +157,22 @@ class BaseQuerySet(object): def __iter__(self): raise NotImplementedError + def _has_data(self): + """ Retrieves whether cursor has any data. """ + + queryset = self.order_by() + return False if queryset.first() is None else True + + def __nonzero__(self): + """ Avoid to open all records in an if stmt in Py2. """ + + return self._has_data() + + def __bool__(self): + """ Avoid to open all records in an if stmt in Py3. """ + + return self._has_data() + # Core functions def all(self): @@ -177,7 +194,7 @@ class BaseQuerySet(object): .. versionadded:: 0.3 """ queryset = self.clone() - queryset = queryset.limit(2) + queryset = queryset.order_by().limit(2) queryset = queryset.filter(*q_objs, **query) try: @@ -391,7 +408,7 @@ class BaseQuerySet(object): ref_q = document_cls.objects(**{field_name + '__in': self}) ref_q_count = ref_q.count() if (doc != document_cls and ref_q_count > 0 - or (doc == document_cls and ref_q_count > 0)): + or (doc == document_cls and ref_q_count > 0)): ref_q.delete(write_concern=write_concern) elif rule == NULLIFY: document_cls.objects(**{field_name + '__in': self}).update( @@ -445,6 +462,8 @@ class BaseQuerySet(object): return result elif result: return result['n'] + except pymongo.errors.DuplicateKeyError, err: + raise NotUniqueError(u'Update failed (%s)' % unicode(err)) except pymongo.errors.OperationFailure, err: if unicode(err) == u'multi not coded yet': message = u'update() method requires MongoDB 1.1.3+' @@ -468,6 +487,59 @@ class BaseQuerySet(object): return self.update( upsert=upsert, multi=False, write_concern=write_concern, **update) + def modify(self, upsert=False, full_response=False, remove=False, new=False, **update): + """Update and return the updated document. + + Returns either the document before or after modification based on `new` + parameter. If no documents match the query and `upsert` is false, + returns ``None``. If upserting and `new` is false, returns ``None``. + + If the full_response parameter is ``True``, the return value will be + the entire response object from the server, including the 'ok' and + 'lastErrorObject' fields, rather than just the modified document. + This is useful mainly because the 'lastErrorObject' document holds + information about the command's execution. + + :param upsert: insert if document doesn't exist (default ``False``) + :param full_response: return the entire response object from the + server (default ``False``) + :param remove: remove rather than updating (default ``False``) + :param new: return updated rather than original document + (default ``False``) + :param update: Django-style update keyword arguments + + .. versionadded:: 0.9 + """ + + if remove and new: + raise OperationError("Conflicting parameters: remove and new") + + if not update and not upsert and not remove: + raise OperationError("No update parameters, must either update or remove") + + queryset = self.clone() + query = queryset._query + update = transform.update(queryset._document, **update) + sort = queryset._ordering + + try: + result = queryset._collection.find_and_modify( + query, update, upsert=upsert, sort=sort, remove=remove, new=new, + full_response=full_response, **self._cursor_args) + except pymongo.errors.DuplicateKeyError, err: + raise NotUniqueError(u"Update failed (%s)" % err) + except pymongo.errors.OperationFailure, err: + raise OperationError(u"Update failed (%s)" % err) + + if full_response: + if result["value"] is not None: + result["value"] = self._document._from_son(result["value"]) + else: + if result is not None: + result = self._document._from_son(result) + + return result + def with_id(self, object_id): """Retrieve the object matching the id provided. Uses `object_id` only and raises InvalidQueryError if a filter has been applied. Returns @@ -524,6 +596,19 @@ class BaseQuerySet(object): return self + def using(self, alias): + """This method is for controlling which database the QuerySet will be evaluated against if you are using more than one database. + + :param alias: The database alias + + .. versionadded:: 0.8 + """ + + with switch_db(self._document, alias) as cls: + collection = cls._get_collection() + + return self.clone_into(self.__class__(self._document, collection)) + def clone(self): """Creates a copy of the current :class:`~mongoengine.queryset.QuerySet` @@ -928,7 +1013,7 @@ class BaseQuerySet(object): if isinstance(output, basestring): mr_args['out'] = output - + elif isinstance(output, dict): ordered_output = [] @@ -937,27 +1022,27 @@ class BaseQuerySet(object): if value: ordered_output.append((part, value)) break - + else: raise OperationError("actionData not specified for output") db_alias = output.get('db_alias') remaing_args = ['db', 'sharded', 'nonAtomic'] - + if db_alias: ordered_output.append(('db', get_db(db_alias).name)) del remaing_args[0] - + for part in remaing_args: value = output.get(part) if value: ordered_output.append((part, value)) - + mr_args['out'] = SON(ordered_output) results = getattr(queryset._collection, map_reduce_function)( - map_f, reduce_f, **mr_args) + map_f, reduce_f, **mr_args) if map_reduce_function == 'map_reduce': results = results.find() @@ -1220,8 +1305,9 @@ class BaseQuerySet(object): if self._ordering: # Apply query ordering self._cursor_obj.sort(self._ordering) - elif self._document._meta['ordering']: - # Otherwise, apply the ordering from the document model + elif self._ordering is None and self._document._meta['ordering']: + # Otherwise, apply the ordering from the document model, unless + # it's been explicitly cleared via order_by with no arguments order = self._get_order_by(self._document._meta['ordering']) self._cursor_obj.sort(order) @@ -1393,7 +1479,7 @@ class BaseQuerySet(object): for subdoc in subclasses: try: subfield = ".".join(f.db_field for f in - subdoc._lookup_field(field.split('.'))) + subdoc._lookup_field(field.split('.'))) ret.append(subfield) found = True break @@ -1423,7 +1509,7 @@ class BaseQuerySet(object): pass key_list.append((key, direction)) - if self._cursor_obj: + if self._cursor_obj and key_list: self._cursor_obj.sort(key_list) return key_list @@ -1481,6 +1567,7 @@ class BaseQuerySet(object): # type of this field and use the corresponding # .to_python(...) from mongoengine.fields import EmbeddedDocumentField + obj = self._document for chunk in path.split('.'): obj = getattr(obj, chunk, None) @@ -1491,6 +1578,7 @@ class BaseQuerySet(object): if obj and data is not None: data = obj.to_python(data) return data + return clean(row) def _sub_js_fields(self, code): @@ -1499,6 +1587,7 @@ class BaseQuerySet(object): substituted for the MongoDB name of the field (specified using the :attr:`name` keyword argument in a field's constructor). """ + def field_sub(match): # Extract just the field name, and look up the field objects field_name = match.group(1).split('.') diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 1437e76b..cebfcc50 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -155,3 +155,10 @@ class QuerySetNoCache(BaseQuerySet): queryset = self.clone() queryset.rewind() return queryset + + +class QuerySetNoDeRef(QuerySet): + """Special no_dereference QuerySet""" + + def __dereference(items, max_depth=1, instance=None, name=None): + return items \ No newline at end of file diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index e31a8b7d..8e88e9fe 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -3,6 +3,7 @@ from collections import defaultdict import pymongo from bson import SON +from mongoengine.connection import get_connection from mongoengine.common import _import_class from mongoengine.errors import InvalidQueryError, LookUpError @@ -38,7 +39,7 @@ def query(_doc_cls=None, _field_operation=False, **query): mongo_query.update(value) continue - parts = key.split('__') + parts = key.rsplit('__') indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] parts = [part for part in parts if not part.isdigit()] # Check for an operator and transform to mongo-style if there is @@ -115,14 +116,21 @@ def query(_doc_cls=None, _field_operation=False, **query): if key in mongo_query and isinstance(mongo_query[key], dict): mongo_query[key].update(value) # $maxDistance needs to come last - convert to SON - if '$maxDistance' in mongo_query[key]: - value_dict = mongo_query[key] + value_dict = mongo_query[key] + if ('$maxDistance' in value_dict and '$near' in value_dict and + isinstance(value_dict['$near'], dict)): + value_son = SON() for k, v in value_dict.iteritems(): if k == '$maxDistance': continue value_son[k] = v - value_son['$maxDistance'] = value_dict['$maxDistance'] + if (get_connection().max_wire_version <= 1): + value_son['$maxDistance'] = value_dict['$maxDistance'] + else: + value_son['$near'] = SON(value_son['$near']) + value_son['$near']['$maxDistance'] = value_dict['$maxDistance'] + mongo_query[key] = value_son else: # Store for manually merging later diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 41f4ebf8..a39b05f0 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -1,8 +1,9 @@ import copy -from mongoengine.errors import InvalidQueryError -from mongoengine.python_support import product, reduce +from itertools import product +from functools import reduce +from mongoengine.errors import InvalidQueryError from mongoengine.queryset import transform __all__ = ('Q',) diff --git a/setup.py b/setup.py index 85707d00..7270331a 100644 --- a/setup.py +++ b/setup.py @@ -38,12 +38,14 @@ CLASSIFIERS = [ 'Operating System :: OS Independent', 'Programming Language :: Python', "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.6", + "Programming Language :: Python :: 2.6.6", "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.1", "Programming Language :: Python :: 3.2", + "Programming Language :: Python :: 3.3", + "Programming Language :: Python :: 3.4", "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", 'Topic :: Database', 'Topic :: Software Development :: Libraries :: Python Modules', ] @@ -51,12 +53,15 @@ CLASSIFIERS = [ extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} if sys.version_info[0] == 3: extra_opts['use_2to3'] = True - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'django>=1.5.1'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'jinja2==2.6', 'Pillow>=2.0.0', 'django>=1.5.1'] if "test" in sys.argv or "nosetests" in sys.argv: extra_opts['packages'] = find_packages() extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} else: - extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'PIL', 'jinja2>=2.6', 'python-dateutil'] + extra_opts['tests_require'] = ['nose', 'coverage', 'blinker', 'django>=1.4.2', 'Pillow>=2.0.0', 'jinja2>=2.6', 'python-dateutil'] + + if sys.version_info[0] == 2 and sys.version_info[1] == 6: + extra_opts['tests_require'].append('unittest2') setup(name='mongoengine', version=VERSION, @@ -72,7 +77,7 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo>=2.5'], + install_requires=['pymongo>=2.7'], test_suite='nose.collector', **extra_opts ) diff --git a/tests/document/delta.py b/tests/document/delta.py index b0f5f01a..738dfa78 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -735,5 +735,47 @@ class DeltaTest(unittest.TestCase): mydoc._clear_changed_fields() self.assertEqual([], mydoc._get_changed_fields()) + def test_referenced_object_changed_attributes(self): + """Ensures that when you save a new reference to a field, the referenced object isn't altered""" + + class Organization(Document): + name = StringField() + + class User(Document): + name = StringField() + org = ReferenceField('Organization', required=True) + + Organization.drop_collection() + User.drop_collection() + + org1 = Organization(name='Org 1') + org1.save() + + org2 = Organization(name='Org 2') + org2.save() + + user = User(name='Fred', org=org1) + user.save() + + org1.reload() + org2.reload() + user.reload() + self.assertEqual(org1.name, 'Org 1') + self.assertEqual(org2.name, 'Org 2') + self.assertEqual(user.name, 'Fred') + + user.name = 'Harold' + user.org = org2 + + org2.name = 'New Org 2' + self.assertEqual(org2.name, 'New Org 2') + + user.save() + org2.save() + + self.assertEqual(org2.name, 'New Org 2') + org2.reload() + self.assertEqual(org2.name, 'New Org 2') + if __name__ == '__main__': unittest.main() diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py index 6263e68c..bf69cb27 100644 --- a/tests/document/dynamic.py +++ b/tests/document/dynamic.py @@ -292,6 +292,44 @@ class DynamicTest(unittest.TestCase): person.save() self.assertEqual(Person.objects.first().age, 35) + def test_dynamic_and_embedded_dict_access(self): + """Ensure embedded dynamic documents work with dict[] style access""" + + class Address(EmbeddedDocument): + city = StringField() + + class Person(DynamicDocument): + name = StringField() + + Person.drop_collection() + + Person(name="Ross", address=Address(city="London")).save() + + person = Person.objects.first() + person.attrval = "This works" + + person["phone"] = "555-1212" # but this should too + + # Same thing two levels deep + person["address"]["city"] = "Lundenne" + person.save() + + self.assertEqual(Person.objects.first().address.city, "Lundenne") + + self.assertEqual(Person.objects.first().phone, "555-1212") + + person = Person.objects.first() + person.address = Address(city="Londinium") + person.save() + + self.assertEqual(Person.objects.first().address.city, "Londinium") + + + person = Person.objects.first() + person["age"] = 35 + person.save() + self.assertEqual(Person.objects.first().age, 35) + if __name__ == '__main__': unittest.main() diff --git a/tests/document/instance.py b/tests/document/instance.py index 07db85a0..54758955 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -15,7 +15,7 @@ from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, from mongoengine import * from mongoengine.errors import (NotRegistered, InvalidDocumentError, - InvalidQueryError) + InvalidQueryError, NotUniqueError) from mongoengine.queryset import NULLIFY, Q from mongoengine.connection import get_db from mongoengine.base import get_document @@ -57,7 +57,7 @@ class InstanceTest(unittest.TestCase): date = DateTimeField(default=datetime.now) meta = { 'max_documents': 10, - 'max_size': 90000, + 'max_size': 4096, } Log.drop_collection() @@ -75,7 +75,7 @@ class InstanceTest(unittest.TestCase): options = Log.objects._collection.options() self.assertEqual(options['capped'], True) self.assertEqual(options['max'], 10) - self.assertEqual(options['size'], 90000) + self.assertTrue(options['size'] >= 4096) # Check that the document cannot be redefined with different options def recreate_log_document(): @@ -820,6 +820,80 @@ class InstanceTest(unittest.TestCase): p1.reload() self.assertEqual(p1.name, p.parent.name) + def test_save_atomicity_condition(self): + + class Widget(Document): + toggle = BooleanField(default=False) + count = IntField(default=0) + save_id = UUIDField() + + def flip(widget): + widget.toggle = not widget.toggle + widget.count += 1 + + def UUID(i): + return uuid.UUID(int=i) + + Widget.drop_collection() + + w1 = Widget(toggle=False, save_id=UUID(1)) + + # ignore save_condition on new record creation + w1.save(save_condition={'save_id':UUID(42)}) + w1.reload() + self.assertFalse(w1.toggle) + self.assertEqual(w1.save_id, UUID(1)) + self.assertEqual(w1.count, 0) + + # mismatch in save_condition prevents save + flip(w1) + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 1) + w1.save(save_condition={'save_id':UUID(42)}) + w1.reload() + self.assertFalse(w1.toggle) + self.assertEqual(w1.count, 0) + + # matched save_condition allows save + flip(w1) + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 1) + w1.save(save_condition={'save_id':UUID(1)}) + w1.reload() + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 1) + + # save_condition can be used to ensure atomic read & updates + # i.e., prevent interleaved reads and writes from separate contexts + w2 = Widget.objects.get() + self.assertEqual(w1, w2) + old_id = w1.save_id + + flip(w1) + w1.save_id = UUID(2) + w1.save(save_condition={'save_id':old_id}) + w1.reload() + self.assertFalse(w1.toggle) + self.assertEqual(w1.count, 2) + flip(w2) + flip(w2) + w2.save(save_condition={'save_id':old_id}) + w2.reload() + self.assertFalse(w2.toggle) + self.assertEqual(w2.count, 2) + + # save_condition uses mongoengine-style operator syntax + flip(w1) + w1.save(save_condition={'count__lt':w1.count}) + w1.reload() + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 3) + flip(w1) + w1.save(save_condition={'count__gte':w1.count}) + w1.reload() + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 3) + def test_update(self): """Ensure that an existing document is updated instead of be overwritten.""" @@ -990,6 +1064,16 @@ class InstanceTest(unittest.TestCase): self.assertRaises(InvalidQueryError, update_no_op_raises) + def test_update_unique_field(self): + class Doc(Document): + name = StringField(unique=True) + + doc1 = Doc(name="first").save() + doc2 = Doc(name="second").save() + + self.assertRaises(NotUniqueError, lambda: + doc2.update(set__name=doc1.name)) + def test_embedded_update(self): """ Test update on `EmbeddedDocumentField` fields @@ -2281,6 +2365,8 @@ class InstanceTest(unittest.TestCase): log.machine = "Localhost" log.save() + self.assertTrue(log.id is not None) + log.log = "Saving" log.save() @@ -2304,6 +2390,8 @@ class InstanceTest(unittest.TestCase): log.machine = "Localhost" log.save() + self.assertTrue(log.id is not None) + log.log = "Saving" log.save() @@ -2411,7 +2499,7 @@ class InstanceTest(unittest.TestCase): for parameter_name, parameter in self.parameters.iteritems(): parameter.expand() - class System(Document): + class NodesSystem(Document): name = StringField(required=True) nodes = MapField(ReferenceField(Node, dbref=False)) @@ -2419,18 +2507,18 @@ class InstanceTest(unittest.TestCase): for node_name, node in self.nodes.iteritems(): node.expand() node.save(*args, **kwargs) - super(System, self).save(*args, **kwargs) + super(NodesSystem, self).save(*args, **kwargs) - System.drop_collection() + NodesSystem.drop_collection() Node.drop_collection() - system = System(name="system") + system = NodesSystem(name="system") system.nodes["node"] = Node() system.save() system.nodes["node"].parameters["param"] = Parameter() system.save() - system = System.objects.first() + system = NodesSystem.objects.first() self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value) def test_embedded_document_equality(self): diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index 902b1512..7ae53e8a 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -279,7 +279,7 @@ class FileTest(unittest.TestCase): t.image.put(f) self.fail("Should have raised an invalidation error") except ValidationError, e: - self.assertEqual("%s" % e, "Invalid image: cannot identify image file") + self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) t = TestImage() t.image.put(open(TEST_IMAGE_PATH, 'rb')) diff --git a/tests/queryset/__init__.py b/tests/queryset/__init__.py index 8a93c19f..c36b2684 100644 --- a/tests/queryset/__init__.py +++ b/tests/queryset/__init__.py @@ -3,3 +3,4 @@ from field_list import * from queryset import * from visitor import * from geo import * +from modify import * \ No newline at end of file diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py index 65ab519a..5148a48e 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/geo.py @@ -5,6 +5,8 @@ import unittest from datetime import datetime, timedelta from mongoengine import * +from nose.plugins.skip import SkipTest + __all__ = ("GeoQueriesTest",) @@ -139,6 +141,7 @@ class GeoQueriesTest(unittest.TestCase): def test_spherical_geospatial_operators(self): """Ensure that spherical geospatial queries are working """ + raise SkipTest("https://jira.mongodb.org/browse/SERVER-14039") class Point(Document): location = GeoPointField() diff --git a/tests/queryset/modify.py b/tests/queryset/modify.py new file mode 100644 index 00000000..e0c7d1fe --- /dev/null +++ b/tests/queryset/modify.py @@ -0,0 +1,102 @@ +import sys +sys.path[0:0] = [""] + +import unittest + +from mongoengine import connect, Document, IntField + +__all__ = ("FindAndModifyTest",) + + +class Doc(Document): + id = IntField(primary_key=True) + value = IntField() + + +class FindAndModifyTest(unittest.TestCase): + + def setUp(self): + connect(db="mongoenginetest") + Doc.drop_collection() + + def assertDbEqual(self, docs): + self.assertEqual(list(Doc._collection.find().sort("id")), docs) + + def test_modify(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).modify(set__value=-1) + self.assertEqual(old_doc.to_json(), doc.to_json()) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + def test_modify_with_new(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + new_doc = Doc.objects(id=1).modify(set__value=-1, new=True) + doc.value = -1 + self.assertEqual(new_doc.to_json(), doc.to_json()) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + def test_modify_not_existing(self): + Doc(id=0, value=0).save() + self.assertEqual(Doc.objects(id=1).modify(set__value=-1), None) + self.assertDbEqual([{"_id": 0, "value": 0}]) + + def test_modify_with_upsert(self): + Doc(id=0, value=0).save() + old_doc = Doc.objects(id=1).modify(set__value=1, upsert=True) + self.assertEqual(old_doc, None) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) + + def test_modify_with_upsert_existing(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).modify(set__value=-1, upsert=True) + self.assertEqual(old_doc.to_json(), doc.to_json()) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + def test_modify_with_upsert_with_new(self): + Doc(id=0, value=0).save() + new_doc = Doc.objects(id=1).modify(upsert=True, new=True, set__value=1) + self.assertEqual(new_doc.to_mongo(), {"_id": 1, "value": 1}) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) + + def test_modify_with_remove(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).modify(remove=True) + self.assertEqual(old_doc.to_json(), doc.to_json()) + self.assertDbEqual([{"_id": 0, "value": 0}]) + + def test_find_and_modify_with_remove_not_existing(self): + Doc(id=0, value=0).save() + self.assertEqual(Doc.objects(id=1).modify(remove=True), None) + self.assertDbEqual([{"_id": 0, "value": 0}]) + + def test_modify_with_order_by(self): + Doc(id=0, value=3).save() + Doc(id=1, value=2).save() + Doc(id=2, value=1).save() + doc = Doc(id=3, value=0).save() + + old_doc = Doc.objects().order_by("-id").modify(set__value=-1) + self.assertEqual(old_doc.to_json(), doc.to_json()) + self.assertDbEqual([ + {"_id": 0, "value": 3}, {"_id": 1, "value": 2}, + {"_id": 2, "value": 1}, {"_id": 3, "value": -1}]) + + def test_modify_with_fields(self): + Doc(id=0, value=0).save() + Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).only("id").modify(set__value=-1) + self.assertEqual(old_doc.to_mongo(), {"_id": 1}) + self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 2fcd466c..62a142f8 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -16,7 +16,7 @@ from bson import ObjectId from mongoengine import * from mongoengine.connection import get_connection, get_db from mongoengine.python_support import PY3 -from mongoengine.context_managers import query_counter +from mongoengine.context_managers import query_counter, switch_db from mongoengine.queryset import (QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, queryset_manager) @@ -25,10 +25,17 @@ from mongoengine.errors import InvalidQueryError __all__ = ("QuerySetTest",) +class db_ops_tracker(query_counter): + def get_ops(self): + ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} + return list(self.db.system.profile.find(ignore_query)) + + class QuerySetTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') + connect(db='mongoenginetest2', alias='test2') class PersonMeta(EmbeddedDocument): weight = IntField() @@ -650,7 +657,10 @@ class QuerySetTest(unittest.TestCase): blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) Blog.objects.insert(blogs, load_bulk=False) - self.assertEqual(q, 1) # 1 for the insert + if (get_connection().max_wire_version <= 1): + self.assertEqual(q, 1) + else: + self.assertEqual(q, 99) # profiling logs each doc now in the bulk op Blog.drop_collection() Blog.ensure_indexes() @@ -659,7 +669,10 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(q, 0) Blog.objects.insert(blogs) - self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch + if (get_connection().max_wire_version <= 1): + self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch + else: + self.assertEqual(q, 100) # 99 for insert, and 1 for in bulk fetch Blog.drop_collection() @@ -1040,6 +1053,54 @@ class QuerySetTest(unittest.TestCase): expected = [blog_post_1, blog_post_2, blog_post_3] self.assertSequence(qs, expected) + def test_clear_ordering(self): + """ Ensure that the default ordering can be cleared by calling order_by(). + """ + class BlogPost(Document): + title = StringField() + published_date = DateTimeField() + + meta = { + 'ordering': ['-published_date'] + } + + BlogPost.drop_collection() + + with db_ops_tracker() as q: + BlogPost.objects.filter(title='whatever').first() + self.assertEqual(len(q.get_ops()), 1) + self.assertEqual(q.get_ops()[0]['query']['$orderby'], {u'published_date': -1}) + + with db_ops_tracker() as q: + BlogPost.objects.filter(title='whatever').order_by().first() + self.assertEqual(len(q.get_ops()), 1) + print q.get_ops()[0]['query'] + self.assertFalse('$orderby' in q.get_ops()[0]['query']) + + def test_no_ordering_for_get(self): + """ Ensure that Doc.objects.get doesn't use any ordering. + """ + class BlogPost(Document): + title = StringField() + published_date = DateTimeField() + + meta = { + 'ordering': ['-published_date'] + } + + BlogPost.objects.create(title='whatever', published_date=datetime.utcnow()) + + with db_ops_tracker() as q: + BlogPost.objects.get(title='whatever') + self.assertEqual(len(q.get_ops()), 1) + self.assertFalse('$orderby' in q.get_ops()[0]['query']) + + # Ordering should be ignored for .get even if we set it explicitly + with db_ops_tracker() as q: + BlogPost.objects.order_by('-title').get(title='whatever') + self.assertEqual(len(q.get_ops()), 1) + self.assertFalse('$orderby' in q.get_ops()[0]['query']) + def test_find_embedded(self): """Ensure that an embedded document is properly returned from a query. """ @@ -1930,7 +1991,7 @@ class QuerySetTest(unittest.TestCase): Test map/reduce custom output """ register_connection('test2', 'mongoenginetest2') - + class Family(Document): id = IntField( primary_key=True) @@ -1942,10 +2003,10 @@ class QuerySetTest(unittest.TestCase): name = StringField() age = IntField() family = ReferenceField(Family) - + Family.drop_collection() Person.drop_collection() - + # creating first family f1 = Family(id=1, log="Trav 02 de Julho") f1.save() @@ -1964,7 +2025,7 @@ class QuerySetTest(unittest.TestCase): Person(id=5, family=f2, name="Isabella Luanna", age=16).save() Person(id=6, family=f2, name="Sandra Mara", age=36).save() Person(id=7, family=f2, name="Igor Gabriel", age=10).save() - + # creating third family f3 = Family(id=3, log="Av brazil") f3.save() @@ -1997,7 +2058,7 @@ class QuerySetTest(unittest.TestCase): reduce_f = """ function (key, values) { var family = {persons: [], totalAge: 0}; - + values.forEach(function(value) { if (value.persons) { value.persons.forEach(function (person) { @@ -2025,7 +2086,7 @@ class QuerySetTest(unittest.TestCase): results = list(results) collection = get_db('test2').family_map - + self.assertEqual( collection.find_one({'_id': 1}), { '_id': 1, @@ -2058,7 +2119,7 @@ class QuerySetTest(unittest.TestCase): {'age': 25, 'name': u'Paula Leonel'}], 'totalAge': 55} }) - + def test_map_reduce_finalize(self): """Ensure that map, reduce, and finalize run and introduce "scope" by simulating "hotness" ranking with Reddit algorithm. @@ -3091,6 +3152,23 @@ class QuerySetTest(unittest.TestCase): Number.drop_collection() + def test_using(self): + """Ensure that switching databases for a queryset is possible + """ + class Number2(Document): + n = IntField() + + Number2.drop_collection() + with switch_db(Number2, 'test2') as Number2: + Number2.drop_collection() + + for i in range(1, 10): + t = Number2(n=i) + t.switch_db('test2') + t.save() + + self.assertEqual(len(Number2.objects.using('test2')), 9) + def test_unset_reference(self): class Comment(Document): text = StringField() @@ -3720,7 +3798,13 @@ class QuerySetTest(unittest.TestCase): [x for x in people] self.assertEqual(100, len(people._result_cache)) - self.assertEqual(None, people._len) + + import platform + + if platform.python_implementation() != "PyPy": + # PyPy evaluates __len__ when iterating with list comprehensions while CPython does not. + # This may be a bug in PyPy (PyPy/#1802) but it does not affect the behavior of MongoEngine. + self.assertEqual(None, people._len) self.assertEqual(q, 1) list(people) @@ -3948,6 +4032,111 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(Example.objects(size=instance_size).count(), 1) self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1) + def test_cursor_in_an_if_stmt(self): + + class Test(Document): + test_field = StringField() + + Test.drop_collection() + queryset = Test.objects + + if queryset: + raise AssertionError('Empty cursor returns True') + + test = Test() + test.test_field = 'test' + test.save() + + queryset = Test.objects + if not test: + raise AssertionError('Cursor has data and returned False') + + queryset.next() + if not queryset: + raise AssertionError('Cursor has data and it must returns True,' + ' even in the last item.') + + def test_bool_performance(self): + + class Person(Document): + name = StringField() + + Person.drop_collection() + for i in xrange(100): + Person(name="No: %s" % i).save() + + with query_counter() as q: + if Person.objects: + pass + + self.assertEqual(q, 1) + op = q.db.system.profile.find({"ns": + {"$ne": "%s.system.indexes" % q.db.name}})[0] + + self.assertEqual(op['nreturned'], 1) + + + def test_bool_with_ordering(self): + + class Person(Document): + name = StringField() + + Person.drop_collection() + Person(name="Test").save() + + qs = Person.objects.order_by('name') + + with query_counter() as q: + + if qs: + pass + + op = q.db.system.profile.find({"ns": + {"$ne": "%s.system.indexes" % q.db.name}})[0] + + self.assertFalse('$orderby' in op['query'], + 'BaseQuerySet cannot use orderby in if stmt') + + with query_counter() as p: + + for x in qs: + pass + + op = p.db.system.profile.find({"ns": + {"$ne": "%s.system.indexes" % q.db.name}})[0] + + self.assertTrue('$orderby' in op['query'], + 'BaseQuerySet cannot remove orderby in for loop') + + def test_bool_with_ordering_from_meta_dict(self): + + class Person(Document): + name = StringField() + meta = { + 'ordering': ['name'] + } + + Person.drop_collection() + + Person(name="B").save() + Person(name="C").save() + Person(name="A").save() + + with query_counter() as q: + + if Person.objects: + pass + + op = q.db.system.profile.find({"ns": + {"$ne": "%s.system.indexes" % q.db.name}})[0] + + self.assertFalse('$orderby' in op['query'], + 'BaseQuerySet must remove orderby from meta in boolen test') + + self.assertEqual(Person.objects.first().name, 'A') + self.assertTrue(Person.objects._has_data(), + 'Cursor has data and returned False') + if __name__ == '__main__': unittest.main() diff --git a/tests/test_connection.py b/tests/test_connection.py index 96135bc5..a5b1b089 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,6 +1,11 @@ import sys sys.path[0:0] = [""] -import unittest + +try: + import unittest2 as unittest +except ImportError: + import unittest + import datetime import pymongo @@ -34,6 +39,17 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb') self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) + def test_sharing_connections(self): + """Ensure that connections are shared when the connection settings are exactly the same + """ + connect('mongoenginetest', alias='testdb1') + + expected_connection = get_connection('testdb1') + + connect('mongoenginetest', alias='testdb2') + actual_connection = get_connection('testdb2') + self.assertIs(expected_connection, actual_connection) + def test_connect_uri(self): """Ensure that the connect() method works properly with uri's """ diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py new file mode 100644 index 00000000..c761a41e --- /dev/null +++ b/tests/test_datastructures.py @@ -0,0 +1,107 @@ +import unittest +from mongoengine.base.datastructures import StrictDict, SemiStrictDict + +class TestStrictDict(unittest.TestCase): + def strict_dict_class(self, *args, **kwargs): + return StrictDict.create(*args, **kwargs) + def setUp(self): + self.dtype = self.strict_dict_class(("a", "b", "c")) + def test_init(self): + d = self.dtype(a=1, b=1, c=1) + self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) + + def test_init_fails_on_nonexisting_attrs(self): + self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) + + def test_eq(self): + d = self.dtype(a=1, b=1, c=1) + dd = self.dtype(a=1, b=1, c=1) + e = self.dtype(a=1, b=1, c=3) + f = self.dtype(a=1, b=1) + g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1) + h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) + i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) + + self.assertEqual(d, dd) + self.assertNotEqual(d, e) + self.assertNotEqual(d, f) + self.assertNotEqual(d, g) + self.assertNotEqual(f, d) + self.assertEqual(d, h) + self.assertNotEqual(d, i) + + def test_setattr_getattr(self): + d = self.dtype() + d.a = 1 + self.assertEqual(d.a, 1) + self.assertRaises(AttributeError, lambda: d.b) + + def test_setattr_raises_on_nonexisting_attr(self): + d = self.dtype() + def _f(): + d.x=1 + self.assertRaises(AttributeError, _f) + + def test_setattr_getattr_special(self): + d = self.strict_dict_class(["items"]) + d.items = 1 + self.assertEqual(d.items, 1) + + def test_get(self): + d = self.dtype(a=1) + self.assertEqual(d.get('a'), 1) + self.assertEqual(d.get('b', 'bla'), 'bla') + + def test_items(self): + d = self.dtype(a=1) + self.assertEqual(d.items(), [('a', 1)]) + d = self.dtype(a=1, b=2) + self.assertEqual(d.items(), [('a', 1), ('b', 2)]) + + def test_mappings_protocol(self): + d = self.dtype(a=1, b=2) + assert dict(d) == {'a': 1, 'b': 2} + assert dict(**d) == {'a': 1, 'b': 2} + + +class TestSemiSrictDict(TestStrictDict): + def strict_dict_class(self, *args, **kwargs): + return SemiStrictDict.create(*args, **kwargs) + + def test_init_fails_on_nonexisting_attrs(self): + # disable irrelevant test + pass + + def test_setattr_raises_on_nonexisting_attr(self): + # disable irrelevant test + pass + + def test_setattr_getattr_nonexisting_attr_succeeds(self): + d = self.dtype() + d.x = 1 + self.assertEqual(d.x, 1) + + def test_init_succeeds_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2)) + + def test_iter_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + self.assertEqual(list(d), ['a', 'b', 'c', 'x']) + + def test_iteritems_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + self.assertEqual(list(d.iteritems()), [('a', 1), ('b', 1), ('c', 1), ('x', 2)]) + + def tets_cmp_with_strict_dicts(self): + d = self.dtype(a=1, b=1, c=1) + dd = StrictDict.create(("a", "b", "c"))(a=1, b=1, c=1) + self.assertEqual(d, dd) + + def test_cmp_with_strict_dict_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + dd = StrictDict.create(("a", "b", "c", "x"))(a=1, b=1, c=1, x=2) + self.assertEqual(d, dd) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_signals.py b/tests/test_signals.py index 50e5e6b8..3d0cbb3e 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -54,7 +54,9 @@ class SignalTests(unittest.TestCase): @classmethod def post_save(cls, sender, document, **kwargs): + dirty_keys = document._delta()[0].keys() + document._delta()[1].keys() signal_output.append('post_save signal, %s' % document) + signal_output.append('post_save dirty keys, %s' % dirty_keys) if 'created' in kwargs: if kwargs['created']: signal_output.append('Is created') @@ -203,6 +205,7 @@ class SignalTests(unittest.TestCase): "pre_save_post_validation signal, Bill Shakespeare", "Is created", "post_save signal, Bill Shakespeare", + "post_save dirty keys, ['name']", "Is created" ]) @@ -213,6 +216,7 @@ class SignalTests(unittest.TestCase): "pre_save_post_validation signal, William Shakespeare", "Is updated", "post_save signal, William Shakespeare", + "post_save dirty keys, ['name']", "Is updated" ])