Implemented equality between Documents and DBRefs
This commit is contained in:
parent
d4b3649640
commit
c5c7378c63
@ -214,8 +214,9 @@ class BaseDocument(object):
|
|||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, self.__class__) and hasattr(other, 'id'):
|
if isinstance(other, self.__class__) and hasattr(other, 'id'):
|
||||||
if self.id == other.id:
|
return self.id == other.id
|
||||||
return True
|
if isinstance(other, DBRef):
|
||||||
|
return self.id == other.id
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
|
@ -67,7 +67,7 @@ class EmbeddedDocument(BaseDocument):
|
|||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
return self.to_mongo() == other.to_mongo()
|
return self._data == other._data
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
|
@ -2452,5 +2452,31 @@ class InstanceTest(unittest.TestCase):
|
|||||||
f1.ref # Dereferences lazily
|
f1.ref # Dereferences lazily
|
||||||
self.assertEqual(f1, f2)
|
self.assertEqual(f1, f2)
|
||||||
|
|
||||||
|
def test_dbref_equality(self):
|
||||||
|
class Test2(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Test(Document):
|
||||||
|
name = StringField()
|
||||||
|
test2 = ReferenceField('Test2')
|
||||||
|
|
||||||
|
Test.drop_collection()
|
||||||
|
Test2.drop_collection()
|
||||||
|
|
||||||
|
t2 = Test2(name='a')
|
||||||
|
t2.save()
|
||||||
|
|
||||||
|
t = Test(name='b', test2 = t2)
|
||||||
|
|
||||||
|
f = Test._from_son(t.to_mongo())
|
||||||
|
|
||||||
|
dbref = f._data['test2']
|
||||||
|
obj = f.test2
|
||||||
|
self.assertTrue(isinstance(dbref, DBRef))
|
||||||
|
self.assertTrue(isinstance(obj, Test2))
|
||||||
|
self.assertTrue(obj.id == dbref.id)
|
||||||
|
self.assertTrue(obj == dbref)
|
||||||
|
self.assertTrue(dbref == obj)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user