diff --git a/mongoengine/base.py b/mongoengine/base.py index f5a171aa..024602a9 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -6,7 +6,6 @@ import pymongo class ValidationError(Exception): pass - class BaseField(object): """A base class for fields in a MongoDB document. Instances of this class may be added to subclasses of `Document` to define a document's schema. @@ -76,7 +75,10 @@ class ObjectIdField(BaseField): def to_mongo(self, value): if not isinstance(value, pymongo.objectid.ObjectId): - return pymongo.objectid.ObjectId(str(value)) + try: + return pymongo.objectid.ObjectId(str(value)) + except Exception, e: + raise ValidationError(e.message) return value def prepare_query_value(self, value): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 13fa6689..beb8ae00 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -209,20 +209,19 @@ class ReferenceField(BaseField): return super(ReferenceField, self).__get__(instance, owner) def to_mongo(self, document): - if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)): - # document may already be an object id - id_ = document - else: + id_field_name = self.document_type._meta['id_field'] + id_field = self.document_type._fields[id_field_name] + + if isinstance(document, Document): # We need the id from the saved object to create the DBRef id_ = document.id if id_ is None: raise ValidationError('You can only reference documents once ' 'they have been saved to the database') + else: + id_ = document - # id may be a string rather than an ObjectID object - if not isinstance(id_, pymongo.objectid.ObjectId): - id_ = pymongo.objectid.ObjectId(id_) - + id_ = id_field.to_mongo(id_) collection = self.document_type._meta['collection'] return pymongo.dbref.DBRef(collection, id_) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index ff4978d2..9764c699 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -10,6 +10,13 @@ __all__ = ['queryset_manager', 'Q', 'InvalidQueryError', # The maximum number of items to display in a QuerySet.__repr__ REPR_OUTPUT_SIZE = 20 +class DoesNotExist(Exception): + pass + + +class MultipleObjectsReturned(Exception): + pass + class InvalidQueryError(Exception): pass @@ -298,18 +305,43 @@ class QuerySet(object): def get(self, *q_objs, **query): """Retrieve the the matching object raising - 'MultipleObjectsReturned' or 'DoesNotExist' exceptions - if multiple or no results are found. + :class:`~mongoengine.queryset.MultipleObjectsReturned` or + :class:`~mongoengine.queryset.DoesNotExist` exceptions if multiple or + no results are found. """ self.__call__(*q_objs, **query) count = self.count() if count == 1: return self[0] elif count > 1: - raise MultipleObjectsReturned + message = u'%d items returned, instead of 1' % count + raise MultipleObjectsReturned(message) else: - raise DoesNotExist + raise DoesNotExist('Document not found') + def get_or_create(self, *q_objs, **query): + """Retreive unique object or create, if it doesn't exist. Raises + :class:`~mongoengine.queryset.MultipleObjectsReturned` if multiple + results are found. A new document will be created if the document + doesn't exists; a dictionary of default values for the new document + may be provided as a keyword argument called :attr:`defaults`. + """ + defaults = query.get('defaults', {}) + if query.has_key('defaults'): + del query['defaults'] + + self.__call__(*q_objs, **query) + count = self.count() + if count == 0: + query.update(defaults) + doc = self._document(**query) + doc.save() + return doc + elif count == 1: + return self.first() + else: + message = u'%d items returned, instead of 1' % count + raise MultipleObjectsReturned(message) def first(self): """Retrieve the first object matching the query. diff --git a/tests/fields.py b/tests/fields.py index affb0c9b..0ebb143d 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -259,6 +259,40 @@ class FieldTest(unittest.TestCase): User.drop_collection() BlogPost.drop_collection() + def test_reference_query_conversion(self): + """Ensure that ReferenceFields can be queried using objects and values + of the type of the primary key of the referenced object. + """ + class Member(Document): + user_num = IntField(primary_key=True) + + class BlogPost(Document): + title = StringField() + author = ReferenceField(Member) + + Member.drop_collection() + BlogPost.drop_collection() + + m1 = Member(user_num=1) + m1.save() + m2 = Member(user_num=2) + m2.save() + + post1 = BlogPost(title='post 1', author=m1) + post1.save() + + post2 = BlogPost(title='post 2', author=m2) + post2.save() + + post = BlogPost.objects(author=m1.id).first() + self.assertEqual(post.id, post1.id) + + post = BlogPost.objects(author=m2.id).first() + self.assertEqual(post.id, post2.id) + + Member.drop_collection() + BlogPost.drop_collection() + if __name__ == '__main__': unittest.main() diff --git a/tests/queryset.py b/tests/queryset.py index 3e73da86..00f3e461 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2,7 +2,8 @@ import unittest import pymongo from datetime import datetime -from mongoengine.queryset import QuerySet, MultipleObjectsReturned, DoesNotExist +from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, + DoesNotExist) from mongoengine import * @@ -136,7 +137,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(person.name, "User A") def test_find_only_one(self): - """Ensure that a query using find_one returns a valid result. + """Ensure that a query using ``get`` returns at most one result. """ # Try retrieving when no objects exists self.assertRaises(DoesNotExist, self.Person.objects.get) @@ -156,6 +157,33 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age__lt=30) self.assertEqual(person.name, "User A") + def test_get_or_create(self): + """Ensure that ``get_or_create`` returns one result or creates a new + document. + """ + person1 = self.Person(name="User A", age=20) + person1.save() + person2 = self.Person(name="User B", age=30) + person2.save() + + # Retrieve the first person from the database + self.assertRaises(MultipleObjectsReturned, + self.Person.objects.get_or_create) + + # Use a query to filter the people found to just person2 + person = self.Person.objects.get_or_create(age=30) + self.assertEqual(person.name, "User B") + + person = self.Person.objects.get_or_create(age__lt=30) + self.assertEqual(person.name, "User A") + + # Try retrieving when no objects exists - new doc should be created + self.Person.objects.get_or_create(age=50, defaults={'name': 'User C'}) + + person = self.Person.objects.get(age=50) + self.assertEqual(person.name, "User C") + + def test_filter_chaining(self): """Ensure filters can be chained together. """