Better BaseDocument equality check when not saved

When 2 instances of a Document had id = None they would be considered
equal unless an __eq__ were implemented.

We now return False for such case. It now behaves more similar to
Django's ORM.
This commit is contained in:
André Ericson 2014-11-09 16:13:49 -03:00
parent c4f7db6c04
commit 2af55baa9a
3 changed files with 30 additions and 1 deletions

View File

@ -241,10 +241,12 @@ class BaseDocument(object):
return txt_type('%s object' % self.__class__.__name__) return txt_type('%s object' % self.__class__.__name__)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, self.__class__) and hasattr(other, 'id'): if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None:
return self.id == other.id return self.id == other.id
if isinstance(other, DBRef): if isinstance(other, DBRef):
return self._get_collection_name() == other.collection and self.id == other.id return self._get_collection_name() == other.collection and self.id == other.id
if self.id is None:
return self is other
return False return False
def __ne__(self, other): def __ne__(self, other):

View File

@ -2503,6 +2503,10 @@ class InstanceTest(unittest.TestCase):
doc_name = StringField() doc_name = StringField()
doc = EmbeddedDocumentField(Embedded) doc = EmbeddedDocumentField(Embedded)
def __eq__(self, other):
return (self.doc_name == other.doc_name and
self.doc == other.doc)
classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc"))
dict_doc = Doc(**{"doc_name": "my doc", dict_doc = Doc(**{"doc_name": "my doc",
"doc": {"name": "embedded doc"}}) "doc": {"name": "embedded doc"}})
@ -2519,6 +2523,10 @@ class InstanceTest(unittest.TestCase):
doc_name = StringField() doc_name = StringField()
docs = ListField(EmbeddedDocumentField(Embedded)) docs = ListField(EmbeddedDocumentField(Embedded))
def __eq__(self, other):
return (self.doc_name == other.doc_name and
self.docs == other.docs)
classic_doc = Doc(doc_name="my doc", docs=[ classic_doc = Doc(doc_name="my doc", docs=[
Embedded(name="embedded doc1"), Embedded(name="embedded doc1"),
Embedded(name="embedded doc2")]) Embedded(name="embedded doc2")])
@ -2719,5 +2727,16 @@ class InstanceTest(unittest.TestCase):
self.assertEquals(p4.height, 189) self.assertEquals(p4.height, 189)
self.assertEquals(Person.objects(height=189).count(), 1) self.assertEquals(Person.objects(height=189).count(), 1)
def test_not_saved_eq(self):
"""Ensure we can compare documents not saved.
"""
class Person(Document):
pass
p = Person()
p1 = Person()
self.assertNotEqual(p, p1)
self.assertEqual(p, p)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -51,6 +51,10 @@ class TestJson(unittest.TestCase):
string = StringField() string = StringField()
embedded_field = EmbeddedDocumentField(Embedded) embedded_field = EmbeddedDocumentField(Embedded)
def __eq__(self, other):
return (self.string == other.string and
self.embedded_field == other.embedded_field)
doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) doc = Doc(string="Hi", embedded_field=Embedded(string="Hi"))
doc_json = doc.to_json(sort_keys=True, separators=(',', ':')) doc_json = doc.to_json(sort_keys=True, separators=(',', ':'))
@ -99,6 +103,10 @@ class TestJson(unittest.TestCase):
generic_embedded_document_field = GenericEmbeddedDocumentField( generic_embedded_document_field = GenericEmbeddedDocumentField(
default=lambda: EmbeddedDoc()) default=lambda: EmbeddedDoc())
def __eq__(self, other):
import json
return json.loads(self.to_json()) == json.loads(other.to_json())
doc = Doc() doc = Doc()
self.assertEqual(doc, Doc.from_json(doc.to_json())) self.assertEqual(doc, Doc.from_json(doc.to_json()))