cleaner as_pymongo + drop coerce_types

This commit is contained in:
Stefan Wojcik 2017-05-07 23:28:35 -04:00
parent 33e9ef2106
commit e751ab55c8
2 changed files with 63 additions and 73 deletions

View File

@ -67,7 +67,6 @@ class BaseQuerySet(object):
self._scalar = [] self._scalar = []
self._none = False self._none = False
self._as_pymongo = False self._as_pymongo = False
self._as_pymongo_coerce = False
self._search_text = None self._search_text = None
# If inheritance is allowed, only return instances and instances of # If inheritance is allowed, only return instances and instances of
@ -728,11 +727,12 @@ class BaseQuerySet(object):
'%s is not a subclass of BaseQuerySet' % new_qs.__name__) '%s is not a subclass of BaseQuerySet' % new_qs.__name__)
copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj',
'_where_clause', '_loaded_fields', '_ordering', '_snapshot', '_where_clause', '_loaded_fields', '_ordering',
'_timeout', '_class_check', '_slave_okay', '_read_preference', '_snapshot', '_timeout', '_class_check', '_slave_okay',
'_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', '_read_preference', '_iter', '_scalar', '_as_pymongo',
'_limit', '_skip', '_hint', '_auto_dereference', '_limit', '_skip', '_hint', '_auto_dereference',
'_search_text', 'only_fields', '_max_time_ms', '_comment') '_search_text', 'only_fields', '_max_time_ms',
'_comment')
for prop in copy_props: for prop in copy_props:
val = getattr(self, prop) val = getattr(self, prop)
@ -939,7 +939,8 @@ class BaseQuerySet(object):
posts = BlogPost.objects(...).fields(slice__comments=5) posts = BlogPost.objects(...).fields(slice__comments=5)
:param kwargs: A set keywors arguments identifying what to include. :param kwargs: A set of keyword arguments identifying what to
include, exclude, or slice.
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
@ -1128,16 +1129,12 @@ class BaseQuerySet(object):
"""An alias for scalar""" """An alias for scalar"""
return self.scalar(*fields) return self.scalar(*fields)
def as_pymongo(self, coerce_types=False): def as_pymongo(self):
"""Instead of returning Document instances, return raw values from """Instead of returning Document instances, return raw values from
pymongo. pymongo.
:param coerce_types: Field types (if applicable) would be use to
coerce types.
""" """
queryset = self.clone() queryset = self.clone()
queryset._as_pymongo = True queryset._as_pymongo = True
queryset._as_pymongo_coerce = coerce_types
return queryset return queryset
def max_time_ms(self, ms): def max_time_ms(self, ms):
@ -1799,59 +1796,25 @@ class BaseQuerySet(object):
return tuple(data) return tuple(data)
def _get_as_pymongo(self, row): def _get_as_pymongo(self, doc):
# Extract which fields paths we should follow if .fields(...) was """Clean up a PyMongo doc, removing fields that were only fetched
# used. If not, handle all fields. for the sake of MongoEngine's implementation, and return it.
if not getattr(self, '__as_pymongo_fields', None): """
self.__as_pymongo_fields = [] # Always remove _cls as a MongoEngine's implementation detail.
if '_cls' in doc:
del doc['_cls']
for field in self._loaded_fields.fields - set(['_cls']): # If the _id was not included in a .only or was excluded in a .exclude,
self.__as_pymongo_fields.append(field) # remove it from the doc (we always fetch it so that we can properly
while '.' in field: # construct documents).
field, _ = field.rsplit('.', 1) fields = self._loaded_fields
self.__as_pymongo_fields.append(field) if fields and '_id' in doc and (
(fields.value == QueryFieldList.ONLY and '_id' not in fields.fields) or
(fields.value == QueryFieldList.EXCLUDE and '_id' in fields.fields)
):
del doc['_id']
all_fields = not self.__as_pymongo_fields return doc
def clean(data, path=None):
path = path or ''
if isinstance(data, dict):
new_data = {}
for key, value in data.iteritems():
new_path = '%s.%s' % (path, key) if path else key
if all_fields:
include_field = True
elif self._loaded_fields.value == QueryFieldList.ONLY:
include_field = new_path in self.__as_pymongo_fields
else:
include_field = new_path not in self.__as_pymongo_fields
if include_field:
new_data[key] = clean(value, path=new_path)
data = new_data
elif isinstance(data, list):
data = [clean(d, path=path) for d in data]
else:
if self._as_pymongo_coerce:
# If we need to coerce types, we need to determine the
# type of this field and use the corresponding
# .to_python(...)
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
obj = self._document
for chunk in path.split('.'):
obj = getattr(obj, chunk, None)
if obj is None:
break
elif isinstance(obj, EmbeddedDocumentField):
obj = obj.document_type
if obj and data is not None:
data = obj.to_python(data)
return data
return clean(row)
def _sub_js_fields(self, code): def _sub_js_fields(self, code):
"""When fields are specified with [~fieldname] syntax, where """When fields are specified with [~fieldname] syntax, where

View File

@ -4392,21 +4392,44 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) self.assertEqual(doc_objects, Doc.objects.from_json(json_data))
def test_as_pymongo(self): def test_as_pymongo(self):
from decimal import Decimal from decimal import Decimal
class LastLogin(EmbeddedDocument):
location = StringField()
ip = StringField()
class User(Document): class User(Document):
id = ObjectIdField('_id') id = ObjectIdField('_id')
name = StringField() name = StringField()
age = IntField() age = IntField()
price = DecimalField() price = DecimalField()
last_login = EmbeddedDocumentField(LastLogin)
User.drop_collection() User.drop_collection()
User(name="Bob Dole", age=89, price=Decimal('1.11')).save()
User(name="Barack Obama", age=51, price=Decimal('2.22')).save() User.objects.create(name="Bob Dole", age=89, price=Decimal('1.11'))
User.objects.create(
name="Barack Obama",
age=51,
price=Decimal('2.22'),
last_login=LastLogin(
location='White House',
ip='104.107.108.116'
)
)
results = User.objects.as_pymongo()
self.assertEqual(
set(results[0].keys()),
set(['_id', 'name', 'age', 'price'])
)
self.assertEqual(
set(results[1].keys()),
set(['_id', 'name', 'age', 'price', 'last_login'])
)
results = User.objects.only('id', 'name').as_pymongo() results = User.objects.only('id', 'name').as_pymongo()
self.assertEqual(sorted(results[0].keys()), sorted(['_id', 'name'])) self.assertEqual(set(results[0].keys()), set(['_id', 'name']))
users = User.objects.only('name', 'price').as_pymongo() users = User.objects.only('name', 'price').as_pymongo()
results = list(users) results = list(users)
@ -4417,16 +4440,20 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(results[1]['name'], 'Barack Obama') self.assertEqual(results[1]['name'], 'Barack Obama')
self.assertEqual(results[1]['price'], 2.22) self.assertEqual(results[1]['price'], 2.22)
# Test coerce_types users = User.objects.only('name', 'last_login').as_pymongo()
users = User.objects.only(
'name', 'price').as_pymongo(coerce_types=True)
results = list(users) results = list(users)
self.assertTrue(isinstance(results[0], dict)) self.assertTrue(isinstance(results[0], dict))
self.assertTrue(isinstance(results[1], dict)) self.assertTrue(isinstance(results[1], dict))
self.assertEqual(results[0]['name'], 'Bob Dole') self.assertEqual(results[0], {
self.assertEqual(results[0]['price'], Decimal('1.11')) 'name': 'Bob Dole'
self.assertEqual(results[1]['name'], 'Barack Obama') })
self.assertEqual(results[1]['price'], Decimal('2.22')) self.assertEqual(results[1], {
'name': 'Barack Obama',
'last_login': {
'location': 'White House',
'ip': '104.107.108.116'
}
})
def test_as_pymongo_json_limit_fields(self): def test_as_pymongo_json_limit_fields(self):