Compare commits

..

1 Commits

Author SHA1 Message Date
Stefan Wojcik
2579e0b840 make EmbeddedDocument not hashable by default 2017-04-16 21:27:44 -04:00
4 changed files with 38 additions and 18 deletions

View File

@@ -1,6 +1,6 @@
import copy import copy
import numbers import numbers
from collections.abc import Hashable from collections import Hashable
from functools import partial from functools import partial
from bson import ObjectId, json_util from bson import ObjectId, json_util
@@ -272,13 +272,6 @@ class BaseDocument(object):
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self):
if getattr(self, 'pk', None) is None:
# For new object
return super(BaseDocument, self).__hash__()
else:
return hash(self.pk)
def clean(self): def clean(self):
""" """
Hook for doing document level data cleaning before validation is run. Hook for doing document level data cleaning before validation is run.

View File

@@ -60,6 +60,12 @@ class EmbeddedDocument(BaseDocument):
my_metaclass = DocumentMetaclass my_metaclass = DocumentMetaclass
__metaclass__ = DocumentMetaclass __metaclass__ = DocumentMetaclass
# A generic embedded document doesn't have any immutable properties
# that describe it uniquely, hence it shouldn't be hashable. You can
# define your own __hash__ method on a subclass if you need your
# embedded documents to be hashable.
__hash__ = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(EmbeddedDocument, self).__init__(*args, **kwargs) super(EmbeddedDocument, self).__init__(*args, **kwargs)
self._instance = None self._instance = None
@@ -160,6 +166,15 @@ class Document(BaseDocument):
"""Set the primary key.""" """Set the primary key."""
return setattr(self, self._meta['id_field'], value) return setattr(self, self._meta['id_field'], value)
def __hash__(self):
"""Return the hash based on the PK of this document. If it's new
and doesn't have a PK yet, return the default object hash instead.
"""
if self.pk is None:
return super(BaseDocument, self).__hash__()
else:
return hash(self.pk)
@classmethod @classmethod
def _get_db(cls): def _get_db(cls):
"""Some Model using other db_alias""" """Some Model using other db_alias"""
@@ -808,7 +823,7 @@ class Document(BaseDocument):
collection = cls._get_collection() collection = cls._get_collection()
# 746: when connection is via mongos, the read preference is not necessarily an indication that # 746: when connection is via mongos, the read preference is not necessarily an indication that
# this code runs on a secondary # this code runs on a secondary
if collection.is_mongos is not None and collection.read_preference.mode > 1: if not collection.is_mongos and collection.read_preference > 1:
return return
# determine if an index which we are creating includes # determine if an index which we are creating includes

View File

@@ -6,7 +6,7 @@ import socket
import time import time
import uuid import uuid
import warnings import warnings
from collections.abc import Mapping from collections import Mapping
from operator import itemgetter from operator import itemgetter
from bson import Binary, DBRef, ObjectId, SON from bson import Binary, DBRef, ObjectId, SON

View File

@@ -2164,7 +2164,7 @@ class InstanceTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
pass pass
# Clear old datas # Clear old data
User.drop_collection() User.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
@@ -2176,17 +2176,18 @@ class InstanceTest(unittest.TestCase):
b1 = BlogPost.objects.create() b1 = BlogPost.objects.create()
b2 = BlogPost.objects.create() b2 = BlogPost.objects.create()
# in List # Make sure docs are properly identified in a list (__eq__ is used
# for the comparison).
all_user_list = list(User.objects.all()) all_user_list = list(User.objects.all())
self.assertTrue(u1 in all_user_list) self.assertTrue(u1 in all_user_list)
self.assertTrue(u2 in all_user_list) self.assertTrue(u2 in all_user_list)
self.assertTrue(u3 in all_user_list) self.assertTrue(u3 in all_user_list)
self.assertFalse(u4 in all_user_list) # New object self.assertTrue(u4 not in all_user_list) # New object
self.assertFalse(b1 in all_user_list) # Other object self.assertTrue(b1 not in all_user_list) # Other object
self.assertFalse(b2 in all_user_list) # Other object self.assertTrue(b2 not in all_user_list) # Other object
# in Dict # Make sure docs can be used as keys in a dict (__hash__ is used
# for hashing the docs).
all_user_dic = {} all_user_dic = {}
for u in User.objects.all(): for u in User.objects.all():
all_user_dic[u] = "OK" all_user_dic[u] = "OK"
@@ -2198,9 +2199,20 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(all_user_dic.get(b1, False), False) # Other object self.assertEqual(all_user_dic.get(b1, False), False) # Other object
self.assertEqual(all_user_dic.get(b2, False), False) # Other object self.assertEqual(all_user_dic.get(b2, False), False) # Other object
# in Set # Make sure docs are properly identified in a set (__hash__ is used
# for hashing the docs).
all_user_set = set(User.objects.all()) all_user_set = set(User.objects.all())
self.assertTrue(u1 in all_user_set) self.assertTrue(u1 in all_user_set)
self.assertTrue(u4 not in all_user_set)
self.assertTrue(b1 not in all_user_list)
self.assertTrue(b2 not in all_user_list)
# Make sure duplicate docs aren't accepted in the set
self.assertEqual(len(all_user_set), 3)
all_user_set.add(u1)
all_user_set.add(u2)
all_user_set.add(u3)
self.assertEqual(len(all_user_set), 3)
def test_picklable(self): def test_picklable(self):
pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])