diff --git a/docs/apireference.rst b/docs/apireference.rst index 857a14b0..7ba93408 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -79,6 +79,7 @@ Fields .. autoclass:: mongoengine.fields.GenericEmbeddedDocumentField .. autoclass:: mongoengine.fields.DynamicField .. autoclass:: mongoengine.fields.ListField +.. autoclass:: mongoengine.fields.EmbeddedDocumentListField .. autoclass:: mongoengine.fields.SortedListField .. autoclass:: mongoengine.fields.DictField .. autoclass:: mongoengine.fields.MapField @@ -103,6 +104,21 @@ Fields .. autoclass:: mongoengine.fields.ImageGridFsProxy .. autoclass:: mongoengine.fields.ImproperlyConfigured +Embedded Document Querying +========================== + +.. versionadded:: 0.9 + +Additional queries for Embedded Documents are available when using the +:class:`~mongoengine.EmbeddedDocumentListField` to store a list of embedded +documents. + +A list of embedded documents is returned as a special list with the +following methods: + +.. autoclass:: mongoengine.base.datastructures.EmbeddedDocumentList + :members: + Misc ==== diff --git a/docs/changelog.rst b/docs/changelog.rst index 4e53234f..9aa6aa3a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in 0.9.X - DEV ====================== +- Added `EmbeddedDocumentListField` for Lists of Embedded Documents. #826 - ComplexDateTimeField should fall back to None when null=True #864 - Request Support for $min, $max Field update operators #863 - `BaseDict` does not follow `setdefault` #866 diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 022338d7..bac67ddc 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -2,8 +2,9 @@ import weakref import functools import itertools from mongoengine.common import _import_class +from mongoengine.errors import DoesNotExist, MultipleObjectsReturned -__all__ = ("BaseDict", "BaseList") +__all__ = ("BaseDict", "BaseList", "EmbeddedDocumentList") class BaseDict(dict): @@ -106,7 +107,7 @@ class BaseList(list): if isinstance(instance, (Document, EmbeddedDocument)): self._instance = weakref.proxy(instance) self._name = name - return super(BaseList, self).__init__(list_items) + super(BaseList, self).__init__(list_items) def __getitem__(self, key, *args, **kwargs): value = super(BaseList, self).__getitem__(key) @@ -191,6 +192,167 @@ class BaseList(list): self._instance._mark_as_changed(self._name) +class EmbeddedDocumentList(BaseList): + + @classmethod + def __match_all(cls, i, kwargs): + items = kwargs.items() + return all([ + getattr(i, k) == v or str(getattr(i, k)) == v for k, v in items + ]) + + @classmethod + def __only_matches(cls, obj, kwargs): + if not kwargs: + return obj + return filter(lambda i: cls.__match_all(i, kwargs), obj) + + def __init__(self, list_items, instance, name): + super(EmbeddedDocumentList, self).__init__(list_items, instance, name) + self._instance = instance + + def filter(self, **kwargs): + """ + Filters the list by only including embedded documents with the + given keyword arguments. + + :param kwargs: The keyword arguments corresponding to the fields to + filter on. *Multiple arguments are treated as if they are ANDed + together.* + :return: A new ``EmbeddedDocumentList`` containing the matching + embedded documents. + + Raises ``AttributeError`` if a given keyword is not a valid field for + the embedded document class. + """ + values = self.__only_matches(self, kwargs) + return EmbeddedDocumentList(values, self._instance, self._name) + + def exclude(self, **kwargs): + """ + Filters the list by excluding embedded documents with the given + keyword arguments. + + :param kwargs: The keyword arguments corresponding to the fields to + exclude on. *Multiple arguments are treated as if they are ANDed + together.* + :return: A new ``EmbeddedDocumentList`` containing the non-matching + embedded documents. + + Raises ``AttributeError`` if a given keyword is not a valid field for + the embedded document class. + """ + exclude = self.__only_matches(self, kwargs) + values = [item for item in self if item not in exclude] + return EmbeddedDocumentList(values, self._instance, self._name) + + def count(self): + """ + The number of embedded documents in the list. + + :return: The length of the list, equivalent to the result of ``len()``. + """ + return len(self) + + def get(self, **kwargs): + """ + Retrieves an embedded document determined by the given keyword + arguments. + + :param kwargs: The keyword arguments corresponding to the fields to + search on. *Multiple arguments are treated as if they are ANDed + together.* + :return: The embedded document matched by the given keyword arguments. + + Raises ``DoesNotExist`` if the arguments used to query an embedded + document returns no results. ``MultipleObjectsReturned`` if more + than one result is returned. + """ + values = self.__only_matches(self, kwargs) + if len(values) == 0: + raise DoesNotExist( + "%s matching query does not exist." % self._name + ) + elif len(values) > 1: + raise MultipleObjectsReturned( + "%d items returned, instead of 1" % len(values) + ) + + return values[0] + + def first(self): + """ + Returns the first embedded document in the list, or ``None`` if empty. + """ + if len(self) > 0: + return self[0] + + def create(self, **values): + """ + Creates a new embedded document and saves it to the database. + + .. note:: + The embedded document changes are not automatically saved + to the database after calling this method. + + :param values: A dictionary of values for the embedded document. + :return: The new embedded document instance. + """ + name = self._name + EmbeddedClass = self._instance._fields[name].field.document_type_obj + self._instance[self._name].append(EmbeddedClass(**values)) + + return self._instance[self._name][-1] + + def save(self, *args, **kwargs): + """ + Saves the ancestor document. + + :param args: Arguments passed up to the ancestor Document's save + method. + :param kwargs: Keyword arguments passed up to the ancestor Document's + save method. + """ + self._instance.save(*args, **kwargs) + + def delete(self): + """ + Deletes the embedded documents from the database. + + .. note:: + The embedded document changes are not automatically saved + to the database after calling this method. + + :return: The number of entries deleted. + """ + values = list(self) + for item in values: + self._instance[self._name].remove(item) + + return len(values) + + def update(self, **update): + """ + Updates the embedded documents with the given update values. + + .. note:: + The embedded document changes are not automatically saved + to the database after calling this method. + + :param update: A dictionary of update values to apply to each + embedded document. + :return: The number of entries updated. + """ + if len(update) == 0: + return 0 + values = list(self) + for item in values: + for k, v in update.items(): + setattr(item, k, v) + + return len(values) + + class StrictDict(object): __slots__ = () _special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create']) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 998e366f..6bf38ee2 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -16,7 +16,13 @@ from mongoengine.errors import (ValidationError, InvalidDocumentError, from mongoengine.python_support import PY3, txt_type from mongoengine.base.common import get_document, ALLOW_INHERITANCE -from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict +from mongoengine.base.datastructures import ( + BaseDict, + BaseList, + EmbeddedDocumentList, + StrictDict, + SemiStrictDict +) from mongoengine.base.fields import ComplexBaseField __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') @@ -419,6 +425,8 @@ class BaseDocument(object): if not isinstance(value, (dict, list, tuple)): return value + EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') + is_list = False if not hasattr(value, 'items'): is_list = True @@ -442,7 +450,10 @@ class BaseDocument(object): # Convert lists / values so we can watch for any changes on them if (isinstance(value, (list, tuple)) and not isinstance(value, BaseList)): - value = BaseList(value, self, name) + if issubclass(type(self), EmbeddedDocumentListField): + value = EmbeddedDocumentList(value, self, name) + else: + value = BaseList(value, self, name) elif isinstance(value, dict) and not isinstance(value, BaseDict): value = BaseDict(value, self, name) diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 359ea6d2..aa16804e 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -9,7 +9,9 @@ from mongoengine.common import _import_class from mongoengine.errors import ValidationError from mongoengine.base.common import ALLOW_INHERITANCE -from mongoengine.base.datastructures import BaseDict, BaseList +from mongoengine.base.datastructures import ( + BaseDict, BaseList, EmbeddedDocumentList +) __all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") @@ -210,6 +212,7 @@ class ComplexBaseField(BaseField): ReferenceField = _import_class('ReferenceField') GenericReferenceField = _import_class('GenericReferenceField') + EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') dereference = (self._auto_dereference and (self.field is None or isinstance(self.field, (GenericReferenceField, ReferenceField)))) @@ -226,9 +229,12 @@ class ComplexBaseField(BaseField): value = super(ComplexBaseField, self).__get__(instance, owner) # Convert lists / values so we can watch for any changes on them - if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): - value = BaseList(value, instance, self.name) + if isinstance(value, (list, tuple)): + if (issubclass(type(self), EmbeddedDocumentListField) and + not isinstance(value, EmbeddedDocumentList)): + value = EmbeddedDocumentList(value, instance, self.name) + elif not isinstance(value, BaseList): + value = BaseList(value, instance, self.name) instance._data[self.name] = value elif isinstance(value, dict) and not isinstance(value, BaseDict): value = BaseDict(value, instance, self.name) diff --git a/mongoengine/common.py b/mongoengine/common.py index 7c0c18d2..3e63e98e 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -1,4 +1,5 @@ _class_registry_cache = {} +_field_list_cache = [] def _import_class(cls_name): @@ -20,13 +21,16 @@ def _import_class(cls_name): doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument', 'MapReduceDocument') - field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', - 'FileField', 'GenericReferenceField', - 'GenericEmbeddedDocumentField', 'GeoPointField', - 'PointField', 'LineStringField', 'ListField', - 'PolygonField', 'ReferenceField', 'StringField', - 'CachedReferenceField', - 'ComplexBaseField', 'GeoJsonBaseField') + + # Field Classes + if not _field_list_cache: + from mongoengine.fields import __all__ as fields + _field_list_cache.extend(fields) + from mongoengine.base.fields import __all__ as fields + _field_list_cache.extend(fields) + + field_classes = _field_list_cache + queryset_classes = ('OperationError',) deref_classes = ('DeReference',) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index a22e3473..415d5678 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,6 +1,9 @@ from bson import DBRef, SON -from base import (BaseDict, BaseList, TopLevelDocumentMetaclass, get_document) +from base import ( + BaseDict, BaseList, EmbeddedDocumentList, + TopLevelDocumentMetaclass, get_document +) from fields import (ReferenceField, ListField, DictField, MapField) from connection import get_db from queryset import QuerySet @@ -189,6 +192,9 @@ class DeReference(object): if not hasattr(items, 'items'): is_list = True + list_type = BaseList + if isinstance(items, EmbeddedDocumentList): + list_type = EmbeddedDocumentList as_tuple = isinstance(items, tuple) iterator = enumerate(items) data = [] @@ -225,7 +231,7 @@ class DeReference(object): if instance and name: if is_list: - return tuple(data) if as_tuple else BaseList(data, instance, name) + return tuple(data) if as_tuple else list_type(data, instance, name) return BaseDict(data, instance, name) depth += 1 return data diff --git a/mongoengine/document.py b/mongoengine/document.py index 03318c84..5b1b313d 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -9,9 +9,16 @@ from bson import ObjectId from bson.dbref import DBRef from mongoengine import signals from mongoengine.common import _import_class -from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, - BaseDocument, BaseDict, BaseList, - ALLOW_INHERITANCE, get_document) +from mongoengine.base import ( + DocumentMetaclass, + TopLevelDocumentMetaclass, + BaseDocument, + BaseDict, + BaseList, + EmbeddedDocumentList, + ALLOW_INHERITANCE, + get_document +) from mongoengine.errors import ValidationError, InvalidQueryError, InvalidDocumentError from mongoengine.queryset import (OperationError, NotUniqueError, QuerySet, transform) @@ -76,6 +83,12 @@ class EmbeddedDocument(BaseDocument): def __ne__(self, other): return not self.__eq__(other) + def save(self, *args, **kwargs): + self._instance.save(*args, **kwargs) + + def reload(self, *args, **kwargs): + self._instance.reload(*args, **kwargs) + class Document(BaseDocument): @@ -560,6 +573,9 @@ class Document(BaseDocument): if isinstance(value, BaseDict): value = [(k, self._reload(k, v)) for k, v in value.items()] value = BaseDict(value, self, key) + elif isinstance(value, EmbeddedDocumentList): + value = [self._reload(key, v) for v in value] + value = EmbeddedDocumentList(value, self, key) elif isinstance(value, BaseList): value = [self._reload(key, v) for v in value] value = BaseList(value, self, key) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index cfe66a48..9d66774e 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -39,13 +39,13 @@ __all__ = [ 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', - 'SortedListField', 'DictField', 'MapField', 'ReferenceField', - 'CachedReferenceField', 'GenericReferenceField', 'BinaryField', - 'GridFSError', 'GridFSProxy', 'FileField', 'ImageGridFsProxy', - 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', - 'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField', - 'MultiPointField', 'MultiLineStringField', 'MultiPolygonField', - 'GeoJsonBaseField'] + 'SortedListField', 'EmbeddedDocumentListField', 'DictField', + 'MapField', 'ReferenceField', 'CachedReferenceField', + 'GenericReferenceField', 'BinaryField', 'GridFSError', 'GridFSProxy', + 'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField', + 'GeoPointField', 'PointField', 'LineStringField', 'PolygonField', + 'SequenceField', 'UUIDField', 'MultiPointField', 'MultiLineStringField', + 'MultiPolygonField', 'GeoJsonBaseField'] RECURSIVE_REFERENCE_CONSTANT = 'self' @@ -728,6 +728,32 @@ class ListField(ComplexBaseField): return super(ListField, self).prepare_query_value(op, value) +class EmbeddedDocumentListField(ListField): + """A :class:`~mongoengine.ListField` designed specially to hold a list of + embedded documents to provide additional query helpers. + + .. note:: + The only valid list values are subclasses of + :class:`~mongoengine.EmbeddedDocument`. + + .. versionadded:: 0.9 + + """ + + def __init__(self, document_type, *args, **kwargs): + """ + :param document_type: The type of + :class:`~mongoengine.EmbeddedDocument` the list will hold. + :param args: Arguments passed directly into the parent + :class:`~mongoengine.ListField`. + :param kwargs: Keyword arguments passed directly into the parent + :class:`~mongoengine.ListField`. + """ + super(EmbeddedDocumentListField, self).__init__( + field=EmbeddedDocumentField(document_type), **kwargs + ) + + class SortedListField(ListField): """A ListField that sorts the contents of its list before writing to diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 8e957c68..64898995 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -18,11 +18,11 @@ from bson import Binary, DBRef, ObjectId from mongoengine import * from mongoengine.connection import get_db from mongoengine.base import _document_registry -from mongoengine.base.datastructures import BaseDict +from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList from mongoengine.errors import NotRegistered from mongoengine.python_support import PY3, b, bin_type -__all__ = ("FieldTest", ) +__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") class FieldTest(unittest.TestCase): @@ -3159,5 +3159,473 @@ class FieldTest(unittest.TestCase): self.assertRaises(FieldDoesNotExist, test) + +class EmbeddedDocumentListFieldTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.db = connect(db='EmbeddedDocumentListFieldTestCase') + + class Comments(EmbeddedDocument): + author = StringField() + message = StringField() + + class BlogPost(Document): + comments = EmbeddedDocumentListField(Comments) + + cls.Comments = Comments + cls.BlogPost = BlogPost + + def setUp(self): + """ + Create two BlogPost entries in the database, each with + several EmbeddedDocuments. + """ + self.post1 = self.BlogPost(comments=[ + self.Comments(author='user1', message='message1'), + self.Comments(author='user2', message='message1') + ]).save() + + self.post2 = self.BlogPost(comments=[ + self.Comments(author='user2', message='message2'), + self.Comments(author='user2', message='message3'), + self.Comments(author='user3', message='message1') + ]).save() + + def tearDown(self): + self.BlogPost.drop_collection() + + @classmethod + def tearDownClass(cls): + cls.db.drop_database('EmbeddedDocumentListFieldTestCase') + + def test_no_keyword_filter(self): + """ + Tests the filter method of a List of Embedded Documents + with a no keyword. + """ + filtered = self.post1.comments.filter() + + # Ensure nothing was changed + # < 2.6 Incompatible > + # self.assertListEqual(filtered, self.post1.comments) + self.assertEqual(filtered, self.post1.comments) + + def test_single_keyword_filter(self): + """ + Tests the filter method of a List of Embedded Documents + with a single keyword. + """ + filtered = self.post1.comments.filter(author='user1') + + # Ensure only 1 entry was returned. + self.assertEqual(len(filtered), 1) + + # Ensure the entry returned is the correct entry. + self.assertEqual(filtered[0].author, 'user1') + + def test_multi_keyword_filter(self): + """ + Tests the filter method of a List of Embedded Documents + with multiple keywords. + """ + filtered = self.post2.comments.filter( + author='user2', message='message2' + ) + + # Ensure only 1 entry was returned. + self.assertEqual(len(filtered), 1) + + # Ensure the entry returned is the correct entry. + self.assertEqual(filtered[0].author, 'user2') + self.assertEqual(filtered[0].message, 'message2') + + def test_chained_filter(self): + """ + Tests chained filter methods of a List of Embedded Documents + """ + filtered = self.post2.comments.filter(author='user2').filter( + message='message2' + ) + + # Ensure only 1 entry was returned. + self.assertEqual(len(filtered), 1) + + # Ensure the entry returned is the correct entry. + self.assertEqual(filtered[0].author, 'user2') + self.assertEqual(filtered[0].message, 'message2') + + def test_unknown_keyword_filter(self): + """ + Tests the filter method of a List of Embedded Documents + when the keyword is not a known keyword. + """ + # < 2.6 Incompatible > + # with self.assertRaises(AttributeError): + # self.post2.comments.filter(year=2) + self.assertRaises(AttributeError, self.post2.comments.filter, year=2) + + def test_no_keyword_exclude(self): + """ + Tests the exclude method of a List of Embedded Documents + with a no keyword. + """ + filtered = self.post1.comments.exclude() + + # Ensure everything was removed + # < 2.6 Incompatible > + # self.assertListEqual(filtered, []) + self.assertEqual(filtered, []) + + def test_single_keyword_exclude(self): + """ + Tests the exclude method of a List of Embedded Documents + with a single keyword. + """ + excluded = self.post1.comments.exclude(author='user1') + + # Ensure only 1 entry was returned. + self.assertEqual(len(excluded), 1) + + # Ensure the entry returned is the correct entry. + self.assertEqual(excluded[0].author, 'user2') + + def test_multi_keyword_exclude(self): + """ + Tests the exclude method of a List of Embedded Documents + with multiple keywords. + """ + excluded = self.post2.comments.exclude( + author='user3', message='message1' + ) + + # Ensure only 2 entries were returned. + self.assertEqual(len(excluded), 2) + + # Ensure the entries returned are the correct entries. + self.assertEqual(excluded[0].author, 'user2') + self.assertEqual(excluded[1].author, 'user2') + + def test_non_matching_exclude(self): + """ + Tests the exclude method of a List of Embedded Documents + when the keyword does not match any entries. + """ + excluded = self.post2.comments.exclude(author='user4') + + # Ensure the 3 entries still exist. + self.assertEqual(len(excluded), 3) + + def test_unknown_keyword_exclude(self): + """ + Tests the exclude method of a List of Embedded Documents + when the keyword is not a known keyword. + """ + # < 2.6 Incompatible > + # with self.assertRaises(AttributeError): + # self.post2.comments.exclude(year=2) + self.assertRaises(AttributeError, self.post2.comments.exclude, year=2) + + def test_chained_filter_exclude(self): + """ + Tests the exclude method after a filter method of a List of + Embedded Documents. + """ + excluded = self.post2.comments.filter(author='user2').exclude( + message='message2' + ) + + # Ensure only 1 entry was returned. + self.assertEqual(len(excluded), 1) + + # Ensure the entry returned is the correct entry. + self.assertEqual(excluded[0].author, 'user2') + self.assertEqual(excluded[0].message, 'message3') + + def test_count(self): + """ + Tests the count method of a List of Embedded Documents. + """ + self.assertEqual(self.post1.comments.count(), 2) + self.assertEqual(self.post1.comments.count(), len(self.post1.comments)) + + def test_filtered_count(self): + """ + Tests the filter + count method of a List of Embedded Documents. + """ + count = self.post1.comments.filter(author='user1').count() + self.assertEqual(count, 1) + + def test_single_keyword_get(self): + """ + Tests the get method of a List of Embedded Documents using a + single keyword. + """ + comment = self.post1.comments.get(author='user1') + + # < 2.6 Incompatible > + # self.assertIsInstance(comment, self.Comments) + self.assertTrue(isinstance(comment, self.Comments)) + self.assertEqual(comment.author, 'user1') + + def test_multi_keyword_get(self): + """ + Tests the get method of a List of Embedded Documents using + multiple keywords. + """ + comment = self.post2.comments.get(author='user2', message='message2') + + # < 2.6 Incompatible > + # self.assertIsInstance(comment, self.Comments) + self.assertTrue(isinstance(comment, self.Comments)) + self.assertEqual(comment.author, 'user2') + self.assertEqual(comment.message, 'message2') + + def test_no_keyword_multiple_return_get(self): + """ + Tests the get method of a List of Embedded Documents without + a keyword to return multiple documents. + """ + # < 2.6 Incompatible > + # with self.assertRaises(MultipleObjectsReturned): + # self.post1.comments.get() + self.assertRaises(MultipleObjectsReturned, self.post1.comments.get) + + def test_keyword_multiple_return_get(self): + """ + Tests the get method of a List of Embedded Documents with a keyword + to return multiple documents. + """ + # < 2.6 Incompatible > + # with self.assertRaises(MultipleObjectsReturned): + # self.post2.comments.get(author='user2') + self.assertRaises( + MultipleObjectsReturned, self.post2.comments.get, author='user2' + ) + + def test_unknown_keyword_get(self): + """ + Tests the get method of a List of Embedded Documents with an + unknown keyword. + """ + # < 2.6 Incompatible > + # with self.assertRaises(AttributeError): + # self.post2.comments.get(year=2020) + self.assertRaises(AttributeError, self.post2.comments.get, year=2020) + + def test_no_result_get(self): + """ + Tests the get method of a List of Embedded Documents where get + returns no results. + """ + # < 2.6 Incompatible > + # with self.assertRaises(DoesNotExist): + # self.post1.comments.get(author='user3') + self.assertRaises( + DoesNotExist, self.post1.comments.get, author='user3' + ) + + def test_first(self): + """ + Tests the first method of a List of Embedded Documents to + ensure it returns the first comment. + """ + comment = self.post1.comments.first() + + # Ensure a Comment object was returned. + # < 2.6 Incompatible > + # self.assertIsInstance(comment, self.Comments) + self.assertTrue(isinstance(comment, self.Comments)) + self.assertEqual(comment, self.post1.comments[0]) + + def test_create(self): + """ + Test the create method of a List of Embedded Documents. + """ + comment = self.post1.comments.create( + author='user4', message='message1' + ) + self.post1.save() + + # Ensure the returned value is the comment object. + # < 2.6 Incompatible > + # self.assertIsInstance(comment, self.Comments) + self.assertTrue(isinstance(comment, self.Comments)) + self.assertEqual(comment.author, 'user4') + self.assertEqual(comment.message, 'message1') + + # Ensure the new comment was actually saved to the database. + # < 2.6 Incompatible > + # self.assertIn( + # comment, + # self.BlogPost.objects(comments__author='user4')[0].comments + # ) + self.assertTrue( + comment in self.BlogPost.objects( + comments__author='user4' + )[0].comments + ) + + def test_filtered_create(self): + """ + Test the create method of a List of Embedded Documents chained + to a call to the filter method. Filtering should have no effect + on creation. + """ + comment = self.post1.comments.filter(author='user1').create( + author='user4', message='message1' + ) + self.post1.save() + + # Ensure the returned value is the comment object. + # < 2.6 Incompatible > + # self.assertIsInstance(comment, self.Comments) + self.assertTrue(isinstance(comment, self.Comments)) + self.assertEqual(comment.author, 'user4') + self.assertEqual(comment.message, 'message1') + + # Ensure the new comment was actually saved to the database. + # < 2.6 Incompatible > + # self.assertIn( + # comment, + # self.BlogPost.objects(comments__author='user4')[0].comments + # ) + self.assertTrue( + comment in self.BlogPost.objects( + comments__author='user4' + )[0].comments + ) + + def test_no_keyword_update(self): + """ + Tests the update method of a List of Embedded Documents with + no keywords. + """ + original = list(self.post1.comments) + number = self.post1.comments.update() + self.post1.save() + + # Ensure that nothing was altered. + # < 2.6 Incompatible > + # self.assertIn( + # original[0], + # self.BlogPost.objects(id=self.post1.id)[0].comments + # ) + self.assertTrue( + original[0] in self.BlogPost.objects(id=self.post1.id)[0].comments + ) + + # < 2.6 Incompatible > + # self.assertIn( + # original[1], + # self.BlogPost.objects(id=self.post1.id)[0].comments + # ) + self.assertTrue( + original[1] in self.BlogPost.objects(id=self.post1.id)[0].comments + ) + + # Ensure the method returned 0 as the number of entries + # modified + self.assertEqual(number, 0) + + def test_single_keyword_update(self): + """ + Tests the update method of a List of Embedded Documents with + a single keyword. + """ + number = self.post1.comments.update(author='user4') + self.post1.save() + + comments = self.BlogPost.objects(id=self.post1.id)[0].comments + + # Ensure that the database was updated properly. + self.assertEqual(comments[0].author, 'user4') + self.assertEqual(comments[1].author, 'user4') + + # Ensure the method returned 2 as the number of entries + # modified + self.assertEqual(number, 2) + + def test_save(self): + """ + Tests the save method of a List of Embedded Documents. + """ + comments = self.post1.comments + new_comment = self.Comments(author='user4') + comments.append(new_comment) + comments.save() + + # Ensure that the new comment has been added to the database. + # < 2.6 Incompatible > + # self.assertIn( + # new_comment, + # self.BlogPost.objects(id=self.post1.id)[0].comments + # ) + self.assertTrue( + new_comment in self.BlogPost.objects(id=self.post1.id)[0].comments + ) + + def test_delete(self): + """ + Tests the delete method of a List of Embedded Documents. + """ + number = self.post1.comments.delete() + self.post1.save() + + # Ensure that all the comments under post1 were deleted in the + # database. + # < 2.6 Incompatible > + # self.assertListEqual( + # self.BlogPost.objects(id=self.post1.id)[0].comments, [] + # ) + self.assertEqual( + self.BlogPost.objects(id=self.post1.id)[0].comments, [] + ) + + # Ensure that post1 comments were deleted from the list. + # < 2.6 Incompatible > + # self.assertListEqual(self.post1.comments, []) + self.assertEqual(self.post1.comments, []) + + # Ensure that comments still returned a EmbeddedDocumentList object. + # < 2.6 Incompatible > + # self.assertIsInstance(self.post1.comments, EmbeddedDocumentList) + self.assertTrue(isinstance(self.post1.comments, EmbeddedDocumentList)) + + # Ensure that the delete method returned 2 as the number of entries + # deleted from the database + self.assertEqual(number, 2) + + def test_filtered_delete(self): + """ + Tests the delete method of a List of Embedded Documents + after the filter method has been called. + """ + comment = self.post1.comments[1] + number = self.post1.comments.filter(author='user2').delete() + self.post1.save() + + # Ensure that only the user2 comment was deleted. + # < 2.6 Incompatible > + # self.assertNotIn( + # comment, self.BlogPost.objects(id=self.post1.id)[0].comments + # ) + self.assertTrue( + comment not in self.BlogPost.objects(id=self.post1.id)[0].comments + ) + self.assertEqual( + len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1 + ) + + # Ensure that the user2 comment no longer exists in the list. + # < 2.6 Incompatible > + # self.assertNotIn(comment, self.post1.comments) + self.assertTrue(comment not in self.post1.comments) + self.assertEqual(len(self.post1.comments), 1) + + # Ensure that the delete method returned 1 as the number of entries + # deleted from the database + self.assertEqual(number, 1) + if __name__ == '__main__': unittest.main()