diff --git a/.travis.yml b/.travis.yml index 34702192..0af5c269 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,7 @@ python: - '3.2' - '3.3' - '3.4' +- '3.5' - pypy - pypy3 env: @@ -24,7 +25,8 @@ install: - sudo apt-get install python-dev python3-dev libopenjpeg-dev zlib1g-dev libjpeg-turbo8-dev libtiff4-dev libjpeg8-dev libfreetype6-dev liblcms2-dev libwebp-dev tcl8.5-dev tk8.5-dev python-tk -- travis_retry pip install tox>=1.9 coveralls +# virtualenv>=14.0.0 has dropped Python 3.2 support +- travis_retry pip install "virtualenv<14.0.0" "tox>=1.9" coveralls - travis_retry tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test script: - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage diff --git a/AUTHORS b/AUTHORS index 8b8a9264..e545b851 100644 --- a/AUTHORS +++ b/AUTHORS @@ -230,3 +230,14 @@ that much better: * Amit Lichtenberg (https://github.com/amitlicht) * Gang Li (https://github.com/iici-gli) * Lars Butler (https://github.com/larsbutler) + * George Macon (https://github.com/gmacon) + * Ashley Whetter (https://github.com/AWhetter) + * Paul-Armand Verhaegen (https://github.com/paularmand) + * Steven Rossiter (https://github.com/BeardedSteve) + * Luo Peng (https://github.com/RussellLuo) + * Bryan Bennett (https://github.com/bbenne10) + * Gilb's Gilb's (https://github.com/gilbsgilbs) + * Joshua Nedrud (https://github.com/Neurostack) + * Shu Shen (https://github.com/shushen) + * xiaost7 (https://github.com/xiaost7) + * Victor Varvaryuk diff --git a/README.rst b/README.rst index f4c92d5f..547ecbd9 100644 --- a/README.rst +++ b/README.rst @@ -19,10 +19,10 @@ MongoEngine About ===== MongoEngine is a Python Object-Document Mapper for working with MongoDB. -Documentation available at http://mongoengine-odm.rtfd.org - there is currently -a `tutorial `_, a `user guide -`_ and an `API reference -`_. +Documentation available at https://mongoengine-odm.readthedocs.io - there is currently +a `tutorial `_, a `user guide +`_ and an `API reference +`_. Installation ============ @@ -48,7 +48,9 @@ Optional Dependencies Examples ======== -Some simple examples of what MongoEngine code looks like:: +Some simple examples of what MongoEngine code looks like: + +.. code :: python class BlogPost(Document): title = StringField(required=True, max_length=200) @@ -97,7 +99,7 @@ Some simple examples of what MongoEngine code looks like:: Tests ===== To run the test suite, ensure you are running a local instance of MongoDB on -the standard port, and run: ``python setup.py nosetests``. +the standard port and have installed ``nose`` and ``rednose``, and run: ``python setup.py nosetests``. To run the test suite on every supported Python version and every supported PyMongo version, you can use ``tox``. diff --git a/docs/changelog.rst b/docs/changelog.rst index 7d3cfa84..b54b5cb6 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,8 +2,50 @@ Changelog ========= -Changes in 0.10.1 - DEV +Changes in 0.10.7 - DEV ======================= +- Fixed the bug where dynamic doc has index inside a dict field #1278 +- Fixed not being able to specify `use_db_field=False` on `ListField(EmbeddedDocumentField)` instances +- Fixed cascade delete mixing among collections #1224 +- Add `signal_kwargs` argument to `Document.save`, `Document.delete` and `BaseQuerySet.insert` to be passed to signals calls #1206 +- Raise `OperationError` when trying to do a `drop_collection` on document with no collection set. +- count on ListField of EmbeddedDocumentField fails. #1187 +- Fixed long fields stored as int32 in Python 3. #1253 +- MapField now handles unicodes keys correctly. #1267 +- ListField now handles negative indicies correctly. #1270 +- Fixed AttributeError when initializing EmbeddedDocument with positional args. #681 +- Fixed no_cursor_timeout error with pymongo 3.0+ #1304 +- Replaced map-reduce based QuerySet.sum/average with aggregation-based implementations #1336 +- Fixed support for `__` to escape field names that match operators names in `update` #1351 + +Changes in 0.10.6 +================= +- Add support for mocking MongoEngine based on mongomock. #1151 +- Fixed not being able to run tests on Windows. #1153 +- Allow creation of sparse compound indexes. #1114 +- count on ListField of EmbeddedDocumentField fails. #1187 + +Changes in 0.10.5 +================= +- Fix for reloading of strict with special fields. #1156 + +Changes in 0.10.4 +================= +- SaveConditionError is now importable from the top level package. #1165 +- upsert_one method added. #1157 + +Changes in 0.10.3 +================= +- Fix `read_preference` (it had chaining issues with PyMongo 2.x and it didn't work at all with PyMongo 3.x) #1042 + +Changes in 0.10.2 +================= +- Allow shard key to point to a field in an embedded document. #551 +- Allow arbirary metadata in fields. #1129 +- ReferenceFields now support abstract document types. #837 + +Changes in 0.10.1 +================= - Fix infinite recursion with CASCADE delete rules under specific conditions. #1046 - Fix CachedReferenceField bug when loading cached docs as DBRef but failing to save them. #1047 - Fix ignored chained options #842 @@ -13,6 +55,8 @@ Changes in 0.10.1 - DEV - Fix ListField minus index assignment does not work. #1119 - Remove code that marks field as changed when the field has default but not existed in database #1126 - Remove test dependencies (nose and rednose) from install dependencies list. #1079 +- Recursively build query when using elemMatch operator. #1130 +- Fix instance back references for lists of embedded documents. #1131 Changes in 0.10.0 ================= diff --git a/docs/code/tumblelog.py b/docs/code/tumblelog.py index 0e40e899..c10160ea 100644 --- a/docs/code/tumblelog.py +++ b/docs/code/tumblelog.py @@ -17,6 +17,10 @@ class Post(Document): tags = ListField(StringField(max_length=30)) comments = ListField(EmbeddedDocumentField(Comment)) + # bugfix + meta = {'allow_inheritance': True} + + class TextPost(Post): content = StringField() @@ -45,7 +49,8 @@ print 'ALL POSTS' print for post in Post.objects: print post.title - print '=' * post.title.count() + #print '=' * post.title.count() + print "=" * 20 if isinstance(post, TextPost): print post.content diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 8f7382ee..6ac88f01 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -29,7 +29,7 @@ documents are serialized based on their field order. Dynamic document schemas ======================== -One of the benefits of MongoDb is dynamic schemas for a collection, whilst data +One of the benefits of MongoDB is dynamic schemas for a collection, whilst data should be planned and organised (after all explicit is better than implicit!) there are scenarios where having dynamic / expando style documents is desirable. @@ -75,6 +75,7 @@ are as follows: * :class:`~mongoengine.fields.DynamicField` * :class:`~mongoengine.fields.EmailField` * :class:`~mongoengine.fields.EmbeddedDocumentField` +* :class:`~mongoengine.fields.EmbeddedDocumentListField` * :class:`~mongoengine.fields.FileField` * :class:`~mongoengine.fields.FloatField` * :class:`~mongoengine.fields.GenericEmbeddedDocumentField` @@ -172,11 +173,11 @@ arguments can be set on all fields: class Shirt(Document): size = StringField(max_length=3, choices=SIZE) -:attr:`help_text` (Default: None) - Optional help text to output with the field -- used by form libraries - -:attr:`verbose_name` (Default: None) - Optional human-readable name for the field -- used by form libraries +:attr:`**kwargs` (Optional) + You can supply additional metadata as arbitrary additional keyword + arguments. You can not override existing attributes, however. Common + choices include `help_text` and `verbose_name`, commonly used by form and + widget libraries. List fields diff --git a/docs/guide/index.rst b/docs/guide/index.rst index c4077888..46eb7af2 100644 --- a/docs/guide/index.rst +++ b/docs/guide/index.rst @@ -13,3 +13,4 @@ User Guide gridfs signals text-indexes + mongomock diff --git a/docs/guide/mongomock.rst b/docs/guide/mongomock.rst new file mode 100644 index 00000000..1d5227ec --- /dev/null +++ b/docs/guide/mongomock.rst @@ -0,0 +1,21 @@ +============================== +Use mongomock for testing +============================== + +`mongomock `_ is a package to do just +what the name implies, mocking a mongo database. + +To use with mongoengine, simply specify mongomock when connecting with +mongoengine: + +.. code-block:: python + + connect('mongoenginetest', host='mongomock://localhost') + conn = get_connection() + +or with an alias: + +.. code-block:: python + + connect('mongoenginetest', host='mongomock://localhost', alias='testdb') + conn = get_connection('testdb') diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 5f7a3de9..913de5d6 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -237,7 +237,7 @@ is preferred for achieving this:: # All except for the first 5 people users = User.objects[5:] - # 5 users, starting from the 10th user found + # 5 users, starting from the 11th user found users = User.objects[10:15] You may also index the query to retrieve a single result. If an item at that diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 474c2154..65250b62 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -14,7 +14,7 @@ import errors __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + list(queryset.__all__) + signals.__all__ + list(errors.__all__)) -VERSION = (0, 10, 0) +VERSION = (0, 10, 6) def get_version(): diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index e4d2b392..466b5e88 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -199,7 +199,8 @@ class BaseList(list): def _mark_as_changed(self, key=None): if hasattr(self._instance, '_mark_as_changed'): if key: - self._instance._mark_as_changed('%s.%s' % (self._name, key)) + self._instance._mark_as_changed('%s.%s' % (self._name, + key % len(self))) else: self._instance._mark_as_changed(self._name) @@ -210,7 +211,7 @@ class EmbeddedDocumentList(BaseList): def __match_all(cls, i, kwargs): items = kwargs.items() return all([ - getattr(i, k) == v or str(getattr(i, k)) == v for k, v in items + getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items ]) @classmethod diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 12d3dfa0..4959991e 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -51,7 +51,7 @@ class BaseDocument(object): # We only want named arguments. field = iter(self._fields_ordered) # If its an automatic id field then skip to the first defined field - if self._auto_id_field: + if getattr(self, '_auto_id_field', False): next(field) for value in args: name = next(field) @@ -325,20 +325,17 @@ class BaseDocument(object): if value is not None: - if isinstance(field, EmbeddedDocumentField): - if fields: - key = '%s.' % field_name - embedded_fields = [ - i.replace(key, '') for i in fields - if i.startswith(key)] + if fields: + key = '%s.' % field_name + embedded_fields = [ + i.replace(key, '') for i in fields + if i.startswith(key)] - else: - embedded_fields = [] - - value = field.to_mongo(value, use_db_field=use_db_field, - fields=embedded_fields) else: - value = field.to_mongo(value) + embedded_fields = [] + + value = field.to_mongo(value, use_db_field=use_db_field, + fields=embedded_fields) # Handle self generating fields if value is None and field._auto_gen: @@ -835,10 +832,6 @@ class BaseDocument(object): if index_list: spec['fields'] = index_list - if spec.get('sparse', False) and len(spec['fields']) > 1: - raise ValueError( - 'Sparse indexes can only have one field in them. ' - 'See https://jira.mongodb.org/browse/SERVER-2193') return spec @@ -974,7 +967,7 @@ class BaseDocument(object): if hasattr(getattr(field, 'field', None), 'lookup_member'): new_field = field.field.lookup_member(field_name) elif cls._dynamic and (isinstance(field, DynamicField) or - getattr(getattr(field, 'document_type'), '_dynamic')): + getattr(getattr(field, 'document_type', None), '_dynamic', None)): new_field = DynamicField(db_field=field_name) else: # Look up subfield on the previous field or raise diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 304c084d..a803657d 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -41,8 +41,8 @@ class BaseField(object): def __init__(self, db_field=None, name=None, required=False, default=None, unique=False, unique_with=None, primary_key=False, - validation=None, choices=None, verbose_name=None, - help_text=None, null=False, sparse=False, custom_data=None): + validation=None, choices=None, null=False, sparse=False, + **kwargs): """ :param db_field: The database field to store this field in (defaults to the name of the field) @@ -60,16 +60,15 @@ class BaseField(object): field. Generally this is deprecated in favour of the `FIELD.validate` method :param choices: (optional) The valid choices - :param verbose_name: (optional) The verbose name for the field. - Designed to be human readable and is often used when generating - model forms from the document model. - :param help_text: (optional) The help text for this field and is often - used when generating model forms from the document model. :param null: (optional) Is the field value can be null. If no and there is a default value then the default value is set :param sparse: (optional) `sparse=True` combined with `unique=True` and `required=False` means that uniqueness won't be enforced for `None` values - :param custom_data: (optional) Custom metadata for this field. + :param **kwargs: (optional) Arbitrary indirection-free metadata for + this field can be supplied as additional keyword arguments and + accessed as attributes of the field. Must not conflict with any + existing attributes. Common metadata includes `verbose_name` and + `help_text`. """ self.db_field = (db_field or name) if not primary_key else '_id' @@ -83,12 +82,19 @@ class BaseField(object): self.primary_key = primary_key self.validation = validation self.choices = choices - self.verbose_name = verbose_name - self.help_text = help_text self.null = null self.sparse = sparse self._owner_document = None - self.custom_data = custom_data + + # Detect and report conflicts between metadata and base properties. + conflicts = set(dir(self)) & set(kwargs) + if conflicts: + raise TypeError("%s already has attribute(s): %s" % ( + self.__class__.__name__, ', '.join(conflicts) )) + + # Assign metadata to the instance + # This efficient method is available because no __slots__ are defined. + self.__dict__.update(kwargs) # Adjust the appropriate creation counter, and save our local copy. if self.db_field == '_id': @@ -127,7 +133,7 @@ class BaseField(object): if (self.name not in instance._data or instance._data[self.name] != value): instance._mark_as_changed(self.name) - except: + except Exception: # Values cant be compared eg: naive and tz datetimes # So mark it as changed instance._mark_as_changed(self.name) @@ -135,6 +141,10 @@ class BaseField(object): EmbeddedDocument = _import_class('EmbeddedDocument') if isinstance(value, EmbeddedDocument): value._instance = weakref.proxy(instance) + elif isinstance(value, (list, tuple)): + for v in value: + if isinstance(v, EmbeddedDocument): + v._instance = weakref.proxy(instance) instance._data[self.name] = value def error(self, message="", errors=None, field_name=None): @@ -148,7 +158,7 @@ class BaseField(object): """ return value - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): """Convert a Python type to a MongoDB-compatible type. """ return self.to_python(value) @@ -275,8 +285,6 @@ class ComplexBaseField(BaseField): def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. """ - Document = _import_class('Document') - if isinstance(value, basestring): return value @@ -296,6 +304,7 @@ class ComplexBaseField(BaseField): value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()]) else: + Document = _import_class('Document') value_dict = {} for k, v in value.items(): if isinstance(v, Document): @@ -315,7 +324,7 @@ class ComplexBaseField(BaseField): key=operator.itemgetter(0))] return value_dict - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): """Convert a Python type to a MongoDB-compatible type. """ Document = _import_class("Document") @@ -327,9 +336,10 @@ class ComplexBaseField(BaseField): if hasattr(value, 'to_mongo'): if isinstance(value, Document): - return GenericReferenceField().to_mongo(value) + return GenericReferenceField().to_mongo( + value, **kwargs) cls = value.__class__ - val = value.to_mongo() + val = value.to_mongo(**kwargs) # If it's a document that is not inherited add _cls if isinstance(value, EmbeddedDocument): val['_cls'] = cls.__name__ @@ -344,7 +354,7 @@ class ComplexBaseField(BaseField): return value if self.field: - value_dict = dict([(key, self.field.to_mongo(item)) + value_dict = dict([(key, self.field.to_mongo(item, **kwargs)) for key, item in value.iteritems()]) else: value_dict = {} @@ -363,19 +373,20 @@ class ComplexBaseField(BaseField): meta.get('allow_inheritance', ALLOW_INHERITANCE) is True) if not allow_inheritance and not self.field: - value_dict[k] = GenericReferenceField().to_mongo(v) + value_dict[k] = GenericReferenceField().to_mongo( + v, **kwargs) else: collection = v._get_collection_name() value_dict[k] = DBRef(collection, v.pk) elif hasattr(v, 'to_mongo'): cls = v.__class__ - val = v.to_mongo() + val = v.to_mongo(**kwargs) # If it's a document that is not inherited add _cls if isinstance(v, (Document, EmbeddedDocument)): val['_cls'] = cls.__name__ value_dict[k] = val else: - value_dict[k] = self.to_mongo(v) + value_dict[k] = self.to_mongo(v, **kwargs) if is_list: # Convert back to a list return [v for _, v in sorted(value_dict.items(), @@ -429,11 +440,11 @@ class ObjectIdField(BaseField): try: if not isinstance(value, ObjectId): value = ObjectId(value) - except: + except Exception: pass return value - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): if not isinstance(value, ObjectId): try: return ObjectId(unicode(value)) @@ -448,7 +459,7 @@ class ObjectIdField(BaseField): def validate(self, value): try: ObjectId(unicode(value)) - except: + except Exception: self.error('Invalid Object ID') @@ -500,7 +511,7 @@ class GeoJsonBaseField(BaseField): # Quick and dirty validator try: value[0][0][0] - except: + except (TypeError, IndexError): return "Invalid Polygon must contain at least one valid linestring" errors = [] @@ -524,7 +535,7 @@ class GeoJsonBaseField(BaseField): # Quick and dirty validator try: value[0][0] - except: + except (TypeError, IndexError): return "Invalid LineString must contain at least one valid point" errors = [] @@ -555,7 +566,7 @@ class GeoJsonBaseField(BaseField): # Quick and dirty validator try: value[0][0] - except: + except (TypeError, IndexError): return "Invalid MultiPoint must contain at least one valid point" errors = [] @@ -574,7 +585,7 @@ class GeoJsonBaseField(BaseField): # Quick and dirty validator try: value[0][0][0] - except: + except (TypeError, IndexError): return "Invalid MultiLineString must contain at least one valid linestring" errors = [] @@ -596,7 +607,7 @@ class GeoJsonBaseField(BaseField): # Quick and dirty validator try: value[0][0][0][0] - except: + except (TypeError, IndexError): return "Invalid MultiPolygon must contain at least one valid Polygon" errors = [] @@ -608,7 +619,7 @@ class GeoJsonBaseField(BaseField): if errors: return "Invalid MultiPolygon:\n%s" % ", ".join(errors) - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): if isinstance(value, dict): return value return SON([("type", self._type), ("coordinates", value)]) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 4abca1ab..4055a9b6 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -38,8 +38,11 @@ def register_connection(alias, name=None, host=None, port=None, :param username: username to authenticate with :param password: password to authenticate with :param authentication_source: database to authenticate against + :param is_mock: explicitly use mongomock for this connection + (can also be done by using `mongomock://` as db host prefix) :param kwargs: allow ad-hoc parameters to be passed into the pymongo driver + .. versionchanged:: 0.10.6 - added mongomock support """ global _connection_settings @@ -54,8 +57,13 @@ def register_connection(alias, name=None, host=None, port=None, } # Handle uri style connections - if "://" in conn_settings['host']: - uri_dict = uri_parser.parse_uri(conn_settings['host']) + conn_host = conn_settings['host'] + if conn_host.startswith('mongomock://'): + conn_settings['is_mock'] = True + # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` + conn_settings['host'] = conn_host.replace('mongomock://', 'mongodb://', 1) + elif '://' in conn_host: + uri_dict = uri_parser.parse_uri(conn_host) conn_settings.update({ 'name': uri_dict.get('database') or name, 'username': uri_dict.get('username'), @@ -106,7 +114,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn_settings.pop('password', None) conn_settings.pop('authentication_source', None) - connection_class = MongoClient + is_mock = conn_settings.pop('is_mock', None) + if is_mock: + # Use MongoClient from mongomock + try: + import mongomock + except ImportError: + raise RuntimeError('You need mongomock installed ' + 'to mock MongoEngine.') + connection_class = mongomock.MongoClient + else: + # Use MongoClient from pymongo + connection_class = MongoClient + if 'replicaSet' in conn_settings: # Discard port since it can't be used on MongoReplicaSetClient conn_settings.pop('port', None) @@ -126,6 +146,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): connection_settings.pop('name', None) connection_settings.pop('username', None) connection_settings.pop('password', None) + connection_settings.pop('authentication_source', None) if conn_settings == connection_settings and _connections.get(db_alias, None): connection = _connections[db_alias] break diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 7fcc2ad2..0428095a 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,5 +1,7 @@ from bson import DBRef, SON +from mongoengine.python_support import txt_type + from base import ( BaseDict, BaseList, EmbeddedDocumentList, TopLevelDocumentMetaclass, get_document @@ -226,7 +228,7 @@ class DeReference(object): data[k]._data[field_name] = self.object_map.get( (v['_ref'].collection, v['_ref'].id), v) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: - item_name = "{0}.{1}.{2}".format(name, k, field_name) + item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: item_name = '%s.%s' % (name, k) if name else name diff --git a/mongoengine/document.py b/mongoengine/document.py index 9d2d9c5f..2fac15b0 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -217,7 +217,7 @@ class Document(BaseDocument): Returns True if the document has been updated or False if the document in the database doesn't match the query. - .. note:: All unsaved changes that has been made to the document are + .. note:: All unsaved changes that have been made to the document are rejected if the method returns True. :param query: the update will be performed only if the document in the @@ -250,7 +250,7 @@ class Document(BaseDocument): def save(self, force_insert=False, validate=True, clean=True, write_concern=None, cascade=None, cascade_kwargs=None, - _refs=None, save_condition=None, **kwargs): + _refs=None, save_condition=None, signal_kwargs=None, **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. @@ -276,6 +276,8 @@ class Document(BaseDocument): :param save_condition: only perform save if matching record in db satisfies condition(s) (e.g. version number). Raises :class:`OperationError` if the conditions are not satisfied + :parm signal_kwargs: (optional) kwargs dictionary to be passed to + the signal calls. .. versionchanged:: 0.5 In existing documents it only saves changed fields using @@ -297,8 +299,11 @@ class Document(BaseDocument): :class:`OperationError` exception raised if save_condition fails. .. versionchanged:: 0.10.1 :class: save_condition failure now raises a `SaveConditionError` + .. versionchanged:: 0.10.7 + Add signal_kwargs argument """ - signals.pre_save.send(self.__class__, document=self) + signal_kwargs = signal_kwargs or {} + signals.pre_save.send(self.__class__, document=self, **signal_kwargs) if validate: self.validate(clean=clean) @@ -311,7 +316,7 @@ class Document(BaseDocument): created = ('_id' not in doc or self._created or force_insert) signals.pre_save_post_validation.send(self.__class__, document=self, - created=created) + created=created, **signal_kwargs) try: collection = self._get_collection() @@ -341,8 +346,12 @@ class Document(BaseDocument): select_dict['_id'] = object_id shard_key = self.__class__._meta.get('shard_key', tuple()) for k in shard_key: - actual_key = self._db_field_map.get(k, k) - select_dict[actual_key] = doc[actual_key] + path = self._lookup_field(k.split('.')) + actual_key = [p.db_field for p in path] + val = doc + for ak in actual_key: + val = val[ak] + select_dict['.'.join(actual_key)] = val def is_new_object(last_error): if last_error is not None: @@ -396,14 +405,15 @@ class Document(BaseDocument): if created or id_field not in self._meta.get('shard_key', []): self[id_field] = self._fields[id_field].to_python(object_id) - signals.post_save.send(self.__class__, document=self, created=created) + signals.post_save.send(self.__class__, document=self, + created=created, **signal_kwargs) self._clear_changed_fields() self._created = False return self def cascade_save(self, *args, **kwargs): """Recursively saves any references / - generic references on an objects""" + generic references on the document""" _refs = kwargs.get('_refs', []) or [] ReferenceField = _import_class('ReferenceField') @@ -444,7 +454,12 @@ class Document(BaseDocument): select_dict = {'pk': self.pk} shard_key = self.__class__._meta.get('shard_key', tuple()) for k in shard_key: - select_dict[k] = getattr(self, k) + path = self._lookup_field(k.split('.')) + actual_key = [p.db_field for p in path] + val = self + for ak in actual_key: + val = getattr(val, ak) + select_dict['__'.join(actual_key)] = val return select_dict def update(self, **kwargs): @@ -467,18 +482,24 @@ class Document(BaseDocument): # Need to add shard key to query, or you get an error return self._qs.filter(**self._object_key).update_one(**kwargs) - def delete(self, **write_concern): + def delete(self, signal_kwargs=None, **write_concern): """Delete the :class:`~mongoengine.Document` from the database. This will only take effect if the document has been previously saved. + :parm signal_kwargs: (optional) kwargs dictionary to be passed to + the signal calls. :param write_concern: Extra keyword arguments are passed down which will be used as options for the resultant ``getLastError`` command. For example, ``save(..., write_concern={w: 2, fsync: True}, ...)`` will wait until at least two servers have recorded the write and will force an fsync on the primary server. + + .. versionchanged:: 0.10.7 + Add signal_kwargs argument """ - signals.pre_delete.send(self.__class__, document=self) + signal_kwargs = signal_kwargs or {} + signals.pre_delete.send(self.__class__, document=self, **signal_kwargs) # Delete FileFields separately FileField = _import_class('FileField') @@ -492,7 +513,7 @@ class Document(BaseDocument): except pymongo.errors.OperationFailure, err: message = u'Could not delete document (%s)' % err.message raise OperationError(message) - signals.post_delete.send(self.__class__, document=self) + signals.post_delete.send(self.__class__, document=self, **signal_kwargs) def switch_db(self, db_alias, keep_created=True): """ @@ -595,11 +616,16 @@ class Document(BaseDocument): if not fields or field in fields: try: setattr(self, field, self._reload(field, obj[field])) - except KeyError: - # If field is removed from the database while the object - # is in memory, a reload would cause a KeyError - # i.e. obj.update(unset__field=1) followed by obj.reload() - delattr(self, field) + except (KeyError, AttributeError): + try: + # If field is a special field, e.g. items is stored as _reserved_items, + # an KeyError is thrown. So try to retrieve the field from _data + setattr(self, field, self._reload(field, obj._data.get(field))) + except KeyError: + # If field is removed from the database while the object + # is in memory, a reload would cause a KeyError + # i.e. obj.update(unset__field=1) followed by obj.reload() + delattr(self, field) self._changed_fields = obj._changed_fields self._created = False @@ -653,10 +679,20 @@ class Document(BaseDocument): def drop_collection(cls): """Drops the entire collection associated with this :class:`~mongoengine.Document` type from the database. + + Raises :class:`OperationError` if the document has no collection set + (i.g. if it is `abstract`) + + .. versionchanged:: 0.10.7 + :class:`OperationError` exception raised if no collection available """ + col_name = cls._get_collection_name() + if not col_name: + raise OperationError('Document %s has no collection defined ' + '(is it abstract ?)' % cls) cls._collection = None db = cls._get_db() - db.drop_collection(cls._get_collection_name()) + db.drop_collection(col_name) @classmethod def create_index(cls, keys, background=False, **kwargs): @@ -945,7 +981,7 @@ class MapReduceDocument(object): if not isinstance(self.key, id_field_type): try: self.key = id_field_type(self.key) - except: + except Exception: raise Exception("Could not cast key as %s" % id_field_type.__name__) diff --git a/mongoengine/errors.py b/mongoengine/errors.py index 2c5c2946..15830b5c 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -6,7 +6,7 @@ from mongoengine.python_support import txt_type __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', 'OperationError', 'NotUniqueError', 'FieldDoesNotExist', - 'ValidationError') + 'ValidationError', 'SaveConditionError') class NotRegistered(Exception): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f5899311..2db6383c 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -8,6 +8,8 @@ import uuid import warnings from operator import itemgetter +import six + try: import dateutil except ImportError: @@ -18,6 +20,10 @@ else: import pymongo import gridfs from bson import Binary, DBRef, SON, ObjectId +try: + from bson.int64 import Int64 +except ImportError: + Int64 = long from mongoengine.errors import ValidationError from mongoengine.python_support import (PY3, bin_type, txt_type, @@ -65,7 +71,7 @@ class StringField(BaseField): return value try: value = value.decode('utf-8') - except: + except Exception: pass return value @@ -194,7 +200,7 @@ class IntField(BaseField): def validate(self, value): try: value = int(value) - except: + except Exception: self.error('%s could not be converted to int' % value) if self.min_value is not None and value < self.min_value: @@ -225,10 +231,13 @@ class LongField(BaseField): pass return value + def to_mongo(self, value, **kwargs): + return Int64(value) + def validate(self, value): try: value = long(value) - except: + except Exception: self.error('%s could not be converted to long' % value) if self.min_value is not None and value < self.min_value: @@ -260,10 +269,14 @@ class FloatField(BaseField): return value def validate(self, value): - if isinstance(value, int): - value = float(value) + if isinstance(value, six.integer_types): + try: + value = float(value) + except OverflowError: + self.error('The value is too large to be converted to float') + if not isinstance(value, float): - self.error('FloatField only accepts float values') + self.error('FloatField only accepts float and integer values') if self.min_value is not None and value < self.min_value: self.error('Float value is too small') @@ -325,7 +338,7 @@ class DecimalField(BaseField): return value return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding) - def to_mongo(self, value, use_db_field=True): + def to_mongo(self, value, **kwargs): if value is None: return value if self.force_string: @@ -388,7 +401,7 @@ class DateTimeField(BaseField): if not isinstance(new_value, (datetime.datetime, datetime.date)): self.error(u'cannot parse date "%s"' % value) - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): if value is None: return value if isinstance(value, datetime.datetime): @@ -508,10 +521,10 @@ class ComplexDateTimeField(StringField): original_value = value try: return self._convert_from_string(value) - except: + except Exception: return original_value - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): value = self.to_python(value) return self._convert_from_datetime(value) @@ -546,11 +559,10 @@ class EmbeddedDocumentField(BaseField): return self.document_type._from_son(value, _auto_dereference=self._auto_dereference) return value - def to_mongo(self, value, use_db_field=True, fields=[]): + def to_mongo(self, value, **kwargs): if not isinstance(value, self.document_type): return value - return self.document_type.to_mongo(value, use_db_field, - fields=fields) + return self.document_type.to_mongo(value, **kwargs) def validate(self, value, clean=True): """Make sure that the document instance is an instance of the @@ -600,11 +612,11 @@ class GenericEmbeddedDocumentField(BaseField): value.validate(clean=clean) - def to_mongo(self, document, use_db_field=True): + def to_mongo(self, document, **kwargs): if document is None: return None - data = document.to_mongo(use_db_field) + data = document.to_mongo(**kwargs) if '_cls' not in data: data['_cls'] = document._class_name return data @@ -616,7 +628,7 @@ class DynamicField(BaseField): Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data""" - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): """Convert a Python type to a MongoDB compatible type. """ @@ -625,7 +637,7 @@ class DynamicField(BaseField): if hasattr(value, 'to_mongo'): cls = value.__class__ - val = value.to_mongo() + val = value.to_mongo(**kwargs) # If we its a document thats not inherited add _cls if isinstance(value, Document): val = {"_ref": value.to_dbref(), "_cls": cls.__name__} @@ -643,7 +655,7 @@ class DynamicField(BaseField): data = {} for k, v in value.iteritems(): - data[k] = self.to_mongo(v) + data[k] = self.to_mongo(v, **kwargs) value = data if is_list: # Convert back to a list @@ -697,7 +709,7 @@ class ListField(ComplexBaseField): def prepare_query_value(self, op, value): if self.field: - if op in ('set', 'unset') and ( + if op in ('set', 'unset', None) and ( not isinstance(value, basestring) and not isinstance(value, BaseDocument) and hasattr(value, '__iter__')): @@ -755,8 +767,8 @@ class SortedListField(ListField): self._order_reverse = kwargs.pop('reverse') super(SortedListField, self).__init__(field, **kwargs) - def to_mongo(self, value): - value = super(SortedListField, self).to_mongo(value) + def to_mongo(self, value, **kwargs): + value = super(SortedListField, self).to_mongo(value, **kwargs) if self._ordering is not None: return sorted(value, key=itemgetter(self._ordering), reverse=self._order_reverse) @@ -863,12 +875,11 @@ class ReferenceField(BaseField): The options are: - * DO_NOTHING - don't do anything (default). - * NULLIFY - Updates the reference to null. - * CASCADE - Deletes the documents associated with the reference. - * DENY - Prevent the deletion of the reference object. - * PULL - Pull the reference from a :class:`~mongoengine.fields.ListField` - of references + * DO_NOTHING (0) - don't do anything (default). + * NULLIFY (1) - Updates the reference to null. + * CASCADE (2) - Deletes the documents associated with the reference. + * DENY (3) - Prevent the deletion of the reference object. + * PULL (4) - Pull the reference from a :class:`~mongoengine.fields.ListField` of references Alternative syntax for registering delete rules (useful when implementing bi-directional delete rules) @@ -879,7 +890,7 @@ class ReferenceField(BaseField): content = StringField() foo = ReferenceField('Foo') - Bar.register_delete_rule(Foo, 'bar', NULLIFY) + Foo.register_delete_rule(Bar, 'foo', NULLIFY) .. note :: `reverse_delete_rule` does not trigger pre / post delete signals to be @@ -896,6 +907,10 @@ class ReferenceField(BaseField): or as the :class:`~pymongo.objectid.ObjectId`.id . :param reverse_delete_rule: Determines what to do when the referring object is deleted + + .. note :: + A reference to an abstract document type is always stored as a + :class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`. """ if not isinstance(document_type, basestring): if not issubclass(document_type, (Document, basestring)): @@ -928,33 +943,46 @@ class ReferenceField(BaseField): self._auto_dereference = instance._fields[self.name]._auto_dereference # Dereference DBRefs if self._auto_dereference and isinstance(value, DBRef): - value = self.document_type._get_db().dereference(value) + if hasattr(value, 'cls'): + # Dereference using the class type specified in the reference + cls = get_document(value.cls) + else: + cls = self.document_type + value = cls._get_db().dereference(value) if value is not None: - instance._data[self.name] = self.document_type._from_son(value) + instance._data[self.name] = cls._from_son(value) return super(ReferenceField, self).__get__(instance, owner) - def to_mongo(self, document): + def to_mongo(self, document, **kwargs): if isinstance(document, DBRef): if not self.dbref: return document.id return document - id_field_name = self.document_type._meta['id_field'] - id_field = self.document_type._fields[id_field_name] - if isinstance(document, Document): # We need the id from the saved object to create the DBRef id_ = document.pk if id_ is None: self.error('You can only reference documents once they have' ' been saved to the database') + + # Use the attributes from the document instance, so that they + # override the attributes of this field's document type + cls = document else: id_ = document + cls = self.document_type - id_ = id_field.to_mongo(id_) - if self.dbref: - collection = self.document_type._get_collection_name() + id_field_name = cls._meta['id_field'] + id_field = cls._fields[id_field_name] + + id_ = id_field.to_mongo(id_, **kwargs) + if self.document_type._meta.get('abstract'): + collection = cls._get_collection_name() + return DBRef(collection, id_, cls=cls._class_name) + elif self.dbref: + collection = cls._get_collection_name() return DBRef(collection, id_) return id_ @@ -983,6 +1011,14 @@ class ReferenceField(BaseField): self.error('You can only reference documents once they have been ' 'saved to the database') + if self.document_type._meta.get('abstract') and \ + not isinstance(value, self.document_type): + self.error('%s is not an instance of abstract reference' + ' type %s' % (value._class_name, + self.document_type._class_name) + ) + + def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -990,7 +1026,7 @@ class ReferenceField(BaseField): class CachedReferenceField(BaseField): """ A referencefield with cache fields to purpose pseudo-joins - + .. versionadded:: 0.9 """ @@ -1064,7 +1100,7 @@ class CachedReferenceField(BaseField): return super(CachedReferenceField, self).__get__(instance, owner) - def to_mongo(self, document): + def to_mongo(self, document, **kwargs): id_field_name = self.document_type._meta['id_field'] id_field = self.document_type._fields[id_field_name] @@ -1079,10 +1115,11 @@ class CachedReferenceField(BaseField): # TODO: should raise here or will fail next statement value = SON(( - ("_id", id_field.to_mongo(id_)), + ("_id", id_field.to_mongo(id_, **kwargs)), )) - value.update(dict(document.to_mongo(fields=self.fields))) + kwargs['fields'] = self.fields + value.update(dict(document.to_mongo(**kwargs))) return value def prepare_query_value(self, op, value): @@ -1198,7 +1235,7 @@ class GenericReferenceField(BaseField): doc = doc_cls._from_son(doc) return doc - def to_mongo(self, document, use_db_field=True): + def to_mongo(self, document, **kwargs): if document is None: return None @@ -1217,7 +1254,7 @@ class GenericReferenceField(BaseField): else: id_ = document - id_ = id_field.to_mongo(id_) + id_ = id_field.to_mongo(id_, **kwargs) collection = document._get_collection_name() ref = DBRef(collection, id_) return SON(( @@ -1246,7 +1283,7 @@ class BinaryField(BaseField): value = bin_type(value) return super(BinaryField, self).__set__(instance, value) - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): return Binary(value) def validate(self, value): @@ -1346,7 +1383,7 @@ class GridFSProxy(object): if self.gridout is None: self.gridout = self.fs.get(self.grid_id) return self.gridout - except: + except Exception: # File has been deleted return None @@ -1384,7 +1421,7 @@ class GridFSProxy(object): else: try: return gridout.read(size) - except: + except Exception: return "" def delete(self): @@ -1449,7 +1486,7 @@ class FileField(BaseField): if grid_file: try: grid_file.delete() - except: + except Exception: pass # Create a new proxy object as we don't already have one @@ -1471,7 +1508,7 @@ class FileField(BaseField): db_alias=db_alias, collection_name=collection_name) - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): # Store the GridFS file id in MongoDB if isinstance(value, self.proxy_class) and value.grid_id is not None: return value.grid_id @@ -1683,17 +1720,17 @@ class SequenceField(BaseField): :param collection_name: Name of the counter collection (default 'mongoengine.counters') :param sequence_name: Name of the sequence in the collection (default 'ClassName.counter') :param value_decorator: Any callable to use as a counter (default int) - + Use any callable as `value_decorator` to transform calculated counter into any value suitable for your needs, e.g. string or hexadecimal representation of the default integer counter value. - + .. note:: - - In case the counter is defined in the abstract document, it will be - common to all inherited documents and the default sequence name will + + In case the counter is defined in the abstract document, it will be + common to all inherited documents and the default sequence name will be the class name of the abstract document. - + .. versionadded:: 0.5 .. versionchanged:: 0.8 added `value_decorator` """ @@ -1817,11 +1854,11 @@ class UUIDField(BaseField): if not isinstance(value, basestring): value = unicode(value) return uuid.UUID(value) - except: + except Exception: return original_value return value - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): if not self._binary: return unicode(value) elif isinstance(value, basestring): diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 0e183889..7efb0fb6 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -266,7 +266,8 @@ class BaseQuerySet(object): result = None return result - def insert(self, doc_or_docs, load_bulk=True, write_concern=None): + def insert(self, doc_or_docs, load_bulk=True, + write_concern=None, signal_kwargs=None): """bulk insert documents :param doc_or_docs: a document or list of documents to be inserted @@ -279,11 +280,15 @@ class BaseQuerySet(object): ``insert(..., {w: 2, fsync: True})`` will wait until at least two servers have recorded the write and will force an fsync on each server being written to. + :parm signal_kwargs: (optional) kwargs dictionary to be passed to + the signal calls. By default returns document instances, set ``load_bulk`` to False to return just ``ObjectIds`` .. versionadded:: 0.5 + .. versionchanged:: 0.10.7 + Add signal_kwargs argument """ Document = _import_class('Document') @@ -296,7 +301,6 @@ class BaseQuerySet(object): return_one = True docs = [docs] - raw = [] for doc in docs: if not isinstance(doc, self._document): msg = ("Some documents inserted aren't instances of %s" @@ -305,9 +309,12 @@ class BaseQuerySet(object): if doc.pk and not doc._created: msg = "Some documents have ObjectIds use doc.update() instead" raise OperationError(msg) - raw.append(doc.to_mongo()) - signals.pre_bulk_insert.send(self._document, documents=docs) + signal_kwargs = signal_kwargs or {} + signals.pre_bulk_insert.send(self._document, + documents=docs, **signal_kwargs) + + raw = [doc.to_mongo() for doc in docs] try: ids = self._collection.insert(raw, **write_concern) except pymongo.errors.DuplicateKeyError, err: @@ -324,7 +331,7 @@ class BaseQuerySet(object): if not load_bulk: signals.post_bulk_insert.send( - self._document, documents=docs, loaded=False) + self._document, documents=docs, loaded=False, **signal_kwargs) return return_one and ids[0] or ids documents = self.in_bulk(ids) @@ -332,7 +339,7 @@ class BaseQuerySet(object): for obj_id in ids: results.append(documents.get(obj_id)) signals.post_bulk_insert.send( - self._document, documents=results, loaded=True) + self._document, documents=results, loaded=True, **signal_kwargs) return return_one and results[0] or results def count(self, with_limit_and_skip=False): @@ -403,8 +410,10 @@ class BaseQuerySet(object): rule = doc._meta['delete_rules'][rule_entry] if rule == CASCADE: cascade_refs = set() if cascade_refs is None else cascade_refs - for ref in queryset: - cascade_refs.add(ref.id) + # Handle recursive reference + if doc._collection == document_cls._collection: + for ref in queryset: + cascade_refs.add(ref.id) ref_q = document_cls.objects(**{field_name + '__in': self, 'id__nin': cascade_refs}) ref_q_count = ref_q.count() if ref_q_count > 0: @@ -425,7 +434,7 @@ class BaseQuerySet(object): full_result=False, **update): """Perform an atomic update on the fields matched by the query. - :param upsert: Any existing document with that "_id" is overwritten. + :param upsert: insert if document doesn't exist (default ``False``) :param multi: Update multiple documents. :param write_concern: Extra keyword arguments are passed down which will be used as options for the resultant @@ -471,10 +480,36 @@ class BaseQuerySet(object): raise OperationError(message) raise OperationError(u'Update failed (%s)' % unicode(err)) - def update_one(self, upsert=False, write_concern=None, **update): - """Perform an atomic update on first field matched by the query. + def upsert_one(self, write_concern=None, **update): + """Overwrite or add the first document matched by the query. - :param upsert: Any existing document with that "_id" is overwritten. + :param write_concern: Extra keyword arguments are passed down which + will be used as options for the resultant + ``getLastError`` command. For example, + ``save(..., write_concern={w: 2, fsync: True}, ...)`` will + wait until at least two servers have recorded the write and + will force an fsync on the primary server. + :param update: Django-style update keyword arguments + + :returns the new or overwritten document + + .. versionadded:: 0.10.2 + """ + + atomic_update = self.update(multi=False, upsert=True, write_concern=write_concern, + full_result=True, **update) + + if atomic_update['updatedExisting']: + document = self.get() + else: + document = self._document.objects.with_id(atomic_update['upserted']) + return document + + def update_one(self, upsert=False, write_concern=None, **update): + """Perform an atomic update on the fields of the first document + matched by the query. + + :param upsert: insert if document doesn't exist (default ``False``) :param write_concern: Extra keyword arguments are passed down which will be used as options for the resultant ``getLastError`` command. For example, @@ -929,6 +964,7 @@ class BaseQuerySet(object): validate_read_preference('read_preference', read_preference) queryset = self.clone() queryset._read_preference = read_preference + queryset._cursor_obj = None # we need to re-create the cursor object whenever we apply read_preference return queryset def scalar(self, *fields): @@ -1201,66 +1237,28 @@ class BaseQuerySet(object): def sum(self, field): """Sum over the values of the specified field. - :param field: the field to sum over; use dot-notation to refer to + :param field: the field to sum over; use dot notation to refer to embedded document fields - - .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work - with sharding. """ - map_func = """ - function() { - var path = '{{~%(field)s}}'.split('.'), - field = this; - - for (p in path) { - if (typeof field != 'undefined') - field = field[path[p]]; - else - break; - } - - if (field && field.constructor == Array) { - field.forEach(function(item) { - emit(1, item||0); - }); - } else if (typeof field != 'undefined') { - emit(1, field||0); - } - } - """ % dict(field=field) - - reduce_func = Code(""" - function(key, values) { - var sum = 0; - for (var i in values) { - sum += values[i]; - } - return sum; - } - """) - - for result in self.map_reduce(map_func, reduce_func, output='inline'): - return result.value - else: - return 0 - - def aggregate_sum(self, field): - """Sum over the values of the specified field. - - :param field: the field to sum over; use dot-notation to refer to - embedded document fields - - This method is more performant than the regular `sum`, because it uses - the aggregation framework instead of map-reduce. - """ - result = self._document._get_collection().aggregate([ + pipeline = [ {'$match': self._query}, {'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}} - ]) + ] + + # if we're performing a sum over a list field, we sum up all the + # elements in the list, hence we need to $unwind the arrays first + ListField = _import_class('ListField') + field_parts = field.split('.') + field_instances = self._document._lookup_field(field_parts) + if isinstance(field_instances[-1], ListField): + pipeline.insert(1, {'$unwind': '$' + field}) + + result = self._document._get_collection().aggregate(pipeline) if IS_PYMONGO_3: - result = list(result) + result = tuple(result) else: result = result.get('result') + if result: return result[0]['total'] return 0 @@ -1268,73 +1266,26 @@ class BaseQuerySet(object): def average(self, field): """Average over the values of the specified field. - :param field: the field to average over; use dot-notation to refer to + :param field: the field to average over; use dot notation to refer to embedded document fields - - .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work - with sharding. """ - map_func = """ - function() { - var path = '{{~%(field)s}}'.split('.'), - field = this; - - for (p in path) { - if (typeof field != 'undefined') - field = field[path[p]]; - else - break; - } - - if (field && field.constructor == Array) { - field.forEach(function(item) { - emit(1, {t: item||0, c: 1}); - }); - } else if (typeof field != 'undefined') { - emit(1, {t: field||0, c: 1}); - } - } - """ % dict(field=field) - - reduce_func = Code(""" - function(key, values) { - var out = {t: 0, c: 0}; - for (var i in values) { - var value = values[i]; - out.t += value.t; - out.c += value.c; - } - return out; - } - """) - - finalize_func = Code(""" - function(key, value) { - return value.t / value.c; - } - """) - - for result in self.map_reduce(map_func, reduce_func, - finalize_f=finalize_func, output='inline'): - return result.value - else: - return 0 - - def aggregate_average(self, field): - """Average over the values of the specified field. - - :param field: the field to average over; use dot-notation to refer to - embedded document fields - - This method is more performant than the regular `average`, because it - uses the aggregation framework instead of map-reduce. - """ - result = self._document._get_collection().aggregate([ + pipeline = [ {'$match': self._query}, {'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}} - ]) + ] + + # if we're performing an average over a list field, we average out + # all the elements in the list, hence we need to $unwind the arrays + # first + ListField = _import_class('ListField') + field_parts = field.split('.') + field_instances = self._document._lookup_field(field_parts) + if isinstance(field_instances[-1], ListField): + pipeline.insert(1, {'$unwind': '$' + field}) + + result = self._document._get_collection().aggregate(pipeline) if IS_PYMONGO_3: - result = list(result) + result = tuple(result) else: result = result.get('result') if result: @@ -1351,7 +1302,7 @@ class BaseQuerySet(object): Can only do direct simple mappings and cannot map across :class:`~mongoengine.fields.ReferenceField` or :class:`~mongoengine.fields.GenericReferenceField` for more complex - counting a manual map reduce call would is required. + counting a manual map reduce call is required. If the field is a :class:`~mongoengine.fields.ListField`, the items within each list will be counted individually. @@ -1425,7 +1376,7 @@ class BaseQuerySet(object): msg = "The snapshot option is not anymore available with PyMongo 3+" warnings.warn(msg, DeprecationWarning) cursor_args = { - 'no_cursor_timeout': self._timeout + 'no_cursor_timeout': not self._timeout } if self._loaded_fields: cursor_args[fields_name] = self._loaded_fields.as_dict() @@ -1442,8 +1393,16 @@ class BaseQuerySet(object): def _cursor(self): if self._cursor_obj is None: - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) + # In PyMongo 3+, we define the read preference on a collection + # level, not a cursor level. Thus, we need to get a cloned + # collection object using `with_options` first. + if IS_PYMONGO_3 and self._read_preference is not None: + self._cursor_obj = self._collection\ + .with_options(read_preference=self._read_preference)\ + .find(self._query, **self._cursor_args) + else: + self._cursor_obj = self._collection.find(self._query, + **self._cursor_args) # Apply where clauses to cursor if self._where_clause: where_clause = self._sub_js_fields(self._where_clause) @@ -1660,7 +1619,7 @@ class BaseQuerySet(object): key = key.replace('__', '.') try: key = self._document._translate_field_name(key) - except: + except Exception: pass key_list.append((key, direction)) diff --git a/mongoengine/queryset/manager.py b/mongoengine/queryset/manager.py index 47c2143d..199205e9 100644 --- a/mongoengine/queryset/manager.py +++ b/mongoengine/queryset/manager.py @@ -29,7 +29,7 @@ class QuerySetManager(object): Document.objects is accessed. """ if instance is not None: - # Document class being used rather than a document object + # Document object being used rather than a document class return self # owner is the document that contains the QuerySetManager diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 6e5f7220..5121463b 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -38,7 +38,7 @@ class QuerySet(BaseQuerySet): def __len__(self): """Since __len__ is called quite frequently (for example, as part of - list(qs) we populate the result cache and cache the length. + list(qs)), we populate the result cache and cache the length. """ if self._len is not None: return self._len diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 03f3acf0..e5e7f83f 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -26,12 +26,12 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + STRING_OPERATORS + CUSTOM_OPERATORS) -def query(_doc_cls=None, **query): +def query(_doc_cls=None, **kwargs): """Transform a query from Django-style format to Mongo format. """ mongo_query = {} merge_query = defaultdict(list) - for key, value in sorted(query.items()): + for key, value in sorted(kwargs.items()): if key == "__raw__": mongo_query.update(value) continue @@ -44,7 +44,7 @@ def query(_doc_cls=None, **query): if len(parts) > 1 and parts[-1] in MATCH_OPERATORS: op = parts.pop() - # Allw to escape operator-like field name by __ + # Allow to escape operator-like field name by __ if len(parts) > 1 and parts[-1] == "": parts.pop() @@ -105,13 +105,18 @@ def query(_doc_cls=None, **query): if op: if op in GEO_OPERATORS: value = _geo_operator(field, op, value) - elif op in CUSTOM_OPERATORS: - if op in ('elem_match', 'match'): - value = field.prepare_query_value(op, value) - value = {"$elemMatch": value} + elif op in ('match', 'elemMatch'): + ListField = _import_class('ListField') + EmbeddedDocumentField = _import_class('EmbeddedDocumentField') + if (isinstance(value, dict) and isinstance(field, ListField) and + isinstance(field.field, EmbeddedDocumentField)): + value = query(field.field.document_type, **value) else: - NotImplementedError("Custom method '%s' has not " - "been implemented" % op) + value = field.prepare_query_value(op, value) + value = {"$elemMatch": value} + elif op in CUSTOM_OPERATORS: + NotImplementedError("Custom method '%s' has not " + "been implemented" % op) elif op not in STRING_OPERATORS: value = {'$' + op: value} @@ -207,6 +212,10 @@ def update(_doc_cls=None, **update): if parts[-1] in COMPARISON_OPERATORS: match = parts.pop() + # Allow to escape operator-like field name by __ + if len(parts) > 1 and parts[-1] == "": + parts.pop() + if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] try: @@ -359,20 +368,24 @@ def _infer_geometry(value): "type and coordinates keys") elif isinstance(value, (list, set)): # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon? + # TODO: should both TypeError and IndexError be alike interpreted? + try: value[0][0][0] return {"$geometry": {"type": "Polygon", "coordinates": value}} - except: + except (TypeError, IndexError): pass + try: value[0][0] return {"$geometry": {"type": "LineString", "coordinates": value}} - except: + except (TypeError, IndexError): pass + try: value[0] return {"$geometry": {"type": "Point", "coordinates": value}} - except: + except (TypeError, IndexError): pass raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary " diff --git a/requirements.txt b/requirements.txt index 03935868..b6a5b06c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ -pymongo>=2.7.1 nose +pymongo>=2.7.1 +six==1.10.0 diff --git a/setup.py b/setup.py index 7384e04d..e34d834a 100644 --- a/setup.py +++ b/setup.py @@ -10,11 +10,12 @@ except ImportError: DESCRIPTION = 'MongoEngine is a Python Object-Document ' + \ 'Mapper for working with MongoDB.' -LONG_DESCRIPTION = None + try: - LONG_DESCRIPTION = open('README.rst').read() -except: - pass + with open('README.rst') as fin: + LONG_DESCRIPTION = fin.read() +except Exception: + LONG_DESCRIPTION = None def get_version(version_tuple): @@ -77,7 +78,7 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo>=2.7.1'], + install_requires=['pymongo>=2.7.1', 'six'], test_suite='nose.collector', **extra_opts ) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index ccc6cf44..e13d8b84 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -5,6 +5,7 @@ import sys sys.path[0:0] = [""] import pymongo +from random import randint from nose.plugins.skip import SkipTest from datetime import datetime @@ -16,9 +17,11 @@ __all__ = ("IndexesTest", ) class IndexesTest(unittest.TestCase): + _MAX_RAND = 10 ** 10 def setUp(self): - self.connection = connect(db='mongoenginetest') + self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND)) + self.connection = connect(db=self.db_name) self.db = get_db() class Person(Document): @@ -32,10 +35,7 @@ class IndexesTest(unittest.TestCase): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue - self.db.drop_collection(collection) + self.connection.drop_database(self.db) def test_indexes_document(self): """Ensure that indexes are used when meta[indexes] is specified for @@ -822,33 +822,29 @@ class IndexesTest(unittest.TestCase): name = StringField(required=True) term = StringField(required=True) - class Report(Document): + class ReportEmbedded(Document): key = EmbeddedDocumentField(CompoundKey, primary_key=True) text = StringField() - Report.drop_collection() - my_key = CompoundKey(name="n", term="ok") - report = Report(text="OK", key=my_key).save() + report = ReportEmbedded(text="OK", key=my_key).save() self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, report.to_mongo()) - self.assertEqual(report, Report.objects.get(pk=my_key)) + self.assertEqual(report, ReportEmbedded.objects.get(pk=my_key)) def test_compound_key_dictfield(self): - class Report(Document): + class ReportDictField(Document): key = DictField(primary_key=True) text = StringField() - Report.drop_collection() - my_key = {"name": "n", "term": "ok"} - report = Report(text="OK", key=my_key).save() + report = ReportDictField(text="OK", key=my_key).save() self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, report.to_mongo()) - self.assertEqual(report, Report.objects.get(pk=my_key)) + self.assertEqual(report, ReportDictField.objects.get(pk=my_key)) def test_string_indexes(self): @@ -863,6 +859,20 @@ class IndexesTest(unittest.TestCase): self.assertTrue([('provider_ids.foo', 1)] in info) self.assertTrue([('provider_ids.bar', 1)] in info) + def test_sparse_compound_indexes(self): + + class MyDoc(Document): + provider_ids = DictField() + meta = { + "indexes": [{'fields': ("provider_ids.foo", "provider_ids.bar"), + 'sparse': True}], + } + + info = MyDoc.objects._collection.index_information() + self.assertEqual([('provider_ids.foo', 1), ('provider_ids.bar', 1)], + info['provider_ids.foo_1_provider_ids.bar_1']['key']) + self.assertTrue(info['provider_ids.foo_1_provider_ids.bar_1']['sparse']) + def test_text_indexes(self): class Book(Document): @@ -895,26 +905,38 @@ class IndexesTest(unittest.TestCase): Issue #812 """ + # Use a new connection and database since dropping the database could + # cause concurrent tests to fail. + connection = connect(db='tempdatabase', + alias='test_indexes_after_database_drop') + class BlogPost(Document): title = StringField() slug = StringField(unique=True) - BlogPost.drop_collection() + meta = {'db_alias': 'test_indexes_after_database_drop'} - # Create Post #1 - post1 = BlogPost(title='test1', slug='test') - post1.save() + try: + BlogPost.drop_collection() - # Drop the Database - self.connection.drop_database(BlogPost._get_db().name) + # Create Post #1 + post1 = BlogPost(title='test1', slug='test') + post1.save() - # Re-create Post #1 - post1 = BlogPost(title='test1', slug='test') - post1.save() + # Drop the Database + connection.drop_database('tempdatabase') + + # Re-create Post #1 + post1 = BlogPost(title='test1', slug='test') + post1.save() + + # Create Post #2 + post2 = BlogPost(title='test2', slug='test') + self.assertRaises(NotUniqueError, post2.save) + finally: + # Drop the temporary database at the end + connection.drop_database('tempdatabase') - # Create Post #2 - post2 = BlogPost(title='test2', slug='test') - self.assertRaises(NotUniqueError, post2.save) def test_index_dont_send_cls_option(self): """ diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 7673a103..957938be 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -411,7 +411,7 @@ class InheritanceTest(unittest.TestCase): try: class MyDocument(DateCreatedDocument, DateUpdatedDocument): pass - except: + except Exception: self.assertTrue(False, "Couldn't create MyDocument class") def test_abstract_documents(self): diff --git a/tests/document/instance.py b/tests/document/instance.py index 3d41857e..cb2c1746 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -7,12 +7,13 @@ import os import pickle import unittest import uuid +import weakref from datetime import datetime from bson import DBRef, ObjectId from tests import fixtures from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, - PickleDyanmicEmbedded, PickleDynamicTest) + PickleDynamicEmbedded, PickleDynamicTest) from mongoengine import * from mongoengine.errors import (NotRegistered, InvalidDocumentError, @@ -30,6 +31,8 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), __all__ = ("InstanceTest",) + + class InstanceTest(unittest.TestCase): def setUp(self): @@ -63,6 +66,14 @@ class InstanceTest(unittest.TestCase): list(self.Person._get_collection().find().sort("id")), sorted(docs, key=lambda doc: doc["_id"])) + def assertHasInstance(self, field, instance): + self.assertTrue(hasattr(field, "_instance")) + self.assertTrue(field._instance is not None) + if isinstance(field._instance, weakref.ProxyType): + self.assertTrue(field._instance.__eq__(instance)) + else: + self.assertEqual(field._instance, instance) + def test_capped_collection(self): """Ensure that capped collections work properly. """ @@ -473,6 +484,20 @@ class InstanceTest(unittest.TestCase): doc.reload() Animal.drop_collection() + def test_reload_sharded_nested(self): + class SuperPhylum(EmbeddedDocument): + name = StringField() + + class Animal(Document): + superphylum = EmbeddedDocumentField(SuperPhylum) + meta = {'shard_key': ('superphylum.name',)} + + Animal.drop_collection() + doc = Animal(superphylum=SuperPhylum(name='Deuterostomia')) + doc.save() + doc.reload() + Animal.drop_collection() + def test_reload_referencing(self): """Ensures reloading updates weakrefs correctly """ @@ -546,6 +571,28 @@ class InstanceTest(unittest.TestCase): except Exception: self.assertFalse("Threw wrong exception") + def test_reload_of_non_strict_with_special_field_name(self): + """Ensures reloading works for documents with meta strict == False + """ + class Post(Document): + meta = { + 'strict': False + } + title = StringField() + items = ListField() + + Post.drop_collection() + + Post._get_collection().insert({ + "title": "Items eclipse", + "items": ["more lorem", "even more ipsum"] + }) + + post = Post.objects.first() + post.reload() + self.assertEqual(post.title, "Items eclipse") + self.assertEqual(post.items, ["more lorem", "even more ipsum"]) + def test_dictionary_access(self): """Ensure that dictionary-style field access works properly. """ @@ -608,10 +655,12 @@ class InstanceTest(unittest.TestCase): embedded_field = EmbeddedDocumentField(Embedded) Doc.drop_collection() - Doc(embedded_field=Embedded(string="Hi")).save() + doc = Doc(embedded_field=Embedded(string="Hi")) + self.assertHasInstance(doc.embedded_field, doc) + doc.save() doc = Doc.objects.get() - self.assertEqual(doc, doc.embedded_field._instance) + self.assertHasInstance(doc.embedded_field, doc) def test_embedded_document_complex_instance(self): """Ensure that embedded documents in complex fields can reference @@ -623,10 +672,25 @@ class InstanceTest(unittest.TestCase): embedded_field = ListField(EmbeddedDocumentField(Embedded)) Doc.drop_collection() - Doc(embedded_field=[Embedded(string="Hi")]).save() + doc = Doc(embedded_field=[Embedded(string="Hi")]) + self.assertHasInstance(doc.embedded_field[0], doc) + doc.save() doc = Doc.objects.get() - self.assertEqual(doc, doc.embedded_field[0]._instance) + self.assertHasInstance(doc.embedded_field[0], doc) + + def test_embedded_document_complex_instance_no_use_db_field(self): + """Ensure that use_db_field is propagated to list of Emb Docs + """ + class Embedded(EmbeddedDocument): + string = StringField(db_field='s') + + class Doc(Document): + embedded_field = ListField(EmbeddedDocumentField(Embedded)) + + d = Doc(embedded_field=[Embedded(string="Hi")]).to_mongo( + use_db_field=False).to_dict() + self.assertEqual(d['embedded_field'], [{'string': 'Hi'}]) def test_instance_is_set_on_setattr(self): @@ -639,11 +703,28 @@ class InstanceTest(unittest.TestCase): Account.drop_collection() acc = Account() acc.email = Email(email='test@example.com') - self.assertTrue(hasattr(acc._data["email"], "_instance")) + self.assertHasInstance(acc._data["email"], acc) acc.save() acc1 = Account.objects.first() - self.assertTrue(hasattr(acc1._data["email"], "_instance")) + self.assertHasInstance(acc1._data["email"], acc1) + + def test_instance_is_set_on_setattr_on_embedded_document_list(self): + + class Email(EmbeddedDocument): + email = EmailField() + + class Account(Document): + emails = EmbeddedDocumentListField(Email) + + Account.drop_collection() + acc = Account() + acc.emails = [Email(email='test@example.com')] + self.assertHasInstance(acc._data["emails"][0], acc) + acc.save() + + acc1 = Account.objects.first() + self.assertHasInstance(acc1._data["emails"][0], acc1) def test_document_clean(self): class TestDocument(Document): @@ -1825,6 +1906,62 @@ class InstanceTest(unittest.TestCase): author.delete() self.assertEqual(BlogPost.objects.count(), 0) + def test_reverse_delete_rule_with_custom_id_field(self): + """Ensure that a referenced document with custom primary key + is also deleted upon deletion. + """ + class User(Document): + name = StringField(primary_key=True) + + class Book(Document): + author = ReferenceField(User, reverse_delete_rule=CASCADE) + reviewer = ReferenceField(User, reverse_delete_rule=NULLIFY) + + User.drop_collection() + Book.drop_collection() + + user = User(name='Mike').save() + reviewer = User(name='John').save() + book = Book(author=user, reviewer=reviewer).save() + + reviewer.delete() + self.assertEqual(Book.objects.count(), 1) + self.assertEqual(Book.objects.get().reviewer, None) + + user.delete() + self.assertEqual(Book.objects.count(), 0) + + def test_reverse_delete_rule_with_shared_id_among_collections(self): + """Ensure that cascade delete rule doesn't mix id among collections. + """ + class User(Document): + id = IntField(primary_key=True) + + class Book(Document): + id = IntField(primary_key=True) + author = ReferenceField(User, reverse_delete_rule=CASCADE) + + User.drop_collection() + Book.drop_collection() + + user_1 = User(id=1).save() + user_2 = User(id=2).save() + book_1 = Book(id=1, author=user_2).save() + book_2 = Book(id=2, author=user_1).save() + + user_2.delete() + # Deleting user_2 should also delete book_1 but not book_2 + self.assertEqual(Book.objects.count(), 1) + self.assertEqual(Book.objects.get(), book_2) + + user_3 = User(id=3).save() + book_3 = Book(id=3, author=user_3).save() + + user_3.delete() + # Deleting user_3 should also delete book_3 + self.assertEqual(Book.objects.count(), 1) + self.assertEqual(Book.objects.get(), book_2) + def test_reverse_delete_rule_with_document_inheritance(self): """Ensure that a referenced document is also deleted upon deletion of a child document. @@ -2180,7 +2317,7 @@ class InstanceTest(unittest.TestCase): pickle_doc = PickleDynamicTest( name="test", number=1, string="One", lists=['1', '2']) - pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar") + pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar") pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved pickle_doc.save() @@ -2683,6 +2820,32 @@ class InstanceTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) + def test_shard_key_in_embedded_document(self): + class Foo(EmbeddedDocument): + foo = StringField() + + class Bar(Document): + meta = { + 'shard_key': ('foo.foo',) + } + foo = EmbeddedDocumentField(Foo) + bar = StringField() + + foo_doc = Foo(foo='hello') + bar_doc = Bar(foo=foo_doc, bar='world') + bar_doc.save() + + self.assertTrue(bar_doc.id is not None) + + bar_doc.bar = 'baz' + bar_doc.save() + + def change_shard_key(): + bar_doc.foo.foo = 'something' + bar_doc.save() + + self.assertRaises(OperationError, change_shard_key) + def test_shard_key_primary(self): class LogEntry(Document): machine = StringField(primary_key=True) @@ -2765,6 +2928,20 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 42) + def test_positional_creation_embedded(self): + """Ensure that embedded document may be created using positional arguments. + """ + job = self.Job("Test Job", 4) + self.assertEqual(job.name, "Test Job") + self.assertEqual(job.years, 4) + + def test_mixed_creation_embedded(self): + """Ensure that embedded document may be created using mixed arguments. + """ + job = self.Job("Test Job", years=4) + self.assertEqual(job.name, "Test Job") + self.assertEqual(job.years, 4) + def test_mixed_creation_dynamic(self): """Ensure that document may be created using mixed arguments. """ diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 0089f60d..ef35874d 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- import sys + +import six from nose.plugins.skip import SkipTest sys.path[0:0] = [""] @@ -10,6 +12,7 @@ import uuid import math import itertools import re +import six try: import dateutil @@ -19,6 +22,10 @@ except ImportError: from decimal import Decimal from bson import Binary, DBRef, ObjectId +try: + from bson.int64 import Int64 +except ImportError: + Int64 = long from mongoengine import * from mongoengine.connection import get_db @@ -399,20 +406,37 @@ class FieldTest(unittest.TestCase): class Person(Document): height = FloatField(min_value=0.1, max_value=3.5) + class BigPerson(Document): + height = FloatField() + person = Person() person.height = 1.89 person.validate() person.height = '2.0' self.assertRaises(ValidationError, person.validate) + person.height = 0.01 self.assertRaises(ValidationError, person.validate) + person.height = 4.0 self.assertRaises(ValidationError, person.validate) person_2 = Person(height='something invalid') self.assertRaises(ValidationError, person_2.validate) + big_person = BigPerson() + + for value, value_type in enumerate(six.integer_types): + big_person.height = value_type(value) + big_person.validate() + + big_person.height = 2 ** 500 + big_person.validate() + + big_person.height = 2 ** 100000 # Too big for a float value + self.assertRaises(ValidationError, big_person.validate) + def test_decimal_validation(self): """Ensure that invalid values cannot be assigned to decimal fields. """ @@ -1184,6 +1208,19 @@ class FieldTest(unittest.TestCase): simple = simple.reload() self.assertEqual(simple.widgets, [4]) + def test_list_field_with_negative_indices(self): + + class Simple(Document): + widgets = ListField() + + simple = Simple(widgets=[1, 2, 3, 4]).save() + simple.widgets[-1] = 5 + self.assertEqual(['widgets.3'], simple._changed_fields) + simple.save() + + simple = simple.reload() + self.assertEqual(simple.widgets, [1, 2, 3, 5]) + def test_list_field_complex(self): """Ensure that the list fields can handle the complex types.""" @@ -1563,6 +1600,29 @@ class FieldTest(unittest.TestCase): actions__friends__operation='drink', actions__friends__object='beer').count()) + def test_map_field_unicode(self): + + class Info(EmbeddedDocument): + description = StringField() + value_list = ListField(field=StringField()) + + class BlogPost(Document): + info_dict = MapField(field=EmbeddedDocumentField(Info)) + + BlogPost.drop_collection() + + tree = BlogPost(info_dict={ + u"éééé": { + 'description': u"VALUE: éééé" + } + }) + + tree.save() + + self.assertEqual(BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, u"VALUE: éééé") + + BlogPost.drop_collection() + def test_embedded_db_field(self): class Embedded(EmbeddedDocument): @@ -1599,6 +1659,8 @@ class FieldTest(unittest.TestCase): name = StringField() preferences = EmbeddedDocumentField(PersonPreferences) + Person.drop_collection() + person = Person(name='Test User') person.preferences = 'My Preferences' self.assertRaises(ValidationError, person.validate) @@ -1631,12 +1693,39 @@ class FieldTest(unittest.TestCase): content = StringField() author = EmbeddedDocumentField(User) + BlogPost.drop_collection() + post = BlogPost(content='What I did today...') post.author = PowerUser(name='Test User', power=47) post.save() self.assertEqual(47, BlogPost.objects.first().author.power) + def test_embedded_document_inheritance_with_list(self): + """Ensure that nested list of subclassed embedded documents is + handled correctly. + """ + + class Group(EmbeddedDocument): + name = StringField() + content = ListField(StringField()) + + class Basedoc(Document): + groups = ListField(EmbeddedDocumentField(Group)) + meta = {'abstract': True} + + class User(Basedoc): + doctype = StringField(require=True, default='userdata') + + User.drop_collection() + + content = ['la', 'le', 'lu'] + group = Group(name='foo', content=content) + foobar = User(groups=[group]) + foobar.save() + + self.assertEqual(content, User.objects.first().groups[0].content) + def test_reference_validation(self): """Ensure that invalid docment objects cannot be assigned to reference fields. @@ -2329,6 +2418,91 @@ class FieldTest(unittest.TestCase): Member.drop_collection() BlogPost.drop_collection() + def test_drop_abstract_document(self): + """Ensure that an abstract document cannot be dropped given it + has no underlying collection. + """ + class AbstractDoc(Document): + name = StringField() + meta = {"abstract": True} + + self.assertRaises(OperationError, AbstractDoc.drop_collection) + + def test_reference_class_with_abstract_parent(self): + """Ensure that a class with an abstract parent can be referenced. + """ + class Sibling(Document): + name = StringField() + meta = {"abstract": True} + + class Sister(Sibling): + pass + + class Brother(Sibling): + sibling = ReferenceField(Sibling) + + Sister.drop_collection() + Brother.drop_collection() + + sister = Sister(name="Alice") + sister.save() + brother = Brother(name="Bob", sibling=sister) + brother.save() + + self.assertEquals(Brother.objects[0].sibling.name, sister.name) + + Sister.drop_collection() + Brother.drop_collection() + + def test_reference_abstract_class(self): + """Ensure that an abstract class instance cannot be used in the + reference of that abstract class. + """ + class Sibling(Document): + name = StringField() + meta = {"abstract": True} + + class Sister(Sibling): + pass + + class Brother(Sibling): + sibling = ReferenceField(Sibling) + + Sister.drop_collection() + Brother.drop_collection() + + sister = Sibling(name="Alice") + brother = Brother(name="Bob", sibling=sister) + self.assertRaises(ValidationError, brother.save) + + Sister.drop_collection() + Brother.drop_collection() + + def test_abstract_reference_base_type(self): + """Ensure that an an abstract reference fails validation when given a + Document that does not inherit from the abstract type. + """ + class Sibling(Document): + name = StringField() + meta = {"abstract": True} + + class Brother(Sibling): + sibling = ReferenceField(Sibling) + + class Mother(Document): + name = StringField() + + Brother.drop_collection() + Mother.drop_collection() + + mother = Mother(name="Carol") + mother.save() + brother = Brother(name="Bob", sibling=mother) + self.assertRaises(ValidationError, brother.save) + + Brother.drop_collection() + Mother.drop_collection() + def test_generic_reference(self): """Ensure that a GenericReferenceField properly dereferences items. """ @@ -3353,7 +3527,7 @@ class FieldTest(unittest.TestCase): def __init__(self, **kwargs): super(EnumField, self).__init__(**kwargs) - def to_mongo(self, value): + def to_mongo(self, value, **kwargs): return value def to_python(self, value): @@ -3520,6 +3694,19 @@ class FieldTest(unittest.TestCase): self.assertRaises(FieldDoesNotExist, test) + def test_long_field_is_considered_as_int64(self): + """ + Tests that long fields are stored as long in mongo, even if long value + is small enough to be an int. + """ + class TestLongFieldConsideredAsInt64(Document): + some_long = LongField() + + doc = TestLongFieldConsideredAsInt64(some_long=42).save() + db = get_db() + self.assertTrue(isinstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64)) + self.assertTrue(isinstance(doc.some_long, six.integer_types)) + class EmbeddedDocumentListFieldTestCase(unittest.TestCase): @@ -3907,6 +4094,17 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase): # modified self.assertEqual(number, 2) + def test_unicode(self): + """ + Tests that unicode strings handled correctly + """ + post = self.BlogPost(comments=[ + self.Comments(author='user1', message=u'сообщение'), + self.Comments(author='user2', message=u'хабарлама') + ]).save() + self.assertEqual(post.comments.get(message=u'сообщение').author, + 'user1') + def test_save(self): """ Tests the save method of a List of Embedded Documents. diff --git a/tests/fixtures.py b/tests/fixtures.py index b3bf73e8..d8eb8487 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -26,7 +26,7 @@ class NewDocumentPickleTest(Document): new_field = StringField() -class PickleDyanmicEmbedded(DynamicEmbeddedDocument): +class PickleDynamicEmbedded(DynamicEmbeddedDocument): date = DateTimeField(default=datetime.now) diff --git a/tests/migration/__init__.py b/tests/migration/__init__.py index 6fc83e02..ef62d876 100644 --- a/tests/migration/__init__.py +++ b/tests/migration/__init__.py @@ -1,8 +1,11 @@ +import unittest + from convert_to_new_inheritance_model import * from decimalfield_as_float import * -from refrencefield_dbref_to_object_id import * +from referencefield_dbref_to_object_id import * from turn_off_inheritance import * from uuidfield_to_binary import * + if __name__ == '__main__': unittest.main() diff --git a/tests/migration/refrencefield_dbref_to_object_id.py b/tests/migration/referencefield_dbref_to_object_id.py similarity index 100% rename from tests/migration/refrencefield_dbref_to_object_id.py rename to tests/migration/referencefield_dbref_to_object_id.py diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 944c6fc1..9d926803 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -680,12 +680,21 @@ class QuerySetTest(unittest.TestCase): def test_upsert_one(self): self.Person.drop_collection() - self.Person.objects(name="Bob", age=30).update_one(upsert=True) + bob = self.Person.objects(name="Bob", age=30).upsert_one() - bob = self.Person.objects.first() self.assertEqual("Bob", bob.name) self.assertEqual(30, bob.age) + bob.name = "Bobby" + bob.save() + + bobby = self.Person.objects(name="Bobby", age=30).upsert_one() + + self.assertEqual("Bobby", bobby.name) + self.assertEqual(30, bobby.age) + self.assertEqual(bob.id, bobby.id) + + def test_set_on_insert(self): self.Person.drop_collection() @@ -2757,25 +2766,15 @@ class QuerySetTest(unittest.TestCase): avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0 self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) - self.assertAlmostEqual( - int(self.Person.objects.aggregate_average('age')), avg - ) self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.average('age')), avg) - self.assertEqual( - int(self.Person.objects.aggregate_average('age')), avg - ) # dot notation self.Person( name='person meta', person_meta=self.PersonMeta(weight=0)).save() self.assertAlmostEqual( int(self.Person.objects.average('person_meta.weight')), 0) - self.assertAlmostEqual( - int(self.Person.objects.aggregate_average('person_meta.weight')), - 0 - ) for i, weight in enumerate(ages): self.Person( @@ -2784,19 +2783,11 @@ class QuerySetTest(unittest.TestCase): self.assertAlmostEqual( int(self.Person.objects.average('person_meta.weight')), avg ) - self.assertAlmostEqual( - int(self.Person.objects.aggregate_average('person_meta.weight')), - avg - ) self.Person(name='test meta none').save() self.assertEqual( int(self.Person.objects.average('person_meta.weight')), avg ) - self.assertEqual( - int(self.Person.objects.aggregate_average('person_meta.weight')), - avg - ) # test summing over a filtered queryset over_50 = [a for a in ages if a >= 50] @@ -2805,10 +2796,6 @@ class QuerySetTest(unittest.TestCase): self.Person.objects.filter(age__gte=50).average('age'), avg ) - self.assertEqual( - self.Person.objects.filter(age__gte=50).aggregate_average('age'), - avg - ) def test_sum(self): """Ensure that field can be summed over correctly. @@ -2818,15 +2805,9 @@ class QuerySetTest(unittest.TestCase): self.Person(name='test%s' % i, age=age).save() self.assertEqual(self.Person.objects.sum('age'), sum(ages)) - self.assertEqual( - self.Person.objects.aggregate_sum('age'), sum(ages) - ) self.Person(name='ageless person').save() self.assertEqual(self.Person.objects.sum('age'), sum(ages)) - self.assertEqual( - self.Person.objects.aggregate_sum('age'), sum(ages) - ) for i, age in enumerate(ages): self.Person(name='test meta%s' % @@ -2835,26 +2816,15 @@ class QuerySetTest(unittest.TestCase): self.assertEqual( self.Person.objects.sum('person_meta.weight'), sum(ages) ) - self.assertEqual( - self.Person.objects.aggregate_sum('person_meta.weight'), - sum(ages) - ) self.Person(name='weightless person').save() self.assertEqual(self.Person.objects.sum('age'), sum(ages)) - self.assertEqual( - self.Person.objects.aggregate_sum('age'), sum(ages) - ) # test summing over a filtered queryset self.assertEqual( self.Person.objects.filter(age__gte=50).sum('age'), sum([a for a in ages if a >= 50]) ) - self.assertEqual( - self.Person.objects.filter(age__gte=50).aggregate_sum('age'), - sum([a for a in ages if a >= 50]) - ) def test_embedded_average(self): class Pay(EmbeddedDocument): @@ -2867,21 +2837,12 @@ class QuerySetTest(unittest.TestCase): Doc.drop_collection() - Doc(name=u"Wilson Junior", - pay=Pay(value=150)).save() + Doc(name='Wilson Junior', pay=Pay(value=150)).save() + Doc(name='Isabella Luanna', pay=Pay(value=530)).save() + Doc(name='Tayza mariana', pay=Pay(value=165)).save() + Doc(name='Eliana Costa', pay=Pay(value=115)).save() - Doc(name=u"Isabella Luanna", - pay=Pay(value=530)).save() - - Doc(name=u"Tayza mariana", - pay=Pay(value=165)).save() - - Doc(name=u"Eliana Costa", - pay=Pay(value=115)).save() - - self.assertEqual( - Doc.objects.average('pay.value'), - 240) + self.assertEqual(Doc.objects.average('pay.value'), 240) def test_embedded_array_average(self): class Pay(EmbeddedDocument): @@ -2889,26 +2850,16 @@ class QuerySetTest(unittest.TestCase): class Doc(Document): name = StringField() - pay = EmbeddedDocumentField( - Pay) + pay = EmbeddedDocumentField(Pay) Doc.drop_collection() - Doc(name=u"Wilson Junior", - pay=Pay(values=[150, 100])).save() + Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save() + Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save() + Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save() + Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save() - Doc(name=u"Isabella Luanna", - pay=Pay(values=[530, 100])).save() - - Doc(name=u"Tayza mariana", - pay=Pay(values=[165, 100])).save() - - Doc(name=u"Eliana Costa", - pay=Pay(values=[115, 100])).save() - - self.assertEqual( - Doc.objects.average('pay.values'), - 170) + self.assertEqual(Doc.objects.average('pay.values'), 170) def test_array_average(self): class Doc(Document): @@ -2921,9 +2872,7 @@ class QuerySetTest(unittest.TestCase): Doc(values=[165, 100]).save() Doc(values=[115, 100]).save() - self.assertEqual( - Doc.objects.average('values'), - 170) + self.assertEqual(Doc.objects.average('values'), 170) def test_embedded_sum(self): class Pay(EmbeddedDocument): @@ -2931,26 +2880,16 @@ class QuerySetTest(unittest.TestCase): class Doc(Document): name = StringField() - pay = EmbeddedDocumentField( - Pay) + pay = EmbeddedDocumentField(Pay) Doc.drop_collection() - Doc(name=u"Wilson Junior", - pay=Pay(value=150)).save() + Doc(name='Wilson Junior', pay=Pay(value=150)).save() + Doc(name='Isabella Luanna', pay=Pay(value=530)).save() + Doc(name='Tayza mariana', pay=Pay(value=165)).save() + Doc(name='Eliana Costa', pay=Pay(value=115)).save() - Doc(name=u"Isabella Luanna", - pay=Pay(value=530)).save() - - Doc(name=u"Tayza mariana", - pay=Pay(value=165)).save() - - Doc(name=u"Eliana Costa", - pay=Pay(value=115)).save() - - self.assertEqual( - Doc.objects.sum('pay.value'), - 960) + self.assertEqual(Doc.objects.sum('pay.value'), 960) def test_embedded_array_sum(self): class Pay(EmbeddedDocument): @@ -2958,26 +2897,16 @@ class QuerySetTest(unittest.TestCase): class Doc(Document): name = StringField() - pay = EmbeddedDocumentField( - Pay) + pay = EmbeddedDocumentField(Pay) Doc.drop_collection() - Doc(name=u"Wilson Junior", - pay=Pay(values=[150, 100])).save() + Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save() + Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save() + Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save() + Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save() - Doc(name=u"Isabella Luanna", - pay=Pay(values=[530, 100])).save() - - Doc(name=u"Tayza mariana", - pay=Pay(values=[165, 100])).save() - - Doc(name=u"Eliana Costa", - pay=Pay(values=[115, 100])).save() - - self.assertEqual( - Doc.objects.sum('pay.values'), - 1360) + self.assertEqual(Doc.objects.sum('pay.values'), 1360) def test_array_sum(self): class Doc(Document): @@ -2990,9 +2919,7 @@ class QuerySetTest(unittest.TestCase): Doc(values=[165, 100]).save() Doc(values=[115, 100]).save() - self.assertEqual( - Doc.objects.sum('values'), - 1360) + self.assertEqual(Doc.objects.sum('values'), 1360) def test_distinct(self): """Ensure that the QuerySet.distinct method works. @@ -3604,6 +3531,15 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(MyDoc.objects.count(), 10) self.assertEqual(MyDoc.objects.none().count(), 0) + def test_count_list_embedded(self): + class B(EmbeddedDocument): + c = StringField() + + class A(Document): + b = ListField(EmbeddedDocumentField(B)) + + self.assertEqual(A.objects(b=[{'c': 'c'}]).count(), 0) + def test_call_after_limits_set(self): """Ensure that re-filtering after slicing works """ @@ -4105,6 +4041,10 @@ class QuerySetTest(unittest.TestCase): Foo(shape="circle", color="purple", thick=False)]) b2.save() + b3 = Bar(foo=[Foo(shape="square", thick=True), + Foo(shape="circle", color="purple", thick=False)]) + b3.save() + ak = list( Bar.objects(foo__match={'shape': "square", "color": "purple"})) self.assertEqual([b1], ak) @@ -4116,6 +4056,22 @@ class QuerySetTest(unittest.TestCase): ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple"))) self.assertEqual([b1], ak) + ak = list( + Bar.objects(foo__elemMatch={'shape': "square", "color__exists": True})) + self.assertEqual([b1, b2], ak) + + ak = list( + Bar.objects(foo__match={'shape': "square", "color__exists": True})) + self.assertEqual([b1, b2], ak) + + ak = list( + Bar.objects(foo__elemMatch={'shape': "square", "color__exists": False})) + self.assertEqual([b3], ak) + + ak = list( + Bar.objects(foo__match={'shape': "square", "color__exists": False})) + self.assertEqual([b3], ak) + def test_upsert_includes_cls(self): """Upserts should include _cls information for inheritable classes """ @@ -4156,7 +4112,11 @@ class QuerySetTest(unittest.TestCase): def test_read_preference(self): class Bar(Document): - pass + txt = StringField() + + meta = { + 'indexes': [ 'txt' ] + } Bar.drop_collection() bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY)) @@ -4168,9 +4128,51 @@ class QuerySetTest(unittest.TestCase): error_class = TypeError self.assertRaises(error_class, Bar.objects, read_preference='Primary') + # read_preference as a kwarg bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED) self.assertEqual( bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._cursor._Cursor__read_preference, + ReadPreference.SECONDARY_PREFERRED) + + # read_preference as a query set method + bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._cursor._Cursor__read_preference, + ReadPreference.SECONDARY_PREFERRED) + + # read_preference after skip + bars = Bar.objects.skip(1) \ + .read_preference(ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._cursor._Cursor__read_preference, + ReadPreference.SECONDARY_PREFERRED) + + # read_preference after limit + bars = Bar.objects.limit(1) \ + .read_preference(ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._cursor._Cursor__read_preference, + ReadPreference.SECONDARY_PREFERRED) + + # read_preference after order_by + bars = Bar.objects.order_by('txt') \ + .read_preference(ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._cursor._Cursor__read_preference, + ReadPreference.SECONDARY_PREFERRED) + + # read_preference after hint + bars = Bar.objects.hint([('txt', 1)]) \ + .read_preference(ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._cursor._Cursor__read_preference, + ReadPreference.SECONDARY_PREFERRED) def test_json_simple(self): @@ -4824,5 +4826,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(1, Doc.objects(item__type__="axe").count()) + if __name__ == '__main__': unittest.main() diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index a543317a..1cb8223d 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -224,6 +224,10 @@ class TransformTest(unittest.TestCase): self.assertEqual(1, Doc.objects(item__type__="axe").count()) self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count()) + Doc.objects(id=doc.id).update(set__item__type__='sword') + self.assertEqual(1, Doc.objects(item__type__="sword").count()) + self.assertEqual(0, Doc.objects(item__type__="axe").count()) + def test_understandable_error_raised(self): class Event(Document): title = StringField() diff --git a/tests/test_connection.py b/tests/test_connection.py index 1b7b7a22..b2f7406e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -8,6 +8,7 @@ try: import unittest2 as unittest except ImportError: import unittest +from nose.plugins.skip import SkipTest import pymongo from bson.tz_util import utc @@ -51,6 +52,42 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb') self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) + def test_connect_in_mocking(self): + """Ensure that the connect() method works properly in mocking. + """ + try: + import mongomock + except ImportError: + raise SkipTest('you need mongomock installed to run this testcase') + + connect('mongoenginetest', host='mongomock://localhost') + conn = get_connection() + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect('mongoenginetest2', host='mongomock://localhost', alias='testdb2') + conn = get_connection('testdb2') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect('mongoenginetest3', host='mongodb://localhost', is_mock=True, alias='testdb3') + conn = get_connection('testdb3') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect('mongoenginetest4', is_mock=True, alias='testdb4') + conn = get_connection('testdb4') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect(host='mongodb://localhost:27017/mongoenginetest5', is_mock=True, alias='testdb5') + conn = get_connection('testdb5') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect(host='mongomock://localhost:27017/mongoenginetest6', alias='testdb6') + conn = get_connection('testdb6') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect(host='mongomock://localhost:27017/mongoenginetest7', is_mock=True, alias='testdb7') + conn = get_connection('testdb7') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + def test_disconnect(self): """Ensure that the disconnect() method works properly """ @@ -151,7 +188,7 @@ class ConnectionTest(unittest.TestCase): self.assertRaises(ConnectionError, get_db, 'test1') # Authentication succeeds with "authSource" - test_conn2 = connect( + connect( 'mongoenginetest', alias='test2', host=('mongodb://username2:password@localhost/' 'mongoenginetest?authSource=admin') diff --git a/tests/test_dereference.py b/tests/test_dereference.py index e1ae3740..11bdd612 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -12,9 +12,13 @@ from mongoengine.context_managers import query_counter class FieldTest(unittest.TestCase): - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() + @classmethod + def setUpClass(cls): + cls.db = connect(db='mongoenginetest') + + @classmethod + def tearDownClass(cls): + cls.db.drop_database('mongoenginetest') def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. @@ -304,6 +308,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() Post.drop_collection() + SimpleList.drop_collection() u1 = User.objects.create(name='u1') u2 = User.objects.create(name='u2') diff --git a/tests/test_signals.py b/tests/test_signals.py index 8672925c..23da7cd4 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -25,6 +25,8 @@ class SignalTests(unittest.TestCase): connect(db='mongoenginetest') class Author(Document): + # Make the id deterministic for easier testing + id = SequenceField(primary_key=True) name = StringField() def __unicode__(self): @@ -33,7 +35,7 @@ class SignalTests(unittest.TestCase): @classmethod def pre_init(cls, sender, document, *args, **kwargs): signal_output.append('pre_init signal, %s' % cls.__name__) - signal_output.append(str(kwargs['values'])) + signal_output.append(kwargs['values']) @classmethod def post_init(cls, sender, document, **kwargs): @@ -43,48 +45,55 @@ class SignalTests(unittest.TestCase): @classmethod def pre_save(cls, sender, document, **kwargs): signal_output.append('pre_save signal, %s' % document) + signal_output.append(kwargs) @classmethod def pre_save_post_validation(cls, sender, document, **kwargs): signal_output.append('pre_save_post_validation signal, %s' % document) - if 'created' in kwargs: - if kwargs['created']: - signal_output.append('Is created') - else: - signal_output.append('Is updated') + if kwargs.pop('created', False): + signal_output.append('Is created') + else: + signal_output.append('Is updated') + signal_output.append(kwargs) @classmethod def post_save(cls, sender, document, **kwargs): dirty_keys = document._delta()[0].keys() + document._delta()[1].keys() signal_output.append('post_save signal, %s' % document) signal_output.append('post_save dirty keys, %s' % dirty_keys) - if 'created' in kwargs: - if kwargs['created']: - signal_output.append('Is created') - else: - signal_output.append('Is updated') + if kwargs.pop('created', False): + signal_output.append('Is created') + else: + signal_output.append('Is updated') + signal_output.append(kwargs) @classmethod def pre_delete(cls, sender, document, **kwargs): signal_output.append('pre_delete signal, %s' % document) + signal_output.append(kwargs) @classmethod def post_delete(cls, sender, document, **kwargs): signal_output.append('post_delete signal, %s' % document) + signal_output.append(kwargs) @classmethod def pre_bulk_insert(cls, sender, documents, **kwargs): signal_output.append('pre_bulk_insert signal, %s' % documents) + signal_output.append(kwargs) @classmethod def post_bulk_insert(cls, sender, documents, **kwargs): signal_output.append('post_bulk_insert signal, %s' % documents) - if kwargs.get('loaded', False): + if kwargs.pop('loaded', False): signal_output.append('Is loaded') else: signal_output.append('Not loaded') + signal_output.append(kwargs) + self.Author = Author Author.drop_collection() + Author.id.set_next_value(0) class Another(Document): @@ -96,10 +105,12 @@ class SignalTests(unittest.TestCase): @classmethod def pre_delete(cls, sender, document, **kwargs): signal_output.append('pre_delete signal, %s' % document) + signal_output.append(kwargs) @classmethod def post_delete(cls, sender, document, **kwargs): signal_output.append('post_delete signal, %s' % document) + signal_output.append(kwargs) self.Another = Another Another.drop_collection() @@ -118,6 +129,41 @@ class SignalTests(unittest.TestCase): self.ExplicitId = ExplicitId ExplicitId.drop_collection() + class Post(Document): + title = StringField() + content = StringField() + active = BooleanField(default=False) + + def __unicode__(self): + return self.title + + @classmethod + def pre_bulk_insert(cls, sender, documents, **kwargs): + signal_output.append('pre_bulk_insert signal, %s' % + [(doc, {'active': documents[n].active}) + for n, doc in enumerate(documents)]) + + # make changes here, this is just an example - + # it could be anything that needs pre-validation or looks-ups before bulk bulk inserting + for document in documents: + if not document.active: + document.active = True + signal_output.append(kwargs) + + @classmethod + def post_bulk_insert(cls, sender, documents, **kwargs): + signal_output.append('post_bulk_insert signal, %s' % + [(doc, {'active': documents[n].active}) + for n, doc in enumerate(documents)]) + if kwargs.pop('loaded', False): + signal_output.append('Is loaded') + else: + signal_output.append('Not loaded') + signal_output.append(kwargs) + + self.Post = Post + Post.drop_collection() + # Save up the number of connected signals so that we can check at the # end that all the signals we register get properly unregistered self.pre_signals = ( @@ -147,6 +193,9 @@ class SignalTests(unittest.TestCase): signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId) + signals.pre_bulk_insert.connect(Post.pre_bulk_insert, sender=Post) + signals.post_bulk_insert.connect(Post.post_bulk_insert, sender=Post) + def tearDown(self): signals.pre_init.disconnect(self.Author.pre_init) signals.post_init.disconnect(self.Author.post_init) @@ -163,6 +212,9 @@ class SignalTests(unittest.TestCase): signals.post_save.disconnect(self.ExplicitId.post_save) + signals.pre_bulk_insert.disconnect(self.Post.pre_bulk_insert) + signals.post_bulk_insert.disconnect(self.Post.post_bulk_insert) + # Check that all our signals got disconnected properly. post_signals = ( len(signals.pre_init.receivers), @@ -199,66 +251,121 @@ class SignalTests(unittest.TestCase): a.save() self.get_signal_output(lambda: None) # eliminate signal output a1 = self.Author.objects(name='Bill Shakespeare')[0] - + self.assertEqual(self.get_signal_output(create_author), [ "pre_init signal, Author", - "{'name': 'Bill Shakespeare'}", + {'name': 'Bill Shakespeare'}, "post_init signal, Bill Shakespeare, document._created = True", ]) a1 = self.Author(name='Bill Shakespeare') self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, Bill Shakespeare", + {}, "pre_save_post_validation signal, Bill Shakespeare", "Is created", + {}, "post_save signal, Bill Shakespeare", "post_save dirty keys, ['name']", - "Is created" + "Is created", + {} ]) a1.reload() a1.name = 'William Shakespeare' self.assertEqual(self.get_signal_output(a1.save), [ "pre_save signal, William Shakespeare", + {}, "pre_save_post_validation signal, William Shakespeare", "Is updated", + {}, "post_save signal, William Shakespeare", "post_save dirty keys, ['name']", - "Is updated" + "Is updated", + {} ]) self.assertEqual(self.get_signal_output(a1.delete), [ 'pre_delete signal, William Shakespeare', + {}, 'post_delete signal, William Shakespeare', + {} ]) - signal_output = self.get_signal_output(load_existing_author) - # test signal_output lines separately, because of random ObjectID after object load - self.assertEqual(signal_output[0], + self.assertEqual(self.get_signal_output(load_existing_author), [ "pre_init signal, Author", - ) - self.assertEqual(signal_output[2], - "post_init signal, Bill Shakespeare, document._created = False", - ) + {'id': 2, 'name': 'Bill Shakespeare'}, + "post_init signal, Bill Shakespeare, document._created = False" + ]) - - signal_output = self.get_signal_output(bulk_create_author_with_load) - - # The output of this signal is not entirely deterministic. The reloaded - # object will have an object ID. Hence, we only check part of the output - self.assertEqual(signal_output[3], "pre_bulk_insert signal, []" - ) - self.assertEqual(signal_output[-2:], - ["post_bulk_insert signal, []", - "Is loaded",]) + self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [ + 'pre_init signal, Author', + {'name': 'Bill Shakespeare'}, + 'post_init signal, Bill Shakespeare, document._created = True', + 'pre_bulk_insert signal, []', + {}, + 'pre_init signal, Author', + {'id': 3, 'name': 'Bill Shakespeare'}, + 'post_init signal, Bill Shakespeare, document._created = False', + 'post_bulk_insert signal, []', + 'Is loaded', + {} + ]) self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [ "pre_init signal, Author", - "{'name': 'Bill Shakespeare'}", + {'name': 'Bill Shakespeare'}, "post_init signal, Bill Shakespeare, document._created = True", "pre_bulk_insert signal, []", + {}, "post_bulk_insert signal, []", "Not loaded", + {} + ]) + + def test_signal_kwargs(self): + """ Make sure signal_kwargs is passed to signals calls. """ + + def live_and_let_die(): + a = self.Author(name='Bill Shakespeare') + a.save(signal_kwargs={'live': True, 'die': False}) + a.delete(signal_kwargs={'live': False, 'die': True}) + + self.assertEqual(self.get_signal_output(live_and_let_die), [ + "pre_init signal, Author", + {'name': 'Bill Shakespeare'}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_save signal, Bill Shakespeare", + {'die': False, 'live': True}, + "pre_save_post_validation signal, Bill Shakespeare", + "Is created", + {'die': False, 'live': True}, + "post_save signal, Bill Shakespeare", + "post_save dirty keys, ['name']", + "Is created", + {'die': False, 'live': True}, + 'pre_delete signal, Bill Shakespeare', + {'die': True, 'live': False}, + 'post_delete signal, Bill Shakespeare', + {'die': True, 'live': False} + ]) + + def bulk_create_author(): + a1 = self.Author(name='Bill Shakespeare') + self.Author.objects.insert([a1], signal_kwargs={'key': True}) + + self.assertEqual(self.get_signal_output(bulk_create_author), [ + 'pre_init signal, Author', + {'name': 'Bill Shakespeare'}, + 'post_init signal, Bill Shakespeare, document._created = True', + 'pre_bulk_insert signal, []', + {'key': True}, + 'pre_init signal, Author', + {'id': 2, 'name': 'Bill Shakespeare'}, + 'post_init signal, Bill Shakespeare, document._created = False', + 'post_bulk_insert signal, []', + 'Is loaded', + {'key': True} ]) def test_queryset_delete_signals(self): @@ -267,7 +374,9 @@ class SignalTests(unittest.TestCase): self.Another(name='Bill Shakespeare').save() self.assertEqual(self.get_signal_output(self.Another.objects.delete), [ 'pre_delete signal, Bill Shakespeare', + {}, 'post_delete signal, Bill Shakespeare', + {} ]) def test_signals_with_explicit_doc_ids(self): @@ -306,6 +415,23 @@ class SignalTests(unittest.TestCase): ei.switch_db("testdb-1", keep_created=False) self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + def test_signals_bulk_insert(self): + def bulk_set_active_post(): + posts = [ + self.Post(title='Post 1'), + self.Post(title='Post 2'), + self.Post(title='Post 3') + ] + self.Post.objects.insert(posts) + + results = self.get_signal_output(bulk_set_active_post) + self.assertEqual(results, [ + "pre_bulk_insert signal, [(, {'active': False}), (, {'active': False}), (, {'active': False})]", + {}, + "post_bulk_insert signal, [(, {'active': True}), (, {'active': True}), (, {'active': True})]", + 'Is loaded', + {} + ]) if __name__ == '__main__': unittest.main() diff --git a/tox.ini b/tox.ini index e6aa7c81..124c8843 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = {py26,py27,py32,py33,py34,pypy,pypy3}-{mg27,mg28} +envlist = {py26,py27,py32,py33,py34,py35,pypy,pypy3}-{mg27,mg28} #envlist = {py26,py27,py32,py33,py34,pypy,pypy3}-{mg27,mg28,mg30,mgdev} [testenv] @@ -12,3 +12,6 @@ deps = mg28: PyMongo>=2.8,<3.0 mg30: PyMongo>=3.0 mgdev: https://github.com/mongodb/mongo-python-driver/tarball/master +setenv = + PYTHON_EGG_CACHE = {envdir}/python-eggs +passenv = windir