From 1951b52aa5b62839f000deacb5ef9a02592c03f4 Mon Sep 17 00:00:00 2001 From: Emmanuel Leblond Date: Tue, 9 Jun 2015 16:20:35 +0200 Subject: [PATCH] Fix #1017 (document clash between same ids but different collections) --- mongoengine/dereference.py | 25 ++++++++++++++++--------- tests/test_dereference.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 0428397c..8e8920d4 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -128,21 +128,25 @@ class DeReference(object): """ object_map = {} for collection, dbrefs in self.reference_map.iteritems(): - refs = [dbref for dbref in dbrefs - if unicode(dbref).encode('utf-8') not in object_map] if hasattr(collection, 'objects'): # We have a document class for the refs + col_name = collection._get_collection_name() + refs = [dbref for dbref in dbrefs + if (col_name, dbref) not in object_map] references = collection.objects.in_bulk(refs) for key, doc in references.iteritems(): - object_map[key] = doc + object_map[(col_name, key)] = doc else: # Generic reference: use the refs data to convert to document if isinstance(doc_type, (ListField, DictField, MapField,)): continue + refs = [dbref for dbref in dbrefs + if (collection, dbref) not in object_map] + if doc_type: references = doc_type._get_db()[collection].find({'_id': {'$in': refs}}) for ref in references: doc = doc_type._from_son(ref) - object_map[doc.id] = doc + object_map[(collection, doc.id)] = doc else: references = get_db()[collection].find({'_id': {'$in': refs}}) for ref in references: @@ -154,7 +158,7 @@ class DeReference(object): for x in collection.split('_')))._from_son(ref) else: doc = doc_type._from_son(ref) - object_map[doc.id] = doc + object_map[(collection, doc.id)] = doc return object_map def _attach_objects(self, items, depth=0, instance=None, name=None): @@ -180,7 +184,8 @@ class DeReference(object): if isinstance(items, (dict, SON)): if '_ref' in items: - return self.object_map.get(items['_ref'].id, items) + return self.object_map.get( + (items['_ref'].collection, items['_ref'].id), items) elif '_cls' in items: doc = get_document(items['_cls'])._from_son(items) _cls = doc._data.pop('_cls', None) @@ -216,9 +221,11 @@ class DeReference(object): for field_name, field in v._fields.iteritems(): v = data[k]._data.get(field_name, None) if isinstance(v, (DBRef)): - data[k]._data[field_name] = self.object_map.get(v.id, v) + data[k]._data[field_name] = self.object_map.get( + (v.collection, v.id), v) elif isinstance(v, (dict, SON)) and '_ref' in v: - data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) + data[k]._data[field_name] = self.object_map.get( + (v['_ref'].collection , v['_ref'].id), v) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: item_name = "{0}.{1}.{2}".format(name, k, field_name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) @@ -226,7 +233,7 @@ class DeReference(object): item_name = '%s.%s' % (name, k) if name else name data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name) elif hasattr(v, 'id'): - data[k] = self.object_map.get(v.id, v) + data[k] = self.object_map.get((v.collection, v.id), v) if instance and name: if is_list: diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 2115b45a..e1ae3740 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -1026,6 +1026,43 @@ class FieldTest(unittest.TestCase): self.assertEqual(type(foo.bar), Bar) self.assertEqual(type(foo.baz), Baz) + + def test_document_reload_reference_integrity(self): + """ + Ensure reloading a document with multiple similar id + in different collections doesn't mix them. + """ + class Topic(Document): + id = IntField(primary_key=True) + class User(Document): + id = IntField(primary_key=True) + name = StringField() + class Message(Document): + id = IntField(primary_key=True) + topic = ReferenceField(Topic) + author = ReferenceField(User) + + Topic.drop_collection() + User.drop_collection() + Message.drop_collection() + + # All objects share the same id, but each in a different collection + topic = Topic(id=1).save() + user = User(id=1, name='user-name').save() + Message(id=1, topic=topic, author=user).save() + + concurrent_change_user = User.objects.get(id=1) + concurrent_change_user.name = 'new-name' + concurrent_change_user.save() + self.assertNotEqual(user.name, 'new-name') + + msg = Message.objects.get(id=1) + msg.reload() + self.assertEqual(msg.topic, topic) + self.assertEqual(msg.author, user) + self.assertEqual(msg.author.name, 'new-name') + + def test_list_lookup_not_checked_in_map(self): """Ensure we dereference list data correctly """