Compare commits

..

1 Commits

Author SHA1 Message Date
Stefan Wojcik
7195236a3b better db_field validation 2017-05-07 20:26:52 -04:00
7 changed files with 230 additions and 103 deletions

View File

@@ -16,6 +16,8 @@ python:
- 2.7 - 2.7
- 3.5 - 3.5
- pypy - pypy
- pypy3.3-5.2-alpha1
env: env:
- MONGODB=2.6 PYMONGO=2.7 - MONGODB=2.6 PYMONGO=2.7

View File

@@ -1,3 +1,4 @@
from collections import OrderedDict
from bson import DBRef, SON from bson import DBRef, SON
import six import six
@@ -201,6 +202,10 @@ class DeReference(object):
as_tuple = isinstance(items, tuple) as_tuple = isinstance(items, tuple)
iterator = enumerate(items) iterator = enumerate(items)
data = [] data = []
elif isinstance(items, OrderedDict):
is_list = False
iterator = items.iteritems()
data = OrderedDict()
else: else:
is_list = False is_list = False
iterator = items.iteritems() iterator = items.iteritems()

View File

@@ -6,6 +6,7 @@ import socket
import time import time
import uuid import uuid
import warnings import warnings
from collections import Mapping
from operator import itemgetter from operator import itemgetter
from bson import Binary, DBRef, ObjectId, SON from bson import Binary, DBRef, ObjectId, SON
@@ -704,6 +705,14 @@ class DynamicField(BaseField):
Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data""" Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data"""
def __init__(self, container_class=dict, *args, **kwargs):
self._container_cls = container_class
if not issubclass(self._container_cls, Mapping):
self.error('The class that is specified in `container_class` parameter '
'must be a subclass of `dict`.')
super(DynamicField, self).__init__(*args, **kwargs)
def to_mongo(self, value, use_db_field=True, fields=None): def to_mongo(self, value, use_db_field=True, fields=None):
"""Convert a Python type to a MongoDB compatible type. """Convert a Python type to a MongoDB compatible type.
""" """
@@ -729,7 +738,7 @@ class DynamicField(BaseField):
is_list = True is_list = True
value = {k: v for k, v in enumerate(value)} value = {k: v for k, v in enumerate(value)}
data = {} data = self._container_cls()
for k, v in value.iteritems(): for k, v in value.iteritems():
data[k] = self.to_mongo(v, use_db_field, fields) data[k] = self.to_mongo(v, use_db_field, fields)

View File

@@ -67,6 +67,7 @@ class BaseQuerySet(object):
self._scalar = [] self._scalar = []
self._none = False self._none = False
self._as_pymongo = False self._as_pymongo = False
self._as_pymongo_coerce = False
self._search_text = None self._search_text = None
# If inheritance is allowed, only return instances and instances of # If inheritance is allowed, only return instances and instances of
@@ -727,12 +728,11 @@ class BaseQuerySet(object):
'%s is not a subclass of BaseQuerySet' % new_qs.__name__) '%s is not a subclass of BaseQuerySet' % new_qs.__name__)
copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj',
'_where_clause', '_loaded_fields', '_ordering', '_where_clause', '_loaded_fields', '_ordering', '_snapshot',
'_snapshot', '_timeout', '_class_check', '_slave_okay', '_timeout', '_class_check', '_slave_okay', '_read_preference',
'_read_preference', '_iter', '_scalar', '_as_pymongo', '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce',
'_limit', '_skip', '_hint', '_auto_dereference', '_limit', '_skip', '_hint', '_auto_dereference',
'_search_text', 'only_fields', '_max_time_ms', '_search_text', 'only_fields', '_max_time_ms', '_comment')
'_comment')
for prop in copy_props: for prop in copy_props:
val = getattr(self, prop) val = getattr(self, prop)
@@ -939,8 +939,7 @@ class BaseQuerySet(object):
posts = BlogPost.objects(...).fields(slice__comments=5) posts = BlogPost.objects(...).fields(slice__comments=5)
:param kwargs: A set of keyword arguments identifying what to :param kwargs: A set keywors arguments identifying what to include.
include, exclude, or slice.
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
@@ -1129,15 +1128,16 @@ class BaseQuerySet(object):
"""An alias for scalar""" """An alias for scalar"""
return self.scalar(*fields) return self.scalar(*fields)
def as_pymongo(self): def as_pymongo(self, coerce_types=False):
"""Instead of returning Document instances, return raw values from """Instead of returning Document instances, return raw values from
pymongo. pymongo.
This method is particularly useful if you don't need dereferencing :param coerce_types: Field types (if applicable) would be use to
and care primarily about the speed of data retrieval. coerce types.
""" """
queryset = self.clone() queryset = self.clone()
queryset._as_pymongo = True queryset._as_pymongo = True
queryset._as_pymongo_coerce = coerce_types
return queryset return queryset
def max_time_ms(self, ms): def max_time_ms(self, ms):
@@ -1799,25 +1799,59 @@ class BaseQuerySet(object):
return tuple(data) return tuple(data)
def _get_as_pymongo(self, doc): def _get_as_pymongo(self, row):
"""Clean up a PyMongo doc, removing fields that were only fetched # Extract which fields paths we should follow if .fields(...) was
for the sake of MongoEngine's implementation, and return it. # used. If not, handle all fields.
""" if not getattr(self, '__as_pymongo_fields', None):
# Always remove _cls as a MongoEngine's implementation detail. self.__as_pymongo_fields = []
if '_cls' in doc:
del doc['_cls']
# If the _id was not included in a .only or was excluded in a .exclude, for field in self._loaded_fields.fields - set(['_cls']):
# remove it from the doc (we always fetch it so that we can properly self.__as_pymongo_fields.append(field)
# construct documents). while '.' in field:
fields = self._loaded_fields field, _ = field.rsplit('.', 1)
if fields and '_id' in doc and ( self.__as_pymongo_fields.append(field)
(fields.value == QueryFieldList.ONLY and '_id' not in fields.fields) or
(fields.value == QueryFieldList.EXCLUDE and '_id' in fields.fields)
):
del doc['_id']
return doc all_fields = not self.__as_pymongo_fields
def clean(data, path=None):
path = path or ''
if isinstance(data, dict):
new_data = {}
for key, value in data.iteritems():
new_path = '%s.%s' % (path, key) if path else key
if all_fields:
include_field = True
elif self._loaded_fields.value == QueryFieldList.ONLY:
include_field = new_path in self.__as_pymongo_fields
else:
include_field = new_path not in self.__as_pymongo_fields
if include_field:
new_data[key] = clean(value, path=new_path)
data = new_data
elif isinstance(data, list):
data = [clean(d, path=path) for d in data]
else:
if self._as_pymongo_coerce:
# If we need to coerce types, we need to determine the
# type of this field and use the corresponding
# .to_python(...)
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
obj = self._document
for chunk in path.split('.'):
obj = getattr(obj, chunk, None)
if obj is None:
break
elif isinstance(obj, EmbeddedDocumentField):
obj = obj.document_type
if obj and data is not None:
data = obj.to_python(data)
return data
return clean(row)
def _sub_js_fields(self, code): def _sub_js_fields(self, code):
"""When fields are specified with [~fieldname] syntax, where """When fields are specified with [~fieldname] syntax, where

View File

@@ -5,9 +5,11 @@ import uuid
import math import math
import itertools import itertools
import re import re
import pymongo
import sys import sys
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from collections import OrderedDict
import six import six
try: try:
@@ -26,9 +28,12 @@ except ImportError:
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList,
_document_registry) _document_registry, TopLevelDocumentMetaclass)
from tests.utils import MongoDBTestCase from tests.utils import MongoDBTestCase, MONGO_TEST_DB
from mongoengine.python_support import IS_PYMONGO_3
if IS_PYMONGO_3:
from bson import CodecOptions
__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase")
@@ -4183,6 +4188,67 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
self.assertTrue(hasattr(CustomData.c_field, 'custom_data')) self.assertTrue(hasattr(CustomData.c_field, 'custom_data'))
self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a'])
def test_dynamicfield_with_container_class(self):
"""
Tests that object can be stored in order by DynamicField class
with container_class parameter.
"""
raw_data = [('d', 1), ('c', 2), ('b', 3), ('a', 4)]
class Doc(Document):
ordered_data = DynamicField(container_class=OrderedDict)
unordered_data = DynamicField()
Doc.drop_collection()
doc = Doc(ordered_data=OrderedDict(raw_data), unordered_data=dict(raw_data)).save()
# checks that the data is in order
self.assertEqual(type(doc.ordered_data), OrderedDict)
self.assertEqual(type(doc.unordered_data), dict)
self.assertEqual(','.join(doc.ordered_data.keys()), 'd,c,b,a')
# checks that the data is stored to the database in order
pymongo_db = pymongo.MongoClient()[MONGO_TEST_DB]
if IS_PYMONGO_3:
codec_option = CodecOptions(document_class=OrderedDict)
db_doc = pymongo_db.doc.with_options(codec_options=codec_option).find_one()
else:
db_doc = pymongo_db.doc.find_one(as_class=OrderedDict)
self.assertEqual(','.join(doc.ordered_data.keys()), 'd,c,b,a')
def test_dynamicfield_with_wrong_container_class(self):
with self.assertRaises(ValidationError):
class DocWithInvalidField:
data = DynamicField(container_class=list)
def test_dynamicfield_with_wrong_container_class_and_reload_docuemnt(self):
# This is because 'codec_options' is supported on pymongo3 or later
if IS_PYMONGO_3:
class OrderedDocument(Document):
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass
@classmethod
def _get_collection(cls):
collection = super(OrderedDocument, cls)._get_collection()
opts = CodecOptions(document_class=OrderedDict)
return collection.with_options(codec_options=opts)
raw_data = [('d', 1), ('c', 2), ('b', 3), ('a', 4)]
class Doc(OrderedDocument):
data = DynamicField(container_class=OrderedDict)
Doc.drop_collection()
doc = Doc(data=OrderedDict(raw_data)).save()
doc.reload()
self.assertEqual(type(doc.data), OrderedDict)
self.assertEqual(','.join(doc.data.keys()), 'd,c,b,a')
class CachedReferenceFieldTest(MongoDBTestCase): class CachedReferenceFieldTest(MongoDBTestCase):

View File

@@ -4047,35 +4047,6 @@ class QuerySetTest(unittest.TestCase):
plist = list(Person.objects.scalar('name', 'state')) plist = list(Person.objects.scalar('name', 'state'))
self.assertEqual(plist, [(u'Wilson JR', s1)]) self.assertEqual(plist, [(u'Wilson JR', s1)])
def test_generic_reference_field_with_only_and_as_pymongo(self):
class TestPerson(Document):
name = StringField()
class TestActivity(Document):
name = StringField()
owner = GenericReferenceField()
TestPerson.drop_collection()
TestActivity.drop_collection()
person = TestPerson(name='owner')
person.save()
a1 = TestActivity(name='a1', owner=person)
a1.save()
activity = TestActivity.objects(owner=person).scalar('id', 'owner').no_dereference().first()
self.assertEqual(activity[0], a1.pk)
self.assertEqual(activity[1]['_ref'], DBRef('test_person', person.pk))
activity = TestActivity.objects(owner=person).only('id', 'owner')[0]
self.assertEqual(activity.pk, a1.pk)
self.assertEqual(activity.owner, person)
activity = TestActivity.objects(owner=person).only('id', 'owner').as_pymongo().first()
self.assertEqual(activity['_id'], a1.pk)
self.assertTrue(activity['owner']['_ref'], DBRef('test_person', person.pk))
def test_scalar_db_field(self): def test_scalar_db_field(self):
class TestDoc(Document): class TestDoc(Document):
@@ -4421,44 +4392,21 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) self.assertEqual(doc_objects, Doc.objects.from_json(json_data))
def test_as_pymongo(self): def test_as_pymongo(self):
from decimal import Decimal
class LastLogin(EmbeddedDocument): from decimal import Decimal
location = StringField()
ip = StringField()
class User(Document): class User(Document):
id = ObjectIdField('_id') id = ObjectIdField('_id')
name = StringField() name = StringField()
age = IntField() age = IntField()
price = DecimalField() price = DecimalField()
last_login = EmbeddedDocumentField(LastLogin)
User.drop_collection() User.drop_collection()
User(name="Bob Dole", age=89, price=Decimal('1.11')).save()
User.objects.create(name="Bob Dole", age=89, price=Decimal('1.11')) User(name="Barack Obama", age=51, price=Decimal('2.22')).save()
User.objects.create(
name="Barack Obama",
age=51,
price=Decimal('2.22'),
last_login=LastLogin(
location='White House',
ip='104.107.108.116'
)
)
results = User.objects.as_pymongo()
self.assertEqual(
set(results[0].keys()),
set(['_id', 'name', 'age', 'price'])
)
self.assertEqual(
set(results[1].keys()),
set(['_id', 'name', 'age', 'price', 'last_login'])
)
results = User.objects.only('id', 'name').as_pymongo() results = User.objects.only('id', 'name').as_pymongo()
self.assertEqual(set(results[0].keys()), set(['_id', 'name'])) self.assertEqual(sorted(results[0].keys()), sorted(['_id', 'name']))
users = User.objects.only('name', 'price').as_pymongo() users = User.objects.only('name', 'price').as_pymongo()
results = list(users) results = list(users)
@@ -4469,20 +4417,16 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(results[1]['name'], 'Barack Obama') self.assertEqual(results[1]['name'], 'Barack Obama')
self.assertEqual(results[1]['price'], 2.22) self.assertEqual(results[1]['price'], 2.22)
users = User.objects.only('name', 'last_login').as_pymongo() # Test coerce_types
users = User.objects.only(
'name', 'price').as_pymongo(coerce_types=True)
results = list(users) results = list(users)
self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[0], dict))
self.assertTrue(isinstance(results[1], dict)) self.assertTrue(isinstance(results[1], dict))
self.assertEqual(results[0], { self.assertEqual(results[0]['name'], 'Bob Dole')
'name': 'Bob Dole' self.assertEqual(results[0]['price'], Decimal('1.11'))
}) self.assertEqual(results[1]['name'], 'Barack Obama')
self.assertEqual(results[1], { self.assertEqual(results[1]['price'], Decimal('2.22'))
'name': 'Barack Obama',
'last_login': {
'location': 'White House',
'ip': '104.107.108.116'
}
})
def test_as_pymongo_json_limit_fields(self): def test_as_pymongo_json_limit_fields(self):
@@ -4646,6 +4590,7 @@ class QuerySetTest(unittest.TestCase):
def test_no_cache(self): def test_no_cache(self):
"""Ensure you can add meta data to file""" """Ensure you can add meta data to file"""
class Noddy(Document): class Noddy(Document):
fields = DictField() fields = DictField()
@@ -4663,19 +4608,15 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(list(docs)), 100) self.assertEqual(len(list(docs)), 100)
# Can't directly get a length of a no-cache queryset.
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
len(docs) len(docs)
# Another iteration over the queryset should result in another db op.
with query_counter() as q: with query_counter() as q:
self.assertEqual(q, 0)
list(docs) list(docs)
self.assertEqual(q, 1) self.assertEqual(q, 1)
# ... and another one to double-check.
with query_counter() as q:
list(docs) list(docs)
self.assertEqual(q, 1) self.assertEqual(q, 2)
def test_nested_queryset_iterator(self): def test_nested_queryset_iterator(self):
# Try iterating the same queryset twice, nested. # Try iterating the same queryset twice, nested.

View File

@@ -2,10 +2,15 @@
import unittest import unittest
from bson import DBRef, ObjectId from bson import DBRef, ObjectId
from collections import OrderedDict
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.context_managers import query_counter from mongoengine.context_managers import query_counter
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.base import TopLevelDocumentMetaclass
if IS_PYMONGO_3:
from bson import CodecOptions
class FieldTest(unittest.TestCase): class FieldTest(unittest.TestCase):
@@ -1287,5 +1292,70 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2) self.assertEqual(q, 2)
def test_dynamic_field_dereference(self):
class Merchandise(Document):
name = StringField()
price = IntField()
class Store(Document):
merchandises = DynamicField()
Merchandise.drop_collection()
Store.drop_collection()
merchandises = {
'#1': Merchandise(name='foo', price=100).save(),
'#2': Merchandise(name='bar', price=120).save(),
'#3': Merchandise(name='baz', price=110).save(),
}
Store(merchandises=merchandises).save()
store = Store.objects().first()
for obj in store.merchandises.values():
self.assertFalse(isinstance(obj, Merchandise))
store.select_related()
for obj in store.merchandises.values():
self.assertTrue(isinstance(obj, Merchandise))
def test_dynamic_field_dereference_with_ordering_guarantee_on_pymongo3(self):
# This is because 'codec_options' is supported on pymongo3 or later
if IS_PYMONGO_3:
class OrderedDocument(Document):
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass
@classmethod
def _get_collection(cls):
collection = super(OrderedDocument, cls)._get_collection()
opts = CodecOptions(document_class=OrderedDict)
return collection.with_options(codec_options=opts)
class Merchandise(Document):
name = StringField()
price = IntField()
class Store(OrderedDocument):
merchandises = DynamicField(container_class=OrderedDict)
Merchandise.drop_collection()
Store.drop_collection()
merchandises = OrderedDict()
merchandises['#1'] = Merchandise(name='foo', price=100).save()
merchandises['#2'] = Merchandise(name='bar', price=120).save()
merchandises['#3'] = Merchandise(name='baz', price=110).save()
Store(merchandises=merchandises).save()
store = Store.objects().first()
store.select_related()
# confirms that the load data order is same with the one at storing
self.assertTrue(type(store.merchandises), OrderedDict)
self.assertEqual(','.join(store.merchandises.keys()), '#1,#2,#3')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()