diff --git a/mongoengine/base.py b/mongoengine/base.py index 495e2418..0c024be5 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -7,16 +7,26 @@ import pymongo import pymongo.objectid -_document_registry = {} - -def get_document(name): - return _document_registry[name] +class NotRegistered(Exception): + pass class ValidationError(Exception): pass +_document_registry = {} + +def get_document(name): + if name not in _document_registry: + raise NotRegistered(""" + `%s` has not been registered in the document registry. + Importing the document class automatically registers it, has it + been imported? + """.strip() % name) + return _document_registry[name] + + class BaseField(object): """A base class for fields in a MongoDB document. Instances of this class may be added to subclasses of `Document` to define a document's schema. @@ -295,6 +305,30 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): for spec in meta['indexes']] + base_indexes new_class._meta['indexes'] = user_indexes + unique_indexes = cls._unique_with_indexes(new_class) + new_class._meta['unique_indexes'] = unique_indexes + + for field_name, field in new_class._fields.items(): + # Check for custom primary key + if field.primary_key: + current_pk = new_class._meta['id_field'] + if current_pk and current_pk != field_name: + raise ValueError('Cannot override primary key field') + + if not current_pk: + new_class._meta['id_field'] = field_name + # Make 'Document.id' an alias to the real primary key field + new_class.id = field + + if not new_class._meta['id_field']: + new_class._meta['id_field'] = 'id' + new_class._fields['id'] = ObjectIdField(db_field='_id') + new_class.id = new_class._fields['id'] + + return new_class + + @classmethod + def _unique_with_indexes(cls, new_class, namespace=""): unique_indexes = [] for field_name, field in new_class._fields.items(): # Generate a list of indexes needed by uniqueness constraints @@ -320,28 +354,16 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): unique_fields += unique_with # Add the new index to the list - index = [(f, pymongo.ASCENDING) for f in unique_fields] + index = [("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields] unique_indexes.append(index) - # Check for custom primary key - if field.primary_key: - current_pk = new_class._meta['id_field'] - if current_pk and current_pk != field_name: - raise ValueError('Cannot override primary key field') + # Grab any embedded document field unique indexes + if field.__class__.__name__ == "EmbeddedDocumentField": + field_namespace = "%s." % field_name + unique_indexes += cls._unique_with_indexes(field.document_type, + field_namespace) - if not current_pk: - new_class._meta['id_field'] = field_name - # Make 'Document.id' an alias to the real primary key field - new_class.id = field - - new_class._meta['unique_indexes'] = unique_indexes - - if not new_class._meta['id_field']: - new_class._meta['id_field'] = 'id' - new_class._fields['id'] = ObjectIdField(db_field='_id') - new_class.id = new_class._fields['id'] - - return new_class + return unique_indexes class BaseDocument(object): diff --git a/mongoengine/document.py b/mongoengine/document.py index 196662c3..771b9229 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -40,44 +40,54 @@ class Document(BaseDocument): presence of `_cls` and `_types`, set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` dictionary. - A :class:`~mongoengine.Document` may use a **Capped Collection** by + A :class:`~mongoengine.Document` may use a **Capped Collection** by specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` dictionary. :attr:`max_documents` is the maximum number of documents that - is allowed to be stored in the collection, and :attr:`max_size` is the - maximum size of the collection in bytes. If :attr:`max_size` is not - specified and :attr:`max_documents` is, :attr:`max_size` defaults to + is allowed to be stored in the collection, and :attr:`max_size` is the + maximum size of the collection in bytes. If :attr:`max_size` is not + specified and :attr:`max_documents` is, :attr:`max_size` defaults to 10000000 bytes (10MB). Indexes may be created by specifying :attr:`indexes` in the :attr:`meta` - dictionary. The value should be a list of field names or tuples of field + dictionary. The value should be a list of field names or tuples of field names. Index direction may be specified by prefixing the field names with a **+** or **-** sign. """ __metaclass__ = TopLevelDocumentMetaclass - def save(self, safe=True, force_insert=False, validate=True): + def save(self, safe=True, force_insert=False, validate=True, write_options=None): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. - If ``safe=True`` and the operation is unsuccessful, an + If ``safe=True`` and the operation is unsuccessful, an :class:`~mongoengine.OperationError` will be raised. :param safe: check if the operation succeeded before returning - :param force_insert: only try to create a new document, don't allow + :param force_insert: only try to create a new document, don't allow updates of existing documents :param validate: validates the document; set to ``False`` to skip. + :param write_options: Extra keyword arguments are passed down to + :meth:`~pymongo.collection.Collection.save` OR + :meth:`~pymongo.collection.Collection.insert` + which will be used as options for the resultant ``getLastError`` command. + For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers + have recorded the write and will force an fsync on each server being written to. """ if validate: self.validate() + + if not write_options: + write_options = {} + doc = self.to_mongo() try: collection = self.__class__.objects._collection if force_insert: - object_id = collection.insert(doc, safe=safe) + object_id = collection.insert(doc, safe=safe, **write_options) else: - object_id = collection.save(doc, safe=safe) + object_id = collection.save(doc, safe=safe, **write_options) except pymongo.errors.OperationFailure, err: message = 'Could not save document (%s)' if u'duplicate key' in unicode(err): @@ -131,9 +141,9 @@ class MapReduceDocument(object): """A document returned from a map/reduce query. :param collection: An instance of :class:`~pymongo.Collection` - :param key: Document/result key, often an instance of - :class:`~pymongo.objectid.ObjectId`. If supplied as - an ``ObjectId`` found in the given ``collection``, + :param key: Document/result key, often an instance of + :class:`~pymongo.objectid.ObjectId`. If supplied as + an ``ObjectId`` found in the given ``collection``, the object can be accessed via the ``object`` property. :param value: The result(s) for this key. @@ -148,7 +158,7 @@ class MapReduceDocument(object): @property def object(self): - """Lazy-load the object referenced by ``self.key``. ``self.key`` + """Lazy-load the object referenced by ``self.key``. ``self.key`` should be the ``primary_key``. """ id_field = self._document()._meta['id_field'] diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 9fcfcc2b..0cc8219b 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -522,6 +522,9 @@ class GenericReferenceField(BaseField): """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). + note: Any documents used as a generic reference must be registered in the + document registry. Importing the model will automatically register it. + .. versionadded:: 0.3 """ @@ -601,6 +604,7 @@ class GridFSProxy(object): self.fs = gridfs.GridFS(_get_db()) # Filesystem instance self.newfile = None # Used for partial writes self.grid_id = grid_id # Store GridFS id for file + self.gridout = None def __getattr__(self, name): obj = self.get() @@ -614,8 +618,12 @@ class GridFSProxy(object): def get(self, id=None): if id: self.grid_id = id + if self.grid_id is None: + return None try: - return self.fs.get(id or self.grid_id) + if self.gridout is None: + self.gridout = self.fs.get(self.grid_id) + return self.gridout except: # File has been deleted return None @@ -645,9 +653,9 @@ class GridFSProxy(object): self.grid_id = self.newfile._id self.newfile.writelines(lines) - def read(self): + def read(self, size=-1): try: - return self.get().read() + return self.get().read(size) except: return None @@ -655,6 +663,7 @@ class GridFSProxy(object): # Delete file from GridFS, FileField still remains self.fs.delete(self.grid_id) self.grid_id = None + self.gridout = None def replace(self, file, **kwargs): self.delete() diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index f58328ac..1d004d68 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -8,6 +8,7 @@ import pymongo.objectid import re import copy import itertools +import operator __all__ = ['queryset_manager', 'Q', 'InvalidQueryError', 'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] @@ -280,30 +281,30 @@ class QueryFieldList(object): ONLY = True EXCLUDE = False - def __init__(self, fields=[], direction=ONLY, always_include=[]): - self.direction = direction + def __init__(self, fields=[], value=ONLY, always_include=[]): + self.value = value self.fields = set(fields) self.always_include = set(always_include) def as_dict(self): - return dict((field, self.direction) for field in self.fields) + return dict((field, self.value) for field in self.fields) def __add__(self, f): if not self.fields: self.fields = f.fields - self.direction = f.direction - elif self.direction is self.ONLY and f.direction is self.ONLY: + self.value = f.value + elif self.value is self.ONLY and f.value is self.ONLY: self.fields = self.fields.intersection(f.fields) - elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE: + elif self.value is self.EXCLUDE and f.value is self.EXCLUDE: self.fields = self.fields.union(f.fields) - elif self.direction is self.ONLY and f.direction is self.EXCLUDE: + elif self.value is self.ONLY and f.value is self.EXCLUDE: self.fields -= f.fields - elif self.direction is self.EXCLUDE and f.direction is self.ONLY: - self.direction = self.ONLY + elif self.value is self.EXCLUDE and f.value is self.ONLY: + self.value = self.ONLY self.fields = f.fields - self.fields if self.always_include: - if self.direction is self.ONLY and self.fields: + if self.value is self.ONLY and self.fields: self.fields = self.fields.union(self.always_include) else: self.fields -= self.always_include @@ -311,7 +312,7 @@ class QueryFieldList(object): def reset(self): self.fields = set([]) - self.direction = self.ONLY + self.value = self.ONLY def __nonzero__(self): return bool(self.fields) @@ -551,7 +552,7 @@ class QuerySet(object): return '.'.join(parts) @classmethod - def _transform_query(cls, _doc_cls=None, **query): + def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): """Transform a query from Django-style format to Mongo format. """ operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', @@ -646,7 +647,7 @@ class QuerySet(object): raise self._document.DoesNotExist("%s matching query does not exist." % self._document._class_name) - def get_or_create(self, *q_objs, **query): + def get_or_create(self, write_options=None, *q_objs, **query): """Retrieve unique object or create, if it doesn't exist. Returns a tuple of ``(object, created)``, where ``object`` is the retrieved or created object and ``created`` is a boolean specifying whether a new object was created. Raises @@ -656,6 +657,10 @@ class QuerySet(object): dictionary of default values for the new document may be provided as a keyword argument called :attr:`defaults`. + :param write_options: optional extra keyword arguments used if we + have to create a new document. + Passes any write_options onto :meth:`~mongoengine.document.Document.save` + .. versionadded:: 0.3 """ defaults = query.get('defaults', {}) @@ -667,7 +672,7 @@ class QuerySet(object): if count == 0: query.update(defaults) doc = self._document(**query) - doc.save() + doc.save(write_options=write_options) return doc, True elif count == 1: return self.first(), False @@ -893,10 +898,8 @@ class QuerySet(object): .. versionadded:: 0.3 """ - fields = self._fields_to_dbfields(fields) - self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY) - return self - + fields = dict([(f, QueryFieldList.ONLY) for f in fields]) + return self.fields(**fields) def exclude(self, *fields): """Opposite to .only(), exclude some document's fields. :: @@ -905,8 +908,44 @@ class QuerySet(object): :param fields: fields to exclude """ - fields = self._fields_to_dbfields(fields) - self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE) + fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) + return self.fields(**fields) + + def fields(self, **kwargs): + """Manipulate how you load this document's fields. Used by `.only()` + and `.exclude()` to manipulate which fields to retrieve. Fields also + allows for a greater level of control for example: + + Retrieving a Subrange of Array Elements + --------------------------------------- + + You can use the $slice operator to retrieve a subrange of elements in + an array :: + + post = BlogPost.objects(...).fields(slice__comments=5) // first 5 comments + + :param kwargs: A dictionary identifying what to include + + .. versionadded:: 0.5 + """ + + # Check for an operator and transform to mongo-style if there is + operators = ["slice"] + cleaned_fields = [] + for key, value in kwargs.items(): + parts = key.split('__') + op = None + if parts[0] in operators: + op = parts.pop(0) + value = {'$' + op: value} + key = '.'.join(parts) + cleaned_fields.append((key, value)) + + fields = sorted(cleaned_fields, key=operator.itemgetter(1)) + for value, group in itertools.groupby(fields, lambda x: x[1]): + fields = [field for field, value in group] + fields = self._fields_to_dbfields(fields) + self._loaded_fields += QueryFieldList(fields, value=value) return self def all_fields(self): @@ -1062,22 +1101,27 @@ class QuerySet(object): return mongo_update - def update(self, safe_update=True, upsert=False, **update): + def update(self, safe_update=True, upsert=False, write_options=None, **update): """Perform an atomic update on the fields matched by the query. When ``safe_update`` is used, the number of affected documents is returned. - :param safe: check if the operation succeeded before returning - :param update: Django-style update keyword arguments + :param safe_update: check if the operation succeeded before returning + :param upsert: Any existing document with that "_id" is overwritten. + :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update` .. versionadded:: 0.2 """ if pymongo.version < '1.1.1': raise OperationError('update() method requires PyMongo 1.1.1+') + if not write_options: + write_options = {} + update = QuerySet._transform_update(self._document, **update) try: ret = self._collection.update(self._query, update, multi=True, - upsert=upsert, safe=safe_update) + upsert=upsert, safe=safe_update, + **write_options) if ret is not None and 'n' in ret: return ret['n'] except pymongo.errors.OperationFailure, err: @@ -1086,22 +1130,27 @@ class QuerySet(object): raise OperationError(message) raise OperationError(u'Update failed (%s)' % unicode(err)) - def update_one(self, safe_update=True, upsert=False, **update): + def update_one(self, safe_update=True, upsert=False, write_options=None, **update): """Perform an atomic update on first field matched by the query. When ``safe_update`` is used, the number of affected documents is returned. - :param safe: check if the operation succeeded before returning + :param safe_update: check if the operation succeeded before returning + :param upsert: Any existing document with that "_id" is overwritten. + :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update` :param update: Django-style update keyword arguments .. versionadded:: 0.2 """ + if not write_options: + write_options = {} update = QuerySet._transform_update(self._document, **update) try: # Explicitly provide 'multi=False' to newer versions of PyMongo # as the default may change to 'True' if pymongo.version >= '1.1.1': ret = self._collection.update(self._query, update, multi=False, - upsert=upsert, safe=safe_update) + upsert=upsert, safe=safe_update, + **write_options) else: # Older versions of PyMongo don't support 'multi' ret = self._collection.update(self._query, update, @@ -1284,7 +1333,7 @@ class QuerySetManager(object): # Create collection as a capped collection if specified if owner._meta['max_size'] or owner._meta['max_documents']: # Get max document limit and max byte size from meta - max_size = owner._meta['max_size'] or 10000000 # 10MB default + max_size = owner._meta['max_size'] or 10000000 # 10MB default max_documents = owner._meta['max_documents'] if collection in db.collection_names(): diff --git a/tests/document.py b/tests/document.py index dee0b712..8f47ec3c 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1,11 +1,23 @@ import unittest from datetime import datetime import pymongo +import pickle from mongoengine import * +from mongoengine.base import BaseField from mongoengine.connection import _get_db +class PickleEmbedded(EmbeddedDocument): + date = DateTimeField(default=datetime.now) + +class PickleTest(Document): + number = IntField() + string = StringField() + embedded = EmbeddedDocumentField(PickleEmbedded) + lists = ListField(StringField()) + + class DocumentTest(unittest.TestCase): def setUp(self): @@ -200,6 +212,22 @@ class DocumentTest(unittest.TestCase): Person.drop_collection() self.assertFalse(collection in self.db.collection_names()) + def test_collection_name_and_primary(self): + """Ensure that a collection with a specified name may be used. + """ + + class Person(Document): + name = StringField(primary_key=True) + meta = {'collection': 'app'} + + user = Person(name="Test User") + user.save() + + user_obj = Person.objects[0] + self.assertEqual(user_obj.name, "Test User") + + Person.drop_collection() + def test_inherited_collections(self): """Ensure that subclassed documents don't override parents' collections. """ @@ -334,6 +362,10 @@ class DocumentTest(unittest.TestCase): post2 = BlogPost(title='test2', slug='test') self.assertRaises(OperationError, post2.save) + + def test_unique_with(self): + """Ensure that unique_with constraints are applied to fields. + """ class Date(EmbeddedDocument): year = IntField(db_field='yr') @@ -357,6 +389,63 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() + def test_unique_embedded_document(self): + """Ensure that uniqueness constraints are applied to fields on embedded documents. + """ + class SubDocument(EmbeddedDocument): + year = IntField(db_field='yr') + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField() + sub = EmbeddedDocumentField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) + post1.save() + + # sub.slug is different so won't raise exception + post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) + post2.save() + + # Now there will be two docs with the same sub.slug + post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) + self.assertRaises(OperationError, post3.save) + + BlogPost.drop_collection() + + def test_unique_with_embedded_document_and_embedded_unique(self): + """Ensure that uniqueness constraints are applied to fields on + embedded documents. And work with unique_with as well. + """ + class SubDocument(EmbeddedDocument): + year = IntField(db_field='yr') + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField(unique_with='sub.year') + sub = EmbeddedDocumentField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) + post1.save() + + # sub.slug is different so won't raise exception + post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) + post2.save() + + # Now there will be two docs with the same sub.slug + post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) + self.assertRaises(OperationError, post3.save) + + # Now there will be two docs with the same title and year + post3 = BlogPost(title='test1', sub=SubDocument(year=2009, slug='test-1')) + self.assertRaises(OperationError, post3.save) + + BlogPost.drop_collection() + def test_unique_and_indexes(self): """Ensure that 'unique' constraints aren't overridden by meta.indexes. @@ -798,6 +887,25 @@ class DocumentTest(unittest.TestCase): self.Person.drop_collection() BlogPost.drop_collection() + def subclasses_and_unique_keys_works(self): + + class A(Document): + pass + + class B(A): + foo = BooleanField(unique=True) + + A.drop_collection() + B.drop_collection() + + A().save() + A().save() + B(foo=True).save() + + self.assertEquals(A.objects.count(), 2) + self.assertEquals(B.objects.count(), 1) + A.drop_collection() + B.drop_collection() def tearDown(self): self.Person.drop_collection() @@ -850,6 +958,43 @@ class DocumentTest(unittest.TestCase): self.assertTrue(u1 in all_user_set ) + def test_picklable(self): + + pickle_doc = PickleTest(number=1, string="OH HAI", lists=['1', '2']) + pickle_doc.embedded = PickleEmbedded() + pickle_doc.save() + + pickled_doc = pickle.dumps(pickle_doc) + resurrected = pickle.loads(pickled_doc) + + self.assertEquals(resurrected, pickle_doc) + + resurrected.string = "Working" + resurrected.save() + + pickle_doc.reload() + self.assertEquals(resurrected, pickle_doc) + + def test_write_options(self): + """Test that passing write_options works""" + + self.Person.drop_collection() + + write_options = {"fsync": True} + + author, created = self.Person.objects.get_or_create( + name='Test User', write_options=write_options) + author.save(write_options=write_options) + + self.Person.objects.update(set__name='Ross', write_options=write_options) + + author = self.Person.objects.first() + self.assertEquals(author.name, 'Ross') + + self.Person.objects.update_one(set__name='Test User', write_options=write_options) + author = self.Person.objects.first() + self.assertEquals(author.name, 'Test User') + if __name__ == '__main__': unittest.main() diff --git a/tests/fields.py b/tests/fields.py index f24b5eda..38409b6a 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -7,6 +7,7 @@ import gridfs from mongoengine import * from mongoengine.connection import _get_db +from mongoengine.base import _document_registry, NotRegistered class FieldTest(unittest.TestCase): @@ -45,7 +46,7 @@ class FieldTest(unittest.TestCase): """ class Person(Document): name = StringField() - + person = Person(name='Test User') self.assertEqual(person.id, None) @@ -95,7 +96,7 @@ class FieldTest(unittest.TestCase): link.url = 'http://www.google.com:8080' link.validate() - + def test_int_validation(self): """Ensure that invalid values cannot be assigned to int fields. """ @@ -129,12 +130,12 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, person.validate) person.height = 4.0 self.assertRaises(ValidationError, person.validate) - + def test_decimal_validation(self): """Ensure that invalid values cannot be assigned to decimal fields. """ class Person(Document): - height = DecimalField(min_value=Decimal('0.1'), + height = DecimalField(min_value=Decimal('0.1'), max_value=Decimal('3.5')) Person.drop_collection() @@ -249,7 +250,7 @@ class FieldTest(unittest.TestCase): post.save() post.reload() self.assertEqual(post.tags, ['fun', 'leisure']) - + comment1 = Comment(content='Good for you', order=1) comment2 = Comment(content='Yay.', order=0) comments = [comment1, comment2] @@ -315,7 +316,7 @@ class FieldTest(unittest.TestCase): person.validate() def test_embedded_document_inheritance(self): - """Ensure that subclasses of embedded documents may be provided to + """Ensure that subclasses of embedded documents may be provided to EmbeddedDocumentFields of the superclass' type. """ class User(EmbeddedDocument): @@ -327,7 +328,7 @@ class FieldTest(unittest.TestCase): class BlogPost(Document): content = StringField() author = EmbeddedDocumentField(User) - + post = BlogPost(content='What I did today...') post.author = User(name='Test User') post.author = PowerUser(name='Test User', power=47) @@ -370,7 +371,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() BlogPost.drop_collection() - + def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -434,7 +435,7 @@ class FieldTest(unittest.TestCase): class TreeNode(EmbeddedDocument): name = StringField() children = ListField(EmbeddedDocumentField('self')) - + tree = Tree(name="Tree") first_child = TreeNode(name="Child 1") @@ -442,7 +443,7 @@ class FieldTest(unittest.TestCase): second_child = TreeNode(name="Child 2") first_child.children.append(second_child) - + third_child = TreeNode(name="Child 3") first_child.children.append(third_child) @@ -506,20 +507,20 @@ class FieldTest(unittest.TestCase): Member.drop_collection() BlogPost.drop_collection() - + def test_generic_reference(self): """Ensure that a GenericReferenceField properly dereferences items. """ class Link(Document): title = StringField() meta = {'allow_inheritance': False} - + class Post(Document): title = StringField() - + class Bookmark(Document): bookmark_object = GenericReferenceField() - + Link.drop_collection() Post.drop_collection() Bookmark.drop_collection() @@ -574,16 +575,49 @@ class FieldTest(unittest.TestCase): user = User(bookmarks=[post_1, link_1]) user.save() - + user = User.objects(bookmarks__all=[post_1, link_1]).first() - + self.assertEqual(user.bookmarks[0], post_1) self.assertEqual(user.bookmarks[1], link_1) - + Link.drop_collection() Post.drop_collection() User.drop_collection() + def test_generic_reference_document_not_registered(self): + """Ensure dereferencing out of the document registry throws a + `NotRegistered` error. + """ + class Link(Document): + title = StringField() + + class User(Document): + bookmarks = ListField(GenericReferenceField()) + + Link.drop_collection() + User.drop_collection() + + link_1 = Link(title="Pitchfork") + link_1.save() + + user = User(bookmarks=[link_1]) + user.save() + + # Mimic User and Link definitions being in a different file + # and the Link model not being imported in the User file. + del(_document_registry["Link"]) + + user = User.objects.first() + try: + user.bookmarks + raise AssertionError, "Link was removed from the registry" + except NotRegistered: + pass + + Link.drop_collection() + User.drop_collection() + def test_binary_fields(self): """Ensure that binary fields can be stored and retrieved. """ @@ -701,6 +735,12 @@ class FieldTest(unittest.TestCase): self.assertTrue(streamfile == result) self.assertEquals(result.file.read(), text + more_text) self.assertEquals(result.file.content_type, content_type) + result.file.seek(0) + self.assertEquals(result.file.tell(), 0) + self.assertEquals(result.file.read(len(text)), text) + self.assertEquals(result.file.tell(), len(text)) + self.assertEquals(result.file.read(len(more_text)), more_text) + self.assertEquals(result.file.tell(), len(text + more_text)) result.file.delete() # Ensure deleted file returns None @@ -721,7 +761,7 @@ class FieldTest(unittest.TestCase): result = SetFile.objects.first() self.assertTrue(setfile == result) self.assertEquals(result.file.read(), more_text) - result.file.delete() + result.file.delete() PutFile.drop_collection() StreamFile.drop_collection() diff --git a/tests/queryset.py b/tests/queryset.py index 48ce6272..5a2c46cb 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -597,6 +597,81 @@ class QuerySetTest(unittest.TestCase): Email.drop_collection() + def test_slicing_fields(self): + """Ensure that query slicing an array works. + """ + class Numbers(Document): + n = ListField(IntField()) + + Numbers.drop_collection() + + numbers = Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) + numbers.save() + + # first three + numbers = Numbers.objects.fields(slice__n=3).get() + self.assertEquals(numbers.n, [0, 1, 2]) + + # last three + numbers = Numbers.objects.fields(slice__n=-3).get() + self.assertEquals(numbers.n, [-3, -2, -1]) + + # skip 2, limit 3 + numbers = Numbers.objects.fields(slice__n=[2, 3]).get() + self.assertEquals(numbers.n, [2, 3, 4]) + + # skip to fifth from last, limit 4 + numbers = Numbers.objects.fields(slice__n=[-5, 4]).get() + self.assertEquals(numbers.n, [-5, -4, -3, -2]) + + # skip to fifth from last, limit 10 + numbers = Numbers.objects.fields(slice__n=[-5, 10]).get() + self.assertEquals(numbers.n, [-5, -4, -3, -2, -1]) + + # skip to fifth from last, limit 10 dict method + numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get() + self.assertEquals(numbers.n, [-5, -4, -3, -2, -1]) + + def test_slicing_nested_fields(self): + """Ensure that query slicing an embedded array works. + """ + + class EmbeddedNumber(EmbeddedDocument): + n = ListField(IntField()) + + class Numbers(Document): + embedded = EmbeddedDocumentField(EmbeddedNumber) + + Numbers.drop_collection() + + numbers = Numbers() + numbers.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) + numbers.save() + + # first three + numbers = Numbers.objects.fields(slice__embedded__n=3).get() + self.assertEquals(numbers.embedded.n, [0, 1, 2]) + + # last three + numbers = Numbers.objects.fields(slice__embedded__n=-3).get() + self.assertEquals(numbers.embedded.n, [-3, -2, -1]) + + # skip 2, limit 3 + numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get() + self.assertEquals(numbers.embedded.n, [2, 3, 4]) + + # skip to fifth from last, limit 4 + numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2]) + + # skip to fifth from last, limit 10 + numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1]) + + # skip to fifth from last, limit 10 dict method + numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1]) + def test_find_embedded(self): """Ensure that an embedded document is properly returned from a query. """ @@ -1294,6 +1369,7 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): tags = ListField(StringField()) deleted = BooleanField(default=False) + date = DateTimeField(default=datetime.now) @queryset_manager def objects(doc_cls, queryset): @@ -1301,7 +1377,7 @@ class QuerySetTest(unittest.TestCase): @queryset_manager def music_posts(doc_cls, queryset): - return queryset(tags='music', deleted=False) + return queryset(tags='music', deleted=False).order_by('-date') BlogPost.drop_collection() @@ -1317,7 +1393,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual([p.id for p in BlogPost.objects], [post1.id, post2.id, post3.id]) self.assertEqual([p.id for p in BlogPost.music_posts], - [post1.id, post2.id]) + [post2.id, post1.id]) BlogPost.drop_collection() @@ -1760,6 +1836,25 @@ class QuerySetTest(unittest.TestCase): Number.drop_collection() + def test_order_works_with_primary(self): + """Ensure that order_by and primary work. + """ + class Number(Document): + n = IntField(primary_key=True) + + Number.drop_collection() + + Number(n=1).save() + Number(n=2).save() + Number(n=3).save() + + numbers = [n.n for n in Number.objects.order_by('-n')] + self.assertEquals([3, 2, 1], numbers) + + numbers = [n.n for n in Number.objects.order_by('+n')] + self.assertEquals([1, 2, 3], numbers) + Number.drop_collection() + class QTest(unittest.TestCase): @@ -1931,49 +2026,53 @@ class QueryFieldListTest(unittest.TestCase): def test_include_include(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'a': True, 'b': True}) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'b': True}) def test_include_exclude(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'a': True, 'b': True}) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) self.assertEqual(q.as_dict(), {'a': True}) def test_exclude_exclude(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) self.assertEqual(q.as_dict(), {'a': False, 'b': False}) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) def test_exclude_include(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) self.assertEqual(q.as_dict(), {'a': False, 'b': False}) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'c': True}) def test_always_include(self): q = QueryFieldList(always_include=['x', 'y']) - q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) - def test_reset(self): q = QueryFieldList(always_include=['x', 'y']) - q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) q.reset() self.assertFalse(q) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) + def test_using_a_slice(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a'], value={"$slice": 5}) + self.assertEqual(q.as_dict(), {'a': {"$slice": 5}}) + if __name__ == '__main__': unittest.main()