Merge remote branch 'hmarr/master'

Conflicts:
	mongoengine/fields.py
This commit is contained in:
Florian Schlachter 2010-04-15 23:10:34 +02:00
commit 0a074e52e0
7 changed files with 146 additions and 30 deletions

View File

@ -178,7 +178,7 @@ either a single field name, or a list or tuple of field names::
class User(Document): class User(Document):
username = StringField(unique=True) username = StringField(unique=True)
first_name = StringField() first_name = StringField()
last_name = StringField(unique_with='last_name') last_name = StringField(unique_with='first_name')
Document collections Document collections
==================== ====================

View File

@ -135,8 +135,8 @@ additional keyword argument, :attr:`defaults` may be provided, which will be
used as default values for the new document, in the case that it should need used as default values for the new document, in the case that it should need
to be created:: to be created::
>>> a = User.objects.get_or_create(name='User A', defaults={'age': 30}) >>> a, created = User.objects.get_or_create(name='User A', defaults={'age': 30})
>>> b = User.objects.get_or_create(name='User A', defaults={'age': 40}) >>> b, created = User.objects.get_or_create(name='User A', defaults={'age': 40})
>>> a.name == b.name and a.age == b.age >>> a.name == b.name and a.age == b.age
True True
@ -172,7 +172,7 @@ custom manager methods as you like::
@queryset_manager @queryset_manager
def live_posts(doc_cls, queryset): def live_posts(doc_cls, queryset):
return queryset.order_by('-date') return queryset.filter(published=True)
BlogPost(title='test1', published=False).save() BlogPost(title='test1', published=False).save()
BlogPost(title='test2', published=True).save() BlogPost(title='test2', published=True).save()

View File

@ -1,5 +1,7 @@
from queryset import QuerySet, QuerySetManager from queryset import QuerySet, QuerySetManager
from queryset import DoesNotExist, MultipleObjectsReturned
import sys
import pymongo import pymongo
@ -167,9 +169,24 @@ class DocumentMetaclass(type):
for field in new_class._fields.values(): for field in new_class._fields.values():
field.owner_document = new_class field.owner_document = new_class
module = attrs.pop('__module__')
new_class.add_to_class('DoesNotExist', subclass_exception('DoesNotExist',
tuple(x.DoesNotExist
for k,x in superclasses.items())
or (DoesNotExist,), module))
new_class.add_to_class('MultipleObjectsReturned', subclass_exception('MultipleObjectsReturned',
tuple(x.MultipleObjectsReturned
for k,x in superclasses.items())
or (MultipleObjectsReturned,), module))
return new_class return new_class
def add_to_class(self, name, value):
setattr(self, name, value)
class TopLevelDocumentMetaclass(DocumentMetaclass): class TopLevelDocumentMetaclass(DocumentMetaclass):
"""Metaclass for top-level documents (i.e. documents that have their own """Metaclass for top-level documents (i.e. documents that have their own
collection in the database. collection in the database.
@ -417,3 +434,11 @@ class BaseDocument(object):
if self.id == other.id: if self.id == other.id:
return True return True
return False return False
if sys.version_info < (2, 5):
# Prior to Python 2.5, Exception was an old-style class
def subclass_exception(name, parents, unused):
return types.ClassType(name, parents, {})
else:
def subclass_exception(name, parents, module):
return type(name, parents, {'__module__': module})

View File

@ -1,6 +1,7 @@
from base import BaseField, ObjectIdField, ValidationError, get_document from base import BaseField, ObjectIdField, ValidationError, get_document
from document import Document, EmbeddedDocument from document import Document, EmbeddedDocument
from connection import _get_db from connection import _get_db
from operator import itemgetter
import re import re
import pymongo import pymongo
@ -12,7 +13,7 @@ __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
'ObjectIdField', 'ReferenceField', 'ValidationError', 'ObjectIdField', 'ReferenceField', 'ValidationError',
'DecimalField', 'URLField', 'GenericReferenceField', 'DecimalField', 'URLField', 'GenericReferenceField',
'BinaryField', 'EmailField', 'GeoLocationField'] 'BinaryField', 'SortedListField', 'EmailField', 'GeoLocationField']
RECURSIVE_REFERENCE_CONSTANT = 'self' RECURSIVE_REFERENCE_CONSTANT = 'self'
@ -169,6 +170,9 @@ class DecimalField(BaseField):
value = unicode(value) value = unicode(value)
return decimal.Decimal(value) return decimal.Decimal(value)
def to_mongo(self, value):
return unicode(value)
def validate(self, value): def validate(self, value):
if not isinstance(value, decimal.Decimal): if not isinstance(value, decimal.Decimal):
if not isinstance(value, basestring): if not isinstance(value, basestring):
@ -320,6 +324,23 @@ class ListField(BaseField):
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.field.lookup_member(member_name) return self.field.lookup_member(member_name)
class SortedListField(ListField):
"""A ListField that sorts the contents of its list before writing to
the database in order to ensure that a sorted list is always
retrieved.
"""
_ordering = None
def __init__(self, field, **kwargs):
if 'ordering' in kwargs.keys():
self._ordering = kwargs.pop('ordering')
super(SortedListField, self).__init__(field, **kwargs)
def to_mongo(self, value):
if self._ordering is not None:
return sorted([self.field.to_mongo(item) for item in value], key=itemgetter(self._ordering))
return sorted([self.field.to_mongo(item) for item in value])
class DictField(BaseField): class DictField(BaseField):
"""A dictionary field that wraps a standard Python dictionary. This is """A dictionary field that wraps a standard Python dictionary. This is

View File

@ -14,7 +14,6 @@ REPR_OUTPUT_SIZE = 20
class DoesNotExist(Exception): class DoesNotExist(Exception):
pass pass
class MultipleObjectsReturned(Exception): class MultipleObjectsReturned(Exception):
pass pass
@ -26,6 +25,8 @@ class InvalidQueryError(Exception):
class OperationError(Exception): class OperationError(Exception):
pass pass
class InvalidCollectionError(Exception):
pass
RE_TYPE = type(re.compile('')) RE_TYPE = type(re.compile(''))
@ -345,8 +346,9 @@ class QuerySet(object):
def get(self, *q_objs, **query): def get(self, *q_objs, **query):
"""Retrieve the the matching object raising """Retrieve the the matching object raising
:class:`~mongoengine.queryset.MultipleObjectsReturned` or :class:`~mongoengine.queryset.MultipleObjectsReturned` or
:class:`~mongoengine.queryset.DoesNotExist` exceptions if multiple or `DocumentName.MultipleObjectsReturned` exception if multiple results and
no results are found. :class:`~mongoengine.queryset.DoesNotExist` or `DocumentName.DoesNotExist`
if no results are found.
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
@ -356,16 +358,20 @@ class QuerySet(object):
return self[0] return self[0]
elif count > 1: elif count > 1:
message = u'%d items returned, instead of 1' % count message = u'%d items returned, instead of 1' % count
raise MultipleObjectsReturned(message) raise self._document.MultipleObjectsReturned(message)
else: else:
raise DoesNotExist('Document not found') raise self._document.DoesNotExist("%s matching query does not exist."
% self._document._class_name)
def get_or_create(self, *q_objs, **query): def get_or_create(self, *q_objs, **query):
"""Retreive unique object or create, if it doesn't exist. Raises """Retrieve unique object or create, if it doesn't exist. Returns a tuple of
:class:`~mongoengine.queryset.MultipleObjectsReturned` if multiple ``(object, created)``, where ``object`` is the retrieved or created object
results are found. A new document will be created if the document and ``created`` is a boolean specifying whether a new object was created. Raises
doesn't exists; a dictionary of default values for the new document :class:`~mongoengine.queryset.MultipleObjectsReturned` or
may be provided as a keyword argument called :attr:`defaults`. `DocumentName.MultipleObjectsReturned` if multiple results are found.
A new document will be created if the document doesn't exists; a
dictionary of default values for the new document may be provided as a
keyword argument called :attr:`defaults`.
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
@ -379,12 +385,12 @@ class QuerySet(object):
query.update(defaults) query.update(defaults)
doc = self._document(**query) doc = self._document(**query)
doc.save() doc.save()
return doc return doc, True
elif count == 1: elif count == 1:
return self.first() return self.first(), False
else: else:
message = u'%d items returned, instead of 1' % count message = u'%d items returned, instead of 1' % count
raise MultipleObjectsReturned(message) raise self._document.MultipleObjectsReturned(message)
def first(self): def first(self):
"""Retrieve the first object matching the query. """Retrieve the first object matching the query.
@ -873,10 +879,6 @@ class QuerySet(object):
return repr(data) return repr(data)
class InvalidCollectionError(Exception):
pass
class QuerySetManager(object): class QuerySetManager(object):
def __init__(self, manager_func=None): def __init__(self, manager_func=None):

View File

@ -136,12 +136,16 @@ class FieldTest(unittest.TestCase):
height = DecimalField(min_value=Decimal('0.1'), height = DecimalField(min_value=Decimal('0.1'),
max_value=Decimal('3.5')) max_value=Decimal('3.5'))
Person.drop_collection()
person = Person() person = Person()
person.height = Decimal('1.89') person.height = Decimal('1.89')
person.validate() person.save()
person.reload()
self.assertEqual(person.height, Decimal('1.89'))
person.height = '2.0' person.height = '2.0'
person.validate() person.save()
person.height = 0.01 person.height = 0.01
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
person.height = Decimal('0.01') person.height = Decimal('0.01')
@ -149,6 +153,8 @@ class FieldTest(unittest.TestCase):
person.height = Decimal('4.0') person.height = Decimal('4.0')
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
Person.drop_collection()
def test_boolean_validation(self): def test_boolean_validation(self):
"""Ensure that invalid values cannot be assigned to boolean fields. """Ensure that invalid values cannot be assigned to boolean fields.
""" """
@ -212,6 +218,37 @@ class FieldTest(unittest.TestCase):
post.comments = 'yay' post.comments = 'yay'
self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
def test_sorted_list_sorting(self):
"""Ensure that a sorted list field properly sorts values.
"""
class Comment(EmbeddedDocument):
order = IntField()
content = StringField()
class BlogPost(Document):
content = StringField()
comments = SortedListField(EmbeddedDocumentField(Comment), ordering='order')
tags = SortedListField(StringField())
post = BlogPost(content='Went for a walk today...')
post.save()
post.tags = ['leisure', 'fun']
post.save()
post.reload()
self.assertEqual(post.tags, ['fun', 'leisure'])
comment1 = Comment(content='Good for you', order=1)
comment2 = Comment(content='Yay.', order=0)
comments = [comment1, comment2]
post.comments = comments
post.save()
post.reload()
self.assertEqual(post.comments[0].content, comment2.content)
self.assertEqual(post.comments[1].content, comment1.content)
BlogPost.drop_collection()
def test_dict_validation(self): def test_dict_validation(self):
"""Ensure that dict types work as expected. """Ensure that dict types work as expected.
""" """

View File

@ -147,6 +147,7 @@ class QuerySetTest(unittest.TestCase):
""" """
# Try retrieving when no objects exists # Try retrieving when no objects exists
self.assertRaises(DoesNotExist, self.Person.objects.get) self.assertRaises(DoesNotExist, self.Person.objects.get)
self.assertRaises(self.Person.DoesNotExist, self.Person.objects.get)
person1 = self.Person(name="User A", age=20) person1 = self.Person(name="User A", age=20)
person1.save() person1.save()
@ -155,6 +156,7 @@ class QuerySetTest(unittest.TestCase):
# Retrieve the first person from the database # Retrieve the first person from the database
self.assertRaises(MultipleObjectsReturned, self.Person.objects.get) self.assertRaises(MultipleObjectsReturned, self.Person.objects.get)
self.assertRaises(self.Person.MultipleObjectsReturned, self.Person.objects.get)
# Use a query to filter the people found to just person2 # Use a query to filter the people found to just person2
person = self.Person.objects.get(age=30) person = self.Person.objects.get(age=30)
@ -163,6 +165,9 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects.get(age__lt=30) person = self.Person.objects.get(age__lt=30)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
def test_get_or_create(self): def test_get_or_create(self):
"""Ensure that ``get_or_create`` returns one result or creates a new """Ensure that ``get_or_create`` returns one result or creates a new
document. document.
@ -175,16 +180,21 @@ class QuerySetTest(unittest.TestCase):
# Retrieve the first person from the database # Retrieve the first person from the database
self.assertRaises(MultipleObjectsReturned, self.assertRaises(MultipleObjectsReturned,
self.Person.objects.get_or_create) self.Person.objects.get_or_create)
self.assertRaises(self.Person.MultipleObjectsReturned,
self.Person.objects.get_or_create)
# Use a query to filter the people found to just person2 # Use a query to filter the people found to just person2
person = self.Person.objects.get_or_create(age=30) person, created = self.Person.objects.get_or_create(age=30)
self.assertEqual(person.name, "User B") self.assertEqual(person.name, "User B")
self.assertEqual(created, False)
person = self.Person.objects.get_or_create(age__lt=30) person, created = self.Person.objects.get_or_create(age__lt=30)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
self.assertEqual(created, False)
# Try retrieving when no objects exists - new doc should be created # Try retrieving when no objects exists - new doc should be created
self.Person.objects.get_or_create(age=50, defaults={'name': 'User C'}) person, created = self.Person.objects.get_or_create(age=50, defaults={'name': 'User C'})
self.assertEqual(created, True)
person = self.Person.objects.get(age=50) person = self.Person.objects.get(age=50)
self.assertEqual(person.name, "User C") self.assertEqual(person.name, "User C")
@ -616,6 +626,27 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_update_pull(self):
"""Ensure that the 'pull' update operation works correctly.
"""
class Comment(EmbeddedDocument):
content = StringField()
class BlogPost(Document):
slug = StringField()
comments = ListField(EmbeddedDocumentField(Comment))
comment1 = Comment(content="test1")
comment2 = Comment(content="test2")
post = BlogPost(slug="test", comments=[comment1, comment2])
post.save()
self.assertTrue(comment2 in post.comments)
BlogPost.objects(slug="test").update(pull__comments__content="test2")
post.reload()
self.assertTrue(comment2 not in post.comments)
def test_order_by(self): def test_order_by(self):
"""Ensure that QuerySets may be ordered. """Ensure that QuerySets may be ordered.
""" """