This commit is contained in:
Yohan Graterol 2014-11-09 21:31:56 -05:00
commit a12b2de74a
3 changed files with 30 additions and 1 deletions

View File

@ -241,10 +241,12 @@ class BaseDocument(object):
return txt_type('%s object' % self.__class__.__name__) return txt_type('%s object' % self.__class__.__name__)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, self.__class__) and hasattr(other, 'id'): if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None:
return self.id == other.id return self.id == other.id
if isinstance(other, DBRef): if isinstance(other, DBRef):
return self._get_collection_name() == other.collection and self.id == other.id return self._get_collection_name() == other.collection and self.id == other.id
if self.id is None:
return self is other
return False return False
def __ne__(self, other): def __ne__(self, other):

View File

@ -2542,6 +2542,10 @@ class InstanceTest(unittest.TestCase):
doc_name = StringField() doc_name = StringField()
doc = EmbeddedDocumentField(Embedded) doc = EmbeddedDocumentField(Embedded)
def __eq__(self, other):
return (self.doc_name == other.doc_name and
self.doc == other.doc)
classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc"))
dict_doc = Doc(**{"doc_name": "my doc", dict_doc = Doc(**{"doc_name": "my doc",
"doc": {"name": "embedded doc"}}) "doc": {"name": "embedded doc"}})
@ -2558,6 +2562,10 @@ class InstanceTest(unittest.TestCase):
doc_name = StringField() doc_name = StringField()
docs = ListField(EmbeddedDocumentField(Embedded)) docs = ListField(EmbeddedDocumentField(Embedded))
def __eq__(self, other):
return (self.doc_name == other.doc_name and
self.docs == other.docs)
classic_doc = Doc(doc_name="my doc", docs=[ classic_doc = Doc(doc_name="my doc", docs=[
Embedded(name="embedded doc1"), Embedded(name="embedded doc1"),
Embedded(name="embedded doc2")]) Embedded(name="embedded doc2")])
@ -2792,5 +2800,16 @@ class InstanceTest(unittest.TestCase):
u_from_db = User.objects.get(name='user') u_from_db = User.objects.get(name='user')
self.assertEquals(u_from_db.height, None) self.assertEquals(u_from_db.height, None)
def test_not_saved_eq(self):
"""Ensure we can compare documents not saved.
"""
class Person(Document):
pass
p = Person()
p1 = Person()
self.assertNotEqual(p, p1)
self.assertEqual(p, p)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -51,6 +51,10 @@ class TestJson(unittest.TestCase):
string = StringField() string = StringField()
embedded_field = EmbeddedDocumentField(Embedded) embedded_field = EmbeddedDocumentField(Embedded)
def __eq__(self, other):
return (self.string == other.string and
self.embedded_field == other.embedded_field)
doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) doc = Doc(string="Hi", embedded_field=Embedded(string="Hi"))
doc_json = doc.to_json(sort_keys=True, separators=(',', ':')) doc_json = doc.to_json(sort_keys=True, separators=(',', ':'))
@ -99,6 +103,10 @@ class TestJson(unittest.TestCase):
generic_embedded_document_field = GenericEmbeddedDocumentField( generic_embedded_document_field = GenericEmbeddedDocumentField(
default=lambda: EmbeddedDoc()) default=lambda: EmbeddedDoc())
def __eq__(self, other):
import json
return json.loads(self.to_json()) == json.loads(other.to_json())
doc = Doc() doc = Doc()
self.assertEqual(doc, Doc.from_json(doc.to_json())) self.assertEqual(doc, Doc.from_json(doc.to_json()))