From 3d0d2f48ad3fd4bfe2561a6f36cfb2caa54c724a Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 19 Jun 2012 10:57:43 +0100 Subject: [PATCH] Fixed map_field embedded db_field bug fixes hmarr/mongoengine#512 --- docs/changelog.rst | 1 + mongoengine/queryset.py | 8 ++++++-- tests/test_fields.py | 42 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 3be5f2cd..0f0965ee 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.6.X ================ +- Fixed map_field embedded db_field issue - Fixed .save() _delta issue with DbRefs - Fixed Django TestCase - Added cmp to Embedded Document diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index ab35afbc..ed00a73a 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -620,6 +620,7 @@ class QuerySet(object): "Can't use index on unsubscriptable field (%s)" % err) fields.append(field_name) continue + if field is None: # Look up first field from the document if field_name == 'pk': @@ -637,8 +638,11 @@ class QuerySet(object): from mongoengine.fields import ReferenceField, GenericReferenceField if isinstance(field, (ReferenceField, GenericReferenceField)): raise InvalidQueryError('Cannot perform join in mongoDB: %s' % '__'.join(parts)) - # Look up subfield on the previous field - new_field = field.lookup_member(field_name) + if getattr(field, 'field', None): + new_field = field.field.lookup_member(field_name) + else: + # Look up subfield on the previous field + new_field = field.lookup_member(field_name) from base import ComplexBaseField if not new_field and isinstance(field, ComplexBaseField): fields.append(field_name) diff --git a/tests/test_fields.py b/tests/test_fields.py index 85f29119..04d5a34d 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -913,6 +913,48 @@ class FieldTest(unittest.TestCase): Extensible.drop_collection() + def test_embedded_mapfield_db_field(self): + + class Embedded(EmbeddedDocument): + number = IntField(default=0, db_field='i') + + class Test(Document): + my_map = MapField(field=EmbeddedDocumentField(Embedded), db_field='x') + + Test.drop_collection() + + test = Test() + test.my_map['DICTIONARY_KEY'] = Embedded(number=1) + test.save() + + Test.objects.update_one(inc__my_map__DICTIONARY_KEY__number=1) + + test = Test.objects.get() + self.assertEqual(test.my_map['DICTIONARY_KEY'].number, 2) + doc = self.db.test.find_one() + self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2) + + def test_embedded_db_field(self): + + class Embedded(EmbeddedDocument): + number = IntField(default=0, db_field='i') + + class Test(Document): + embedded = EmbeddedDocumentField(Embedded, db_field='x') + + Test.drop_collection() + + test = Test() + test.embedded = Embedded(number=1) + test.save() + + Test.objects.update_one(inc__embedded__number=1) + + test = Test.objects.get() + self.assertEqual(test.embedded.number, 2) + doc = self.db.test.find_one() + self.assertEqual(doc['x']['i'], 2) + def test_embedded_document_validation(self): """Ensure that invalid embedded documents cannot be assigned to embedded document fields.