From 3da37fbf6ebbc638b7ebafe68904d8af6b26fbdd Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 20 Jul 2012 16:17:35 +0100 Subject: [PATCH] Updated for pymongo --- mongoengine/__init__.py | 2 +- mongoengine/base.py | 28 +++++++------- mongoengine/document.py | 22 +++++------ mongoengine/fields.py | 28 +++++++------- mongoengine/queryset.py | 85 ++++++++++++++++++++++++----------------- tests/document.py | 65 +++++++++++++++---------------- tests/queryset.py | 66 ++++++++++++++++---------------- 7 files changed, 158 insertions(+), 138 deletions(-) diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 6d18ffe7..d501660d 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ + __author__ = 'Harry Marr' -VERSION = (0, 4, 0) +VERSION = (0, 4, 1) def get_version(): version = '%s.%s' % (VERSION[0], VERSION[1]) diff --git a/mongoengine/base.py b/mongoengine/base.py index 6b74cb07..49a2b436 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -2,8 +2,8 @@ from queryset import QuerySet, QuerySetManager from queryset import DoesNotExist, MultipleObjectsReturned import sys +import bson import pymongo -import pymongo.objectid _document_registry = {} @@ -21,11 +21,11 @@ class BaseField(object): may be added to subclasses of `Document` to define a document's schema. """ - # Fields may have _types inserted into indexes by default + # Fields may have _types inserted into indexes by default _index_with_types = True _geo_index = False - def __init__(self, db_field=None, name=None, required=False, default=None, + def __init__(self, db_field=None, name=None, required=False, default=None, unique=False, unique_with=None, primary_key=False, validation=None, choices=None): self.db_field = (db_field or name) if not primary_key else '_id' @@ -43,7 +43,7 @@ class BaseField(object): self.choices = choices def __get__(self, instance, owner): - """Descriptor for retrieving a value from a field in a document. Do + """Descriptor for retrieving a value from a field in a document. Do any necessary conversion between Python and MongoDB types. 
""" if instance is None: @@ -111,9 +111,9 @@ class ObjectIdField(BaseField): # return unicode(value) def to_mongo(self, value): - if not isinstance(value, pymongo.objectid.ObjectId): + if not isinstance(value, bson.objectid.ObjectId): try: - return pymongo.objectid.ObjectId(unicode(value)) + return bson.objectid.ObjectId(unicode(value)) except Exception, e: #e.message attribute has been deprecated since Python 2.6 raise ValidationError(unicode(e)) @@ -124,7 +124,7 @@ class ObjectIdField(BaseField): def validate(self, value): try: - pymongo.objectid.ObjectId(unicode(value)) + bson.objectid.ObjectId(unicode(value)) except: raise ValidationError('Invalid Object ID') @@ -153,8 +153,8 @@ class DocumentMetaclass(type): superclasses.update(base._superclasses) if hasattr(base, '_meta'): - # Ensure that the Document class may be subclassed - - # inheritance may be disabled to remove dependency on + # Ensure that the Document class may be subclassed - + # inheritance may be disabled to remove dependency on # additional fields _cls and _types if base._meta.get('allow_inheritance', True) == False: raise ValueError('Document %s may not be subclassed' % @@ -193,12 +193,12 @@ class DocumentMetaclass(type): module = attrs.get('__module__') - base_excs = tuple(base.DoesNotExist for base in bases + base_excs = tuple(base.DoesNotExist for base in bases if hasattr(base, 'DoesNotExist')) or (DoesNotExist,) exc = subclass_exception('DoesNotExist', base_excs, module) new_class.add_to_class('DoesNotExist', exc) - base_excs = tuple(base.MultipleObjectsReturned for base in bases + base_excs = tuple(base.MultipleObjectsReturned for base in bases if hasattr(base, 'MultipleObjectsReturned')) base_excs = base_excs or (MultipleObjectsReturned,) exc = subclass_exception('MultipleObjectsReturned', base_excs, module) @@ -220,9 +220,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): def __new__(cls, name, bases, attrs): super_new = super(TopLevelDocumentMetaclass, cls).__new__ - # Classes defined in this package are abstract and should not have + # Classes defined in this package are abstract and should not have # their own metadata with DB collection, etc. - # __metaclass__ is only set on the class with the __metaclass__ + # __metaclass__ is only set on the class with the __metaclass__ # attribute (i.e. it is not set on subclasses). This differentiates # 'real' documents from the 'Document' class if attrs.get('__metaclass__') == TopLevelDocumentMetaclass: @@ -347,7 +347,7 @@ class BaseDocument(object): are present. """ # Get a list of tuples of field names and their current values - fields = [(field, getattr(self, name)) + fields = [(field, getattr(self, name)) for name, field in self._fields.items()] # Ensure that each field is matched to a valid value diff --git a/mongoengine/document.py b/mongoengine/document.py index fef737db..44512184 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -40,16 +40,16 @@ class Document(BaseDocument): presence of `_cls` and `_types`, set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` dictionary. - A :class:`~mongoengine.Document` may use a **Capped Collection** by + A :class:`~mongoengine.Document` may use a **Capped Collection** by specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` dictionary. :attr:`max_documents` is the maximum number of documents that - is allowed to be stored in the collection, and :attr:`max_size` is the - maximum size of the collection in bytes. 
If :attr:`max_size` is not - specified and :attr:`max_documents` is, :attr:`max_size` defaults to + is allowed to be stored in the collection, and :attr:`max_size` is the + maximum size of the collection in bytes. If :attr:`max_size` is not + specified and :attr:`max_documents` is, :attr:`max_size` defaults to 10000000 bytes (10MB). Indexes may be created by specifying :attr:`indexes` in the :attr:`meta` - dictionary. The value should be a list of field names or tuples of field + dictionary. The value should be a list of field names or tuples of field names. Index direction may be specified by prefixing the field names with a **+** or **-** sign. """ @@ -61,11 +61,11 @@ class Document(BaseDocument): document already exists, it will be updated, otherwise it will be created. - If ``safe=True`` and the operation is unsuccessful, an + If ``safe=True`` and the operation is unsuccessful, an :class:`~mongoengine.OperationError` will be raised. :param safe: check if the operation succeeded before returning - :param force_insert: only try to create a new document, don't allow + :param force_insert: only try to create a new document, don't allow updates of existing documents :param validate: validates the document; set to ``False`` for skiping """ @@ -123,9 +123,9 @@ class MapReduceDocument(object): """A document returned from a map/reduce query. :param collection: An instance of :class:`~pymongo.Collection` - :param key: Document/result key, often an instance of - :class:`~pymongo.objectid.ObjectId`. If supplied as - an ``ObjectId`` found in the given ``collection``, + :param key: Document/result key, often an instance of + :class:`~bson.objectid.ObjectId`. If supplied as + an ``ObjectId`` found in the given ``collection``, the object can be accessed via the ``object`` property. :param value: The result(s) for this key. @@ -140,7 +140,7 @@ class MapReduceDocument(object): @property def object(self): - """Lazy-load the object referenced by ``self.key``. ``self.key`` + """Lazy-load the object referenced by ``self.key``. ``self.key`` should be the ``primary_key``. 
""" id_field = self._document()._meta['id_field'] diff --git a/mongoengine/fields.py b/mongoengine/fields.py index e95fd65e..a0f619f9 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -5,9 +5,9 @@ from operator import itemgetter import re import pymongo -import pymongo.dbref -import pymongo.son -import pymongo.binary +import bson.dbref +import bson.son +import bson.binary import datetime import decimal import gridfs @@ -300,13 +300,13 @@ class ListField(BaseField): if isinstance(self.field, ReferenceField): referenced_type = self.field.document_type - # Get value from document instance if available + # Get value from document instance if available value_list = instance._data.get(self.name) if value_list: deref_list = [] for value in value_list: # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): + if isinstance(value, (bson.dbref.DBRef)): value = _get_db().dereference(value) deref_list.append(referenced_type._from_son(value)) else: @@ -319,7 +319,7 @@ class ListField(BaseField): deref_list = [] for value in value_list: # Dereference DBRefs - if isinstance(value, (dict, pymongo.son.SON)): + if isinstance(value, (dict, bson.son.SON)): deref_list.append(self.field.dereference(value)) else: deref_list.append(value) @@ -444,7 +444,7 @@ class ReferenceField(BaseField): # Get value from document instance if available value = instance._data.get(self.name) # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): + if isinstance(value, (bson.dbref.DBRef)): value = _get_db().dereference(value) if value is not None: instance._data[self.name] = self.document_type._from_son(value) @@ -466,13 +466,13 @@ class ReferenceField(BaseField): id_ = id_field.to_mongo(id_) collection = self.document_type._meta['collection'] - return pymongo.dbref.DBRef(collection, id_) + return bson.dbref.DBRef(collection, id_) def prepare_query_value(self, op, value): return self.to_mongo(value) def validate(self, value): - assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) + assert isinstance(value, (self.document_type, bson.dbref.DBRef)) def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -490,7 +490,7 @@ class GenericReferenceField(BaseField): return self value = instance._data.get(self.name) - if isinstance(value, (dict, pymongo.son.SON)): + if isinstance(value, (dict, bson.son.SON)): instance._data[self.name] = self.dereference(value) return super(GenericReferenceField, self).__get__(instance, owner) @@ -518,7 +518,7 @@ class GenericReferenceField(BaseField): id_ = id_field.to_mongo(id_) collection = document._meta['collection'] - ref = pymongo.dbref.DBRef(collection, id_) + ref = bson.dbref.DBRef(collection, id_) return {'_cls': document.__class__.__name__, '_ref': ref} def prepare_query_value(self, op, value): @@ -534,7 +534,7 @@ class BinaryField(BaseField): super(BinaryField, self).__init__(**kwargs) def to_mongo(self, value): - return pymongo.binary.Binary(value) + return bson.binary.Binary(value) def to_python(self, value): # Returns str not unicode as this is binary data @@ -603,7 +603,7 @@ class GridFSProxy(object): if not self.newfile: self.new_file() self.grid_id = self.newfile._id - self.newfile.writelines(lines) + self.newfile.writelines(lines) def read(self): try: @@ -680,7 +680,7 @@ class FileField(BaseField): def validate(self, value): if value.grid_id is not None: assert isinstance(value, GridFSProxy) - assert isinstance(value.grid_id, pymongo.objectid.ObjectId) + assert isinstance(value.grid_id, 
bson.objectid.ObjectId) class GeoPointField(BaseField): diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 519dda03..bbb6cf34 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -2,9 +2,9 @@ from connection import _get_db import pprint import pymongo -import pymongo.code -import pymongo.dbref -import pymongo.objectid +import bson.code +import bson.dbref +import bson.objectid import re import copy import itertools @@ -424,7 +424,7 @@ class QuerySet(object): } if self._loaded_fields: cursor_args['fields'] = self._loaded_fields - self._cursor_obj = self._collection.find(self._query, + self._cursor_obj = self._collection.find(self._query, **cursor_args) # Apply where clauses to cursor if self._where_clause: @@ -476,8 +476,8 @@ class QuerySet(object): operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 'all', 'size', 'exists', 'not'] geo_operators = ['within_distance', 'within_box', 'near'] - match_operators = ['contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', + match_operators = ['contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact'] mongo_query = {} @@ -563,8 +563,8 @@ class QuerySet(object): % self._document._class_name) def get_or_create(self, *q_objs, **query): - """Retrieve unique object or create, if it doesn't exist. Returns a tuple of - ``(object, created)``, where ``object`` is the retrieved or created object + """Retrieve unique object or create, if it doesn't exist. Returns a tuple of + ``(object, created)``, where ``object`` is the retrieved or created object and ``created`` is a boolean specifying whether a new object was created. Raises :class:`~mongoengine.queryset.MultipleObjectsReturned` or `DocumentName.MultipleObjectsReturned` if multiple results are found. @@ -667,8 +667,8 @@ class QuerySet(object): def __len__(self): return self.count() - def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None, - scope=None, keep_temp=False): + def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, + scope=None): """Perform a map/reduce query using the current query spec and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, it must be the last call made, as it does not return a maleable @@ -678,52 +678,61 @@ class QuerySet(object): and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` tests in ``tests.queryset.QuerySetTest`` for usage examples. - :param map_f: map function, as :class:`~pymongo.code.Code` or string + :param map_f: map function, as :class:`~bson.code.Code` or string :param reduce_f: reduce function, as - :class:`~pymongo.code.Code` or string + :class:`~bson.code.Code` or string + :param output: output collection name, if set to 'inline' will try to + use :class:`~pymongo.collection.Collection.inline_map_reduce` + This can also be a dictionary containing output options + see: http://docs.mongodb.org/manual/reference/commands/#mapReduce :param finalize_f: finalize function, an optional function that performs any post-reduction processing. :param scope: values to insert into map/reduce global scope. Optional. :param limit: number of objects from current query to provide to map/reduce method - :param keep_temp: keep temporary table (boolean, default ``True``) Returns an iterator yielding :class:`~mongoengine.document.MapReduceDocument`. - .. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo - :meth:`~pymongo.collection.Collection.map_reduce` helper requires - PyMongo version **>= 1.2**. 
+ .. note:: + + Map/Reduce changed in server version **>= 1.7.4**. The PyMongo + :meth:`~pymongo.collection.Collection.map_reduce` helper requires + PyMongo version **>= 1.11**. + + .. versionchanged:: 0.5 + - removed ``keep_temp`` keyword argument, which was only relevant + for MongoDB server versions older than 1.7.4 .. versionadded:: 0.3 """ from document import MapReduceDocument if not hasattr(self._collection, "map_reduce"): - raise NotImplementedError("Requires MongoDB >= 1.1.1") + raise NotImplementedError("Requires MongoDB >= 1.7.1") map_f_scope = {} - if isinstance(map_f, pymongo.code.Code): + if isinstance(map_f, bson.code.Code): map_f_scope = map_f.scope map_f = unicode(map_f) - map_f = pymongo.code.Code(self._sub_js_fields(map_f), map_f_scope) + map_f = bson.code.Code(self._sub_js_fields(map_f), map_f_scope) reduce_f_scope = {} - if isinstance(reduce_f, pymongo.code.Code): + if isinstance(reduce_f, bson.code.Code): reduce_f_scope = reduce_f.scope reduce_f = unicode(reduce_f) reduce_f_code = self._sub_js_fields(reduce_f) - reduce_f = pymongo.code.Code(reduce_f_code, reduce_f_scope) + reduce_f = bson.code.Code(reduce_f_code, reduce_f_scope) - mr_args = {'query': self._query, 'keeptemp': keep_temp} + mr_args = {'query': self._query} if finalize_f: finalize_f_scope = {} - if isinstance(finalize_f, pymongo.code.Code): + if isinstance(finalize_f, bson.code.Code): finalize_f_scope = finalize_f.scope finalize_f = unicode(finalize_f) finalize_f_code = self._sub_js_fields(finalize_f) - finalize_f = pymongo.code.Code(finalize_f_code, finalize_f_scope) + finalize_f = bson.code.Code(finalize_f_code, finalize_f_scope) mr_args['finalize'] = finalize_f if scope: @@ -732,8 +741,16 @@ class QuerySet(object): if limit: mr_args['limit'] = limit - results = self._collection.map_reduce(map_f, reduce_f, **mr_args) - results = results.find() + if output == 'inline' and not self._ordering: + map_reduce_function = 'inline_map_reduce' + else: + map_reduce_function = 'map_reduce' + mr_args['out'] = output + + results = getattr(self._collection, map_reduce_function)(map_f, reduce_f, **mr_args) + + if map_reduce_function == 'map_reduce': + results = results.find() if self._ordering: results = results.sort(self._ordering) @@ -777,7 +794,7 @@ class QuerySet(object): self._skip, self._limit = key.start, key.stop except IndexError, err: # PyMongo raises an error if key.start == key.stop, catch it, - # bin it, kill it. + # bin it, kill it. start = key.start or 0 if start >= 0 and key.stop >= 0 and key.step is None: if start == key.stop: @@ -933,7 +950,7 @@ class QuerySet(object): return mongo_update def update(self, safe_update=True, upsert=False, **update): - """Perform an atomic update on the fields matched by the query. When + """Perform an atomic update on the fields matched by the query. When ``safe_update`` is used, the number of affected documents is returned. :param safe: check if the operation succeeded before returning @@ -957,7 +974,7 @@ class QuerySet(object): raise OperationError(u'Update failed (%s)' % unicode(err)) def update_one(self, safe_update=True, upsert=False, **update): - """Perform an atomic update on first field matched by the query. When + """Perform an atomic update on first field matched by the query. When ``safe_update`` is used, the number of affected documents is returned. 
:param safe: check if the operation succeeded before returning @@ -985,8 +1002,8 @@ class QuerySet(object): return self def _sub_js_fields(self, code): - """When fields are specified with [~fieldname] syntax, where - *fieldname* is the Python name of a field, *fieldname* will be + """When fields are specified with [~fieldname] syntax, where + *fieldname* is the Python name of a field, *fieldname* will be substituted for the MongoDB name of the field (specified using the :attr:`name` keyword argument in a field's constructor). """ @@ -1009,9 +1026,9 @@ class QuerySet(object): options specified as keyword arguments. As fields in MongoEngine may use different names in the database (set - using the :attr:`db_field` keyword argument to a :class:`Field` + using the :attr:`db_field` keyword argument to a :class:`Field` constructor), a mechanism exists for replacing MongoEngine field names - with the database field names in Javascript code. When accessing a + with the database field names in Javascript code. When accessing a field, use square-bracket notation, and prefix the MongoEngine field name with a tilde (~). @@ -1037,7 +1054,7 @@ class QuerySet(object): query['$where'] = self._where_clause scope['query'] = query - code = pymongo.code.Code(code, scope=scope) + code = bson.code.Code(code, scope=scope) db = _get_db() return db.eval(code, *fields) diff --git a/tests/document.py b/tests/document.py index c0567632..cb005833 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1,5 +1,6 @@ import unittest from datetime import datetime +import bson import pymongo from mongoengine import * @@ -7,7 +8,7 @@ from mongoengine.connection import _get_db class DocumentTest(unittest.TestCase): - + def setUp(self): connect(db='mongoenginetest') self.db = _get_db() @@ -38,7 +39,7 @@ class DocumentTest(unittest.TestCase): name = name_field age = age_field non_field = True - + self.assertEqual(Person._fields['name'], name_field) self.assertEqual(Person._fields['age'], age_field) self.assertFalse('non_field' in Person._fields) @@ -60,7 +61,7 @@ class DocumentTest(unittest.TestCase): mammal_superclasses = {'Animal': Animal} self.assertEqual(Mammal._superclasses, mammal_superclasses) - + dog_superclasses = { 'Animal': Animal, 'Animal.Mammal': Mammal, @@ -68,7 +69,7 @@ class DocumentTest(unittest.TestCase): self.assertEqual(Dog._superclasses, dog_superclasses) def test_get_subclasses(self): - """Ensure that the correct list of subclasses is retrieved by the + """Ensure that the correct list of subclasses is retrieved by the _get_subclasses method. 
""" class Animal(Document): pass @@ -78,15 +79,15 @@ class DocumentTest(unittest.TestCase): class Dog(Mammal): pass mammal_subclasses = { - 'Animal.Mammal.Dog': Dog, + 'Animal.Mammal.Dog': Dog, 'Animal.Mammal.Human': Human } self.assertEqual(Mammal._get_subclasses(), mammal_subclasses) - + animal_subclasses = { 'Animal.Fish': Fish, 'Animal.Mammal': Mammal, - 'Animal.Mammal.Dog': Dog, + 'Animal.Mammal.Dog': Dog, 'Animal.Mammal.Human': Human } self.assertEqual(Animal._get_subclasses(), animal_subclasses) @@ -124,7 +125,7 @@ class DocumentTest(unittest.TestCase): self.assertTrue('name' in Employee._fields) self.assertTrue('salary' in Employee._fields) - self.assertEqual(Employee._meta['collection'], + self.assertEqual(Employee._meta['collection'], self.Person._meta['collection']) # Ensure that MRO error is not raised @@ -146,7 +147,7 @@ class DocumentTest(unittest.TestCase): class Dog(Animal): pass self.assertRaises(ValueError, create_dog_class) - + # Check that _cls etc aren't present on simple documents dog = Animal(name='dog') dog.save() @@ -161,7 +162,7 @@ class DocumentTest(unittest.TestCase): class Employee(self.Person): meta = {'allow_inheritance': False} self.assertRaises(ValueError, create_employee_class) - + # Test the same for embedded documents class Comment(EmbeddedDocument): content = StringField() @@ -186,7 +187,7 @@ class DocumentTest(unittest.TestCase): class Person(Document): name = StringField() meta = {'collection': collection} - + user = Person(name="Test User") user.save() self.assertTrue(collection in self.db.collection_names()) @@ -280,7 +281,7 @@ class DocumentTest(unittest.TestCase): tags = ListField(StringField()) meta = { 'indexes': [ - '-date', + '-date', 'tags', ('category', '-date') ], @@ -296,12 +297,12 @@ class DocumentTest(unittest.TestCase): list(BlogPost.objects) info = BlogPost.objects._collection.index_information() info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] + self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('addDate', -1)] in info) # tags is a list field so it shouldn't have _types in the index self.assertTrue([('tags', 1)] in info) - + class ExtendedBlogPost(BlogPost): title = StringField() meta = {'indexes': ['title']} @@ -311,7 +312,7 @@ class DocumentTest(unittest.TestCase): list(ExtendedBlogPost.objects) info = ExtendedBlogPost.objects._collection.index_information() info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] + self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('title', 1)] in info) @@ -380,7 +381,7 @@ class DocumentTest(unittest.TestCase): class EmailUser(User): email = StringField() - + user = User(username='test', name='test user') user.save() @@ -391,20 +392,20 @@ class DocumentTest(unittest.TestCase): user_son = User.objects._collection.find_one() self.assertEqual(user_son['_id'], 'test') self.assertTrue('username' not in user_son['_id']) - + User.drop_collection() - + user = User(pk='mongo', name='mongo user') user.save() - + user_obj = User.objects.first() self.assertEqual(user_obj.id, 'mongo') self.assertEqual(user_obj.pk, 'mongo') - + user_son = User.objects._collection.find_one() self.assertEqual(user_son['_id'], 'mongo') self.assertTrue('username' not in user_son['_id']) - + User.drop_collection() def 
test_creation(self): @@ -457,18 +458,18 @@ class DocumentTest(unittest.TestCase): """ class Comment(EmbeddedDocument): content = StringField() - + self.assertTrue('content' in Comment._fields) self.assertFalse('id' in Comment._fields) self.assertFalse('collection' in Comment._meta) - + def test_embedded_document_validation(self): """Ensure that embedded documents may be validated. """ class Comment(EmbeddedDocument): date = DateTimeField() content = StringField(required=True) - + comment = Comment() self.assertRaises(ValidationError, comment.validate) @@ -496,7 +497,7 @@ class DocumentTest(unittest.TestCase): # Test skipping validation on save class Recipient(Document): email = EmailField(required=True) - + recipient = Recipient(email='root@localhost') self.assertRaises(ValidationError, recipient.save) try: @@ -517,19 +518,19 @@ class DocumentTest(unittest.TestCase): """Ensure that a document may be saved with a custom _id. """ # Create person object and save it to the database - person = self.Person(name='Test User', age=30, + person = self.Person(name='Test User', age=30, id='497ce96f395f2f052a494fd4') person.save() # Ensure that the object is in the database with the correct _id collection = self.db[self.Person._meta['collection']] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') - + def test_save_custom_pk(self): """Ensure that a document may be saved with a custom _id using pk alias. """ # Create person object and save it to the database - person = self.Person(name='Test User', age=30, + person = self.Person(name='Test User', age=30, pk='497ce96f395f2f052a494fd4') person.save() # Ensure that the object is in the database with the correct _id @@ -565,7 +566,7 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() def test_save_embedded_document(self): - """Ensure that a document with an embedded document field may be + """Ensure that a document with an embedded document field may be saved in the database. """ class EmployeeDetails(EmbeddedDocument): @@ -591,7 +592,7 @@ class DocumentTest(unittest.TestCase): def test_save_reference(self): """Ensure that a document reference field may be saved in the database. 
""" - + class BlogPost(Document): meta = {'collection': 'blogpost_1'} content = StringField() @@ -610,8 +611,8 @@ class DocumentTest(unittest.TestCase): post_obj = BlogPost.objects.first() # Test laziness - self.assertTrue(isinstance(post_obj._data['author'], - pymongo.dbref.DBRef)) + self.assertTrue(isinstance(post_obj._data['author'], + bson.dbref.DBRef)) self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertEqual(post_obj.author.name, 'Test User') diff --git a/tests/queryset.py b/tests/queryset.py index 6ca4174d..54660d49 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -3,6 +3,7 @@ import unittest import pymongo +import bson from datetime import datetime, timedelta from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, @@ -58,7 +59,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(len(people), 2) results = list(people) self.assertTrue(isinstance(results[0], self.Person)) - self.assertTrue(isinstance(results[0].id, (pymongo.objectid.ObjectId, + self.assertTrue(isinstance(results[0].id, (bson.objectid.ObjectId, str, unicode))) self.assertEqual(results[0].name, "User A") self.assertEqual(results[0].age, 20) @@ -162,7 +163,7 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age__lt=30) self.assertEqual(person.name, "User A") - + def test_find_array_position(self): """Ensure that query by array position works. """ @@ -177,7 +178,7 @@ class QuerySetTest(unittest.TestCase): posts = ListField(EmbeddedDocumentField(Post)) Blog.drop_collection() - + Blog.objects.create(tags=['a', 'b']) self.assertEqual(len(Blog.objects(tags__0='a')), 1) self.assertEqual(len(Blog.objects(tags__0='b')), 0) @@ -226,16 +227,16 @@ class QuerySetTest(unittest.TestCase): person, created = self.Person.objects.get_or_create(age=30) self.assertEqual(person.name, "User B") self.assertEqual(created, False) - + person, created = self.Person.objects.get_or_create(age__lt=30) self.assertEqual(person.name, "User A") self.assertEqual(created, False) - + # Try retrieving when no objects exists - new doc should be created kwargs = dict(age=50, defaults={'name': 'User C'}) person, created = self.Person.objects.get_or_create(**kwargs) self.assertEqual(created, True) - + person = self.Person.objects.get(age=50) self.assertEqual(person.name, "User C") @@ -328,7 +329,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(obj, person) obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() self.assertEqual(obj, None) - + # Test unsafe expressions person = self.Person(name='Guido van Rossum [.\'Geek\']') person.save() @@ -559,7 +560,7 @@ class QuerySetTest(unittest.TestCase): obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() self.assertEqual(obj, person) - + obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() self.assertEqual(obj, None) @@ -631,7 +632,7 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): name = StringField(db_field='doc-name') - comments = ListField(EmbeddedDocumentField(Comment), + comments = ListField(EmbeddedDocumentField(Comment), db_field='cmnts') BlogPost.drop_collection() @@ -733,7 +734,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects.update_one(add_to_set__tags='unique') post.reload() self.assertEqual(post.tags.count('unique'), 1) - + BlogPost.drop_collection() def test_update_pull(self): @@ -802,7 +803,7 @@ class QuerySetTest(unittest.TestCase): """ # run a map/reduce operation spanning all posts - results = BlogPost.objects.map_reduce(map_f, reduce_f) + results 
= BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) self.assertEqual(len(results), 4) @@ -813,7 +814,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(film.value, 3) BlogPost.drop_collection() - + def test_map_reduce_with_custom_object_ids(self): """Ensure that QuerySet.map_reduce works properly with custom primary keys. @@ -822,24 +823,24 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): title = StringField(primary_key=True) tags = ListField(StringField()) - + post1 = BlogPost(title="Post #1", tags=["mongodb", "mongoengine"]) post2 = BlogPost(title="Post #2", tags=["django", "mongodb"]) post3 = BlogPost(title="Post #3", tags=["hitchcock films"]) - + post1.save() post2.save() post3.save() - + self.assertEqual(BlogPost._fields['title'].db_field, '_id') self.assertEqual(BlogPost._meta['id_field'], 'title') - + map_f = """ function() { emit(this._id, 1); } """ - + # reduce to a list of tag ids and counts reduce_f = """ function(key, values) { @@ -850,10 +851,10 @@ class QuerySetTest(unittest.TestCase): return total; } """ - - results = BlogPost.objects.map_reduce(map_f, reduce_f) + + results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) - + self.assertEqual(results[0].object, post1) self.assertEqual(results[1].object, post2) self.assertEqual(results[2].object, post3) @@ -943,7 +944,7 @@ class QuerySetTest(unittest.TestCase): finalize_f = """ function(key, value) { - // f(sec_since_epoch,y,z) = + // f(sec_since_epoch,y,z) = // log10(z) + ((y*sec_since_epoch) / 45000) z_10 = Math.log(value.z) / Math.log(10); weight = z_10 + ((value.y * value.t_s) / 45000); @@ -962,6 +963,7 @@ class QuerySetTest(unittest.TestCase): results = Link.objects.order_by("-value") results = results.map_reduce(map_f, reduce_f, + "myresults", finalize_f=finalize_f, scope=scope) results = list(results) @@ -1289,12 +1291,12 @@ class QuerySetTest(unittest.TestCase): title = StringField() date = DateTimeField() location = GeoPointField() - + def __unicode__(self): return self.title - + Event.drop_collection() - + event1 = Event(title="Coltrane Motion @ Double Door", date=datetime.now() - timedelta(days=1), location=[41.909889, -87.677137]) @@ -1304,7 +1306,7 @@ class QuerySetTest(unittest.TestCase): event3 = Event(title="Coltrane Motion @ Empty Bottle", date=datetime.now(), location=[41.900474, -87.686638]) - + event1.save() event2.save() event3.save() @@ -1324,24 +1326,24 @@ class QuerySetTest(unittest.TestCase): self.assertTrue(event2 not in events) self.assertTrue(event1 in events) self.assertTrue(event3 in events) - + # ensure ordering is respected by "near" events = Event.objects(location__near=[41.9120459, -87.67892]) events = events.order_by("-date") self.assertEqual(events.count(), 3) self.assertEqual(list(events), [event3, event1, event2]) - + # find events around san francisco point_and_distance = [[37.7566023, -122.415579], 10] events = Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 1) self.assertEqual(events[0], event2) - + # find events within 1 mile of greenpoint, broolyn, nyc, ny point_and_distance = [[40.7237134, -73.9509714], 1] events = Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 0) - + # ensure ordering is respected by "within_distance" point_and_distance = [[41.9120459, -87.67892], 10] events = Event.objects(location__within_distance=point_and_distance) @@ -1354,7 +1356,7 @@ class QuerySetTest(unittest.TestCase): 
events = Event.objects(location__within_box=box) self.assertEqual(events.count(), 1) self.assertEqual(events[0].id, event2.id) - + Event.drop_collection() def test_custom_querysets(self): @@ -1398,7 +1400,7 @@ class QTest(unittest.TestCase): query = {'age': {'$gte': 18}, 'name': 'test'} self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) - + def test_q_with_dbref(self): """Ensure Q objects handle DBRefs correctly""" connect(db='mongoenginetest') @@ -1440,7 +1442,7 @@ class QTest(unittest.TestCase): query = Q(x__lt=100) & Q(y__ne='NotMyString') query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) mongo_query = { - 'x': {'$lt': 100, '$gt': -100}, + 'x': {'$lt': 100, '$gt': -100}, 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, } self.assertEqual(query.to_query(TestDoc), mongo_query)
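
For callers of the library, the most visible change in this patch is the new required ``output`` argument to ``QuerySet.map_reduce`` (a collection name, ``'inline'``, or a dictionary of output options, replacing the removed ``keep_temp`` keyword), together with the move of BSON types such as ``ObjectId``, ``DBRef`` and ``Code`` from ``pymongo`` to the standalone ``bson`` package. A minimal caller-side sketch of the updated API follows; the ``Page`` document, the database name and the ``"tag_counts"`` output collection are illustrative assumptions, not part of the patch:

    # Illustrative sketch only -- Page, the db name and "tag_counts" are hypothetical.
    import bson
    from mongoengine import Document, StringField, ListField, connect

    class Page(Document):
        title = StringField()
        tags = ListField(StringField())

    connect(db='mapreduce_example')

    page = Page(title="Example", tags=["mongodb", "mongoengine"])
    page.save()
    # ObjectId now comes from the bson package rather than pymongo.objectid
    assert isinstance(page.id, bson.objectid.ObjectId)

    map_f = """
        function() {
            this.tags.forEach(function(tag) {
                emit(tag, 1);
            });
        }
    """

    reduce_f = """
        function(key, values) {
            var total = 0;
            for (var i = 0; i < values.length; i++) {
                total += values[i];
            }
            return total;
        }
    """

    # The third positional argument is now required: a target collection name,
    # 'inline' (which will try to use pymongo's inline_map_reduce), or a dict
    # of output options.
    for result in Page.objects.map_reduce(map_f, reduce_f, "tag_counts"):
        print result.key, result.value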