Added delete method to QuerySet

This commit is contained in:
Harry Marr 2009-11-21 20:03:31 +00:00
parent 3017dc78ed
commit 8e89c8b37a
3 changed files with 66 additions and 31 deletions

View File

@ -8,10 +8,40 @@ class QuerySet(object):
providing Document objects as the results. providing Document objects as the results.
""" """
def __init__(self, document, cursor): def __init__(self, document, collection, query):
self._document = document self._document = document
self._cursor = cursor self._collection = collection
self._query = QuerySet._transform_query(**query)
self._query['_types'] = self._document._class_name
self._cursor_obj = None
@property
def _cursor(self):
if not self._cursor_obj:
self._cursor_obj = self._collection.find(self._query)
return self._cursor_obj
@classmethod
def _transform_query(cls, **query):
"""Transform a query from Django-style format to Mongo format.
"""
operators = ['neq', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists']
mongo_query = {}
for key, value in query.items():
parts = key.split('__')
# Check for an operator and transform to mongo-style if there is
if parts[-1] in operators:
op = parts.pop()
value = {'$' + op: value}
key = '.'.join(parts)
mongo_query[key] = value
return mongo_query
def next(self): def next(self):
"""Wrap the result in a Document object. """Wrap the result in a Document object.
""" """
@ -35,6 +65,11 @@ class QuerySet(object):
self._cursor.skip(n) self._cursor.skip(n)
return self return self
def delete(self):
"""Delete the documents matched by the query.
"""
self._collection.remove(self._query)
def __iter__(self): def __iter__(self):
return self return self
@ -56,31 +91,10 @@ class CollectionManager(object):
_id = self._collection.save(document._to_mongo()) _id = self._collection.save(document._to_mongo())
document._id = _id document._id = _id
def _transform_query(self, **query):
"""Transform a query from Django-style format to Mongo format.
"""
operators = ['neq', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists']
mongo_query = {}
for key, value in query.items():
parts = key.split('__')
# Check for an operator and transform to mongo-style if there is
if parts[-1] in operators:
op = parts.pop()
value = {'$' + op: value}
key = '.'.join(parts)
mongo_query[key] = value
return mongo_query
def find(self, **query): def find(self, **query):
"""Query the collection for documents matching the provided query. """Query the collection for documents matching the provided query.
""" """
query = self._transform_query(**query) return QuerySet(self._document, self._collection, query)
query['_types'] = self._document._class_name
return QuerySet(self._document, self._collection.find(query))
def find_one(self, object_id=None, **query): def find_one(self, object_id=None, **query):
"""Query the collection for document matching the provided query. """Query the collection for document matching the provided query.
@ -92,7 +106,7 @@ class CollectionManager(object):
query = object_id query = object_id
else: else:
# Otherwise, use the query provided # Otherwise, use the query provided
query = self._transform_query(**query) query = QuerySet._transform_query(**query)
query['_types'] = self._document._class_name query['_types'] = self._document._class_name
result = self._collection.find_one(query) result = self._collection.find_one(query)

View File

@ -1,7 +1,7 @@
import unittest import unittest
import pymongo import pymongo
from mongoengine.collection import CollectionManager from mongoengine.collection import CollectionManager, QuerySet
from mongoengine.connection import _get_db from mongoengine.connection import _get_db
from mongoengine import * from mongoengine import *
@ -31,14 +31,13 @@ class CollectionManagerTest(unittest.TestCase):
def test_transform_query(self): def test_transform_query(self):
"""Ensure that the _transform_query function operates correctly. """Ensure that the _transform_query function operates correctly.
""" """
manager = self.Person().objects self.assertEqual(QuerySet._transform_query(name='test', age=30),
self.assertEqual(manager._transform_query(name='test', age=30),
{'name': 'test', 'age': 30}) {'name': 'test', 'age': 30})
self.assertEqual(manager._transform_query(age__lt=30), self.assertEqual(QuerySet._transform_query(age__lt=30),
{'age': {'$lt': 30}}) {'age': {'$lt': 30}})
self.assertEqual(manager._transform_query(friend__age__gte=30), self.assertEqual(QuerySet._transform_query(friend__age__gte=30),
{'friend.age': {'$gte': 30}}) {'friend.age': {'$gte': 30}})
self.assertEqual(manager._transform_query(name__exists=True), self.assertEqual(QuerySet._transform_query(name__exists=True),
{'name': {'$exists': True}}) {'name': {'$exists': True}})
def test_find(self): def test_find(self):
@ -125,6 +124,21 @@ class CollectionManagerTest(unittest.TestCase):
self.db.drop_collection(BlogPost._meta['collection']) self.db.drop_collection(BlogPost._meta['collection'])
def test_delete(self):
"""Ensure that documents are properly deleted from the database.
"""
self.Person(name="User A", age=20).save()
self.Person(name="User B", age=30).save()
self.Person(name="User C", age=40).save()
self.assertEqual(self.Person.objects.find().count(), 3)
self.Person.objects.find(age__lt=30).delete()
self.assertEqual(self.Person.objects.find().count(), 2)
self.Person.objects.find().delete()
self.assertEqual(self.Person.objects.find().count(), 0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,4 +1,5 @@
import unittest import unittest
import pymongo
from mongoengine import * from mongoengine import *
from mongoengine.connection import _get_db from mongoengine.connection import _get_db
@ -246,13 +247,19 @@ class DocumentTest(unittest.TestCase):
author.save() author.save()
post = BlogPost(content='Watched some TV today... how exciting.') post = BlogPost(content='Watched some TV today... how exciting.')
# Should only reference author when saving
post.author = author post.author = author
post.save() post.save()
post_obj = BlogPost.objects.find_one() post_obj = BlogPost.objects.find_one()
# Test laziness
self.assertTrue(isinstance(post_obj._data['author'],
pymongo.dbref.DBRef))
self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertTrue(isinstance(post_obj.author, self.Person))
self.assertEqual(post_obj.author.name, 'Test User') self.assertEqual(post_obj.author.name, 'Test User')
# Ensure that the dereferenced object may be changed and saved
post_obj.author.age = 25 post_obj.author.age = 25
post_obj.author.save() post_obj.author.save()