Compare commits

..

1 Commits

Author SHA1 Message Date
Stefan Wojcik
caa9b34361 use an external sphinx rtd theme 2017-04-30 15:18:38 -04:00
13 changed files with 265 additions and 176 deletions

View File

@@ -16,6 +16,7 @@ python:
- 2.7 - 2.7
- 3.5 - 3.5
- pypy - pypy
- pypy3
env: env:
- MONGODB=2.6 PYMONGO=2.7 - MONGODB=2.6 PYMONGO=2.7

View File

@@ -6,12 +6,6 @@ Development
=========== ===========
- (Fill this out as you fix issues and develop your features). - (Fill this out as you fix issues and develop your features).
Changes in 0.14.0
=================
- BREAKING CHANGE: Removed the `coerce_types` param from `QuerySet.as_pymongo` #1549
- POTENTIAL BREAKING CHANGE: Made EmbeddedDocument not hashable by default #1528
- Improved code quality #1531, #1540, #1541, #1547
Changes in 0.13.0 Changes in 0.13.0
================= =================
- POTENTIAL BREAKING CHANGE: Added Unicode support to the `EmailField`, see - POTENTIAL BREAKING CHANGE: Added Unicode support to the `EmailField`, see

View File

@@ -6,18 +6,6 @@ Development
*********** ***********
(Fill this out whenever you introduce breaking changes to MongoEngine) (Fill this out whenever you introduce breaking changes to MongoEngine)
0.14.0
******
This release includes a few bug fixes and a significant code cleanup. The most
important change is that `QuerySet.as_pymongo` no longer supports a
`coerce_types` mode. If you used it in the past, a) please let us know of your
use case, b) you'll need to override `as_pymongo` to get the desired outcome.
This release also makes the EmbeddedDocument not hashable by default. If you
use embedded documents in sets or dictionaries, you might have to override
`__hash__` and implement a hashing logic specific to your use case. See #1528
for the reason behind this change.
0.13.0 0.13.0
****** ******
This release adds Unicode support to the `EmailField` and changes its This release adds Unicode support to the `EmailField` and changes its

View File

@@ -23,7 +23,7 @@ __all__ = (list(document.__all__) + list(fields.__all__) +
list(signals.__all__) + list(errors.__all__)) list(signals.__all__) + list(errors.__all__))
VERSION = (0, 14, 0) VERSION = (0, 13, 0)
def get_version(): def get_version():

View File

@@ -81,14 +81,7 @@ class BaseField(object):
self.sparse = sparse self.sparse = sparse
self._owner_document = None self._owner_document = None
# Make sure db_field is a string (if it's explicitly defined). # Validate the db_field
if (
self.db_field is not None and
not isinstance(self.db_field, six.string_types)
):
raise TypeError('db_field should be a string.')
# Make sure db_field doesn't contain any forbidden characters.
if isinstance(self.db_field, six.string_types) and ( if isinstance(self.db_field, six.string_types) and (
'.' in self.db_field or '.' in self.db_field or
'\0' in self.db_field or '\0' in self.db_field or

View File

@@ -1,3 +1,4 @@
from collections import OrderedDict
from bson import DBRef, SON from bson import DBRef, SON
import six import six
@@ -201,6 +202,10 @@ class DeReference(object):
as_tuple = isinstance(items, tuple) as_tuple = isinstance(items, tuple)
iterator = enumerate(items) iterator = enumerate(items)
data = [] data = []
elif isinstance(items, OrderedDict):
is_list = False
iterator = items.iteritems()
data = OrderedDict()
else: else:
is_list = False is_list = False
iterator = items.iteritems() iterator = items.iteritems()

View File

@@ -300,7 +300,7 @@ class Document(BaseDocument):
created. created.
:param force_insert: only try to create a new document, don't allow :param force_insert: only try to create a new document, don't allow
updates of existing documents. updates of existing documents
:param validate: validates the document; set to ``False`` to skip. :param validate: validates the document; set to ``False`` to skip.
:param clean: call the document clean method, requires `validate` to be :param clean: call the document clean method, requires `validate` to be
True. True.
@@ -441,21 +441,6 @@ class Document(BaseDocument):
return object_id return object_id
def _get_update_doc(self):
"""Return a dict containing all the $set and $unset operations
that should be sent to MongoDB based on the changes made to this
Document.
"""
updates, removals = self._delta()
update_doc = {}
if updates:
update_doc['$set'] = updates
if removals:
update_doc['$unset'] = removals
return update_doc
def _save_update(self, doc, save_condition, write_concern): def _save_update(self, doc, save_condition, write_concern):
"""Update an existing document. """Update an existing document.
@@ -481,10 +466,15 @@ class Document(BaseDocument):
val = val[ak] val = val[ak]
select_dict['.'.join(actual_key)] = val select_dict['.'.join(actual_key)] = val
update_doc = self._get_update_doc() updates, removals = self._delta()
if update_doc: update_query = {}
if updates:
update_query['$set'] = updates
if removals:
update_query['$unset'] = removals
if updates or removals:
upsert = save_condition is None upsert = save_condition is None
last_error = collection.update(select_dict, update_doc, last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern) upsert=upsert, **write_concern)
if not upsert and last_error['n'] == 0: if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing' raise SaveConditionError('Race condition preventing'

View File

@@ -6,6 +6,7 @@ import socket
import time import time
import uuid import uuid
import warnings import warnings
from collections import Mapping
from operator import itemgetter from operator import itemgetter
from bson import Binary, DBRef, ObjectId, SON from bson import Binary, DBRef, ObjectId, SON
@@ -704,6 +705,14 @@ class DynamicField(BaseField):
Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data""" Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data"""
def __init__(self, container_class=dict, *args, **kwargs):
self._container_cls = container_class
if not issubclass(self._container_cls, Mapping):
self.error('The class that is specified in `container_class` parameter '
'must be a subclass of `dict`.')
super(DynamicField, self).__init__(*args, **kwargs)
def to_mongo(self, value, use_db_field=True, fields=None): def to_mongo(self, value, use_db_field=True, fields=None):
"""Convert a Python type to a MongoDB compatible type. """Convert a Python type to a MongoDB compatible type.
""" """
@@ -729,7 +738,7 @@ class DynamicField(BaseField):
is_list = True is_list = True
value = {k: v for k, v in enumerate(value)} value = {k: v for k, v in enumerate(value)}
data = {} data = self._container_cls()
for k, v in value.iteritems(): for k, v in value.iteritems():
data[k] = self.to_mongo(v, use_db_field, fields) data[k] = self.to_mongo(v, use_db_field, fields)

View File

@@ -67,6 +67,7 @@ class BaseQuerySet(object):
self._scalar = [] self._scalar = []
self._none = False self._none = False
self._as_pymongo = False self._as_pymongo = False
self._as_pymongo_coerce = False
self._search_text = None self._search_text = None
# If inheritance is allowed, only return instances and instances of # If inheritance is allowed, only return instances and instances of
@@ -727,12 +728,11 @@ class BaseQuerySet(object):
'%s is not a subclass of BaseQuerySet' % new_qs.__name__) '%s is not a subclass of BaseQuerySet' % new_qs.__name__)
copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj',
'_where_clause', '_loaded_fields', '_ordering', '_where_clause', '_loaded_fields', '_ordering', '_snapshot',
'_snapshot', '_timeout', '_class_check', '_slave_okay', '_timeout', '_class_check', '_slave_okay', '_read_preference',
'_read_preference', '_iter', '_scalar', '_as_pymongo', '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce',
'_limit', '_skip', '_hint', '_auto_dereference', '_limit', '_skip', '_hint', '_auto_dereference',
'_search_text', 'only_fields', '_max_time_ms', '_search_text', 'only_fields', '_max_time_ms', '_comment')
'_comment')
for prop in copy_props: for prop in copy_props:
val = getattr(self, prop) val = getattr(self, prop)
@@ -939,8 +939,7 @@ class BaseQuerySet(object):
posts = BlogPost.objects(...).fields(slice__comments=5) posts = BlogPost.objects(...).fields(slice__comments=5)
:param kwargs: A set of keyword arguments identifying what to :param kwargs: A set keywors arguments identifying what to include.
include, exclude, or slice.
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
@@ -1129,15 +1128,16 @@ class BaseQuerySet(object):
"""An alias for scalar""" """An alias for scalar"""
return self.scalar(*fields) return self.scalar(*fields)
def as_pymongo(self): def as_pymongo(self, coerce_types=False):
"""Instead of returning Document instances, return raw values from """Instead of returning Document instances, return raw values from
pymongo. pymongo.
This method is particularly useful if you don't need dereferencing :param coerce_types: Field types (if applicable) would be use to
and care primarily about the speed of data retrieval. coerce types.
""" """
queryset = self.clone() queryset = self.clone()
queryset._as_pymongo = True queryset._as_pymongo = True
queryset._as_pymongo_coerce = coerce_types
return queryset return queryset
def max_time_ms(self, ms): def max_time_ms(self, ms):
@@ -1799,25 +1799,59 @@ class BaseQuerySet(object):
return tuple(data) return tuple(data)
def _get_as_pymongo(self, doc): def _get_as_pymongo(self, row):
"""Clean up a PyMongo doc, removing fields that were only fetched # Extract which fields paths we should follow if .fields(...) was
for the sake of MongoEngine's implementation, and return it. # used. If not, handle all fields.
""" if not getattr(self, '__as_pymongo_fields', None):
# Always remove _cls as a MongoEngine's implementation detail. self.__as_pymongo_fields = []
if '_cls' in doc:
del doc['_cls']
# If the _id was not included in a .only or was excluded in a .exclude, for field in self._loaded_fields.fields - set(['_cls']):
# remove it from the doc (we always fetch it so that we can properly self.__as_pymongo_fields.append(field)
# construct documents). while '.' in field:
fields = self._loaded_fields field, _ = field.rsplit('.', 1)
if fields and '_id' in doc and ( self.__as_pymongo_fields.append(field)
(fields.value == QueryFieldList.ONLY and '_id' not in fields.fields) or
(fields.value == QueryFieldList.EXCLUDE and '_id' in fields.fields)
):
del doc['_id']
return doc all_fields = not self.__as_pymongo_fields
def clean(data, path=None):
path = path or ''
if isinstance(data, dict):
new_data = {}
for key, value in data.iteritems():
new_path = '%s.%s' % (path, key) if path else key
if all_fields:
include_field = True
elif self._loaded_fields.value == QueryFieldList.ONLY:
include_field = new_path in self.__as_pymongo_fields
else:
include_field = new_path not in self.__as_pymongo_fields
if include_field:
new_data[key] = clean(value, path=new_path)
data = new_data
elif isinstance(data, list):
data = [clean(d, path=path) for d in data]
else:
if self._as_pymongo_coerce:
# If we need to coerce types, we need to determine the
# type of this field and use the corresponding
# .to_python(...)
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
obj = self._document
for chunk in path.split('.'):
obj = getattr(obj, chunk, None)
if obj is None:
break
elif isinstance(obj, EmbeddedDocumentField):
obj = obj.document_type
if obj and data is not None:
data = obj.to_python(data)
return data
return clean(row)
def _sub_js_fields(self, code): def _sub_js_fields(self, code):
"""When fields are specified with [~fieldname] syntax, where """When fields are specified with [~fieldname] syntax, where

View File

@@ -242,7 +242,7 @@ class InstanceTest(unittest.TestCase):
Zoo.drop_collection() Zoo.drop_collection()
class Zoo(Document): class Zoo(Document):
animals = ListField(GenericReferenceField()) animals = ListField(GenericReferenceField(Animal))
# Save a reference to each animal # Save a reference to each animal
zoo = Zoo(animals=Animal.objects) zoo = Zoo(animals=Animal.objects)

View File

@@ -5,9 +5,11 @@ import uuid
import math import math
import itertools import itertools
import re import re
import pymongo
import sys import sys
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from collections import OrderedDict
import six import six
try: try:
@@ -26,9 +28,12 @@ except ImportError:
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList,
_document_registry) _document_registry, TopLevelDocumentMetaclass)
from tests.utils import MongoDBTestCase from tests.utils import MongoDBTestCase, MONGO_TEST_DB
from mongoengine.python_support import IS_PYMONGO_3
if IS_PYMONGO_3:
from bson import CodecOptions
__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase")
@@ -4183,6 +4188,67 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
self.assertTrue(hasattr(CustomData.c_field, 'custom_data')) self.assertTrue(hasattr(CustomData.c_field, 'custom_data'))
self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a'])
def test_dynamicfield_with_container_class(self):
"""
Tests that object can be stored in order by DynamicField class
with container_class parameter.
"""
raw_data = [('d', 1), ('c', 2), ('b', 3), ('a', 4)]
class Doc(Document):
ordered_data = DynamicField(container_class=OrderedDict)
unordered_data = DynamicField()
Doc.drop_collection()
doc = Doc(ordered_data=OrderedDict(raw_data), unordered_data=dict(raw_data)).save()
# checks that the data is in order
self.assertEqual(type(doc.ordered_data), OrderedDict)
self.assertEqual(type(doc.unordered_data), dict)
self.assertEqual(','.join(doc.ordered_data.keys()), 'd,c,b,a')
# checks that the data is stored to the database in order
pymongo_db = pymongo.MongoClient()[MONGO_TEST_DB]
if IS_PYMONGO_3:
codec_option = CodecOptions(document_class=OrderedDict)
db_doc = pymongo_db.doc.with_options(codec_options=codec_option).find_one()
else:
db_doc = pymongo_db.doc.find_one(as_class=OrderedDict)
self.assertEqual(','.join(doc.ordered_data.keys()), 'd,c,b,a')
def test_dynamicfield_with_wrong_container_class(self):
with self.assertRaises(ValidationError):
class DocWithInvalidField:
data = DynamicField(container_class=list)
def test_dynamicfield_with_wrong_container_class_and_reload_docuemnt(self):
# This is because 'codec_options' is supported on pymongo3 or later
if IS_PYMONGO_3:
class OrderedDocument(Document):
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass
@classmethod
def _get_collection(cls):
collection = super(OrderedDocument, cls)._get_collection()
opts = CodecOptions(document_class=OrderedDict)
return collection.with_options(codec_options=opts)
raw_data = [('d', 1), ('c', 2), ('b', 3), ('a', 4)]
class Doc(OrderedDocument):
data = DynamicField(container_class=OrderedDict)
Doc.drop_collection()
doc = Doc(data=OrderedDict(raw_data)).save()
doc.reload()
self.assertEqual(type(doc.data), OrderedDict)
self.assertEqual(','.join(doc.data.keys()), 'd,c,b,a')
class CachedReferenceFieldTest(MongoDBTestCase): class CachedReferenceFieldTest(MongoDBTestCase):

View File

@@ -917,9 +917,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Blog.objects.count(), 3) self.assertEqual(Blog.objects.count(), 3)
def test_get_changed_fields_query_count(self): def test_get_changed_fields_query_count(self):
"""Make sure we don't perform unnecessary db operations when
none of document's fields were updated.
"""
class Person(Document): class Person(Document):
name = StringField() name = StringField()
owns = ListField(ReferenceField('Organization')) owns = ListField(ReferenceField('Organization'))
@@ -927,8 +925,8 @@ class QuerySetTest(unittest.TestCase):
class Organization(Document): class Organization(Document):
name = StringField() name = StringField()
owner = ReferenceField(Person) owner = ReferenceField('Person')
employees = ListField(ReferenceField(Person)) employees = ListField(ReferenceField('Person'))
class Project(Document): class Project(Document):
name = StringField() name = StringField()
@@ -947,35 +945,35 @@ class QuerySetTest(unittest.TestCase):
with query_counter() as q: with query_counter() as q:
self.assertEqual(q, 0) self.assertEqual(q, 0)
# Fetching a document should result in a query. fresh_o1 = Organization.objects.get(id=o1.id)
org = Organization.objects.get(id=o1.id) self.assertEqual(1, q)
self.assertEqual(q, 1) fresh_o1._get_changed_fields()
self.assertEqual(1, q)
# Checking changed fields of a newly fetched document should not
# result in a query.
org._get_changed_fields()
self.assertEqual(q, 1)
# Saving a doc without changing any of its fields should not result
# in a query (with or without cascade=False).
org = Organization.objects.get(id=o1.id)
with query_counter() as q: with query_counter() as q:
org.save()
self.assertEqual(q, 0) self.assertEqual(q, 0)
org = Organization.objects.get(id=o1.id) fresh_o1 = Organization.objects.get(id=o1.id)
fresh_o1.save() # No changes, does nothing
self.assertEqual(q, 1)
with query_counter() as q: with query_counter() as q:
org.save(cascade=False)
self.assertEqual(q, 0) self.assertEqual(q, 0)
# Saving a doc after you append a reference to it should result in fresh_o1 = Organization.objects.get(id=o1.id)
# two db operations (a query for the reference and an update). fresh_o1.save(cascade=False) # No changes, does nothing
# TODO dereferencing of p2 shouldn't be necessary.
org = Organization.objects.get(id=o1.id) self.assertEqual(q, 1)
with query_counter() as q: with query_counter() as q:
org.employees.append(p2) # dereferences p2 self.assertEqual(q, 0)
org.save() # saves the org
self.assertEqual(q, 2) fresh_o1 = Organization.objects.get(id=o1.id)
fresh_o1.employees.append(p2) # Dereferences
fresh_o1.save(cascade=False) # Saves
self.assertEqual(q, 3)
@skip_pymongo3 @skip_pymongo3
def test_slave_okay(self): def test_slave_okay(self):
@@ -4047,35 +4045,6 @@ class QuerySetTest(unittest.TestCase):
plist = list(Person.objects.scalar('name', 'state')) plist = list(Person.objects.scalar('name', 'state'))
self.assertEqual(plist, [(u'Wilson JR', s1)]) self.assertEqual(plist, [(u'Wilson JR', s1)])
def test_generic_reference_field_with_only_and_as_pymongo(self):
class TestPerson(Document):
name = StringField()
class TestActivity(Document):
name = StringField()
owner = GenericReferenceField()
TestPerson.drop_collection()
TestActivity.drop_collection()
person = TestPerson(name='owner')
person.save()
a1 = TestActivity(name='a1', owner=person)
a1.save()
activity = TestActivity.objects(owner=person).scalar('id', 'owner').no_dereference().first()
self.assertEqual(activity[0], a1.pk)
self.assertEqual(activity[1]['_ref'], DBRef('test_person', person.pk))
activity = TestActivity.objects(owner=person).only('id', 'owner')[0]
self.assertEqual(activity.pk, a1.pk)
self.assertEqual(activity.owner, person)
activity = TestActivity.objects(owner=person).only('id', 'owner').as_pymongo().first()
self.assertEqual(activity['_id'], a1.pk)
self.assertTrue(activity['owner']['_ref'], DBRef('test_person', person.pk))
def test_scalar_db_field(self): def test_scalar_db_field(self):
class TestDoc(Document): class TestDoc(Document):
@@ -4421,44 +4390,21 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) self.assertEqual(doc_objects, Doc.objects.from_json(json_data))
def test_as_pymongo(self): def test_as_pymongo(self):
from decimal import Decimal
class LastLogin(EmbeddedDocument): from decimal import Decimal
location = StringField()
ip = StringField()
class User(Document): class User(Document):
id = ObjectIdField('_id') id = ObjectIdField('_id')
name = StringField() name = StringField()
age = IntField() age = IntField()
price = DecimalField() price = DecimalField()
last_login = EmbeddedDocumentField(LastLogin)
User.drop_collection() User.drop_collection()
User(name="Bob Dole", age=89, price=Decimal('1.11')).save()
User.objects.create(name="Bob Dole", age=89, price=Decimal('1.11')) User(name="Barack Obama", age=51, price=Decimal('2.22')).save()
User.objects.create(
name="Barack Obama",
age=51,
price=Decimal('2.22'),
last_login=LastLogin(
location='White House',
ip='104.107.108.116'
)
)
results = User.objects.as_pymongo()
self.assertEqual(
set(results[0].keys()),
set(['_id', 'name', 'age', 'price'])
)
self.assertEqual(
set(results[1].keys()),
set(['_id', 'name', 'age', 'price', 'last_login'])
)
results = User.objects.only('id', 'name').as_pymongo() results = User.objects.only('id', 'name').as_pymongo()
self.assertEqual(set(results[0].keys()), set(['_id', 'name'])) self.assertEqual(sorted(results[0].keys()), sorted(['_id', 'name']))
users = User.objects.only('name', 'price').as_pymongo() users = User.objects.only('name', 'price').as_pymongo()
results = list(users) results = list(users)
@@ -4469,20 +4415,16 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(results[1]['name'], 'Barack Obama') self.assertEqual(results[1]['name'], 'Barack Obama')
self.assertEqual(results[1]['price'], 2.22) self.assertEqual(results[1]['price'], 2.22)
users = User.objects.only('name', 'last_login').as_pymongo() # Test coerce_types
users = User.objects.only(
'name', 'price').as_pymongo(coerce_types=True)
results = list(users) results = list(users)
self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[0], dict))
self.assertTrue(isinstance(results[1], dict)) self.assertTrue(isinstance(results[1], dict))
self.assertEqual(results[0], { self.assertEqual(results[0]['name'], 'Bob Dole')
'name': 'Bob Dole' self.assertEqual(results[0]['price'], Decimal('1.11'))
}) self.assertEqual(results[1]['name'], 'Barack Obama')
self.assertEqual(results[1], { self.assertEqual(results[1]['price'], Decimal('2.22'))
'name': 'Barack Obama',
'last_login': {
'location': 'White House',
'ip': '104.107.108.116'
}
})
def test_as_pymongo_json_limit_fields(self): def test_as_pymongo_json_limit_fields(self):
@@ -4646,6 +4588,7 @@ class QuerySetTest(unittest.TestCase):
def test_no_cache(self): def test_no_cache(self):
"""Ensure you can add meta data to file""" """Ensure you can add meta data to file"""
class Noddy(Document): class Noddy(Document):
fields = DictField() fields = DictField()
@@ -4663,19 +4606,15 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(list(docs)), 100) self.assertEqual(len(list(docs)), 100)
# Can't directly get a length of a no-cache queryset.
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
len(docs) len(docs)
# Another iteration over the queryset should result in another db op.
with query_counter() as q: with query_counter() as q:
self.assertEqual(q, 0)
list(docs) list(docs)
self.assertEqual(q, 1) self.assertEqual(q, 1)
# ... and another one to double-check.
with query_counter() as q:
list(docs) list(docs)
self.assertEqual(q, 1) self.assertEqual(q, 2)
def test_nested_queryset_iterator(self): def test_nested_queryset_iterator(self):
# Try iterating the same queryset twice, nested. # Try iterating the same queryset twice, nested.

View File

@@ -2,10 +2,15 @@
import unittest import unittest
from bson import DBRef, ObjectId from bson import DBRef, ObjectId
from collections import OrderedDict
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.context_managers import query_counter from mongoengine.context_managers import query_counter
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.base import TopLevelDocumentMetaclass
if IS_PYMONGO_3:
from bson import CodecOptions
class FieldTest(unittest.TestCase): class FieldTest(unittest.TestCase):
@@ -1287,5 +1292,70 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2) self.assertEqual(q, 2)
def test_dynamic_field_dereference(self):
class Merchandise(Document):
name = StringField()
price = IntField()
class Store(Document):
merchandises = DynamicField()
Merchandise.drop_collection()
Store.drop_collection()
merchandises = {
'#1': Merchandise(name='foo', price=100).save(),
'#2': Merchandise(name='bar', price=120).save(),
'#3': Merchandise(name='baz', price=110).save(),
}
Store(merchandises=merchandises).save()
store = Store.objects().first()
for obj in store.merchandises.values():
self.assertFalse(isinstance(obj, Merchandise))
store.select_related()
for obj in store.merchandises.values():
self.assertTrue(isinstance(obj, Merchandise))
def test_dynamic_field_dereference_with_ordering_guarantee_on_pymongo3(self):
# This is because 'codec_options' is supported on pymongo3 or later
if IS_PYMONGO_3:
class OrderedDocument(Document):
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass
@classmethod
def _get_collection(cls):
collection = super(OrderedDocument, cls)._get_collection()
opts = CodecOptions(document_class=OrderedDict)
return collection.with_options(codec_options=opts)
class Merchandise(Document):
name = StringField()
price = IntField()
class Store(OrderedDocument):
merchandises = DynamicField(container_class=OrderedDict)
Merchandise.drop_collection()
Store.drop_collection()
merchandises = OrderedDict()
merchandises['#1'] = Merchandise(name='foo', price=100).save()
merchandises['#2'] = Merchandise(name='bar', price=120).save()
merchandises['#3'] = Merchandise(name='baz', price=110).save()
Store(merchandises=merchandises).save()
store = Store.objects().first()
store.select_related()
# confirms that the load data order is same with the one at storing
self.assertTrue(type(store.merchandises), OrderedDict)
self.assertEqual(','.join(store.merchandises.keys()), '#1,#2,#3')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()