diff --git a/docs/changelog.rst b/docs/changelog.rst index d1714e78..eb8ccc34 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in 0.9.X - DEV ====================== +- Better BaseDocument equality check when not saved #798 - OperationError: Shard Keys are immutable. Tried to update id even though the document is not yet saved #771 - with_limit_and_skip for count should default like in pymongo #759 - Fix storing value of precision attribute in DecimalField #787 diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 8995a304..aea251e4 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -241,10 +241,12 @@ class BaseDocument(object): return txt_type('%s object' % self.__class__.__name__) def __eq__(self, other): - if isinstance(other, self.__class__) and hasattr(other, 'id'): + if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None: return self.id == other.id if isinstance(other, DBRef): return self._get_collection_name() == other.collection and self.id == other.id + if self.id is None: + return self is other return False def __ne__(self, other): diff --git a/tests/document/instance.py b/tests/document/instance.py index 36118512..40c25f8d 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -2542,6 +2542,10 @@ class InstanceTest(unittest.TestCase): doc_name = StringField() doc = EmbeddedDocumentField(Embedded) + def __eq__(self, other): + return (self.doc_name == other.doc_name and + self.doc == other.doc) + classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) dict_doc = Doc(**{"doc_name": "my doc", "doc": {"name": "embedded doc"}}) @@ -2558,6 +2562,10 @@ class InstanceTest(unittest.TestCase): doc_name = StringField() docs = ListField(EmbeddedDocumentField(Embedded)) + def __eq__(self, other): + return (self.doc_name == other.doc_name and + self.docs == other.docs) + classic_doc = Doc(doc_name="my doc", docs=[ Embedded(name="embedded doc1"), Embedded(name="embedded doc2")]) @@ -2792,5 +2800,16 @@ class InstanceTest(unittest.TestCase): u_from_db = User.objects.get(name='user') self.assertEquals(u_from_db.height, None) + def test_not_saved_eq(self): + """Ensure we can compare documents not saved. + """ + class Person(Document): + pass + + p = Person() + p1 = Person() + self.assertNotEqual(p, p1) + self.assertEqual(p, p) + if __name__ == '__main__': unittest.main() diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index fd7795f7..f47b5de5 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -51,6 +51,10 @@ class TestJson(unittest.TestCase): string = StringField() embedded_field = EmbeddedDocumentField(Embedded) + def __eq__(self, other): + return (self.string == other.string and + self.embedded_field == other.embedded_field) + doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) doc_json = doc.to_json(sort_keys=True, separators=(',', ':')) @@ -99,6 +103,10 @@ class TestJson(unittest.TestCase): generic_embedded_document_field = GenericEmbeddedDocumentField( default=lambda: EmbeddedDoc()) + def __eq__(self, other): + import json + return json.loads(self.to_json()) == json.loads(other.to_json()) + doc = Doc() self.assertEqual(doc, Doc.from_json(doc.to_json()))