Merge branch 'dev' into feature/update_lists

This commit is contained in:
Alistair Roche 2011-05-24 08:58:38 +01:00
commit 13935fc335
7 changed files with 476 additions and 102 deletions

View File

@ -7,16 +7,26 @@ import pymongo
import pymongo.objectid import pymongo.objectid
_document_registry = {} class NotRegistered(Exception):
pass
def get_document(name):
return _document_registry[name]
class ValidationError(Exception): class ValidationError(Exception):
pass pass
_document_registry = {}
def get_document(name):
if name not in _document_registry:
raise NotRegistered("""
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip() % name)
return _document_registry[name]
class BaseField(object): class BaseField(object):
"""A base class for fields in a MongoDB document. Instances of this class """A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema. may be added to subclasses of `Document` to define a document's schema.
@ -295,6 +305,30 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
for spec in meta['indexes']] + base_indexes for spec in meta['indexes']] + base_indexes
new_class._meta['indexes'] = user_indexes new_class._meta['indexes'] = user_indexes
unique_indexes = cls._unique_with_indexes(new_class)
new_class._meta['unique_indexes'] = unique_indexes
for field_name, field in new_class._fields.items():
# Check for custom primary key
if field.primary_key:
current_pk = new_class._meta['id_field']
if current_pk and current_pk != field_name:
raise ValueError('Cannot override primary key field')
if not current_pk:
new_class._meta['id_field'] = field_name
# Make 'Document.id' an alias to the real primary key field
new_class.id = field
if not new_class._meta['id_field']:
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id']
return new_class
@classmethod
def _unique_with_indexes(cls, new_class, namespace=""):
unique_indexes = [] unique_indexes = []
for field_name, field in new_class._fields.items(): for field_name, field in new_class._fields.items():
# Generate a list of indexes needed by uniqueness constraints # Generate a list of indexes needed by uniqueness constraints
@ -320,28 +354,16 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
unique_fields += unique_with unique_fields += unique_with
# Add the new index to the list # Add the new index to the list
index = [(f, pymongo.ASCENDING) for f in unique_fields] index = [("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields]
unique_indexes.append(index) unique_indexes.append(index)
# Check for custom primary key # Grab any embedded document field unique indexes
if field.primary_key: if field.__class__.__name__ == "EmbeddedDocumentField":
current_pk = new_class._meta['id_field'] field_namespace = "%s." % field_name
if current_pk and current_pk != field_name: unique_indexes += cls._unique_with_indexes(field.document_type,
raise ValueError('Cannot override primary key field') field_namespace)
if not current_pk: return unique_indexes
new_class._meta['id_field'] = field_name
# Make 'Document.id' an alias to the real primary key field
new_class.id = field
new_class._meta['unique_indexes'] = unique_indexes
if not new_class._meta['id_field']:
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id']
return new_class
class BaseDocument(object): class BaseDocument(object):

View File

@ -40,44 +40,54 @@ class Document(BaseDocument):
presence of `_cls` and `_types`, set :attr:`allow_inheritance` to presence of `_cls` and `_types`, set :attr:`allow_inheritance` to
``False`` in the :attr:`meta` dictionary. ``False`` in the :attr:`meta` dictionary.
A :class:`~mongoengine.Document` may use a **Capped Collection** by A :class:`~mongoengine.Document` may use a **Capped Collection** by
specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta`
dictionary. :attr:`max_documents` is the maximum number of documents that dictionary. :attr:`max_documents` is the maximum number of documents that
is allowed to be stored in the collection, and :attr:`max_size` is the is allowed to be stored in the collection, and :attr:`max_size` is the
maximum size of the collection in bytes. If :attr:`max_size` is not maximum size of the collection in bytes. If :attr:`max_size` is not
specified and :attr:`max_documents` is, :attr:`max_size` defaults to specified and :attr:`max_documents` is, :attr:`max_size` defaults to
10000000 bytes (10MB). 10000000 bytes (10MB).
Indexes may be created by specifying :attr:`indexes` in the :attr:`meta` Indexes may be created by specifying :attr:`indexes` in the :attr:`meta`
dictionary. The value should be a list of field names or tuples of field dictionary. The value should be a list of field names or tuples of field
names. Index direction may be specified by prefixing the field names with names. Index direction may be specified by prefixing the field names with
a **+** or **-** sign. a **+** or **-** sign.
""" """
__metaclass__ = TopLevelDocumentMetaclass __metaclass__ = TopLevelDocumentMetaclass
def save(self, safe=True, force_insert=False, validate=True): def save(self, safe=True, force_insert=False, validate=True, write_options=None):
"""Save the :class:`~mongoengine.Document` to the database. If the """Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be document already exists, it will be updated, otherwise it will be
created. created.
If ``safe=True`` and the operation is unsuccessful, an If ``safe=True`` and the operation is unsuccessful, an
:class:`~mongoengine.OperationError` will be raised. :class:`~mongoengine.OperationError` will be raised.
:param safe: check if the operation succeeded before returning :param safe: check if the operation succeeded before returning
:param force_insert: only try to create a new document, don't allow :param force_insert: only try to create a new document, don't allow
updates of existing documents updates of existing documents
:param validate: validates the document; set to ``False`` to skip. :param validate: validates the document; set to ``False`` to skip.
:param write_options: Extra keyword arguments are passed down to
:meth:`~pymongo.collection.Collection.save` OR
:meth:`~pymongo.collection.Collection.insert`
which will be used as options for the resultant ``getLastError`` command.
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
have recorded the write and will force an fsync on each server being written to.
""" """
if validate: if validate:
self.validate() self.validate()
if not write_options:
write_options = {}
doc = self.to_mongo() doc = self.to_mongo()
try: try:
collection = self.__class__.objects._collection collection = self.__class__.objects._collection
if force_insert: if force_insert:
object_id = collection.insert(doc, safe=safe) object_id = collection.insert(doc, safe=safe, **write_options)
else: else:
object_id = collection.save(doc, safe=safe) object_id = collection.save(doc, safe=safe, **write_options)
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)' message = 'Could not save document (%s)'
if u'duplicate key' in unicode(err): if u'duplicate key' in unicode(err):
@ -131,9 +141,9 @@ class MapReduceDocument(object):
"""A document returned from a map/reduce query. """A document returned from a map/reduce query.
:param collection: An instance of :class:`~pymongo.Collection` :param collection: An instance of :class:`~pymongo.Collection`
:param key: Document/result key, often an instance of :param key: Document/result key, often an instance of
:class:`~pymongo.objectid.ObjectId`. If supplied as :class:`~pymongo.objectid.ObjectId`. If supplied as
an ``ObjectId`` found in the given ``collection``, an ``ObjectId`` found in the given ``collection``,
the object can be accessed via the ``object`` property. the object can be accessed via the ``object`` property.
:param value: The result(s) for this key. :param value: The result(s) for this key.
@ -148,7 +158,7 @@ class MapReduceDocument(object):
@property @property
def object(self): def object(self):
"""Lazy-load the object referenced by ``self.key``. ``self.key`` """Lazy-load the object referenced by ``self.key``. ``self.key``
should be the ``primary_key``. should be the ``primary_key``.
""" """
id_field = self._document()._meta['id_field'] id_field = self._document()._meta['id_field']

View File

@ -522,6 +522,9 @@ class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass """A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily). that will be automatically dereferenced on access (lazily).
note: Any documents used as a generic reference must be registered in the
document registry. Importing the model will automatically register it.
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
@ -601,6 +604,7 @@ class GridFSProxy(object):
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
self.newfile = None # Used for partial writes self.newfile = None # Used for partial writes
self.grid_id = grid_id # Store GridFS id for file self.grid_id = grid_id # Store GridFS id for file
self.gridout = None
def __getattr__(self, name): def __getattr__(self, name):
obj = self.get() obj = self.get()
@ -614,8 +618,12 @@ class GridFSProxy(object):
def get(self, id=None): def get(self, id=None):
if id: if id:
self.grid_id = id self.grid_id = id
if self.grid_id is None:
return None
try: try:
return self.fs.get(id or self.grid_id) if self.gridout is None:
self.gridout = self.fs.get(self.grid_id)
return self.gridout
except: except:
# File has been deleted # File has been deleted
return None return None
@ -645,9 +653,9 @@ class GridFSProxy(object):
self.grid_id = self.newfile._id self.grid_id = self.newfile._id
self.newfile.writelines(lines) self.newfile.writelines(lines)
def read(self): def read(self, size=-1):
try: try:
return self.get().read() return self.get().read(size)
except: except:
return None return None
@ -655,6 +663,7 @@ class GridFSProxy(object):
# Delete file from GridFS, FileField still remains # Delete file from GridFS, FileField still remains
self.fs.delete(self.grid_id) self.fs.delete(self.grid_id)
self.grid_id = None self.grid_id = None
self.gridout = None
def replace(self, file, **kwargs): def replace(self, file, **kwargs):
self.delete() self.delete()

View File

@ -8,6 +8,7 @@ import pymongo.objectid
import re import re
import copy import copy
import itertools import itertools
import operator
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError', __all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] 'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
@ -280,30 +281,30 @@ class QueryFieldList(object):
ONLY = True ONLY = True
EXCLUDE = False EXCLUDE = False
def __init__(self, fields=[], direction=ONLY, always_include=[]): def __init__(self, fields=[], value=ONLY, always_include=[]):
self.direction = direction self.value = value
self.fields = set(fields) self.fields = set(fields)
self.always_include = set(always_include) self.always_include = set(always_include)
def as_dict(self): def as_dict(self):
return dict((field, self.direction) for field in self.fields) return dict((field, self.value) for field in self.fields)
def __add__(self, f): def __add__(self, f):
if not self.fields: if not self.fields:
self.fields = f.fields self.fields = f.fields
self.direction = f.direction self.value = f.value
elif self.direction is self.ONLY and f.direction is self.ONLY: elif self.value is self.ONLY and f.value is self.ONLY:
self.fields = self.fields.intersection(f.fields) self.fields = self.fields.intersection(f.fields)
elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE: elif self.value is self.EXCLUDE and f.value is self.EXCLUDE:
self.fields = self.fields.union(f.fields) self.fields = self.fields.union(f.fields)
elif self.direction is self.ONLY and f.direction is self.EXCLUDE: elif self.value is self.ONLY and f.value is self.EXCLUDE:
self.fields -= f.fields self.fields -= f.fields
elif self.direction is self.EXCLUDE and f.direction is self.ONLY: elif self.value is self.EXCLUDE and f.value is self.ONLY:
self.direction = self.ONLY self.value = self.ONLY
self.fields = f.fields - self.fields self.fields = f.fields - self.fields
if self.always_include: if self.always_include:
if self.direction is self.ONLY and self.fields: if self.value is self.ONLY and self.fields:
self.fields = self.fields.union(self.always_include) self.fields = self.fields.union(self.always_include)
else: else:
self.fields -= self.always_include self.fields -= self.always_include
@ -311,7 +312,7 @@ class QueryFieldList(object):
def reset(self): def reset(self):
self.fields = set([]) self.fields = set([])
self.direction = self.ONLY self.value = self.ONLY
def __nonzero__(self): def __nonzero__(self):
return bool(self.fields) return bool(self.fields)
@ -551,7 +552,7 @@ class QuerySet(object):
return '.'.join(parts) return '.'.join(parts)
@classmethod @classmethod
def _transform_query(cls, _doc_cls=None, **query): def _transform_query(cls, _doc_cls=None, _field_operation=False, **query):
"""Transform a query from Django-style format to Mongo format. """Transform a query from Django-style format to Mongo format.
""" """
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
@ -646,7 +647,7 @@ class QuerySet(object):
raise self._document.DoesNotExist("%s matching query does not exist." raise self._document.DoesNotExist("%s matching query does not exist."
% self._document._class_name) % self._document._class_name)
def get_or_create(self, *q_objs, **query): def get_or_create(self, write_options=None, *q_objs, **query):
"""Retrieve unique object or create, if it doesn't exist. Returns a tuple of """Retrieve unique object or create, if it doesn't exist. Returns a tuple of
``(object, created)``, where ``object`` is the retrieved or created object ``(object, created)``, where ``object`` is the retrieved or created object
and ``created`` is a boolean specifying whether a new object was created. Raises and ``created`` is a boolean specifying whether a new object was created. Raises
@ -656,6 +657,10 @@ class QuerySet(object):
dictionary of default values for the new document may be provided as a dictionary of default values for the new document may be provided as a
keyword argument called :attr:`defaults`. keyword argument called :attr:`defaults`.
:param write_options: optional extra keyword arguments used if we
have to create a new document.
Passes any write_options onto :meth:`~mongoengine.document.Document.save`
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
defaults = query.get('defaults', {}) defaults = query.get('defaults', {})
@ -667,7 +672,7 @@ class QuerySet(object):
if count == 0: if count == 0:
query.update(defaults) query.update(defaults)
doc = self._document(**query) doc = self._document(**query)
doc.save() doc.save(write_options=write_options)
return doc, True return doc, True
elif count == 1: elif count == 1:
return self.first(), False return self.first(), False
@ -893,10 +898,8 @@ class QuerySet(object):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
fields = self._fields_to_dbfields(fields) fields = dict([(f, QueryFieldList.ONLY) for f in fields])
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY) return self.fields(**fields)
return self
def exclude(self, *fields): def exclude(self, *fields):
"""Opposite to .only(), exclude some document's fields. :: """Opposite to .only(), exclude some document's fields. ::
@ -905,8 +908,44 @@ class QuerySet(object):
:param fields: fields to exclude :param fields: fields to exclude
""" """
fields = self._fields_to_dbfields(fields) fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields])
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE) return self.fields(**fields)
def fields(self, **kwargs):
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. Fields also
allows for a greater level of control for example:
Retrieving a Subrange of Array Elements
---------------------------------------
You can use the $slice operator to retrieve a subrange of elements in
an array ::
post = BlogPost.objects(...).fields(slice__comments=5) // first 5 comments
:param kwargs: A dictionary identifying what to include
.. versionadded:: 0.5
"""
# Check for an operator and transform to mongo-style if there is
operators = ["slice"]
cleaned_fields = []
for key, value in kwargs.items():
parts = key.split('__')
op = None
if parts[0] in operators:
op = parts.pop(0)
value = {'$' + op: value}
key = '.'.join(parts)
cleaned_fields.append((key, value))
fields = sorted(cleaned_fields, key=operator.itemgetter(1))
for value, group in itertools.groupby(fields, lambda x: x[1]):
fields = [field for field, value in group]
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, value=value)
return self return self
def all_fields(self): def all_fields(self):
@ -1062,22 +1101,27 @@ class QuerySet(object):
return mongo_update return mongo_update
def update(self, safe_update=True, upsert=False, **update): def update(self, safe_update=True, upsert=False, write_options=None, **update):
"""Perform an atomic update on the fields matched by the query. When """Perform an atomic update on the fields matched by the query. When
``safe_update`` is used, the number of affected documents is returned. ``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning :param safe_update: check if the operation succeeded before returning
:param update: Django-style update keyword arguments :param upsert: Any existing document with that "_id" is overwritten.
:param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
if pymongo.version < '1.1.1': if pymongo.version < '1.1.1':
raise OperationError('update() method requires PyMongo 1.1.1+') raise OperationError('update() method requires PyMongo 1.1.1+')
if not write_options:
write_options = {}
update = QuerySet._transform_update(self._document, **update) update = QuerySet._transform_update(self._document, **update)
try: try:
ret = self._collection.update(self._query, update, multi=True, ret = self._collection.update(self._query, update, multi=True,
upsert=upsert, safe=safe_update) upsert=upsert, safe=safe_update,
**write_options)
if ret is not None and 'n' in ret: if ret is not None and 'n' in ret:
return ret['n'] return ret['n']
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
@ -1086,22 +1130,27 @@ class QuerySet(object):
raise OperationError(message) raise OperationError(message)
raise OperationError(u'Update failed (%s)' % unicode(err)) raise OperationError(u'Update failed (%s)' % unicode(err))
def update_one(self, safe_update=True, upsert=False, **update): def update_one(self, safe_update=True, upsert=False, write_options=None, **update):
"""Perform an atomic update on first field matched by the query. When """Perform an atomic update on first field matched by the query. When
``safe_update`` is used, the number of affected documents is returned. ``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning :param safe_update: check if the operation succeeded before returning
:param upsert: Any existing document with that "_id" is overwritten.
:param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
:param update: Django-style update keyword arguments :param update: Django-style update keyword arguments
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
if not write_options:
write_options = {}
update = QuerySet._transform_update(self._document, **update) update = QuerySet._transform_update(self._document, **update)
try: try:
# Explicitly provide 'multi=False' to newer versions of PyMongo # Explicitly provide 'multi=False' to newer versions of PyMongo
# as the default may change to 'True' # as the default may change to 'True'
if pymongo.version >= '1.1.1': if pymongo.version >= '1.1.1':
ret = self._collection.update(self._query, update, multi=False, ret = self._collection.update(self._query, update, multi=False,
upsert=upsert, safe=safe_update) upsert=upsert, safe=safe_update,
**write_options)
else: else:
# Older versions of PyMongo don't support 'multi' # Older versions of PyMongo don't support 'multi'
ret = self._collection.update(self._query, update, ret = self._collection.update(self._query, update,
@ -1284,7 +1333,7 @@ class QuerySetManager(object):
# Create collection as a capped collection if specified # Create collection as a capped collection if specified
if owner._meta['max_size'] or owner._meta['max_documents']: if owner._meta['max_size'] or owner._meta['max_documents']:
# Get max document limit and max byte size from meta # Get max document limit and max byte size from meta
max_size = owner._meta['max_size'] or 10000000 # 10MB default max_size = owner._meta['max_size'] or 10000000 # 10MB default
max_documents = owner._meta['max_documents'] max_documents = owner._meta['max_documents']
if collection in db.collection_names(): if collection in db.collection_names():

View File

@ -1,11 +1,23 @@
import unittest import unittest
from datetime import datetime from datetime import datetime
import pymongo import pymongo
import pickle
from mongoengine import * from mongoengine import *
from mongoengine.base import BaseField
from mongoengine.connection import _get_db from mongoengine.connection import _get_db
class PickleEmbedded(EmbeddedDocument):
date = DateTimeField(default=datetime.now)
class PickleTest(Document):
number = IntField()
string = StringField()
embedded = EmbeddedDocumentField(PickleEmbedded)
lists = ListField(StringField())
class DocumentTest(unittest.TestCase): class DocumentTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -200,6 +212,22 @@ class DocumentTest(unittest.TestCase):
Person.drop_collection() Person.drop_collection()
self.assertFalse(collection in self.db.collection_names()) self.assertFalse(collection in self.db.collection_names())
def test_collection_name_and_primary(self):
"""Ensure that a collection with a specified name may be used.
"""
class Person(Document):
name = StringField(primary_key=True)
meta = {'collection': 'app'}
user = Person(name="Test User")
user.save()
user_obj = Person.objects[0]
self.assertEqual(user_obj.name, "Test User")
Person.drop_collection()
def test_inherited_collections(self): def test_inherited_collections(self):
"""Ensure that subclassed documents don't override parents' collections. """Ensure that subclassed documents don't override parents' collections.
""" """
@ -334,6 +362,10 @@ class DocumentTest(unittest.TestCase):
post2 = BlogPost(title='test2', slug='test') post2 = BlogPost(title='test2', slug='test')
self.assertRaises(OperationError, post2.save) self.assertRaises(OperationError, post2.save)
def test_unique_with(self):
"""Ensure that unique_with constraints are applied to fields.
"""
class Date(EmbeddedDocument): class Date(EmbeddedDocument):
year = IntField(db_field='yr') year = IntField(db_field='yr')
@ -357,6 +389,63 @@ class DocumentTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_unique_embedded_document(self):
"""Ensure that uniqueness constraints are applied to fields on embedded documents.
"""
class SubDocument(EmbeddedDocument):
year = IntField(db_field='yr')
slug = StringField(unique=True)
class BlogPost(Document):
title = StringField()
sub = EmbeddedDocumentField(SubDocument)
BlogPost.drop_collection()
post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test"))
post1.save()
# sub.slug is different so won't raise exception
post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug'))
post2.save()
# Now there will be two docs with the same sub.slug
post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test'))
self.assertRaises(OperationError, post3.save)
BlogPost.drop_collection()
def test_unique_with_embedded_document_and_embedded_unique(self):
"""Ensure that uniqueness constraints are applied to fields on
embedded documents. And work with unique_with as well.
"""
class SubDocument(EmbeddedDocument):
year = IntField(db_field='yr')
slug = StringField(unique=True)
class BlogPost(Document):
title = StringField(unique_with='sub.year')
sub = EmbeddedDocumentField(SubDocument)
BlogPost.drop_collection()
post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test"))
post1.save()
# sub.slug is different so won't raise exception
post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug'))
post2.save()
# Now there will be two docs with the same sub.slug
post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test'))
self.assertRaises(OperationError, post3.save)
# Now there will be two docs with the same title and year
post3 = BlogPost(title='test1', sub=SubDocument(year=2009, slug='test-1'))
self.assertRaises(OperationError, post3.save)
BlogPost.drop_collection()
def test_unique_and_indexes(self): def test_unique_and_indexes(self):
"""Ensure that 'unique' constraints aren't overridden by """Ensure that 'unique' constraints aren't overridden by
meta.indexes. meta.indexes.
@ -798,6 +887,25 @@ class DocumentTest(unittest.TestCase):
self.Person.drop_collection() self.Person.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def subclasses_and_unique_keys_works(self):
class A(Document):
pass
class B(A):
foo = BooleanField(unique=True)
A.drop_collection()
B.drop_collection()
A().save()
A().save()
B(foo=True).save()
self.assertEquals(A.objects.count(), 2)
self.assertEquals(B.objects.count(), 1)
A.drop_collection()
B.drop_collection()
def tearDown(self): def tearDown(self):
self.Person.drop_collection() self.Person.drop_collection()
@ -850,6 +958,43 @@ class DocumentTest(unittest.TestCase):
self.assertTrue(u1 in all_user_set ) self.assertTrue(u1 in all_user_set )
def test_picklable(self):
pickle_doc = PickleTest(number=1, string="OH HAI", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded()
pickle_doc.save()
pickled_doc = pickle.dumps(pickle_doc)
resurrected = pickle.loads(pickled_doc)
self.assertEquals(resurrected, pickle_doc)
resurrected.string = "Working"
resurrected.save()
pickle_doc.reload()
self.assertEquals(resurrected, pickle_doc)
def test_write_options(self):
"""Test that passing write_options works"""
self.Person.drop_collection()
write_options = {"fsync": True}
author, created = self.Person.objects.get_or_create(
name='Test User', write_options=write_options)
author.save(write_options=write_options)
self.Person.objects.update(set__name='Ross', write_options=write_options)
author = self.Person.objects.first()
self.assertEquals(author.name, 'Ross')
self.Person.objects.update_one(set__name='Test User', write_options=write_options)
author = self.Person.objects.first()
self.assertEquals(author.name, 'Test User')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -7,6 +7,7 @@ import gridfs
from mongoengine import * from mongoengine import *
from mongoengine.connection import _get_db from mongoengine.connection import _get_db
from mongoengine.base import _document_registry, NotRegistered
class FieldTest(unittest.TestCase): class FieldTest(unittest.TestCase):
@ -45,7 +46,7 @@ class FieldTest(unittest.TestCase):
""" """
class Person(Document): class Person(Document):
name = StringField() name = StringField()
person = Person(name='Test User') person = Person(name='Test User')
self.assertEqual(person.id, None) self.assertEqual(person.id, None)
@ -95,7 +96,7 @@ class FieldTest(unittest.TestCase):
link.url = 'http://www.google.com:8080' link.url = 'http://www.google.com:8080'
link.validate() link.validate()
def test_int_validation(self): def test_int_validation(self):
"""Ensure that invalid values cannot be assigned to int fields. """Ensure that invalid values cannot be assigned to int fields.
""" """
@ -129,12 +130,12 @@ class FieldTest(unittest.TestCase):
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
person.height = 4.0 person.height = 4.0
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
def test_decimal_validation(self): def test_decimal_validation(self):
"""Ensure that invalid values cannot be assigned to decimal fields. """Ensure that invalid values cannot be assigned to decimal fields.
""" """
class Person(Document): class Person(Document):
height = DecimalField(min_value=Decimal('0.1'), height = DecimalField(min_value=Decimal('0.1'),
max_value=Decimal('3.5')) max_value=Decimal('3.5'))
Person.drop_collection() Person.drop_collection()
@ -249,7 +250,7 @@ class FieldTest(unittest.TestCase):
post.save() post.save()
post.reload() post.reload()
self.assertEqual(post.tags, ['fun', 'leisure']) self.assertEqual(post.tags, ['fun', 'leisure'])
comment1 = Comment(content='Good for you', order=1) comment1 = Comment(content='Good for you', order=1)
comment2 = Comment(content='Yay.', order=0) comment2 = Comment(content='Yay.', order=0)
comments = [comment1, comment2] comments = [comment1, comment2]
@ -315,7 +316,7 @@ class FieldTest(unittest.TestCase):
person.validate() person.validate()
def test_embedded_document_inheritance(self): def test_embedded_document_inheritance(self):
"""Ensure that subclasses of embedded documents may be provided to """Ensure that subclasses of embedded documents may be provided to
EmbeddedDocumentFields of the superclass' type. EmbeddedDocumentFields of the superclass' type.
""" """
class User(EmbeddedDocument): class User(EmbeddedDocument):
@ -327,7 +328,7 @@ class FieldTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
author = EmbeddedDocumentField(User) author = EmbeddedDocumentField(User)
post = BlogPost(content='What I did today...') post = BlogPost(content='What I did today...')
post.author = User(name='Test User') post.author = User(name='Test User')
post.author = PowerUser(name='Test User', power=47) post.author = PowerUser(name='Test User', power=47)
@ -370,7 +371,7 @@ class FieldTest(unittest.TestCase):
User.drop_collection() User.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def test_list_item_dereference(self): def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced. """Ensure that DBRef items in ListFields are dereferenced.
""" """
@ -434,7 +435,7 @@ class FieldTest(unittest.TestCase):
class TreeNode(EmbeddedDocument): class TreeNode(EmbeddedDocument):
name = StringField() name = StringField()
children = ListField(EmbeddedDocumentField('self')) children = ListField(EmbeddedDocumentField('self'))
tree = Tree(name="Tree") tree = Tree(name="Tree")
first_child = TreeNode(name="Child 1") first_child = TreeNode(name="Child 1")
@ -442,7 +443,7 @@ class FieldTest(unittest.TestCase):
second_child = TreeNode(name="Child 2") second_child = TreeNode(name="Child 2")
first_child.children.append(second_child) first_child.children.append(second_child)
third_child = TreeNode(name="Child 3") third_child = TreeNode(name="Child 3")
first_child.children.append(third_child) first_child.children.append(third_child)
@ -506,20 +507,20 @@ class FieldTest(unittest.TestCase):
Member.drop_collection() Member.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def test_generic_reference(self): def test_generic_reference(self):
"""Ensure that a GenericReferenceField properly dereferences items. """Ensure that a GenericReferenceField properly dereferences items.
""" """
class Link(Document): class Link(Document):
title = StringField() title = StringField()
meta = {'allow_inheritance': False} meta = {'allow_inheritance': False}
class Post(Document): class Post(Document):
title = StringField() title = StringField()
class Bookmark(Document): class Bookmark(Document):
bookmark_object = GenericReferenceField() bookmark_object = GenericReferenceField()
Link.drop_collection() Link.drop_collection()
Post.drop_collection() Post.drop_collection()
Bookmark.drop_collection() Bookmark.drop_collection()
@ -574,16 +575,49 @@ class FieldTest(unittest.TestCase):
user = User(bookmarks=[post_1, link_1]) user = User(bookmarks=[post_1, link_1])
user.save() user.save()
user = User.objects(bookmarks__all=[post_1, link_1]).first() user = User.objects(bookmarks__all=[post_1, link_1]).first()
self.assertEqual(user.bookmarks[0], post_1) self.assertEqual(user.bookmarks[0], post_1)
self.assertEqual(user.bookmarks[1], link_1) self.assertEqual(user.bookmarks[1], link_1)
Link.drop_collection() Link.drop_collection()
Post.drop_collection() Post.drop_collection()
User.drop_collection() User.drop_collection()
def test_generic_reference_document_not_registered(self):
    """Ensure dereferencing out of the document registry throws a
    `NotRegistered` error.
    """
    class Link(Document):
        title = StringField()

    class User(Document):
        bookmarks = ListField(GenericReferenceField())

    Link.drop_collection()
    User.drop_collection()

    link_1 = Link(title="Pitchfork")
    link_1.save()

    user = User(bookmarks=[link_1])
    user.save()

    # Mimic User and Link definitions being in a different file
    # and the Link model not being imported in the User file.
    del _document_registry["Link"]

    user = User.objects.first()
    try:
        user.bookmarks
        # Dereferencing succeeded even though Link was deregistered:
        # that is the failure case. (Was `raise AssertionError, "..."`,
        # Python-2-only syntax with a misleading message.)
        self.fail("NotRegistered was not raised for an unregistered document")
    except NotRegistered:
        pass

    Link.drop_collection()
    User.drop_collection()
def test_binary_fields(self): def test_binary_fields(self):
"""Ensure that binary fields can be stored and retrieved. """Ensure that binary fields can be stored and retrieved.
""" """
@ -701,6 +735,12 @@ class FieldTest(unittest.TestCase):
self.assertTrue(streamfile == result) self.assertTrue(streamfile == result)
self.assertEquals(result.file.read(), text + more_text) self.assertEquals(result.file.read(), text + more_text)
self.assertEquals(result.file.content_type, content_type) self.assertEquals(result.file.content_type, content_type)
result.file.seek(0)
self.assertEquals(result.file.tell(), 0)
self.assertEquals(result.file.read(len(text)), text)
self.assertEquals(result.file.tell(), len(text))
self.assertEquals(result.file.read(len(more_text)), more_text)
self.assertEquals(result.file.tell(), len(text + more_text))
result.file.delete() result.file.delete()
# Ensure deleted file returns None # Ensure deleted file returns None
@ -721,7 +761,7 @@ class FieldTest(unittest.TestCase):
result = SetFile.objects.first() result = SetFile.objects.first()
self.assertTrue(setfile == result) self.assertTrue(setfile == result)
self.assertEquals(result.file.read(), more_text) self.assertEquals(result.file.read(), more_text)
result.file.delete() result.file.delete()
PutFile.drop_collection() PutFile.drop_collection()
StreamFile.drop_collection() StreamFile.drop_collection()

View File

@ -597,6 +597,81 @@ class QuerySetTest(unittest.TestCase):
Email.drop_collection() Email.drop_collection()
def test_slicing_fields(self):
    """Ensure that query slicing an array works.
    """
    class Numbers(Document):
        n = ListField(IntField())

    Numbers.drop_collection()

    Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]).save()

    # Each case projects the list with $slice and checks what comes back.
    # first three
    doc = Numbers.objects.fields(slice__n=3).get()
    self.assertEquals(doc.n, [0, 1, 2])

    # last three
    doc = Numbers.objects.fields(slice__n=-3).get()
    self.assertEquals(doc.n, [-3, -2, -1])

    # skip 2, limit 3
    doc = Numbers.objects.fields(slice__n=[2, 3]).get()
    self.assertEquals(doc.n, [2, 3, 4])

    # skip to fifth from last, limit 4
    doc = Numbers.objects.fields(slice__n=[-5, 4]).get()
    self.assertEquals(doc.n, [-5, -4, -3, -2])

    # skip to fifth from last, limit 10
    doc = Numbers.objects.fields(slice__n=[-5, 10]).get()
    self.assertEquals(doc.n, [-5, -4, -3, -2, -1])

    # skip to fifth from last, limit 10 dict method
    doc = Numbers.objects.fields(n={"$slice": [-5, 10]}).get()
    self.assertEquals(doc.n, [-5, -4, -3, -2, -1])
def test_slicing_nested_fields(self):
    """Ensure that query slicing an embedded array works.
    """
    class EmbeddedNumber(EmbeddedDocument):
        n = ListField(IntField())

    class Numbers(Document):
        embedded = EmbeddedDocumentField(EmbeddedNumber)

    Numbers.drop_collection()

    doc = Numbers()
    doc.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1])
    doc.save()

    # Same $slice cases as the flat-list test, but through an
    # embedded-document path.
    # first three
    fetched = Numbers.objects.fields(slice__embedded__n=3).get()
    self.assertEquals(fetched.embedded.n, [0, 1, 2])

    # last three
    fetched = Numbers.objects.fields(slice__embedded__n=-3).get()
    self.assertEquals(fetched.embedded.n, [-3, -2, -1])

    # skip 2, limit 3
    fetched = Numbers.objects.fields(slice__embedded__n=[2, 3]).get()
    self.assertEquals(fetched.embedded.n, [2, 3, 4])

    # skip to fifth from last, limit 4
    fetched = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get()
    self.assertEquals(fetched.embedded.n, [-5, -4, -3, -2])

    # skip to fifth from last, limit 10
    fetched = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get()
    self.assertEquals(fetched.embedded.n, [-5, -4, -3, -2, -1])

    # skip to fifth from last, limit 10 dict method
    fetched = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get()
    self.assertEquals(fetched.embedded.n, [-5, -4, -3, -2, -1])
def test_find_embedded(self): def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query. """Ensure that an embedded document is properly returned from a query.
""" """
@ -1294,6 +1369,7 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
tags = ListField(StringField()) tags = ListField(StringField())
deleted = BooleanField(default=False) deleted = BooleanField(default=False)
date = DateTimeField(default=datetime.now)
@queryset_manager @queryset_manager
def objects(doc_cls, queryset): def objects(doc_cls, queryset):
@ -1301,7 +1377,7 @@ class QuerySetTest(unittest.TestCase):
@queryset_manager @queryset_manager
def music_posts(doc_cls, queryset): def music_posts(doc_cls, queryset):
return queryset(tags='music', deleted=False) return queryset(tags='music', deleted=False).order_by('-date')
BlogPost.drop_collection() BlogPost.drop_collection()
@ -1317,7 +1393,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual([p.id for p in BlogPost.objects], self.assertEqual([p.id for p in BlogPost.objects],
[post1.id, post2.id, post3.id]) [post1.id, post2.id, post3.id])
self.assertEqual([p.id for p in BlogPost.music_posts], self.assertEqual([p.id for p in BlogPost.music_posts],
[post1.id, post2.id]) [post2.id, post1.id])
BlogPost.drop_collection() BlogPost.drop_collection()
@ -1760,6 +1836,25 @@ class QuerySetTest(unittest.TestCase):
Number.drop_collection() Number.drop_collection()
def test_order_works_with_primary(self):
    """Ensure that order_by and primary work.
    """
    class Number(Document):
        n = IntField(primary_key=True)

    Number.drop_collection()

    for value in (1, 2, 3):
        Number(n=value).save()

    # Descending and ascending sorts on the primary-key field.
    self.assertEquals([3, 2, 1],
                      [doc.n for doc in Number.objects.order_by('-n')])
    self.assertEquals([1, 2, 3],
                      [doc.n for doc in Number.objects.order_by('+n')])

    Number.drop_collection()
class QTest(unittest.TestCase): class QTest(unittest.TestCase):
@ -1931,49 +2026,53 @@ class QueryFieldListTest(unittest.TestCase):
def test_include_include(self): def test_include_include(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True}) self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'b': True}) self.assertEqual(q.as_dict(), {'b': True})
def test_include_exclude(self): def test_include_exclude(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True}) self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': True}) self.assertEqual(q.as_dict(), {'a': True})
def test_exclude_exclude(self): def test_exclude_exclude(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False})
def test_exclude_include(self): def test_exclude_include(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'c': True}) self.assertEqual(q.as_dict(), {'c': True})
def test_always_include(self): def test_always_include(self):
q = QueryFieldList(always_include=['x', 'y']) q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
def test_reset(self): def test_reset(self):
q = QueryFieldList(always_include=['x', 'y']) q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
q.reset() q.reset()
self.assertFalse(q) self.assertFalse(q)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True})
def test_using_a_slice(self):
    """Ensure a $slice projection value passes through QueryFieldList."""
    field_list = QueryFieldList()
    field_list += QueryFieldList(fields=['a'], value={"$slice": 5})
    self.assertEqual(field_list.as_dict(), {'a': {"$slice": 5}})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()