From 0523c2ea4b6d4f4788fe5c99581fa504c16eba74 Mon Sep 17 00:00:00 2001 From: tprimozi Date: Thu, 13 Feb 2014 18:12:33 +0000 Subject: [PATCH] Fixed document equality: documents in different collections can have equal ids. --- mongoengine/base/document.py | 2 +- tests/document/instance.py | 50 ++++++++++++++++++++++++++++++------ 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index a7666efd..be0635ce 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -216,7 +216,7 @@ class BaseDocument(object): if isinstance(other, self.__class__) and hasattr(other, 'id'): return self.id == other.id if isinstance(other, DBRef): - return self.id == other.id + return self._get_collection_name() == other.collection and self.id == other.id return False def __ne__(self, other): diff --git a/tests/document/instance.py b/tests/document/instance.py index a842a221..cdf2b2c1 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -2456,27 +2456,61 @@ class InstanceTest(unittest.TestCase): class Test2(Document): name = StringField() + class Test3(Document): + name = StringField() + class Test(Document): name = StringField() test2 = ReferenceField('Test2') + test3 = ReferenceField('Test3') Test.drop_collection() Test2.drop_collection() + Test3.drop_collection() t2 = Test2(name='a') t2.save() - t = Test(name='b', test2 = t2) + t3 = Test3(name='x') + t3.id = t2.id + t3.save() + + t = Test(name='b', test2=t2, test3=t3) f = Test._from_son(t.to_mongo()) - dbref = f._data['test2'] - obj = f.test2 - self.assertTrue(isinstance(dbref, DBRef)) - self.assertTrue(isinstance(obj, Test2)) - self.assertTrue(obj.id == dbref.id) - self.assertTrue(obj == dbref) - self.assertTrue(dbref == obj) + dbref2 = f._data['test2'] + obj2 = f.test2 + self.assertTrue(isinstance(dbref2, DBRef)) + self.assertTrue(isinstance(obj2, Test2)) + self.assertTrue(obj2.id == dbref2.id) + self.assertTrue(obj2 == dbref2) + self.assertTrue(dbref2 == obj2) + + dbref3 = f._data['test3'] + obj3 = f.test3 + self.assertTrue(isinstance(dbref3, DBRef)) + self.assertTrue(isinstance(obj3, Test3)) + self.assertTrue(obj3.id == dbref3.id) + self.assertTrue(obj3 == dbref3) + self.assertTrue(dbref3 == obj3) + + self.assertTrue(obj2.id == obj3.id) + self.assertTrue(dbref2.id == dbref3.id) + self.assertFalse(dbref2 == dbref3) + self.assertFalse(dbref3 == dbref2) + self.assertTrue(dbref2 != dbref3) + self.assertTrue(dbref3 != dbref2) + + self.assertFalse(obj2 == dbref3) + self.assertFalse(dbref3 == obj2) + self.assertTrue(obj2 != dbref3) + self.assertTrue(dbref3 != obj2) + + self.assertFalse(obj3 == dbref2) + self.assertFalse(dbref2 == obj3) + self.assertTrue(obj3 != dbref2) + self.assertTrue(dbref2 != obj3) if __name__ == '__main__': unittest.main()