diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index bafb7c19..e0c62132 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -344,19 +344,19 @@ class QuerySet(object): self._cursor_obj = None self._limit = None self._skip = None - + def clone(self): """Creates a copy of the current :class:`~mongoengine.queryset.QuerySet`""" c = self.__class__(self._document, self._collection_obj) - + copy_props = ('_initial_query', '_query_obj', '_where_clause', '_loaded_fields', '_ordering', '_snapshot', '_timeout', '_limit', '_skip') - + for prop in copy_props: val = getattr(self, prop) setattr(c, prop, copy.deepcopy(val)) - + return c @property @@ -493,7 +493,7 @@ class QuerySet(object): } if self._loaded_fields: cursor_args['fields'] = self._loaded_fields.as_dict() - self._cursor_obj = self._collection.find(self._query, + self._cursor_obj = self._collection.find(self._query, **cursor_args) # Apply where clauses to cursor if self._where_clause: @@ -553,8 +553,8 @@ class QuerySet(object): operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 'all', 'size', 'exists', 'not'] geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'near', 'near_sphere'] - match_operators = ['contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', + match_operators = ['contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact'] mongo_query = {} @@ -644,8 +644,8 @@ class QuerySet(object): % self._document._class_name) def get_or_create(self, *q_objs, **query): - """Retrieve unique object or create, if it doesn't exist. Returns a tuple of - ``(object, created)``, where ``object`` is the retrieved or created object + """Retrieve unique object or create, if it doesn't exist. Returns a tuple of + ``(object, created)``, where ``object`` is the retrieved or created object and ``created`` is a boolean specifying whether a new object was created. Raises :class:`~mongoengine.queryset.MultipleObjectsReturned` or `DocumentName.MultipleObjectsReturned` if multiple results are found. @@ -857,7 +857,7 @@ class QuerySet(object): self._skip, self._limit = key.start, key.stop except IndexError, err: # PyMongo raises an error if key.start == key.stop, catch it, - # bin it, kill it. + # bin it, kill it. start = key.start or 0 if start >= 0 and key.stop >= 0 and key.step is None: if start == key.stop: @@ -1052,7 +1052,7 @@ class QuerySet(object): return mongo_update def update(self, safe_update=True, upsert=False, **update): - """Perform an atomic update on the fields matched by the query. When + """Perform an atomic update on the fields matched by the query. When ``safe_update`` is used, the number of affected documents is returned. :param safe: check if the operation succeeded before returning @@ -1076,7 +1076,7 @@ class QuerySet(object): raise OperationError(u'Update failed (%s)' % unicode(err)) def update_one(self, safe_update=True, upsert=False, **update): - """Perform an atomic update on first field matched by the query. When + """Perform an atomic update on first field matched by the query. When ``safe_update`` is used, the number of affected documents is returned. :param safe: check if the operation succeeded before returning @@ -1104,8 +1104,8 @@ class QuerySet(object): return self def _sub_js_fields(self, code): - """When fields are specified with [~fieldname] syntax, where - *fieldname* is the Python name of a field, *fieldname* will be + """When fields are specified with [~fieldname] syntax, where + *fieldname* is the Python name of a field, *fieldname* will be substituted for the MongoDB name of the field (specified using the :attr:`name` keyword argument in a field's constructor). """ @@ -1128,9 +1128,9 @@ class QuerySet(object): options specified as keyword arguments. As fields in MongoEngine may use different names in the database (set - using the :attr:`db_field` keyword argument to a :class:`Field` + using the :attr:`db_field` keyword argument to a :class:`Field` constructor), a mechanism exists for replacing MongoEngine field names - with the database field names in Javascript code. When accessing a + with the database field names in Javascript code. When accessing a field, use square-bracket notation, and prefix the MongoEngine field name with a tilde (~). diff --git a/tests/queryset.py b/tests/queryset.py index d4d7fb3a..25431782 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -162,7 +162,7 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age__lt=30) self.assertEqual(person.name, "User A") - + def test_find_array_position(self): """Ensure that query by array position works. """ @@ -177,7 +177,7 @@ class QuerySetTest(unittest.TestCase): posts = ListField(EmbeddedDocumentField(Post)) Blog.drop_collection() - + Blog.objects.create(tags=['a', 'b']) self.assertEqual(len(Blog.objects(tags__0='a')), 1) self.assertEqual(len(Blog.objects(tags__0='b')), 0) @@ -226,16 +226,16 @@ class QuerySetTest(unittest.TestCase): person, created = self.Person.objects.get_or_create(age=30) self.assertEqual(person.name, "User B") self.assertEqual(created, False) - + person, created = self.Person.objects.get_or_create(age__lt=30) self.assertEqual(person.name, "User A") self.assertEqual(created, False) - + # Try retrieving when no objects exists - new doc should be created kwargs = dict(age=50, defaults={'name': 'User C'}) person, created = self.Person.objects.get_or_create(**kwargs) self.assertEqual(created, True) - + person = self.Person.objects.get(age=50) self.assertEqual(person.name, "User C") @@ -328,7 +328,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(obj, person) obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() self.assertEqual(obj, None) - + # Test unsafe expressions person = self.Person(name='Guido van Rossum [.\'Geek\']') person.save() @@ -674,7 +674,7 @@ class QuerySetTest(unittest.TestCase): posts = [post.id for post in q] published_posts = (post1, post2, post3, post5, post6) self.assertTrue(all(obj.id in posts for obj in published_posts)) - + # Check Q object combination date = datetime(2010, 1, 10) @@ -714,7 +714,7 @@ class QuerySetTest(unittest.TestCase): obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() self.assertEqual(obj, person) - + obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() self.assertEqual(obj, None) @@ -786,7 +786,7 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): name = StringField(db_field='doc-name') - comments = ListField(EmbeddedDocumentField(Comment), + comments = ListField(EmbeddedDocumentField(Comment), db_field='cmnts') BlogPost.drop_collection() @@ -958,7 +958,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects.update_one(unset__hits=1) post.reload() self.assertEqual(post.hits, None) - + BlogPost.drop_collection() def test_update_pull(self): @@ -1038,7 +1038,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(film.value, 3) BlogPost.drop_collection() - + def test_map_reduce_with_custom_object_ids(self): """Ensure that QuerySet.map_reduce works properly with custom primary keys. @@ -1047,24 +1047,24 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): title = StringField(primary_key=True) tags = ListField(StringField()) - + post1 = BlogPost(title="Post #1", tags=["mongodb", "mongoengine"]) post2 = BlogPost(title="Post #2", tags=["django", "mongodb"]) post3 = BlogPost(title="Post #3", tags=["hitchcock films"]) - + post1.save() post2.save() post3.save() - + self.assertEqual(BlogPost._fields['title'].db_field, '_id') self.assertEqual(BlogPost._meta['id_field'], 'title') - + map_f = """ function() { emit(this._id, 1); } """ - + # reduce to a list of tag ids and counts reduce_f = """ function(key, values) { @@ -1075,10 +1075,10 @@ class QuerySetTest(unittest.TestCase): return total; } """ - + results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) - + self.assertEqual(results[0].object, post1) self.assertEqual(results[1].object, post2) self.assertEqual(results[2].object, post3) @@ -1168,7 +1168,7 @@ class QuerySetTest(unittest.TestCase): finalize_f = """ function(key, value) { - // f(sec_since_epoch,y,z) = + // f(sec_since_epoch,y,z) = // log10(z) + ((y*sec_since_epoch) / 45000) z_10 = Math.log(value.z) / Math.log(10); weight = z_10 + ((value.y * value.t_s) / 45000); @@ -1452,9 +1452,9 @@ class QuerySetTest(unittest.TestCase): """ class Test(Document): testdict = DictField() - + Test.drop_collection() - + t = Test(testdict={'f': 'Value'}) t.save() @@ -1517,12 +1517,12 @@ class QuerySetTest(unittest.TestCase): title = StringField() date = DateTimeField() location = GeoPointField() - + def __unicode__(self): return self.title - + Event.drop_collection() - + event1 = Event(title="Coltrane Motion @ Double Door", date=datetime.now() - timedelta(days=1), location=[41.909889, -87.677137]) @@ -1532,7 +1532,7 @@ class QuerySetTest(unittest.TestCase): event3 = Event(title="Coltrane Motion @ Empty Bottle", date=datetime.now(), location=[41.900474, -87.686638]) - + event1.save() event2.save() event3.save() @@ -1552,24 +1552,24 @@ class QuerySetTest(unittest.TestCase): self.assertTrue(event2 not in events) self.assertTrue(event1 in events) self.assertTrue(event3 in events) - + # ensure ordering is respected by "near" events = Event.objects(location__near=[41.9120459, -87.67892]) events = events.order_by("-date") self.assertEqual(events.count(), 3) self.assertEqual(list(events), [event3, event1, event2]) - + # find events within 10 degrees of san francisco point_and_distance = [[37.7566023, -122.415579], 10] events = Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 1) self.assertEqual(events[0], event2) - + # find events within 1 degree of greenpoint, broolyn, nyc, ny point_and_distance = [[40.7237134, -73.9509714], 1] events = Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 0) - + # ensure ordering is respected by "within_distance" point_and_distance = [[41.9120459, -87.67892], 10] events = Event.objects(location__within_distance=point_and_distance) @@ -1582,7 +1582,7 @@ class QuerySetTest(unittest.TestCase): events = Event.objects(location__within_box=box) self.assertEqual(events.count(), 1) self.assertEqual(events[0].id, event2.id) - + Event.drop_collection() def test_spherical_geospatial_operators(self): @@ -1692,6 +1692,35 @@ class QuerySetTest(unittest.TestCase): Number.drop_collection() + def test_clone(self): + """Ensure that cloning clones complex querysets + """ + class Number(Document): + n = IntField() + + Number.drop_collection() + + for i in xrange(1, 101): + t = Number(n=i) + t.save() + + test = Number.objects + test2 = test.clone() + self.assertFalse(test == test2) + self.assertEqual(test.count(), test2.count()) + + test = test.filter(n__gt=11) + test2 = test.clone() + self.assertFalse(test == test2) + self.assertEqual(test.count(), test2.count()) + + test = test.limit(10) + test2 = test.clone() + self.assertFalse(test == test2) + self.assertEqual(test.count(), test2.count()) + + Number.drop_collection() + def test_unset_reference(self): class Comment(Document): text = StringField() @@ -1734,7 +1763,7 @@ class QTest(unittest.TestCase): query = {'age': {'$gte': 18}, 'name': 'test'} self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) - + def test_q_with_dbref(self): """Ensure Q objects handle DBRefs correctly""" connect(db='mongoenginetest') @@ -1776,7 +1805,7 @@ class QTest(unittest.TestCase): query = Q(x__lt=100) & Q(y__ne='NotMyString') query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) mongo_query = { - 'x': {'$lt': 100, '$gt': -100}, + 'x': {'$lt': 100, '$gt': -100}, 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, } self.assertEqual(query.to_query(TestDoc), mongo_query) @@ -1850,6 +1879,30 @@ class QTest(unittest.TestCase): for condition in conditions: self.assertTrue(condition in query['$or']) + + def test_q_clone(self): + + class TestDoc(Document): + x = IntField() + + TestDoc.drop_collection() + for i in xrange(1, 101): + t = TestDoc(x=i) + t.save() + + # Check normal cases work without an error + test = TestDoc.objects(Q(x__lt=7) & Q(x__gt=3)) + + self.assertEqual(test.count(), 3) + + test2 = test.clone() + self.assertEqual(test2.count(), 3) + self.assertFalse(test2 == test) + + test2.filter(x=6) + self.assertEqual(test2.count(), 1) + self.assertEqual(test.count(), 3) + class QueryFieldListTest(unittest.TestCase): def test_empty(self): q = QueryFieldList() @@ -1904,8 +1957,5 @@ class QueryFieldListTest(unittest.TestCase): self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) - - - if __name__ == '__main__': unittest.main()