From 24db0d14996dc7edd3478113447a755cb661f94a Mon Sep 17 00:00:00 2001 From: Florian Schlachter Date: Fri, 5 Feb 2010 00:35:49 +0100 Subject: [PATCH 1/3] return db-object to allow low-level access from outside via connect() --- mongoengine/connection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index ee8d735b..770d34f6 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -60,3 +60,4 @@ def connect(db, username=None, password=None, **kwargs): _db_name = db _db_username = username _db_password = password + return _get_db() \ No newline at end of file From 210e9e23af65e5f5e6e7c8caf506a50055f626a0 Mon Sep 17 00:00:00 2001 From: Harry Marr Date: Fri, 12 Feb 2010 02:31:41 +0000 Subject: [PATCH 2/3] Dereferencing of referenced documents within lists --- mongoengine/fields.py | 23 +++++++++++++++++++++++ tests/fields.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 4695bf81..4c739d95 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -155,6 +155,29 @@ class ListField(BaseField): self.field = field super(ListField, self).__init__(**kwargs) + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + if instance is None: + # Document class being used rather than a document object + return self + + if isinstance(self.field, ReferenceField): + referenced_type = self.field.document_type + # Get value from document instance if available + value_list = instance._data.get(self.name) + deref_list = [] + for value in value_list: + # Dereference DBRefs + if isinstance(value, (pymongo.dbref.DBRef)): + value = _get_db().dereference(value) + deref_list.append(referenced_type._from_son(value)) + else: + deref_list.append(value) + instance._data[self.name] = deref_list + + return super(ListField, self).__get__(instance, owner) + def to_python(self, value): return [self.field.to_python(item) for item in value] diff --git a/tests/fields.py b/tests/fields.py index b35a9142..6eee972e 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -285,6 +285,34 @@ class FieldTest(unittest.TestCase): User.drop_collection() BlogPost.drop_collection() + + def test_list_item_dereference(self): + """Ensure that DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User)) + + User.drop_collection() + Group.drop_collection() + + user1 = User(name='user1') + user1.save() + user2 = User(name='user2') + user2.save() + + group = Group(members=[user1, user2]) + group.save() + + group_obj = Group.objects.first() + + self.assertEqual(group_obj.members[0].name, user1.name) + self.assertEqual(group_obj.members[1].name, user2.name) + + User.drop_collection() + Group.drop_collection() def test_reference_query_conversion(self): """Ensure that ReferenceFields can be queried using objects and values From ea1fe6a53802dee9ee462ca92bb94ff0dd173061 Mon Sep 17 00:00:00 2001 From: Harry Marr Date: Fri, 12 Feb 2010 11:21:51 +0000 Subject: [PATCH 3/3] Fixed set/unset issue with ListFields --- mongoengine/base.py | 4 ++-- mongoengine/fields.py | 27 +++++++++++++++------------ mongoengine/queryset.py | 8 ++++---- tests/queryset.py | 24 ++++++++++++++++++++++++ 4 files changed, 45 insertions(+), 18 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 024602a9..2de3477d 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -55,7 +55,7 @@ class BaseField(object): """ return self.to_python(value) - def prepare_query_value(self, value): + def prepare_query_value(self, op, value): """Prepare a value that is being used in a query for PyMongo. """ return value @@ -81,7 +81,7 @@ class ObjectIdField(BaseField): raise ValidationError(e.message) return value - def prepare_query_value(self, value): + def prepare_query_value(self, op, value): return self.to_mongo(value) def validate(self, value): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 4c739d95..a4d315e7 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -136,7 +136,7 @@ class EmbeddedDocumentField(BaseField): def lookup_member(self, member_name): return self.document._fields.get(member_name) - def prepare_query_value(self, value): + def prepare_query_value(self, op, value): return self.to_mongo(value) @@ -166,15 +166,16 @@ class ListField(BaseField): referenced_type = self.field.document_type # Get value from document instance if available value_list = instance._data.get(self.name) - deref_list = [] - for value in value_list: - # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) - deref_list.append(referenced_type._from_son(value)) - else: - deref_list.append(value) - instance._data[self.name] = deref_list + if value_list: + deref_list = [] + for value in value_list: + # Dereference DBRefs + if isinstance(value, (pymongo.dbref.DBRef)): + value = _get_db().dereference(value) + deref_list.append(referenced_type._from_son(value)) + else: + deref_list.append(value) + instance._data[self.name] = deref_list return super(ListField, self).__get__(instance, owner) @@ -197,7 +198,9 @@ class ListField(BaseField): raise ValidationError('All items in a list field must be of the ' 'specified type') - def prepare_query_value(self, value): + def prepare_query_value(self, op, value): + if op in ('set', 'unset'): + return [self.field.to_mongo(v) for v in value] return self.field.to_mongo(value) def lookup_member(self, member_name): @@ -273,7 +276,7 @@ class ReferenceField(BaseField): collection = self.document_type._meta['collection'] return pymongo.dbref.DBRef(collection, id_) - def prepare_query_value(self, value): + def prepare_query_value(self, op, value): return self.to_mongo(value) def validate(self, value): diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 8257e5f4..4c57bd7b 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -284,10 +284,10 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] if op in (None, 'ne', 'gt', 'gte', 'lt', 'lte'): - value = field.prepare_query_value(value) + value = field.prepare_query_value(op, value) elif op in ('in', 'nin', 'all'): # 'in', 'nin' and 'all' require a list of values - value = [field.prepare_query_value(v) for v in value] + value = [field.prepare_query_value(op, v) for v in value] if op: value = {'$' + op: value} @@ -487,9 +487,9 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] if op in (None, 'set', 'unset', 'push', 'pull'): - value = field.prepare_query_value(value) + value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): - value = [field.prepare_query_value(v) for v in value] + value = [field.prepare_query_value(op, v) for v in value] key = '.'.join(parts) diff --git a/tests/queryset.py b/tests/queryset.py index d05fac29..7939997d 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -594,6 +594,30 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() + def test_update_value_conversion(self): + """Ensure that values used in updates are converted before use. + """ + class Group(Document): + members = ListField(ReferenceField(self.Person)) + + Group.drop_collection() + + user1 = self.Person(name='user1') + user1.save() + user2 = self.Person(name='user2') + user2.save() + + group = Group() + group.save() + + Group.objects(id=group.id).update(set__members=[user1, user2]) + group.reload() + + self.assertTrue(len(group.members) == 2) + self.assertEqual(group.members[0].name, user1.name) + self.assertEqual(group.members[1].name, user2.name) + + Group.drop_collection() def test_types_index(self): """Ensure that and index is used when '_types' is being used in a