diff --git a/AUTHORS b/AUTHORS index 4d5e69a3..411e274d 100644 --- a/AUTHORS +++ b/AUTHORS @@ -229,3 +229,4 @@ that much better: * Emile Caron (https://github.com/emilecaron) * Amit Lichtenberg (https://github.com/amitlicht) * Lars Butler (https://github.com/larsbutler) + * George Macon (https://github.com/gmacon) diff --git a/docs/changelog.rst b/docs/changelog.rst index cd349793..6275d6d9 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,6 +2,10 @@ Changelog ========= +Changes in 0.10.2 +================= +- Allow shard key to point to a field in an embedded document. #551 + Changes in 0.10.1 ======================= - Fix infinite recursion with CASCADE delete rules under specific conditions. #1046 diff --git a/mongoengine/document.py b/mongoengine/document.py index 9d2d9c5f..bd2e7c5b 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -341,8 +341,12 @@ class Document(BaseDocument): select_dict['_id'] = object_id shard_key = self.__class__._meta.get('shard_key', tuple()) for k in shard_key: - actual_key = self._db_field_map.get(k, k) - select_dict[actual_key] = doc[actual_key] + path = self._lookup_field(k.split('.')) + actual_key = [p.db_field for p in path] + val = doc + for ak in actual_key: + val = val[ak] + select_dict['.'.join(actual_key)] = val def is_new_object(last_error): if last_error is not None: @@ -444,7 +448,12 @@ class Document(BaseDocument): select_dict = {'pk': self.pk} shard_key = self.__class__._meta.get('shard_key', tuple()) for k in shard_key: - select_dict[k] = getattr(self, k) + path = self._lookup_field(k.split('.')) + actual_key = [p.db_field for p in path] + val = self + for ak in actual_key: + val = getattr(val, ak) + select_dict['__'.join(actual_key)] = val return select_dict def update(self, **kwargs): diff --git a/tests/document/instance.py b/tests/document/instance.py index 56e6765a..5494ac6b 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -484,6 +484,20 @@ class InstanceTest(unittest.TestCase): doc.reload() Animal.drop_collection() + def test_reload_sharded_nested(self): + class SuperPhylum(EmbeddedDocument): + name = StringField() + + class Animal(Document): + superphylum = EmbeddedDocumentField(SuperPhylum) + meta = {'shard_key': ('superphylum.name',)} + + Animal.drop_collection() + doc = Animal(superphylum=SuperPhylum(name='Deuterostomia')) + doc.save() + doc.reload() + Animal.drop_collection() + def test_reload_referencing(self): """Ensures reloading updates weakrefs correctly """ @@ -2715,6 +2729,32 @@ class InstanceTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) + def test_shard_key_in_embedded_document(self): + class Foo(EmbeddedDocument): + foo = StringField() + + class Bar(Document): + meta = { + 'shard_key': ('foo.foo',) + } + foo = EmbeddedDocumentField(Foo) + bar = StringField() + + foo_doc = Foo(foo='hello') + bar_doc = Bar(foo=foo_doc, bar='world') + bar_doc.save() + + self.assertTrue(bar_doc.id is not None) + + bar_doc.bar = 'baz' + bar_doc.save() + + def change_shard_key(): + bar_doc.foo.foo = 'something' + bar_doc.save() + + self.assertRaises(OperationError, change_shard_key) + def test_shard_key_primary(self): class LogEntry(Document): machine = StringField(primary_key=True)