Fix merge MongoEngine/mongoengine#799
This commit is contained in:
		| @@ -241,10 +241,12 @@ class BaseDocument(object): | |||||||
|         return txt_type('%s object' % self.__class__.__name__) |         return txt_type('%s object' % self.__class__.__name__) | ||||||
|  |  | ||||||
|     def __eq__(self, other): |     def __eq__(self, other): | ||||||
|         if isinstance(other, self.__class__) and hasattr(other, 'id'): |         if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None: | ||||||
|             return self.id == other.id |             return self.id == other.id | ||||||
|         if isinstance(other, DBRef): |         if isinstance(other, DBRef): | ||||||
|             return self._get_collection_name() == other.collection and self.id == other.id |             return self._get_collection_name() == other.collection and self.id == other.id | ||||||
|  |         if self.id is None: | ||||||
|  |             return self is other | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     def __ne__(self, other): |     def __ne__(self, other): | ||||||
|   | |||||||
| @@ -2542,6 +2542,10 @@ class InstanceTest(unittest.TestCase): | |||||||
|             doc_name = StringField() |             doc_name = StringField() | ||||||
|             doc = EmbeddedDocumentField(Embedded) |             doc = EmbeddedDocumentField(Embedded) | ||||||
|  |  | ||||||
|  |             def __eq__(self, other): | ||||||
|  |                 return (self.doc_name == other.doc_name and | ||||||
|  |                         self.doc == other.doc) | ||||||
|  |  | ||||||
|         classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) |         classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) | ||||||
|         dict_doc = Doc(**{"doc_name": "my doc", |         dict_doc = Doc(**{"doc_name": "my doc", | ||||||
|                           "doc": {"name": "embedded doc"}}) |                           "doc": {"name": "embedded doc"}}) | ||||||
| @@ -2558,6 +2562,10 @@ class InstanceTest(unittest.TestCase): | |||||||
|             doc_name = StringField() |             doc_name = StringField() | ||||||
|             docs = ListField(EmbeddedDocumentField(Embedded)) |             docs = ListField(EmbeddedDocumentField(Embedded)) | ||||||
|  |  | ||||||
|  |             def __eq__(self, other): | ||||||
|  |                 return (self.doc_name == other.doc_name and | ||||||
|  |                         self.docs == other.docs) | ||||||
|  |  | ||||||
|         classic_doc = Doc(doc_name="my doc", docs=[ |         classic_doc = Doc(doc_name="my doc", docs=[ | ||||||
|                           Embedded(name="embedded doc1"), |                           Embedded(name="embedded doc1"), | ||||||
|                           Embedded(name="embedded doc2")]) |                           Embedded(name="embedded doc2")]) | ||||||
| @@ -2792,5 +2800,16 @@ class InstanceTest(unittest.TestCase): | |||||||
|         u_from_db = User.objects.get(name='user') |         u_from_db = User.objects.get(name='user') | ||||||
|         self.assertEquals(u_from_db.height, None) |         self.assertEquals(u_from_db.height, None) | ||||||
|  |  | ||||||
|  |     def test_not_saved_eq(self): | ||||||
|  |         """Ensure we can compare documents not saved. | ||||||
|  |         """ | ||||||
|  |         class Person(Document): | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         p = Person() | ||||||
|  |         p1 = Person() | ||||||
|  |         self.assertNotEqual(p, p1) | ||||||
|  |         self.assertEqual(p, p) | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -51,6 +51,10 @@ class TestJson(unittest.TestCase): | |||||||
|             string = StringField() |             string = StringField() | ||||||
|             embedded_field = EmbeddedDocumentField(Embedded) |             embedded_field = EmbeddedDocumentField(Embedded) | ||||||
|  |  | ||||||
|  |             def __eq__(self, other): | ||||||
|  |                 return (self.string == other.string and | ||||||
|  |                         self.embedded_field == other.embedded_field) | ||||||
|  |  | ||||||
|         doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) |         doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) | ||||||
|  |  | ||||||
|         doc_json = doc.to_json(sort_keys=True, separators=(',', ':')) |         doc_json = doc.to_json(sort_keys=True, separators=(',', ':')) | ||||||
| @@ -99,6 +103,10 @@ class TestJson(unittest.TestCase): | |||||||
|             generic_embedded_document_field = GenericEmbeddedDocumentField( |             generic_embedded_document_field = GenericEmbeddedDocumentField( | ||||||
|                                         default=lambda: EmbeddedDoc()) |                                         default=lambda: EmbeddedDoc()) | ||||||
|  |  | ||||||
|  |             def __eq__(self, other): | ||||||
|  |                 import json | ||||||
|  |                 return json.loads(self.to_json()) == json.loads(other.to_json()) | ||||||
|  |  | ||||||
|         doc = Doc() |         doc = Doc() | ||||||
|         self.assertEqual(doc, Doc.from_json(doc.to_json())) |         self.assertEqual(doc, Doc.from_json(doc.to_json())) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user