diff --git a/mongoengine/base.py b/mongoengine/base.py index 087684f2..d2eced9e 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -55,7 +55,7 @@ class BaseField(object): """ return self.to_python(value) - def prepare_query_value(self, value): + def prepare_query_value(self, op, value): """Prepare a value that is being used in a query for PyMongo. """ return value @@ -82,7 +82,7 @@ class ObjectIdField(BaseField): raise ValidationError(e.message) return value - def prepare_query_value(self, value): + def prepare_query_value(self, op, value): return self.to_mongo(value) def validate(self, value): diff --git a/mongoengine/connection.py b/mongoengine/connection.py index da8f2baf..ec3bf784 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -59,3 +59,4 @@ def connect(db, username=None, password=None, **kwargs): _db_name = db _db_username = username _db_password = password + return _get_db() \ No newline at end of file diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 1b742bfa..952150ab 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -188,7 +188,7 @@ class EmbeddedDocumentField(BaseField): def lookup_member(self, member_name): return self.document._fields.get(member_name) - def prepare_query_value(self, value): + def prepare_query_value(self, op, value): return self.to_mongo(value) @@ -207,6 +207,30 @@ class ListField(BaseField): self.field = field super(ListField, self).__init__(**kwargs) + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + if instance is None: + # Document class being used rather than a document object + return self + + if isinstance(self.field, ReferenceField): + referenced_type = self.field.document_type + # Get value from document instance if available + value_list = instance._data.get(self.name) + if value_list: + deref_list = [] + for value in value_list: + # Dereference DBRefs + if isinstance(value, (pymongo.dbref.DBRef)): + value = _get_db().dereference(value) + deref_list.append(referenced_type._from_son(value)) + else: + deref_list.append(value) + instance._data[self.name] = deref_list + + return super(ListField, self).__get__(instance, owner) + def to_python(self, value): return [self.field.to_python(item) for item in value] @@ -226,7 +250,9 @@ class ListField(BaseField): raise ValidationError('All items in a list field must be of the ' 'specified type') - def prepare_query_value(self, value): + def prepare_query_value(self, op, value): + if op in ('set', 'unset'): + return [self.field.to_mongo(v) for v in value] return self.field.to_mongo(value) def lookup_member(self, member_name): @@ -302,7 +328,7 @@ class ReferenceField(BaseField): collection = self.document_type._meta['collection'] return pymongo.dbref.DBRef(collection, id_) - def prepare_query_value(self, value): + def prepare_query_value(self, op, value): return self.to_mongo(value) def validate(self, value): diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 8257e5f4..4c57bd7b 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -284,10 +284,10 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] if op in (None, 'ne', 'gt', 'gte', 'lt', 'lte'): - value = field.prepare_query_value(value) + value = field.prepare_query_value(op, value) elif op in ('in', 'nin', 'all'): # 'in', 'nin' and 'all' require a list of values - value = [field.prepare_query_value(v) for v in value] + value = [field.prepare_query_value(op, v) for v in value] if op: value = {'$' + op: value} @@ -487,9 +487,9 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] if op in (None, 'set', 'unset', 'push', 'pull'): - value = field.prepare_query_value(value) + value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): - value = [field.prepare_query_value(v) for v in value] + value = [field.prepare_query_value(op, v) for v in value] key = '.'.join(parts) diff --git a/tests/fields.py b/tests/fields.py index cc4cc5e1..97e5b4e8 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -286,6 +286,34 @@ class FieldTest(unittest.TestCase): User.drop_collection() BlogPost.drop_collection() + + def test_list_item_dereference(self): + """Ensure that DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + user1 = User(name='user1') + user1.save() + user2 = User(name='user2') + user2.save() + + group = Group(members=[user1, user2]) + group.save() + + group_obj = Group.objects.first() + + self.assertEqual(group_obj.members[0].name, user1.name) + self.assertEqual(group_obj.members[1].name, user2.name) + + User.drop_collection() + Group.drop_collection() def test_reference_query_conversion(self): """Ensure that ReferenceFields can be queried using objects and values diff --git a/tests/queryset.py b/tests/queryset.py index 4e058803..02f53f33 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -594,6 +594,30 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() + def test_update_value_conversion(self): + """Ensure that values used in updates are converted before use. + """ + class Group(Document): + members = ListField(ReferenceField(self.Person)) + + Group.drop_collection() + + user1 = self.Person(name='user1') + user1.save() + user2 = self.Person(name='user2') + user2.save() + + group = Group() + group.save() + + Group.objects(id=group.id).update(set__members=[user1, user2]) + group.reload() + + self.assertTrue(len(group.members) == 2) + self.assertEqual(group.members[0].name, user1.name) + self.assertEqual(group.members[1].name, user2.name) + + Group.drop_collection() def test_types_index(self): """Ensure that and index is used when '_types' is being used in a