diff --git a/mongoengine/collection.py b/mongoengine/collection.py index 83cea585..34329790 100644 --- a/mongoengine/collection.py +++ b/mongoengine/collection.py @@ -8,10 +8,40 @@ class QuerySet(object): providing Document objects as the results. """ - def __init__(self, document, cursor): + def __init__(self, document, collection, query): self._document = document - self._cursor = cursor + self._collection = collection + + self._query = QuerySet._transform_query(**query) + self._query['_types'] = self._document._class_name + self._cursor_obj = None + + @property + def _cursor(self): + if not self._cursor_obj: + self._cursor_obj = self._collection.find(self._query) + return self._cursor_obj + @classmethod + def _transform_query(cls, **query): + """Transform a query from Django-style format to Mongo format. + """ + operators = ['neq', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', + 'all', 'size', 'exists'] + + mongo_query = {} + for key, value in query.items(): + parts = key.split('__') + # Check for an operator and transform to mongo-style if there is + if parts[-1] in operators: + op = parts.pop() + value = {'$' + op: value} + + key = '.'.join(parts) + mongo_query[key] = value + + return mongo_query + def next(self): """Wrap the result in a Document object. """ @@ -35,6 +65,11 @@ class QuerySet(object): self._cursor.skip(n) return self + def delete(self): + """Delete the documents matched by the query. + """ + self._collection.remove(self._query) + def __iter__(self): return self @@ -56,31 +91,10 @@ class CollectionManager(object): _id = self._collection.save(document._to_mongo()) document._id = _id - def _transform_query(self, **query): - """Transform a query from Django-style format to Mongo format. - """ - operators = ['neq', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists'] - - mongo_query = {} - for key, value in query.items(): - parts = key.split('__') - # Check for an operator and transform to mongo-style if there is - if parts[-1] in operators: - op = parts.pop() - value = {'$' + op: value} - - key = '.'.join(parts) - mongo_query[key] = value - - return mongo_query - def find(self, **query): """Query the collection for documents matching the provided query. """ - query = self._transform_query(**query) - query['_types'] = self._document._class_name - return QuerySet(self._document, self._collection.find(query)) + return QuerySet(self._document, self._collection, query) def find_one(self, object_id=None, **query): """Query the collection for document matching the provided query. @@ -92,7 +106,7 @@ class CollectionManager(object): query = object_id else: # Otherwise, use the query provided - query = self._transform_query(**query) + query = QuerySet._transform_query(**query) query['_types'] = self._document._class_name result = self._collection.find_one(query) diff --git a/tests/collection.py b/tests/collection.py index fd6532cb..164ba32a 100644 --- a/tests/collection.py +++ b/tests/collection.py @@ -1,7 +1,7 @@ import unittest import pymongo -from mongoengine.collection import CollectionManager +from mongoengine.collection import CollectionManager, QuerySet from mongoengine.connection import _get_db from mongoengine import * @@ -31,14 +31,13 @@ class CollectionManagerTest(unittest.TestCase): def test_transform_query(self): """Ensure that the _transform_query function operates correctly. """ - manager = self.Person().objects - self.assertEqual(manager._transform_query(name='test', age=30), + self.assertEqual(QuerySet._transform_query(name='test', age=30), {'name': 'test', 'age': 30}) - self.assertEqual(manager._transform_query(age__lt=30), + self.assertEqual(QuerySet._transform_query(age__lt=30), {'age': {'$lt': 30}}) - self.assertEqual(manager._transform_query(friend__age__gte=30), + self.assertEqual(QuerySet._transform_query(friend__age__gte=30), {'friend.age': {'$gte': 30}}) - self.assertEqual(manager._transform_query(name__exists=True), + self.assertEqual(QuerySet._transform_query(name__exists=True), {'name': {'$exists': True}}) def test_find(self): @@ -125,6 +124,21 @@ class CollectionManagerTest(unittest.TestCase): self.db.drop_collection(BlogPost._meta['collection']) + def test_delete(self): + """Ensure that documents are properly deleted from the database. + """ + self.Person(name="User A", age=20).save() + self.Person(name="User B", age=30).save() + self.Person(name="User C", age=40).save() + + self.assertEqual(self.Person.objects.find().count(), 3) + + self.Person.objects.find(age__lt=30).delete() + self.assertEqual(self.Person.objects.find().count(), 2) + + self.Person.objects.find().delete() + self.assertEqual(self.Person.objects.find().count(), 0) + if __name__ == '__main__': unittest.main() diff --git a/tests/document.py b/tests/document.py index 2e8145ba..6e97cc8c 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1,4 +1,5 @@ import unittest +import pymongo from mongoengine import * from mongoengine.connection import _get_db @@ -246,13 +247,19 @@ class DocumentTest(unittest.TestCase): author.save() post = BlogPost(content='Watched some TV today... how exciting.') + # Should only reference author when saving post.author = author post.save() post_obj = BlogPost.objects.find_one() + + # Test laziness + self.assertTrue(isinstance(post_obj._data['author'], + pymongo.dbref.DBRef)) self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertEqual(post_obj.author.name, 'Test User') + # Ensure that the dereferenced object may be changed and saved post_obj.author.age = 25 post_obj.author.save()