diff --git a/AUTHORS b/AUTHORS index 10d04c68..1cf7d78a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -258,3 +258,4 @@ that much better: * Leonardo Domingues (https://github.com/leodmgs) * Agustin Barto (https://github.com/abarto) * Stankiewicz Mateusz (https://github.com/mas15) + * Felix Schultheiß (https://github.com/felix-smashdocs) diff --git a/mongoengine/document.py b/mongoengine/document.py index 4a57d511..801c8df8 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -464,9 +464,9 @@ class Document(BaseDocument, metaclass=TopLevelDocumentMetaclass): # insert_one will provoke UniqueError alongside save does not # therefore, it need to catch and call replace_one. if "_id" in doc: - raw_object = wc_collection.find_one_and_replace( - {"_id": doc["_id"]}, doc - ) + select_dict = {"_id": doc["_id"]} + select_dict = self._integrate_shard_key(doc, select_dict) + raw_object = wc_collection.find_one_and_replace(select_dict, doc) if raw_object: return doc["_id"] @@ -489,6 +489,23 @@ class Document(BaseDocument, metaclass=TopLevelDocumentMetaclass): return update_doc + def _integrate_shard_key(self, doc, select_dict): + """Integrates the collection's shard key to the `select_dict`, which will be used for the query. + The value from the shard key is taken from the `doc` and finally the select_dict is returned. + """ + + # Need to add shard key to query, or you get an error + shard_key = self._meta.get("shard_key", tuple()) + for k in shard_key: + path = self._lookup_field(k.split(".")) + actual_key = [p.db_field for p in path] + val = doc + for ak in actual_key: + val = val[ak] + select_dict[".".join(actual_key)] = val + + return select_dict + def _save_update(self, doc, save_condition, write_concern): """Update an existing document. @@ -504,15 +521,7 @@ class Document(BaseDocument, metaclass=TopLevelDocumentMetaclass): select_dict["_id"] = object_id - # Need to add shard key to query, or you get an error - shard_key = self._meta.get("shard_key", tuple()) - for k in shard_key: - path = self._lookup_field(k.split(".")) - actual_key = [p.db_field for p in path] - val = doc - for ak in actual_key: - val = val[ak] - select_dict[".".join(actual_key)] = val + select_dict = self._integrate_shard_key(doc, select_dict) update_doc = self._get_update_doc() if update_doc: @@ -919,7 +928,7 @@ class Document(BaseDocument, metaclass=TopLevelDocumentMetaclass): @classmethod def list_indexes(cls): - """ Lists all of the indexes that should be created for given + """Lists all of the indexes that should be created for given collection. It includes all the indexes from super- and sub-classes. """ if cls._meta.get("abstract"): @@ -984,7 +993,7 @@ class Document(BaseDocument, metaclass=TopLevelDocumentMetaclass): @classmethod def compare_indexes(cls): - """ Compares the indexes defined in MongoEngine with the ones + """Compares the indexes defined in MongoEngine with the ones existing in the database. Returns any missing/extra indexes. """ diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 9554659c..50533af9 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -500,7 +500,7 @@ class TestDocumentInstance(MongoDBTestCase): doc.reload() Animal.drop_collection() - def test_update_shard_key_routing(self): + def test_save_update_shard_key_routing(self): """Ensures updating a doc with a specified shard_key includes it in the query. """ @@ -528,6 +528,29 @@ class TestDocumentInstance(MongoDBTestCase): Animal.drop_collection() + def test_save_create_shard_key_routing(self): + """Ensures inserting a doc with a specified shard_key includes it in + the query. + """ + + class Animal(Document): + _id = UUIDField(binary=False, primary_key=True, default=uuid.uuid4) + is_mammal = BooleanField() + name = StringField() + meta = {"shard_key": ("is_mammal",)} + + Animal.drop_collection() + doc = Animal(is_mammal=True, name="Dog") + + with query_counter() as q: + doc.save() + query_op = q.db.system.profile.find({"ns": "mongoenginetest.animal"})[0] + assert query_op["op"] == "command" + assert query_op["command"]["findAndModify"] == "animal" + assert set(query_op["command"]["query"].keys()) == set(["_id", "is_mammal"]) + + Animal.drop_collection() + def test_reload_with_changed_fields(self): """Ensures reloading will not affect changed fields"""