From c5c7378c63a7cf12223d1778a66f262f6c7eff6a Mon Sep 17 00:00:00 2001 From: tprimozi Date: Tue, 4 Feb 2014 13:41:17 +0000 Subject: [PATCH 1/2] Implemented equality between Documents and DBRefs --- mongoengine/base/document.py | 5 +++-- mongoengine/document.py | 2 +- tests/document/instance.py | 26 ++++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index f5eae8ff..a7666efd 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -214,8 +214,9 @@ class BaseDocument(object): def __eq__(self, other): if isinstance(other, self.__class__) and hasattr(other, 'id'): - if self.id == other.id: - return True + return self.id == other.id + if isinstance(other, DBRef): + return self.id == other.id return False def __ne__(self, other): diff --git a/mongoengine/document.py b/mongoengine/document.py index 114778eb..60ecfe9e 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -67,7 +67,7 @@ class EmbeddedDocument(BaseDocument): def __eq__(self, other): if isinstance(other, self.__class__): - return self.to_mongo() == other.to_mongo() + return self._data == other._data return False def __ne__(self, other): diff --git a/tests/document/instance.py b/tests/document/instance.py index 07db85a0..a842a221 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -2452,5 +2452,31 @@ class InstanceTest(unittest.TestCase): f1.ref # Dereferences lazily self.assertEqual(f1, f2) + def test_dbref_equality(self): + class Test2(Document): + name = StringField() + + class Test(Document): + name = StringField() + test2 = ReferenceField('Test2') + + Test.drop_collection() + Test2.drop_collection() + + t2 = Test2(name='a') + t2.save() + + t = Test(name='b', test2 = t2) + + f = Test._from_son(t.to_mongo()) + + dbref = f._data['test2'] + obj = f.test2 + self.assertTrue(isinstance(dbref, DBRef)) + self.assertTrue(isinstance(obj, Test2)) + self.assertTrue(obj.id == dbref.id) + self.assertTrue(obj == dbref) + self.assertTrue(dbref == obj) + if __name__ == '__main__': unittest.main() From 0523c2ea4b6d4f4788fe5c99581fa504c16eba74 Mon Sep 17 00:00:00 2001 From: tprimozi Date: Thu, 13 Feb 2014 18:12:33 +0000 Subject: [PATCH 2/2] Fixed document equality: documents in different collections can have equal ids. --- mongoengine/base/document.py | 2 +- tests/document/instance.py | 50 ++++++++++++++++++++++++++++++------ 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index a7666efd..be0635ce 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -216,7 +216,7 @@ class BaseDocument(object): if isinstance(other, self.__class__) and hasattr(other, 'id'): return self.id == other.id if isinstance(other, DBRef): - return self.id == other.id + return self._get_collection_name() == other.collection and self.id == other.id return False def __ne__(self, other): diff --git a/tests/document/instance.py b/tests/document/instance.py index a842a221..cdf2b2c1 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -2456,27 +2456,61 @@ class InstanceTest(unittest.TestCase): class Test2(Document): name = StringField() + class Test3(Document): + name = StringField() + class Test(Document): name = StringField() test2 = ReferenceField('Test2') + test3 = ReferenceField('Test3') Test.drop_collection() Test2.drop_collection() + Test3.drop_collection() t2 = Test2(name='a') t2.save() - t = Test(name='b', test2 = t2) + t3 = Test3(name='x') + t3.id = t2.id + t3.save() + + t = Test(name='b', test2=t2, test3=t3) f = Test._from_son(t.to_mongo()) - dbref = f._data['test2'] - obj = f.test2 - self.assertTrue(isinstance(dbref, DBRef)) - self.assertTrue(isinstance(obj, Test2)) - self.assertTrue(obj.id == dbref.id) - self.assertTrue(obj == dbref) - self.assertTrue(dbref == obj) + dbref2 = f._data['test2'] + obj2 = f.test2 + self.assertTrue(isinstance(dbref2, DBRef)) + self.assertTrue(isinstance(obj2, Test2)) + self.assertTrue(obj2.id == dbref2.id) + self.assertTrue(obj2 == dbref2) + self.assertTrue(dbref2 == obj2) + + dbref3 = f._data['test3'] + obj3 = f.test3 + self.assertTrue(isinstance(dbref3, DBRef)) + self.assertTrue(isinstance(obj3, Test3)) + self.assertTrue(obj3.id == dbref3.id) + self.assertTrue(obj3 == dbref3) + self.assertTrue(dbref3 == obj3) + + self.assertTrue(obj2.id == obj3.id) + self.assertTrue(dbref2.id == dbref3.id) + self.assertFalse(dbref2 == dbref3) + self.assertFalse(dbref3 == dbref2) + self.assertTrue(dbref2 != dbref3) + self.assertTrue(dbref3 != dbref2) + + self.assertFalse(obj2 == dbref3) + self.assertFalse(dbref3 == obj2) + self.assertTrue(obj2 != dbref3) + self.assertTrue(dbref3 != obj2) + + self.assertFalse(obj3 == dbref2) + self.assertFalse(dbref2 == obj3) + self.assertTrue(obj3 != dbref2) + self.assertTrue(dbref2 != obj3) if __name__ == '__main__': unittest.main()