Merge pull request #1218 from bbenne10/master

Curry **kwargs through to_mongo on fields
This commit is contained in:
Emmanuel Leblond 2016-01-26 15:53:21 +01:00
commit c946b06be5
7 changed files with 60 additions and 46 deletions

View File

@ -234,3 +234,4 @@ that much better:
* Paul-Armand Verhaegen (https://github.com/paularmand) * Paul-Armand Verhaegen (https://github.com/paularmand)
* Steven Rossiter (https://github.com/BeardedSteve) * Steven Rossiter (https://github.com/BeardedSteve)
* Luo Peng (https://github.com/RussellLuo) * Luo Peng (https://github.com/RussellLuo)
* Bryan Bennett (https://github.com/bbenne10)

View File

@ -7,6 +7,7 @@ Changes in 0.10.6
- Add support for mocking MongoEngine based on mongomock. #1151 - Add support for mocking MongoEngine based on mongomock. #1151
- Fixed not being able to run tests on Windows. #1153 - Fixed not being able to run tests on Windows. #1153
- Allow creation of sparse compound indexes. #1114 - Allow creation of sparse compound indexes. #1114
- Fixed not being able to specify `use_db_field=False` on `ListField(EmbeddedDocumentField)` instances
Changes in 0.10.5 Changes in 0.10.5
================= =================

View File

@ -325,20 +325,17 @@ class BaseDocument(object):
if value is not None: if value is not None:
if isinstance(field, EmbeddedDocumentField): if fields:
if fields: key = '%s.' % field_name
key = '%s.' % field_name embedded_fields = [
embedded_fields = [ i.replace(key, '') for i in fields
i.replace(key, '') for i in fields if i.startswith(key)]
if i.startswith(key)]
else:
embedded_fields = []
value = field.to_mongo(value, use_db_field=use_db_field,
fields=embedded_fields)
else: else:
value = field.to_mongo(value) embedded_fields = []
value = field.to_mongo(value, use_db_field=use_db_field,
fields=embedded_fields)
# Handle self generating fields # Handle self generating fields
if value is None and field._auto_gen: if value is None and field._auto_gen:

View File

@ -158,7 +158,7 @@ class BaseField(object):
""" """
return value return value
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
"""Convert a Python type to a MongoDB-compatible type. """Convert a Python type to a MongoDB-compatible type.
""" """
return self.to_python(value) return self.to_python(value)
@ -325,7 +325,7 @@ class ComplexBaseField(BaseField):
key=operator.itemgetter(0))] key=operator.itemgetter(0))]
return value_dict return value_dict
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
"""Convert a Python type to a MongoDB-compatible type. """Convert a Python type to a MongoDB-compatible type.
""" """
Document = _import_class("Document") Document = _import_class("Document")
@ -337,9 +337,10 @@ class ComplexBaseField(BaseField):
if hasattr(value, 'to_mongo'): if hasattr(value, 'to_mongo'):
if isinstance(value, Document): if isinstance(value, Document):
return GenericReferenceField().to_mongo(value) return GenericReferenceField().to_mongo(
value, **kwargs)
cls = value.__class__ cls = value.__class__
val = value.to_mongo() val = value.to_mongo(**kwargs)
# If it's a document that is not inherited add _cls # If it's a document that is not inherited add _cls
if isinstance(value, EmbeddedDocument): if isinstance(value, EmbeddedDocument):
val['_cls'] = cls.__name__ val['_cls'] = cls.__name__
@ -354,7 +355,7 @@ class ComplexBaseField(BaseField):
return value return value
if self.field: if self.field:
value_dict = dict([(key, self.field.to_mongo(item)) value_dict = dict([(key, self.field.to_mongo(item, **kwargs))
for key, item in value.iteritems()]) for key, item in value.iteritems()])
else: else:
value_dict = {} value_dict = {}
@ -373,19 +374,20 @@ class ComplexBaseField(BaseField):
meta.get('allow_inheritance', ALLOW_INHERITANCE) meta.get('allow_inheritance', ALLOW_INHERITANCE)
is True) is True)
if not allow_inheritance and not self.field: if not allow_inheritance and not self.field:
value_dict[k] = GenericReferenceField().to_mongo(v) value_dict[k] = GenericReferenceField().to_mongo(
v, **kwargs)
else: else:
collection = v._get_collection_name() collection = v._get_collection_name()
value_dict[k] = DBRef(collection, v.pk) value_dict[k] = DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'): elif hasattr(v, 'to_mongo'):
cls = v.__class__ cls = v.__class__
val = v.to_mongo() val = v.to_mongo(**kwargs)
# If it's a document that is not inherited add _cls # If it's a document that is not inherited add _cls
if isinstance(v, (Document, EmbeddedDocument)): if isinstance(v, (Document, EmbeddedDocument)):
val['_cls'] = cls.__name__ val['_cls'] = cls.__name__
value_dict[k] = val value_dict[k] = val
else: else:
value_dict[k] = self.to_mongo(v) value_dict[k] = self.to_mongo(v, **kwargs)
if is_list: # Convert back to a list if is_list: # Convert back to a list
return [v for _, v in sorted(value_dict.items(), return [v for _, v in sorted(value_dict.items(),
@ -443,7 +445,7 @@ class ObjectIdField(BaseField):
pass pass
return value return value
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
if not isinstance(value, ObjectId): if not isinstance(value, ObjectId):
try: try:
return ObjectId(unicode(value)) return ObjectId(unicode(value))
@ -618,7 +620,7 @@ class GeoJsonBaseField(BaseField):
if errors: if errors:
return "Invalid MultiPolygon:\n%s" % ", ".join(errors) return "Invalid MultiPolygon:\n%s" % ", ".join(errors)
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
if isinstance(value, dict): if isinstance(value, dict):
return value return value
return SON([("type", self._type), ("coordinates", value)]) return SON([("type", self._type), ("coordinates", value)])

View File

@ -325,7 +325,7 @@ class DecimalField(BaseField):
return value return value
return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding) return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding)
def to_mongo(self, value, use_db_field=True): def to_mongo(self, value, **kwargs):
if value is None: if value is None:
return value return value
if self.force_string: if self.force_string:
@ -388,7 +388,7 @@ class DateTimeField(BaseField):
if not isinstance(new_value, (datetime.datetime, datetime.date)): if not isinstance(new_value, (datetime.datetime, datetime.date)):
self.error(u'cannot parse date "%s"' % value) self.error(u'cannot parse date "%s"' % value)
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
if value is None: if value is None:
return value return value
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
@ -511,7 +511,7 @@ class ComplexDateTimeField(StringField):
except: except:
return original_value return original_value
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
value = self.to_python(value) value = self.to_python(value)
return self._convert_from_datetime(value) return self._convert_from_datetime(value)
@ -546,11 +546,10 @@ class EmbeddedDocumentField(BaseField):
return self.document_type._from_son(value, _auto_dereference=self._auto_dereference) return self.document_type._from_son(value, _auto_dereference=self._auto_dereference)
return value return value
def to_mongo(self, value, use_db_field=True, fields=[]): def to_mongo(self, value, **kwargs):
if not isinstance(value, self.document_type): if not isinstance(value, self.document_type):
return value return value
return self.document_type.to_mongo(value, use_db_field, return self.document_type.to_mongo(value, **kwargs)
fields=fields)
def validate(self, value, clean=True): def validate(self, value, clean=True):
"""Make sure that the document instance is an instance of the """Make sure that the document instance is an instance of the
@ -600,11 +599,11 @@ class GenericEmbeddedDocumentField(BaseField):
value.validate(clean=clean) value.validate(clean=clean)
def to_mongo(self, document, use_db_field=True): def to_mongo(self, document, **kwargs):
if document is None: if document is None:
return None return None
data = document.to_mongo(use_db_field) data = document.to_mongo(**kwargs)
if '_cls' not in data: if '_cls' not in data:
data['_cls'] = document._class_name data['_cls'] = document._class_name
return data return data
@ -616,7 +615,7 @@ class DynamicField(BaseField):
Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data""" Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data"""
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
"""Convert a Python type to a MongoDB compatible type. """Convert a Python type to a MongoDB compatible type.
""" """
@ -625,7 +624,7 @@ class DynamicField(BaseField):
if hasattr(value, 'to_mongo'): if hasattr(value, 'to_mongo'):
cls = value.__class__ cls = value.__class__
val = value.to_mongo() val = value.to_mongo(**kwargs)
# If we its a document thats not inherited add _cls # If we its a document thats not inherited add _cls
if isinstance(value, Document): if isinstance(value, Document):
val = {"_ref": value.to_dbref(), "_cls": cls.__name__} val = {"_ref": value.to_dbref(), "_cls": cls.__name__}
@ -643,7 +642,7 @@ class DynamicField(BaseField):
data = {} data = {}
for k, v in value.iteritems(): for k, v in value.iteritems():
data[k] = self.to_mongo(v) data[k] = self.to_mongo(v, **kwargs)
value = data value = data
if is_list: # Convert back to a list if is_list: # Convert back to a list
@ -755,8 +754,8 @@ class SortedListField(ListField):
self._order_reverse = kwargs.pop('reverse') self._order_reverse = kwargs.pop('reverse')
super(SortedListField, self).__init__(field, **kwargs) super(SortedListField, self).__init__(field, **kwargs)
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
value = super(SortedListField, self).to_mongo(value) value = super(SortedListField, self).to_mongo(value, **kwargs)
if self._ordering is not None: if self._ordering is not None:
return sorted(value, key=itemgetter(self._ordering), return sorted(value, key=itemgetter(self._ordering),
reverse=self._order_reverse) reverse=self._order_reverse)
@ -942,7 +941,7 @@ class ReferenceField(BaseField):
return super(ReferenceField, self).__get__(instance, owner) return super(ReferenceField, self).__get__(instance, owner)
def to_mongo(self, document): def to_mongo(self, document, **kwargs):
if isinstance(document, DBRef): if isinstance(document, DBRef):
if not self.dbref: if not self.dbref:
return document.id return document.id
@ -965,7 +964,7 @@ class ReferenceField(BaseField):
id_field_name = cls._meta['id_field'] id_field_name = cls._meta['id_field']
id_field = cls._fields[id_field_name] id_field = cls._fields[id_field_name]
id_ = id_field.to_mongo(id_) id_ = id_field.to_mongo(id_, **kwargs)
if self.document_type._meta.get('abstract'): if self.document_type._meta.get('abstract'):
collection = cls._get_collection_name() collection = cls._get_collection_name()
return DBRef(collection, id_, cls=cls._class_name) return DBRef(collection, id_, cls=cls._class_name)
@ -1088,7 +1087,7 @@ class CachedReferenceField(BaseField):
return super(CachedReferenceField, self).__get__(instance, owner) return super(CachedReferenceField, self).__get__(instance, owner)
def to_mongo(self, document): def to_mongo(self, document, **kwargs):
id_field_name = self.document_type._meta['id_field'] id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name] id_field = self.document_type._fields[id_field_name]
@ -1103,10 +1102,11 @@ class CachedReferenceField(BaseField):
# TODO: should raise here or will fail next statement # TODO: should raise here or will fail next statement
value = SON(( value = SON((
("_id", id_field.to_mongo(id_)), ("_id", id_field.to_mongo(id_, **kwargs)),
)) ))
value.update(dict(document.to_mongo(fields=self.fields))) kwargs['fields'] = self.fields
value.update(dict(document.to_mongo(**kwargs)))
return value return value
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
@ -1222,7 +1222,7 @@ class GenericReferenceField(BaseField):
doc = doc_cls._from_son(doc) doc = doc_cls._from_son(doc)
return doc return doc
def to_mongo(self, document, use_db_field=True): def to_mongo(self, document, **kwargs):
if document is None: if document is None:
return None return None
@ -1241,7 +1241,7 @@ class GenericReferenceField(BaseField):
else: else:
id_ = document id_ = document
id_ = id_field.to_mongo(id_) id_ = id_field.to_mongo(id_, **kwargs)
collection = document._get_collection_name() collection = document._get_collection_name()
ref = DBRef(collection, id_) ref = DBRef(collection, id_)
return SON(( return SON((
@ -1270,7 +1270,7 @@ class BinaryField(BaseField):
value = bin_type(value) value = bin_type(value)
return super(BinaryField, self).__set__(instance, value) return super(BinaryField, self).__set__(instance, value)
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
return Binary(value) return Binary(value)
def validate(self, value): def validate(self, value):
@ -1495,7 +1495,7 @@ class FileField(BaseField):
db_alias=db_alias, db_alias=db_alias,
collection_name=collection_name) collection_name=collection_name)
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
# Store the GridFS file id in MongoDB # Store the GridFS file id in MongoDB
if isinstance(value, self.proxy_class) and value.grid_id is not None: if isinstance(value, self.proxy_class) and value.grid_id is not None:
return value.grid_id return value.grid_id
@ -1845,7 +1845,7 @@ class UUIDField(BaseField):
return original_value return original_value
return value return value
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
if not self._binary: if not self._binary:
return unicode(value) return unicode(value)
elif isinstance(value, basestring): elif isinstance(value, basestring):

View File

@ -679,6 +679,19 @@ class InstanceTest(unittest.TestCase):
doc = Doc.objects.get() doc = Doc.objects.get()
self.assertHasInstance(doc.embedded_field[0], doc) self.assertHasInstance(doc.embedded_field[0], doc)
def test_embedded_document_complex_instance_no_use_db_field(self):
"""Ensure that use_db_field is propagated to list of Emb Docs
"""
class Embedded(EmbeddedDocument):
string = StringField(db_field='s')
class Doc(Document):
embedded_field = ListField(EmbeddedDocumentField(Embedded))
d = Doc(embedded_field=[Embedded(string="Hi")]).to_mongo(
use_db_field=False).to_dict()
self.assertEqual(d['embedded_field'], [{'string': 'Hi'}])
def test_instance_is_set_on_setattr(self): def test_instance_is_set_on_setattr(self):
class Email(EmbeddedDocument): class Email(EmbeddedDocument):

View File

@ -3380,7 +3380,7 @@ class FieldTest(unittest.TestCase):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(EnumField, self).__init__(**kwargs) super(EnumField, self).__init__(**kwargs)
def to_mongo(self, value): def to_mongo(self, value, **kwargs):
return value return value
def to_python(self, value): def to_python(self, value):