diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 59204d4d..f30b2c15 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from bson import DBRef, SON import six @@ -201,6 +202,10 @@ class DeReference(object): as_tuple = isinstance(items, tuple) iterator = enumerate(items) data = [] + elif isinstance(items, OrderedDict): + is_list = False + iterator = items.iteritems() + data = OrderedDict() else: is_list = False iterator = items.iteritems() diff --git a/mongoengine/fields.py b/mongoengine/fields.py index b67b385d..06c56f06 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -5,6 +5,7 @@ import re import time import uuid import warnings +from collections import Mapping from operator import itemgetter from bson import Binary, DBRef, ObjectId, SON @@ -619,6 +620,14 @@ class DynamicField(BaseField): Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data""" + def __init__(self, container_class=dict, *args, **kwargs): + self._container_cls = container_class + if not issubclass(self._container_cls, Mapping): + self.error('The class that is specified in `container_class` parameter ' + 'must be a subclass of `dict`.') + + super(DynamicField, self).__init__(*args, **kwargs) + def to_mongo(self, value, use_db_field=True, fields=None): """Convert a Python type to a MongoDB compatible type. """ @@ -644,7 +653,7 @@ class DynamicField(BaseField): is_list = True value = {k: v for k, v in enumerate(value)} - data = {} + data = self._container_cls() for k, v in value.iteritems(): data[k] = self.to_mongo(v, use_db_field, fields) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index f4ad0fa2..4017377d 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -5,8 +5,10 @@ import uuid import math import itertools import re +import pymongo from nose.plugins.skip import SkipTest +from collections import OrderedDict import six try: @@ -25,9 +27,12 @@ except ImportError: from mongoengine import * from mongoengine.connection import get_db from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, - _document_registry) + _document_registry, TopLevelDocumentMetaclass) -from tests.utils import MongoDBTestCase +from tests.utils import MongoDBTestCase, MONGO_TEST_DB +from mongoengine.python_support import IS_PYMONGO_3 +if IS_PYMONGO_3: + from bson import CodecOptions __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") @@ -4110,6 +4115,67 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): self.assertTrue(hasattr(CustomData.c_field, 'custom_data')) self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) + def test_dynamicfield_with_container_class(self): + """ + Tests that object can be stored in order by DynamicField class + with container_class parameter. + """ + raw_data = [('d', 1), ('c', 2), ('b', 3), ('a', 4)] + + class Doc(Document): + ordered_data = DynamicField(container_class=OrderedDict) + unordered_data = DynamicField() + + Doc.drop_collection() + + doc = Doc(ordered_data=OrderedDict(raw_data), unordered_data=dict(raw_data)).save() + + # checks that the data is in order + self.assertEqual(type(doc.ordered_data), OrderedDict) + self.assertEqual(type(doc.unordered_data), dict) + self.assertEqual(','.join(doc.ordered_data.keys()), 'd,c,b,a') + + # checks that the data is stored to the database in order + pymongo_db = pymongo.MongoClient()[MONGO_TEST_DB] + if IS_PYMONGO_3: + codec_option = CodecOptions(document_class=OrderedDict) + db_doc = pymongo_db.doc.with_options(codec_options=codec_option).find_one() + else: + db_doc = pymongo_db.doc.find_one(as_class=OrderedDict) + + self.assertEqual(','.join(doc.ordered_data.keys()), 'd,c,b,a') + + def test_dynamicfield_with_wrong_container_class(self): + with self.assertRaises(ValidationError): + class DocWithInvalidField: + data = DynamicField(container_class=list) + + def test_dynamicfield_with_wrong_container_class_and_reload_docuemnt(self): + # This is because 'codec_options' is supported on pymongo3 or later + if IS_PYMONGO_3: + class OrderedDocument(Document): + my_metaclass = TopLevelDocumentMetaclass + __metaclass__ = TopLevelDocumentMetaclass + + @classmethod + def _get_collection(cls): + collection = super(OrderedDocument, cls)._get_collection() + opts = CodecOptions(document_class=OrderedDict) + + return collection.with_options(codec_options=opts) + + raw_data = [('d', 1), ('c', 2), ('b', 3), ('a', 4)] + + class Doc(OrderedDocument): + data = DynamicField(container_class=OrderedDict) + + Doc.drop_collection() + + doc = Doc(data=OrderedDict(raw_data)).save() + doc.reload() + + self.assertEqual(type(doc.data), OrderedDict) + self.assertEqual(','.join(doc.data.keys()), 'd,c,b,a') class CachedReferenceFieldTest(MongoDBTestCase): diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 7f58a85b..9a976611 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -2,10 +2,15 @@ import unittest from bson import DBRef, ObjectId +from collections import OrderedDict from mongoengine import * from mongoengine.connection import get_db from mongoengine.context_managers import query_counter +from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.base import TopLevelDocumentMetaclass +if IS_PYMONGO_3: + from bson import CodecOptions class FieldTest(unittest.TestCase): @@ -1287,5 +1292,70 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) + def test_dynamic_field_dereference(self): + class Merchandise(Document): + name = StringField() + price = IntField() + + class Store(Document): + merchandises = DynamicField() + + Merchandise.drop_collection() + Store.drop_collection() + + merchandises = { + '#1': Merchandise(name='foo', price=100).save(), + '#2': Merchandise(name='bar', price=120).save(), + '#3': Merchandise(name='baz', price=110).save(), + } + Store(merchandises=merchandises).save() + + store = Store.objects().first() + for obj in store.merchandises.values(): + self.assertFalse(isinstance(obj, Merchandise)) + + store.select_related() + for obj in store.merchandises.values(): + self.assertTrue(isinstance(obj, Merchandise)) + + def test_dynamic_field_dereference_with_ordering_guarantee_on_pymongo3(self): + # This is because 'codec_options' is supported on pymongo3 or later + if IS_PYMONGO_3: + class OrderedDocument(Document): + my_metaclass = TopLevelDocumentMetaclass + __metaclass__ = TopLevelDocumentMetaclass + + @classmethod + def _get_collection(cls): + collection = super(OrderedDocument, cls)._get_collection() + opts = CodecOptions(document_class=OrderedDict) + + return collection.with_options(codec_options=opts) + + class Merchandise(Document): + name = StringField() + price = IntField() + + class Store(OrderedDocument): + merchandises = DynamicField(container_class=OrderedDict) + + Merchandise.drop_collection() + Store.drop_collection() + + merchandises = OrderedDict() + merchandises['#1'] = Merchandise(name='foo', price=100).save() + merchandises['#2'] = Merchandise(name='bar', price=120).save() + merchandises['#3'] = Merchandise(name='baz', price=110).save() + + Store(merchandises=merchandises).save() + + store = Store.objects().first() + + store.select_related() + + # confirms that the load data order is same with the one at storing + self.assertTrue(type(store.merchandises), OrderedDict) + self.assertEqual(','.join(store.merchandises.keys()), '#1,#2,#3') + if __name__ == '__main__': unittest.main()