From 2af55baa9aa9c038df948b5545a2ee44d82258b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ericson?= Date: Sun, 9 Nov 2014 16:13:49 -0300 Subject: [PATCH 1/2] Better BaseDocument equality check when not saved When 2 instances of a Document had id = None they would be considered equal unless an __eq__ were implemented. We now return False for such case. It now behaves more similar to Django's ORM. --- mongoengine/base/document.py | 4 +++- tests/document/instance.py | 19 +++++++++++++++++++ tests/document/json_serialisation.py | 8 ++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index bf5bdf79..6be685ae 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -241,10 +241,12 @@ class BaseDocument(object): return txt_type('%s object' % self.__class__.__name__) def __eq__(self, other): - if isinstance(other, self.__class__) and hasattr(other, 'id'): + if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None: return self.id == other.id if isinstance(other, DBRef): return self._get_collection_name() == other.collection and self.id == other.id + if self.id is None: + return self is other return False def __ne__(self, other): diff --git a/tests/document/instance.py b/tests/document/instance.py index a22f3fbf..6dae88a0 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -2503,6 +2503,10 @@ class InstanceTest(unittest.TestCase): doc_name = StringField() doc = EmbeddedDocumentField(Embedded) + def __eq__(self, other): + return (self.doc_name == other.doc_name and + self.doc == other.doc) + classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) dict_doc = Doc(**{"doc_name": "my doc", "doc": {"name": "embedded doc"}}) @@ -2519,6 +2523,10 @@ class InstanceTest(unittest.TestCase): doc_name = StringField() docs = ListField(EmbeddedDocumentField(Embedded)) + def __eq__(self, other): + return (self.doc_name == other.doc_name and + self.docs == other.docs) + classic_doc = Doc(doc_name="my doc", docs=[ Embedded(name="embedded doc1"), Embedded(name="embedded doc2")]) @@ -2719,5 +2727,16 @@ class InstanceTest(unittest.TestCase): self.assertEquals(p4.height, 189) self.assertEquals(Person.objects(height=189).count(), 1) + def test_not_saved_eq(self): + """Ensure we can compare documents not saved. + """ + class Person(Document): + pass + + p = Person() + p1 = Person() + self.assertNotEqual(p, p1) + self.assertEqual(p, p) + if __name__ == '__main__': unittest.main() diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index fd7795f7..f47b5de5 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -51,6 +51,10 @@ class TestJson(unittest.TestCase): string = StringField() embedded_field = EmbeddedDocumentField(Embedded) + def __eq__(self, other): + return (self.string == other.string and + self.embedded_field == other.embedded_field) + doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) doc_json = doc.to_json(sort_keys=True, separators=(',', ':')) @@ -99,6 +103,10 @@ class TestJson(unittest.TestCase): generic_embedded_document_field = GenericEmbeddedDocumentField( default=lambda: EmbeddedDoc()) + def __eq__(self, other): + import json + return json.loads(self.to_json()) == json.loads(other.to_json()) + doc = Doc() self.assertEqual(doc, Doc.from_json(doc.to_json())) From 2e01e0c30e7dbf373add8d97a1a76375eeacf3fe Mon Sep 17 00:00:00 2001 From: Yohan Graterol Date: Sun, 9 Nov 2014 21:32:50 -0500 Subject: [PATCH 2/2] Added merge to changelog.rst --- docs/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index d1714e78..eb8ccc34 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in 0.9.X - DEV ====================== +- Better BaseDocument equality check when not saved #798 - OperationError: Shard Keys are immutable. Tried to update id even though the document is not yet saved #771 - with_limit_and_skip for count should default like in pymongo #759 - Fix storing value of precision attribute in DecimalField #787