diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index ce96837a..99c8af87 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -272,13 +272,6 @@ class BaseDocument(object): def __ne__(self, other): return not self.__eq__(other) - def __hash__(self): - if getattr(self, 'pk', None) is None: - # For new object - return super(BaseDocument, self).__hash__() - else: - return hash(self.pk) - def clean(self): """ Hook for doing document level data cleaning before validation is run. diff --git a/mongoengine/document.py b/mongoengine/document.py index b79e5e97..0b903d20 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -60,6 +60,12 @@ class EmbeddedDocument(BaseDocument): my_metaclass = DocumentMetaclass __metaclass__ = DocumentMetaclass + # A generic embedded document doesn't have any immutable properties + # that describe it uniquely, hence it shouldn't be hashable. You can + # define your own __hash__ method on a subclass if you need your + # embedded documents to be hashable. + __hash__ = None + def __init__(self, *args, **kwargs): super(EmbeddedDocument, self).__init__(*args, **kwargs) self._instance = None @@ -160,6 +166,15 @@ class Document(BaseDocument): """Set the primary key.""" return setattr(self, self._meta['id_field'], value) + def __hash__(self): + """Return the hash based on the PK of this document. If it's new + and doesn't have a PK yet, return the default object hash instead. + """ + if self.pk is None: + return super(BaseDocument, self).__hash__() + else: + return hash(self.pk) + @classmethod def _get_db(cls): """Some Model using other db_alias""" diff --git a/tests/document/instance.py b/tests/document/instance.py index c59de96f..c98b1405 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -2164,7 +2164,7 @@ class InstanceTest(unittest.TestCase): class BlogPost(Document): pass - # Clear old datas + # Clear old data User.drop_collection() BlogPost.drop_collection() @@ -2176,17 +2176,18 @@ class InstanceTest(unittest.TestCase): b1 = BlogPost.objects.create() b2 = BlogPost.objects.create() - # in List + # Make sure docs are properly identified in a list (__eq__ is used + # for the comparison). all_user_list = list(User.objects.all()) - self.assertTrue(u1 in all_user_list) self.assertTrue(u2 in all_user_list) self.assertTrue(u3 in all_user_list) - self.assertFalse(u4 in all_user_list) # New object - self.assertFalse(b1 in all_user_list) # Other object - self.assertFalse(b2 in all_user_list) # Other object + self.assertTrue(u4 not in all_user_list) # New object + self.assertTrue(b1 not in all_user_list) # Other object + self.assertTrue(b2 not in all_user_list) # Other object - # in Dict + # Make sure docs can be used as keys in a dict (__hash__ is used + # for hashing the docs). all_user_dic = {} for u in User.objects.all(): all_user_dic[u] = "OK" @@ -2198,9 +2199,20 @@ class InstanceTest(unittest.TestCase): self.assertEqual(all_user_dic.get(b1, False), False) # Other object self.assertEqual(all_user_dic.get(b2, False), False) # Other object - # in Set + # Make sure docs are properly identified in a set (__hash__ is used + # for hashing the docs). all_user_set = set(User.objects.all()) self.assertTrue(u1 in all_user_set) + self.assertTrue(u4 not in all_user_set) + self.assertTrue(b1 not in all_user_list) + self.assertTrue(b2 not in all_user_list) + + # Make sure duplicate docs aren't accepted in the set + self.assertEqual(len(all_user_set), 3) + all_user_set.add(u1) + all_user_set.add(u2) + all_user_set.add(u3) + self.assertEqual(len(all_user_set), 3) def test_picklable(self): pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])