Merge pull request #567 from tomprimozic/master

Implemented equality between Documents and DBRefs
This commit is contained in:
Ross Lawley 2014-06-27 11:37:24 +01:00
commit 324e3972a6
3 changed files with 64 additions and 3 deletions

View File

@ -229,8 +229,9 @@ class BaseDocument(object):
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, self.__class__) and hasattr(other, 'id'): if isinstance(other, self.__class__) and hasattr(other, 'id'):
if self.id == other.id: return self.id == other.id
return True if isinstance(other, DBRef):
return self._get_collection_name() == other.collection and self.id == other.id
return False return False
def __ne__(self, other): def __ne__(self, other):

View File

@ -69,7 +69,7 @@ class EmbeddedDocument(BaseDocument):
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
return self.to_mongo() == other.to_mongo() return self._data == other._data
return False return False
def __ne__(self, other): def __ne__(self, other):

View File

@ -2556,5 +2556,65 @@ class InstanceTest(unittest.TestCase):
f1.ref # Dereferences lazily f1.ref # Dereferences lazily
self.assertEqual(f1, f2) self.assertEqual(f1, f2)
def test_dbref_equality(self):
class Test2(Document):
name = StringField()
class Test3(Document):
name = StringField()
class Test(Document):
name = StringField()
test2 = ReferenceField('Test2')
test3 = ReferenceField('Test3')
Test.drop_collection()
Test2.drop_collection()
Test3.drop_collection()
t2 = Test2(name='a')
t2.save()
t3 = Test3(name='x')
t3.id = t2.id
t3.save()
t = Test(name='b', test2=t2, test3=t3)
f = Test._from_son(t.to_mongo())
dbref2 = f._data['test2']
obj2 = f.test2
self.assertTrue(isinstance(dbref2, DBRef))
self.assertTrue(isinstance(obj2, Test2))
self.assertTrue(obj2.id == dbref2.id)
self.assertTrue(obj2 == dbref2)
self.assertTrue(dbref2 == obj2)
dbref3 = f._data['test3']
obj3 = f.test3
self.assertTrue(isinstance(dbref3, DBRef))
self.assertTrue(isinstance(obj3, Test3))
self.assertTrue(obj3.id == dbref3.id)
self.assertTrue(obj3 == dbref3)
self.assertTrue(dbref3 == obj3)
self.assertTrue(obj2.id == obj3.id)
self.assertTrue(dbref2.id == dbref3.id)
self.assertFalse(dbref2 == dbref3)
self.assertFalse(dbref3 == dbref2)
self.assertTrue(dbref2 != dbref3)
self.assertTrue(dbref3 != dbref2)
self.assertFalse(obj2 == dbref3)
self.assertFalse(dbref3 == obj2)
self.assertTrue(obj2 != dbref3)
self.assertTrue(dbref3 != obj2)
self.assertFalse(obj3 == dbref2)
self.assertFalse(dbref2 == obj3)
self.assertTrue(obj3 != dbref2)
self.assertTrue(dbref2 != obj3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()