Compare commits

..

1 Commits

Author SHA1 Message Date
Stefan Wojcik
2579e0b840 make EmbeddedDocument not hashable by default 2017-04-16 21:27:44 -04:00
4 changed files with 38 additions and 18 deletions

View File

@@ -1,6 +1,6 @@
import copy
import numbers
from collections.abc import Hashable
from collections import Hashable
from functools import partial
from bson import ObjectId, json_util
@@ -272,13 +272,6 @@ class BaseDocument(object):
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
if getattr(self, 'pk', None) is None:
# For new object
return super(BaseDocument, self).__hash__()
else:
return hash(self.pk)
def clean(self):
"""
Hook for doing document level data cleaning before validation is run.

View File

@@ -60,6 +60,12 @@ class EmbeddedDocument(BaseDocument):
my_metaclass = DocumentMetaclass
__metaclass__ = DocumentMetaclass
# A generic embedded document doesn't have any immutable properties
# that describe it uniquely, hence it shouldn't be hashable. You can
# define your own __hash__ method on a subclass if you need your
# embedded documents to be hashable.
__hash__ = None
def __init__(self, *args, **kwargs):
super(EmbeddedDocument, self).__init__(*args, **kwargs)
self._instance = None
@@ -160,6 +166,15 @@ class Document(BaseDocument):
"""Set the primary key."""
return setattr(self, self._meta['id_field'], value)
def __hash__(self):
"""Return the hash based on the PK of this document. If it's new
and doesn't have a PK yet, return the default object hash instead.
"""
if self.pk is None:
return super(BaseDocument, self).__hash__()
else:
return hash(self.pk)
@classmethod
def _get_db(cls):
"""Some Model using other db_alias"""
@@ -808,7 +823,7 @@ class Document(BaseDocument):
collection = cls._get_collection()
# 746: when connection is via mongos, the read preference is not necessarily an indication that
# this code runs on a secondary
if collection.is_mongos is not None and collection.read_preference.mode > 1:
if not collection.is_mongos and collection.read_preference > 1:
return
# determine if an index which we are creating includes

View File

@@ -6,7 +6,7 @@ import socket
import time
import uuid
import warnings
from collections.abc import Mapping
from collections import Mapping
from operator import itemgetter
from bson import Binary, DBRef, ObjectId, SON

View File

@@ -2164,7 +2164,7 @@ class InstanceTest(unittest.TestCase):
class BlogPost(Document):
pass
# Clear old datas
# Clear old data
User.drop_collection()
BlogPost.drop_collection()
@@ -2176,17 +2176,18 @@ class InstanceTest(unittest.TestCase):
b1 = BlogPost.objects.create()
b2 = BlogPost.objects.create()
# in List
# Make sure docs are properly identified in a list (__eq__ is used
# for the comparison).
all_user_list = list(User.objects.all())
self.assertTrue(u1 in all_user_list)
self.assertTrue(u2 in all_user_list)
self.assertTrue(u3 in all_user_list)
self.assertFalse(u4 in all_user_list) # New object
self.assertFalse(b1 in all_user_list) # Other object
self.assertFalse(b2 in all_user_list) # Other object
self.assertTrue(u4 not in all_user_list) # New object
self.assertTrue(b1 not in all_user_list) # Other object
self.assertTrue(b2 not in all_user_list) # Other object
# in Dict
# Make sure docs can be used as keys in a dict (__hash__ is used
# for hashing the docs).
all_user_dic = {}
for u in User.objects.all():
all_user_dic[u] = "OK"
@@ -2198,9 +2199,20 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(all_user_dic.get(b1, False), False) # Other object
self.assertEqual(all_user_dic.get(b2, False), False) # Other object
# in Set
# Make sure docs are properly identified in a set (__hash__ is used
# for hashing the docs).
all_user_set = set(User.objects.all())
self.assertTrue(u1 in all_user_set)
self.assertTrue(u4 not in all_user_set)
self.assertTrue(b1 not in all_user_list)
self.assertTrue(b2 not in all_user_list)
# Make sure duplicate docs aren't accepted in the set
self.assertEqual(len(all_user_set), 3)
all_user_set.add(u1)
all_user_set.add(u2)
all_user_set.add(u3)
self.assertEqual(len(all_user_set), 3)
def test_picklable(self):
pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])