diff --git a/mongoengine/document.py b/mongoengine/document.py index 4a57d511..a3a30bc1 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -464,8 +464,10 @@ class Document(BaseDocument, metaclass=TopLevelDocumentMetaclass): # insert_one will provoke UniqueError alongside save does not # therefore, it need to catch and call replace_one. if "_id" in doc: + select_dict = {"_id": doc["_id"]} + select_dict = self._integrate_shard_key(doc, select_dict) raw_object = wc_collection.find_one_and_replace( - {"_id": doc["_id"]}, doc + select_dict, doc ) if raw_object: return doc["_id"] @@ -489,6 +491,21 @@ class Document(BaseDocument, metaclass=TopLevelDocumentMetaclass): return update_doc + def _integrate_shard_key(self, doc, select_dict): + + # Need to add shard key to query, or you get an error + shard_key = self._meta.get("shard_key", tuple()) + for k in shard_key: + path = self._lookup_field(k.split(".")) + actual_key = [p.db_field for p in path] + val = doc + for ak in actual_key: + val = val[ak] + select_dict[".".join(actual_key)] = val + + return select_dict + + def _save_update(self, doc, save_condition, write_concern): """Update an existing document. @@ -504,15 +521,7 @@ class Document(BaseDocument, metaclass=TopLevelDocumentMetaclass): select_dict["_id"] = object_id - # Need to add shard key to query, or you get an error - shard_key = self._meta.get("shard_key", tuple()) - for k in shard_key: - path = self._lookup_field(k.split(".")) - actual_key = [p.db_field for p in path] - val = doc - for ak in actual_key: - val = val[ak] - select_dict[".".join(actual_key)] = val + select_dict = self._integrate_shard_key(doc, select_dict) update_doc = self._get_update_doc() if update_doc: