diff --git a/mongoengine/document.py b/mongoengine/document.py index 0fa2460d..b79e5e97 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -332,68 +332,20 @@ class Document(BaseDocument): signals.pre_save_post_validation.send(self.__class__, document=self, created=created, **signal_kwargs) + if self._meta.get('auto_create_index', True): + self.ensure_indexes() + try: - collection = self._get_collection() - if self._meta.get('auto_create_index', True): - self.ensure_indexes() + # Save a new document or update an existing one if created: - if force_insert: - object_id = collection.insert(doc, **write_concern) - else: - object_id = collection.save(doc, **write_concern) - # In PyMongo 3.0, the save() call calls internally the _update() call - # but they forget to return the _id value passed back, therefore getting it back here - # Correct behaviour in 2.X and in 3.0.1+ versions - if not object_id and pymongo.version_tuple == (3, 0): - pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) - object_id = ( - self._qs.filter(pk=pk_as_mongo_obj).first() and - self._qs.filter(pk=pk_as_mongo_obj).first().pk - ) # TODO doesn't this make 2 queries? + object_id = self._save_create(doc, force_insert, write_concern) else: - object_id = doc['_id'] - updates, removals = self._delta() - # Need to add shard key to query, or you get an error - if save_condition is not None: - select_dict = transform.query(self.__class__, - **save_condition) - else: - select_dict = {} - select_dict['_id'] = object_id - shard_key = self._meta.get('shard_key', tuple()) - for k in shard_key: - path = self._lookup_field(k.split('.')) - actual_key = [p.db_field for p in path] - val = doc - for ak in actual_key: - val = val[ak] - select_dict['.'.join(actual_key)] = val - - def is_new_object(last_error): - if last_error is not None: - updated = last_error.get('updatedExisting') - if updated is not None: - return not updated - return created - - update_query = {} - - if updates: - update_query['$set'] = updates - if removals: - update_query['$unset'] = removals - if updates or removals: - upsert = save_condition is None - last_error = collection.update(select_dict, update_query, - upsert=upsert, **write_concern) - if not upsert and last_error['n'] == 0: - raise SaveConditionError('Race condition preventing' - ' document update detected') - created = is_new_object(last_error) + object_id, created = self._save_update(doc, save_condition, + write_concern) if cascade is None: - cascade = self._meta.get( - 'cascade', False) or cascade_kwargs is not None + cascade = (self._meta.get('cascade', False) or + cascade_kwargs is not None) if cascade: kwargs = { @@ -406,6 +358,7 @@ class Document(BaseDocument): kwargs.update(cascade_kwargs) kwargs['_refs'] = _refs self.cascade_save(**kwargs) + except pymongo.errors.DuplicateKeyError as err: message = u'Tried to save duplicate unique keys (%s)' raise NotUniqueError(message % six.text_type(err)) @@ -418,16 +371,91 @@ class Document(BaseDocument): raise NotUniqueError(message % six.text_type(err)) raise OperationError(message % six.text_type(err)) + # Make sure we store the PK on this document now that it's saved id_field = self._meta['id_field'] if created or id_field not in self._meta.get('shard_key', []): self[id_field] = self._fields[id_field].to_python(object_id) signals.post_save.send(self.__class__, document=self, created=created, **signal_kwargs) + self._clear_changed_fields() self._created = False + return self + def _save_create(self, doc, force_insert, write_concern): + """Save a new document. + + Helper method, should only be used inside save(). + """ + collection = self._get_collection() + + if force_insert: + return collection.insert(doc, **write_concern) + + object_id = collection.save(doc, **write_concern) + + # In PyMongo 3.0, the save() call calls internally the _update() call + # but they forget to return the _id value passed back, therefore getting it back here + # Correct behaviour in 2.X and in 3.0.1+ versions + if not object_id and pymongo.version_tuple == (3, 0): + pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk) + object_id = ( + self._qs.filter(pk=pk_as_mongo_obj).first() and + self._qs.filter(pk=pk_as_mongo_obj).first().pk + ) # TODO doesn't this make 2 queries? + + return object_id + + def _save_update(self, doc, save_condition, write_concern): + """Update an existing document. + + Helper method, should only be used inside save(). + """ + collection = self._get_collection() + object_id = doc['_id'] + created = False + + select_dict = {} + if save_condition is not None: + select_dict = transform.query(self.__class__, **save_condition) + + select_dict['_id'] = object_id + + # Need to add shard key to query, or you get an error + shard_key = self._meta.get('shard_key', tuple()) + for k in shard_key: + path = self._lookup_field(k.split('.')) + actual_key = [p.db_field for p in path] + val = doc + for ak in actual_key: + val = val[ak] + select_dict['.'.join(actual_key)] = val + + updates, removals = self._delta() + update_query = {} + if updates: + update_query['$set'] = updates + if removals: + update_query['$unset'] = removals + if updates or removals: + upsert = save_condition is None + last_error = collection.update(select_dict, update_query, + upsert=upsert, **write_concern) + if not upsert and last_error['n'] == 0: + raise SaveConditionError('Race condition preventing' + ' document update detected') + if last_error is not None: + updated_existing = last_error.get('updatedExisting') + if updated_existing is False: + created = True + # !!! This is bad, means we accidentally created a new, + # potentially corrupted document. See + # https://github.com/MongoEngine/mongoengine/issues/564 + + return object_id, created + def cascade_save(self, **kwargs): """Recursively save any references and generic references on the document.