diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index f5eae8ff..a7666efd 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -214,8 +214,9 @@ class BaseDocument(object): def __eq__(self, other): if isinstance(other, self.__class__) and hasattr(other, 'id'): - if self.id == other.id: - return True + return self.id == other.id + if isinstance(other, DBRef): + return self.id == other.id return False def __ne__(self, other): diff --git a/mongoengine/document.py b/mongoengine/document.py index 114778eb..60ecfe9e 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -67,7 +67,7 @@ class EmbeddedDocument(BaseDocument): def __eq__(self, other): if isinstance(other, self.__class__): - return self.to_mongo() == other.to_mongo() + return self._data == other._data return False def __ne__(self, other): diff --git a/tests/document/instance.py b/tests/document/instance.py index 07db85a0..a842a221 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -2452,5 +2452,31 @@ class InstanceTest(unittest.TestCase): f1.ref # Dereferences lazily self.assertEqual(f1, f2) + def test_dbref_equality(self): + class Test2(Document): + name = StringField() + + class Test(Document): + name = StringField() + test2 = ReferenceField('Test2') + + Test.drop_collection() + Test2.drop_collection() + + t2 = Test2(name='a') + t2.save() + + t = Test(name='b', test2 = t2) + + f = Test._from_son(t.to_mongo()) + + dbref = f._data['test2'] + obj = f.test2 + self.assertTrue(isinstance(dbref, DBRef)) + self.assertTrue(isinstance(obj, Test2)) + self.assertTrue(obj.id == dbref.id) + self.assertTrue(obj == dbref) + self.assertTrue(dbref == obj) + if __name__ == '__main__': unittest.main()