Added basic querying - find and find_one

This commit is contained in:
Harry Marr 2009-11-19 01:09:58 +00:00
parent 94be32b387
commit 8ec6fecd23
7 changed files with 141 additions and 12 deletions

View File

@ -44,7 +44,7 @@ class BaseField(object):
try: try:
value = self._to_python(value) value = self._to_python(value)
self._validate(value) self._validate(value)
except ValueError: except (ValueError, AttributeError):
raise ValidationError('Invalid value for field of type "' + raise ValidationError('Invalid value for field of type "' +
self.__class__.__name__ + '"') self.__class__.__name__ + '"')
elif self.required: elif self.required:
@ -145,7 +145,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Set up collection manager, needs the class to have fields so use # Set up collection manager, needs the class to have fields so use
# DocumentMetaclass before instantiating CollectionManager object # DocumentMetaclass before instantiating CollectionManager object
new_class = super_new(cls, name, bases, attrs) new_class = super_new(cls, name, bases, attrs)
setattr(new_class, 'collection', CollectionManager(new_class)) new_class.objects = CollectionManager(new_class)
return new_class return new_class
@ -204,3 +204,11 @@ class BaseDocument(object):
if value is not None: if value is not None:
data[field_name] = field._to_mongo(value) data[field_name] = field._to_mongo(value)
return data return data
@classmethod
def _from_son(cls, son):
"""Create an instance of a Document (subclass) from a PyMongo SOM.
"""
data = dict((str(key), value) for key, value in son.items())
return cls(**data)

View File

@ -1,5 +1,29 @@
from connection import _get_db from connection import _get_db
class QuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor,
providing Document objects as the results.
"""
def __init__(self, document, cursor):
self._document = document
self._cursor = cursor
def next(self):
"""Wrap the result in a Document object.
"""
return self._document._from_son(self._cursor.next())
def count(self):
"""Count the selected elements in the query.
"""
return self._cursor.count()
def __iter__(self):
return self
class CollectionManager(object): class CollectionManager(object):
def __init__(self, document): def __init__(self, document):
@ -14,4 +38,15 @@ class CollectionManager(object):
def _save_document(self, document): def _save_document(self, document):
"""Save the provided document to the collection. """Save the provided document to the collection.
""" """
_id = self._collection.save(document) _id = self._collection.save(document._to_mongo())
document._id = _id
def find(self, query=None):
"""Query the collection for document matching the provided query.
"""
return QuerySet(self._document, self._collection.find(query))
def find_one(self, query=None):
"""Query the collection for document matching the provided query.
"""
return self._document._from_son(self._collection.find_one(query))

View File

@ -14,4 +14,7 @@ class Document(BaseDocument):
__metaclass__ = TopLevelDocumentMetaclass __metaclass__ = TopLevelDocumentMetaclass
def save(self): def save(self):
self.collection._save_document(self._to_mongo()) """Save the document to the database. If the document already exists,
it will be updated, otherwise it will be created.
"""
self.objects._save_document(self)

View File

@ -59,6 +59,8 @@ class EmbeddedDocumentField(BaseField):
super(EmbeddedDocumentField, self).__init__(**kwargs) super(EmbeddedDocumentField, self).__init__(**kwargs)
def _to_python(self, value): def _to_python(self, value):
if not isinstance(value, self.document):
return self.document._from_son(value)
return value return value
def _to_mongo(self, value): def _to_mongo(self, value):
@ -68,6 +70,7 @@ class EmbeddedDocumentField(BaseField):
"""Make sure that the document instance is an instance of the """Make sure that the document instance is an instance of the
EmbeddedDocument subclass provided when the document was defined. EmbeddedDocument subclass provided when the document was defined.
""" """
# Using isinstance also works for subclasses of self.document
if not isinstance(value, self.document): if not isinstance(value, self.document):
raise ValidationError('Invalid embedded document instance ' raise ValidationError('Invalid embedded document instance '
'provided to an EmbeddedDocumentField') 'provided to an EmbeddedDocumentField')

View File

@ -2,6 +2,7 @@ import unittest
import pymongo import pymongo
from mongomap.collection import CollectionManager from mongomap.collection import CollectionManager
from mongomap.connection import _get_db
from mongomap import * from mongomap import *
@ -15,19 +16,79 @@ class CollectionManagerTest(unittest.TestCase):
age = IntField() age = IntField()
self.Person = Person self.Person = Person
self.db = _get_db()
self.db.drop_collection(self.Person._meta['collection'])
def test_initialisation(self): def test_initialisation(self):
"""Ensure that CollectionManager is correctly initialised. """Ensure that CollectionManager is correctly initialised.
""" """
class Person(Document): self.assertTrue(isinstance(self.Person.objects, CollectionManager))
name = StringField() self.assertEqual(self.Person.objects._collection_name,
age = IntField() self.Person._meta['collection'])
self.assertTrue(isinstance(self.Person.objects._collection,
self.assertTrue(isinstance(Person.collection, CollectionManager))
self.assertEqual(Person.collection._collection_name,
Person._meta['collection'])
self.assertTrue(isinstance(Person.collection._collection,
pymongo.collection.Collection)) pymongo.collection.Collection))
def test_find(self):
"""Ensure that a query returns a valid set of results.
"""
person1 = self.Person(name="User A", age=20)
person1.save()
person2 = self.Person(name="User B", age=30)
person2.save()
# Find all people in the collection
people = self.Person.objects.find()
self.assertEqual(people.count(), 2)
results = list(people)
self.assertTrue(isinstance(results[0], self.Person))
self.assertEqual(results[0].name, "User A")
self.assertEqual(results[0].age, 20)
self.assertEqual(results[1].name, "User B")
self.assertEqual(results[1].age, 30)
# Use a query to filter the people found to just person1
people = self.Person.objects.find({'age': 20})
self.assertEqual(people.count(), 1)
person = people.next()
self.assertEqual(person.name, "User A")
self.assertEqual(person.age, 20)
def test_find_one(self):
"""Ensure that a query using find_one returns a valid result.
"""
person1 = self.Person(name="User A", age=20)
person1.save()
person2 = self.Person(name="User B", age=30)
person2.save()
# Retrieve the first person from the database
person = self.Person.objects.find_one()
self.assertTrue(isinstance(person, self.Person))
self.assertEqual(person.name, "User A")
self.assertEqual(person.age, 20)
# Use a query to filter the people found to just person2
person = self.Person.objects.find_one({'age': 30})
self.assertEqual(person.name, "User B")
def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query.
"""
class User(EmbeddedDocument):
name = StringField()
class BlogPost(Document):
content = StringField()
author = EmbeddedDocumentField(User)
post = BlogPost(content='Had a good coffee today...')
post.author = User(name='Test User')
post.save()
result = BlogPost.objects.find_one()
self.assertTrue(isinstance(result.author, User))
self.assertEqual(result.author.name, 'Test User')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -97,6 +97,7 @@ class DocumentTest(unittest.TestCase):
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(person_obj['name'], 'Test User') self.assertEqual(person_obj['name'], 'Test User')
self.assertEqual(person_obj['age'], 30) self.assertEqual(person_obj['age'], 30)
self.assertEqual(str(person_obj['_id']), person._id)
def test_save_custom_id(self): def test_save_custom_id(self):
"""Ensure that a document may be saved with a custom _id. """Ensure that a document may be saved with a custom _id.

View File

@ -103,6 +103,24 @@ class FieldTest(unittest.TestCase):
person.preferences = PersonPreferences(food='Cheese', number=47) person.preferences = PersonPreferences(food='Cheese', number=47)
self.assertEqual(person.preferences.food, 'Cheese') self.assertEqual(person.preferences.food, 'Cheese')
def test_embedded_document_inheritance(self):
"""Ensure that subclasses of embedded documents may be provided to
EmbeddedDocumentFields of the superclass' type.
"""
class User(EmbeddedDocument):
name = StringField()
class PowerUser(User):
power = IntField()
class BlogPost(Document):
content = StringField()
author = EmbeddedDocumentField(User)
post = BlogPost(content='What I did today...')
post.author = User(name='Test User')
post.author = PowerUser(name='Test User', power=47)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()