Merge branch 'master' of git://github.com/flosch/mongoengine
Added unit test for get_or_create, merged flosch's get with punteney's get. Conflicts: mongoengine/queryset.py
This commit is contained in:
commit
ffc9d7b152
@ -6,7 +6,6 @@ import pymongo
|
|||||||
class ValidationError(Exception):
|
class ValidationError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseField(object):
|
class BaseField(object):
|
||||||
"""A base class for fields in a MongoDB document. Instances of this class
|
"""A base class for fields in a MongoDB document. Instances of this class
|
||||||
may be added to subclasses of `Document` to define a document's schema.
|
may be added to subclasses of `Document` to define a document's schema.
|
||||||
@ -76,7 +75,10 @@ class ObjectIdField(BaseField):
|
|||||||
|
|
||||||
def to_mongo(self, value):
|
def to_mongo(self, value):
|
||||||
if not isinstance(value, pymongo.objectid.ObjectId):
|
if not isinstance(value, pymongo.objectid.ObjectId):
|
||||||
|
try:
|
||||||
return pymongo.objectid.ObjectId(str(value))
|
return pymongo.objectid.ObjectId(str(value))
|
||||||
|
except Exception, e:
|
||||||
|
raise ValidationError(e.message)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def prepare_query_value(self, value):
|
def prepare_query_value(self, value):
|
||||||
|
@ -209,20 +209,19 @@ class ReferenceField(BaseField):
|
|||||||
return super(ReferenceField, self).__get__(instance, owner)
|
return super(ReferenceField, self).__get__(instance, owner)
|
||||||
|
|
||||||
def to_mongo(self, document):
|
def to_mongo(self, document):
|
||||||
if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)):
|
id_field_name = self.document_type._meta['id_field']
|
||||||
# document may already be an object id
|
id_field = self.document_type._fields[id_field_name]
|
||||||
id_ = document
|
|
||||||
else:
|
if isinstance(document, Document):
|
||||||
# We need the id from the saved object to create the DBRef
|
# We need the id from the saved object to create the DBRef
|
||||||
id_ = document.id
|
id_ = document.id
|
||||||
if id_ is None:
|
if id_ is None:
|
||||||
raise ValidationError('You can only reference documents once '
|
raise ValidationError('You can only reference documents once '
|
||||||
'they have been saved to the database')
|
'they have been saved to the database')
|
||||||
|
else:
|
||||||
|
id_ = document
|
||||||
|
|
||||||
# id may be a string rather than an ObjectID object
|
id_ = id_field.to_mongo(id_)
|
||||||
if not isinstance(id_, pymongo.objectid.ObjectId):
|
|
||||||
id_ = pymongo.objectid.ObjectId(id_)
|
|
||||||
|
|
||||||
collection = self.document_type._meta['collection']
|
collection = self.document_type._meta['collection']
|
||||||
return pymongo.dbref.DBRef(collection, id_)
|
return pymongo.dbref.DBRef(collection, id_)
|
||||||
|
|
||||||
|
@ -10,6 +10,13 @@ __all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
|
|||||||
# The maximum number of items to display in a QuerySet.__repr__
|
# The maximum number of items to display in a QuerySet.__repr__
|
||||||
REPR_OUTPUT_SIZE = 20
|
REPR_OUTPUT_SIZE = 20
|
||||||
|
|
||||||
|
class DoesNotExist(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MultipleObjectsReturned(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidQueryError(Exception):
|
class InvalidQueryError(Exception):
|
||||||
pass
|
pass
|
||||||
@ -298,18 +305,43 @@ class QuerySet(object):
|
|||||||
|
|
||||||
def get(self, *q_objs, **query):
|
def get(self, *q_objs, **query):
|
||||||
"""Retrieve the the matching object raising
|
"""Retrieve the the matching object raising
|
||||||
'MultipleObjectsReturned' or 'DoesNotExist' exceptions
|
:class:`~mongoengine.queryset.MultipleObjectsReturned` or
|
||||||
if multiple or no results are found.
|
:class:`~mongoengine.queryset.DoesNotExist` exceptions if multiple or
|
||||||
|
no results are found.
|
||||||
"""
|
"""
|
||||||
self.__call__(*q_objs, **query)
|
self.__call__(*q_objs, **query)
|
||||||
count = self.count()
|
count = self.count()
|
||||||
if count == 1:
|
if count == 1:
|
||||||
return self[0]
|
return self[0]
|
||||||
elif count > 1:
|
elif count > 1:
|
||||||
raise MultipleObjectsReturned
|
message = u'%d items returned, instead of 1' % count
|
||||||
|
raise MultipleObjectsReturned(message)
|
||||||
else:
|
else:
|
||||||
raise DoesNotExist
|
raise DoesNotExist('Document not found')
|
||||||
|
|
||||||
|
def get_or_create(self, *q_objs, **query):
|
||||||
|
"""Retreive unique object or create, if it doesn't exist. Raises
|
||||||
|
:class:`~mongoengine.queryset.MultipleObjectsReturned` if multiple
|
||||||
|
results are found. A new document will be created if the document
|
||||||
|
doesn't exists; a dictionary of default values for the new document
|
||||||
|
may be provided as a keyword argument called :attr:`defaults`.
|
||||||
|
"""
|
||||||
|
defaults = query.get('defaults', {})
|
||||||
|
if query.has_key('defaults'):
|
||||||
|
del query['defaults']
|
||||||
|
|
||||||
|
self.__call__(*q_objs, **query)
|
||||||
|
count = self.count()
|
||||||
|
if count == 0:
|
||||||
|
query.update(defaults)
|
||||||
|
doc = self._document(**query)
|
||||||
|
doc.save()
|
||||||
|
return doc
|
||||||
|
elif count == 1:
|
||||||
|
return self.first()
|
||||||
|
else:
|
||||||
|
message = u'%d items returned, instead of 1' % count
|
||||||
|
raise MultipleObjectsReturned(message)
|
||||||
|
|
||||||
def first(self):
|
def first(self):
|
||||||
"""Retrieve the first object matching the query.
|
"""Retrieve the first object matching the query.
|
||||||
|
@ -259,6 +259,40 @@ class FieldTest(unittest.TestCase):
|
|||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
def test_reference_query_conversion(self):
|
||||||
|
"""Ensure that ReferenceFields can be queried using objects and values
|
||||||
|
of the type of the primary key of the referenced object.
|
||||||
|
"""
|
||||||
|
class Member(Document):
|
||||||
|
user_num = IntField(primary_key=True)
|
||||||
|
|
||||||
|
class BlogPost(Document):
|
||||||
|
title = StringField()
|
||||||
|
author = ReferenceField(Member)
|
||||||
|
|
||||||
|
Member.drop_collection()
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
m1 = Member(user_num=1)
|
||||||
|
m1.save()
|
||||||
|
m2 = Member(user_num=2)
|
||||||
|
m2.save()
|
||||||
|
|
||||||
|
post1 = BlogPost(title='post 1', author=m1)
|
||||||
|
post1.save()
|
||||||
|
|
||||||
|
post2 = BlogPost(title='post 2', author=m2)
|
||||||
|
post2.save()
|
||||||
|
|
||||||
|
post = BlogPost.objects(author=m1.id).first()
|
||||||
|
self.assertEqual(post.id, post1.id)
|
||||||
|
|
||||||
|
post = BlogPost.objects(author=m2.id).first()
|
||||||
|
self.assertEqual(post.id, post2.id)
|
||||||
|
|
||||||
|
Member.drop_collection()
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -2,7 +2,8 @@ import unittest
|
|||||||
import pymongo
|
import pymongo
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from mongoengine.queryset import QuerySet, MultipleObjectsReturned, DoesNotExist
|
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned,
|
||||||
|
DoesNotExist)
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
|
|
||||||
|
|
||||||
@ -136,7 +137,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
self.assertEqual(person.name, "User A")
|
self.assertEqual(person.name, "User A")
|
||||||
|
|
||||||
def test_find_only_one(self):
|
def test_find_only_one(self):
|
||||||
"""Ensure that a query using find_one returns a valid result.
|
"""Ensure that a query using ``get`` returns at most one result.
|
||||||
"""
|
"""
|
||||||
# Try retrieving when no objects exists
|
# Try retrieving when no objects exists
|
||||||
self.assertRaises(DoesNotExist, self.Person.objects.get)
|
self.assertRaises(DoesNotExist, self.Person.objects.get)
|
||||||
@ -156,6 +157,33 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
person = self.Person.objects.get(age__lt=30)
|
person = self.Person.objects.get(age__lt=30)
|
||||||
self.assertEqual(person.name, "User A")
|
self.assertEqual(person.name, "User A")
|
||||||
|
|
||||||
|
def test_get_or_create(self):
|
||||||
|
"""Ensure that ``get_or_create`` returns one result or creates a new
|
||||||
|
document.
|
||||||
|
"""
|
||||||
|
person1 = self.Person(name="User A", age=20)
|
||||||
|
person1.save()
|
||||||
|
person2 = self.Person(name="User B", age=30)
|
||||||
|
person2.save()
|
||||||
|
|
||||||
|
# Retrieve the first person from the database
|
||||||
|
self.assertRaises(MultipleObjectsReturned,
|
||||||
|
self.Person.objects.get_or_create)
|
||||||
|
|
||||||
|
# Use a query to filter the people found to just person2
|
||||||
|
person = self.Person.objects.get_or_create(age=30)
|
||||||
|
self.assertEqual(person.name, "User B")
|
||||||
|
|
||||||
|
person = self.Person.objects.get_or_create(age__lt=30)
|
||||||
|
self.assertEqual(person.name, "User A")
|
||||||
|
|
||||||
|
# Try retrieving when no objects exists - new doc should be created
|
||||||
|
self.Person.objects.get_or_create(age=50, defaults={'name': 'User C'})
|
||||||
|
|
||||||
|
person = self.Person.objects.get(age=50)
|
||||||
|
self.assertEqual(person.name, "User C")
|
||||||
|
|
||||||
|
|
||||||
def test_filter_chaining(self):
|
def test_filter_chaining(self):
|
||||||
"""Ensure filters can be chained together.
|
"""Ensure filters can be chained together.
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user