Added queryset clone support and tests, thanks to hensom

Fixes #130
This commit is contained in:
Ross Lawley 2011-05-18 10:30:07 +01:00
parent e3b4563c2b
commit 31521ccff5
2 changed files with 101 additions and 51 deletions

View File

@ -344,19 +344,19 @@ class QuerySet(object):
self._cursor_obj = None self._cursor_obj = None
self._limit = None self._limit = None
self._skip = None self._skip = None
def clone(self): def clone(self):
"""Creates a copy of the current :class:`~mongoengine.queryset.QuerySet`""" """Creates a copy of the current :class:`~mongoengine.queryset.QuerySet`"""
c = self.__class__(self._document, self._collection_obj) c = self.__class__(self._document, self._collection_obj)
copy_props = ('_initial_query', '_query_obj', '_where_clause', copy_props = ('_initial_query', '_query_obj', '_where_clause',
'_loaded_fields', '_ordering', '_snapshot', '_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_limit', '_skip') '_timeout', '_limit', '_skip')
for prop in copy_props: for prop in copy_props:
val = getattr(self, prop) val = getattr(self, prop)
setattr(c, prop, copy.deepcopy(val)) setattr(c, prop, copy.deepcopy(val))
return c return c
@property @property
@ -493,7 +493,7 @@ class QuerySet(object):
} }
if self._loaded_fields: if self._loaded_fields:
cursor_args['fields'] = self._loaded_fields.as_dict() cursor_args['fields'] = self._loaded_fields.as_dict()
self._cursor_obj = self._collection.find(self._query, self._cursor_obj = self._collection.find(self._query,
**cursor_args) **cursor_args)
# Apply where clauses to cursor # Apply where clauses to cursor
if self._where_clause: if self._where_clause:
@ -553,8 +553,8 @@ class QuerySet(object):
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists', 'not'] 'all', 'size', 'exists', 'not']
geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'near', 'near_sphere'] geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'near', 'near_sphere']
match_operators = ['contains', 'icontains', 'startswith', match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'istartswith', 'endswith', 'iendswith',
'exact', 'iexact'] 'exact', 'iexact']
mongo_query = {} mongo_query = {}
@ -644,8 +644,8 @@ class QuerySet(object):
% self._document._class_name) % self._document._class_name)
def get_or_create(self, *q_objs, **query): def get_or_create(self, *q_objs, **query):
"""Retrieve unique object or create, if it doesn't exist. Returns a tuple of """Retrieve unique object or create, if it doesn't exist. Returns a tuple of
``(object, created)``, where ``object`` is the retrieved or created object ``(object, created)``, where ``object`` is the retrieved or created object
and ``created`` is a boolean specifying whether a new object was created. Raises and ``created`` is a boolean specifying whether a new object was created. Raises
:class:`~mongoengine.queryset.MultipleObjectsReturned` or :class:`~mongoengine.queryset.MultipleObjectsReturned` or
`DocumentName.MultipleObjectsReturned` if multiple results are found. `DocumentName.MultipleObjectsReturned` if multiple results are found.
@ -857,7 +857,7 @@ class QuerySet(object):
self._skip, self._limit = key.start, key.stop self._skip, self._limit = key.start, key.stop
except IndexError, err: except IndexError, err:
# PyMongo raises an error if key.start == key.stop, catch it, # PyMongo raises an error if key.start == key.stop, catch it,
# bin it, kill it. # bin it, kill it.
start = key.start or 0 start = key.start or 0
if start >= 0 and key.stop >= 0 and key.step is None: if start >= 0 and key.stop >= 0 and key.step is None:
if start == key.stop: if start == key.stop:
@ -1052,7 +1052,7 @@ class QuerySet(object):
return mongo_update return mongo_update
def update(self, safe_update=True, upsert=False, **update): def update(self, safe_update=True, upsert=False, **update):
"""Perform an atomic update on the fields matched by the query. When """Perform an atomic update on the fields matched by the query. When
``safe_update`` is used, the number of affected documents is returned. ``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning :param safe: check if the operation succeeded before returning
@ -1076,7 +1076,7 @@ class QuerySet(object):
raise OperationError(u'Update failed (%s)' % unicode(err)) raise OperationError(u'Update failed (%s)' % unicode(err))
def update_one(self, safe_update=True, upsert=False, **update): def update_one(self, safe_update=True, upsert=False, **update):
"""Perform an atomic update on first field matched by the query. When """Perform an atomic update on first field matched by the query. When
``safe_update`` is used, the number of affected documents is returned. ``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning :param safe: check if the operation succeeded before returning
@ -1104,8 +1104,8 @@ class QuerySet(object):
return self return self
def _sub_js_fields(self, code): def _sub_js_fields(self, code):
"""When fields are specified with [~fieldname] syntax, where """When fields are specified with [~fieldname] syntax, where
*fieldname* is the Python name of a field, *fieldname* will be *fieldname* is the Python name of a field, *fieldname* will be
substituted for the MongoDB name of the field (specified using the substituted for the MongoDB name of the field (specified using the
:attr:`name` keyword argument in a field's constructor). :attr:`name` keyword argument in a field's constructor).
""" """
@ -1128,9 +1128,9 @@ class QuerySet(object):
options specified as keyword arguments. options specified as keyword arguments.
As fields in MongoEngine may use different names in the database (set As fields in MongoEngine may use different names in the database (set
using the :attr:`db_field` keyword argument to a :class:`Field` using the :attr:`db_field` keyword argument to a :class:`Field`
constructor), a mechanism exists for replacing MongoEngine field names constructor), a mechanism exists for replacing MongoEngine field names
with the database field names in Javascript code. When accessing a with the database field names in Javascript code. When accessing a
field, use square-bracket notation, and prefix the MongoEngine field field, use square-bracket notation, and prefix the MongoEngine field
name with a tilde (~). name with a tilde (~).

View File

@ -162,7 +162,7 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects.get(age__lt=30) person = self.Person.objects.get(age__lt=30)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
def test_find_array_position(self): def test_find_array_position(self):
"""Ensure that query by array position works. """Ensure that query by array position works.
""" """
@ -177,7 +177,7 @@ class QuerySetTest(unittest.TestCase):
posts = ListField(EmbeddedDocumentField(Post)) posts = ListField(EmbeddedDocumentField(Post))
Blog.drop_collection() Blog.drop_collection()
Blog.objects.create(tags=['a', 'b']) Blog.objects.create(tags=['a', 'b'])
self.assertEqual(len(Blog.objects(tags__0='a')), 1) self.assertEqual(len(Blog.objects(tags__0='a')), 1)
self.assertEqual(len(Blog.objects(tags__0='b')), 0) self.assertEqual(len(Blog.objects(tags__0='b')), 0)
@ -226,16 +226,16 @@ class QuerySetTest(unittest.TestCase):
person, created = self.Person.objects.get_or_create(age=30) person, created = self.Person.objects.get_or_create(age=30)
self.assertEqual(person.name, "User B") self.assertEqual(person.name, "User B")
self.assertEqual(created, False) self.assertEqual(created, False)
person, created = self.Person.objects.get_or_create(age__lt=30) person, created = self.Person.objects.get_or_create(age__lt=30)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
self.assertEqual(created, False) self.assertEqual(created, False)
# Try retrieving when no objects exists - new doc should be created # Try retrieving when no objects exists - new doc should be created
kwargs = dict(age=50, defaults={'name': 'User C'}) kwargs = dict(age=50, defaults={'name': 'User C'})
person, created = self.Person.objects.get_or_create(**kwargs) person, created = self.Person.objects.get_or_create(**kwargs)
self.assertEqual(created, True) self.assertEqual(created, True)
person = self.Person.objects.get(age=50) person = self.Person.objects.get(age=50)
self.assertEqual(person.name, "User C") self.assertEqual(person.name, "User C")
@ -328,7 +328,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(obj, person) self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first()
self.assertEqual(obj, None) self.assertEqual(obj, None)
# Test unsafe expressions # Test unsafe expressions
person = self.Person(name='Guido van Rossum [.\'Geek\']') person = self.Person(name='Guido van Rossum [.\'Geek\']')
person.save() person.save()
@ -674,7 +674,7 @@ class QuerySetTest(unittest.TestCase):
posts = [post.id for post in q] posts = [post.id for post in q]
published_posts = (post1, post2, post3, post5, post6) published_posts = (post1, post2, post3, post5, post6)
self.assertTrue(all(obj.id in posts for obj in published_posts)) self.assertTrue(all(obj.id in posts for obj in published_posts))
# Check Q object combination # Check Q object combination
date = datetime(2010, 1, 10) date = datetime(2010, 1, 10)
@ -714,7 +714,7 @@ class QuerySetTest(unittest.TestCase):
obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first()
self.assertEqual(obj, person) self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first()
self.assertEqual(obj, None) self.assertEqual(obj, None)
@ -786,7 +786,7 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
name = StringField(db_field='doc-name') name = StringField(db_field='doc-name')
comments = ListField(EmbeddedDocumentField(Comment), comments = ListField(EmbeddedDocumentField(Comment),
db_field='cmnts') db_field='cmnts')
BlogPost.drop_collection() BlogPost.drop_collection()
@ -958,7 +958,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.objects.update_one(unset__hits=1) BlogPost.objects.update_one(unset__hits=1)
post.reload() post.reload()
self.assertEqual(post.hits, None) self.assertEqual(post.hits, None)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_update_pull(self): def test_update_pull(self):
@ -1038,7 +1038,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(film.value, 3) self.assertEqual(film.value, 3)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_map_reduce_with_custom_object_ids(self): def test_map_reduce_with_custom_object_ids(self):
"""Ensure that QuerySet.map_reduce works properly with custom """Ensure that QuerySet.map_reduce works properly with custom
primary keys. primary keys.
@ -1047,24 +1047,24 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
title = StringField(primary_key=True) title = StringField(primary_key=True)
tags = ListField(StringField()) tags = ListField(StringField())
post1 = BlogPost(title="Post #1", tags=["mongodb", "mongoengine"]) post1 = BlogPost(title="Post #1", tags=["mongodb", "mongoengine"])
post2 = BlogPost(title="Post #2", tags=["django", "mongodb"]) post2 = BlogPost(title="Post #2", tags=["django", "mongodb"])
post3 = BlogPost(title="Post #3", tags=["hitchcock films"]) post3 = BlogPost(title="Post #3", tags=["hitchcock films"])
post1.save() post1.save()
post2.save() post2.save()
post3.save() post3.save()
self.assertEqual(BlogPost._fields['title'].db_field, '_id') self.assertEqual(BlogPost._fields['title'].db_field, '_id')
self.assertEqual(BlogPost._meta['id_field'], 'title') self.assertEqual(BlogPost._meta['id_field'], 'title')
map_f = """ map_f = """
function() { function() {
emit(this._id, 1); emit(this._id, 1);
} }
""" """
# reduce to a list of tag ids and counts # reduce to a list of tag ids and counts
reduce_f = """ reduce_f = """
function(key, values) { function(key, values) {
@ -1075,10 +1075,10 @@ class QuerySetTest(unittest.TestCase):
return total; return total;
} }
""" """
results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults")
results = list(results) results = list(results)
self.assertEqual(results[0].object, post1) self.assertEqual(results[0].object, post1)
self.assertEqual(results[1].object, post2) self.assertEqual(results[1].object, post2)
self.assertEqual(results[2].object, post3) self.assertEqual(results[2].object, post3)
@ -1168,7 +1168,7 @@ class QuerySetTest(unittest.TestCase):
finalize_f = """ finalize_f = """
function(key, value) { function(key, value) {
// f(sec_since_epoch,y,z) = // f(sec_since_epoch,y,z) =
// log10(z) + ((y*sec_since_epoch) / 45000) // log10(z) + ((y*sec_since_epoch) / 45000)
z_10 = Math.log(value.z) / Math.log(10); z_10 = Math.log(value.z) / Math.log(10);
weight = z_10 + ((value.y * value.t_s) / 45000); weight = z_10 + ((value.y * value.t_s) / 45000);
@ -1452,9 +1452,9 @@ class QuerySetTest(unittest.TestCase):
""" """
class Test(Document): class Test(Document):
testdict = DictField() testdict = DictField()
Test.drop_collection() Test.drop_collection()
t = Test(testdict={'f': 'Value'}) t = Test(testdict={'f': 'Value'})
t.save() t.save()
@ -1517,12 +1517,12 @@ class QuerySetTest(unittest.TestCase):
title = StringField() title = StringField()
date = DateTimeField() date = DateTimeField()
location = GeoPointField() location = GeoPointField()
def __unicode__(self): def __unicode__(self):
return self.title return self.title
Event.drop_collection() Event.drop_collection()
event1 = Event(title="Coltrane Motion @ Double Door", event1 = Event(title="Coltrane Motion @ Double Door",
date=datetime.now() - timedelta(days=1), date=datetime.now() - timedelta(days=1),
location=[41.909889, -87.677137]) location=[41.909889, -87.677137])
@ -1532,7 +1532,7 @@ class QuerySetTest(unittest.TestCase):
event3 = Event(title="Coltrane Motion @ Empty Bottle", event3 = Event(title="Coltrane Motion @ Empty Bottle",
date=datetime.now(), date=datetime.now(),
location=[41.900474, -87.686638]) location=[41.900474, -87.686638])
event1.save() event1.save()
event2.save() event2.save()
event3.save() event3.save()
@ -1552,24 +1552,24 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue(event2 not in events) self.assertTrue(event2 not in events)
self.assertTrue(event1 in events) self.assertTrue(event1 in events)
self.assertTrue(event3 in events) self.assertTrue(event3 in events)
# ensure ordering is respected by "near" # ensure ordering is respected by "near"
events = Event.objects(location__near=[41.9120459, -87.67892]) events = Event.objects(location__near=[41.9120459, -87.67892])
events = events.order_by("-date") events = events.order_by("-date")
self.assertEqual(events.count(), 3) self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event3, event1, event2]) self.assertEqual(list(events), [event3, event1, event2])
# find events within 10 degrees of san francisco # find events within 10 degrees of san francisco
point_and_distance = [[37.7566023, -122.415579], 10] point_and_distance = [[37.7566023, -122.415579], 10]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 1) self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2) self.assertEqual(events[0], event2)
# find events within 1 degree of greenpoint, broolyn, nyc, ny # find events within 1 degree of greenpoint, broolyn, nyc, ny
point_and_distance = [[40.7237134, -73.9509714], 1] point_and_distance = [[40.7237134, -73.9509714], 1]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 0) self.assertEqual(events.count(), 0)
# ensure ordering is respected by "within_distance" # ensure ordering is respected by "within_distance"
point_and_distance = [[41.9120459, -87.67892], 10] point_and_distance = [[41.9120459, -87.67892], 10]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
@ -1582,7 +1582,7 @@ class QuerySetTest(unittest.TestCase):
events = Event.objects(location__within_box=box) events = Event.objects(location__within_box=box)
self.assertEqual(events.count(), 1) self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event2.id) self.assertEqual(events[0].id, event2.id)
Event.drop_collection() Event.drop_collection()
def test_spherical_geospatial_operators(self): def test_spherical_geospatial_operators(self):
@ -1692,6 +1692,35 @@ class QuerySetTest(unittest.TestCase):
Number.drop_collection() Number.drop_collection()
def test_clone(self):
"""Ensure that cloning clones complex querysets
"""
class Number(Document):
n = IntField()
Number.drop_collection()
for i in xrange(1, 101):
t = Number(n=i)
t.save()
test = Number.objects
test2 = test.clone()
self.assertFalse(test == test2)
self.assertEqual(test.count(), test2.count())
test = test.filter(n__gt=11)
test2 = test.clone()
self.assertFalse(test == test2)
self.assertEqual(test.count(), test2.count())
test = test.limit(10)
test2 = test.clone()
self.assertFalse(test == test2)
self.assertEqual(test.count(), test2.count())
Number.drop_collection()
def test_unset_reference(self): def test_unset_reference(self):
class Comment(Document): class Comment(Document):
text = StringField() text = StringField()
@ -1734,7 +1763,7 @@ class QTest(unittest.TestCase):
query = {'age': {'$gte': 18}, 'name': 'test'} query = {'age': {'$gte': 18}, 'name': 'test'}
self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query)
def test_q_with_dbref(self): def test_q_with_dbref(self):
"""Ensure Q objects handle DBRefs correctly""" """Ensure Q objects handle DBRefs correctly"""
connect(db='mongoenginetest') connect(db='mongoenginetest')
@ -1776,7 +1805,7 @@ class QTest(unittest.TestCase):
query = Q(x__lt=100) & Q(y__ne='NotMyString') query = Q(x__lt=100) & Q(y__ne='NotMyString')
query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100)
mongo_query = { mongo_query = {
'x': {'$lt': 100, '$gt': -100}, 'x': {'$lt': 100, '$gt': -100},
'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']},
} }
self.assertEqual(query.to_query(TestDoc), mongo_query) self.assertEqual(query.to_query(TestDoc), mongo_query)
@ -1850,6 +1879,30 @@ class QTest(unittest.TestCase):
for condition in conditions: for condition in conditions:
self.assertTrue(condition in query['$or']) self.assertTrue(condition in query['$or'])
def test_q_clone(self):
class TestDoc(Document):
x = IntField()
TestDoc.drop_collection()
for i in xrange(1, 101):
t = TestDoc(x=i)
t.save()
# Check normal cases work without an error
test = TestDoc.objects(Q(x__lt=7) & Q(x__gt=3))
self.assertEqual(test.count(), 3)
test2 = test.clone()
self.assertEqual(test2.count(), 3)
self.assertFalse(test2 == test)
test2.filter(x=6)
self.assertEqual(test2.count(), 1)
self.assertEqual(test.count(), 3)
class QueryFieldListTest(unittest.TestCase): class QueryFieldListTest(unittest.TestCase):
def test_empty(self): def test_empty(self):
q = QueryFieldList() q = QueryFieldList()
@ -1904,8 +1957,5 @@ class QueryFieldListTest(unittest.TestCase):
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()