Compare commits

...

6 Commits

Author SHA1 Message Date
Stefan Wojcik
98e1df0c45 Add continue_on_error optional kwarg to QuerySet.insert 2017-01-14 23:04:55 -05:00
Eli Boyarski
e5acbcc0dd Improved a docstring for FieldDoesNotExist (#1466) 2017-01-09 11:24:27 -05:00
Stefan Wojcik
1b6743ee53 add a changelog entry about broken references raising DoesNotExist 2017-01-08 14:50:16 -05:00
Eli Boyarski
b5fb82d95d Typo fix (#1463) 2017-01-08 12:57:36 -05:00
lanf0n
193aa4e1f2 [#1459] fix typo __neq__ to __ne__ (#1461) 2017-01-05 22:37:09 -05:00
Stefan Wójcik
ebd34427c7 Cleaner Document.save (#1458) 2016-12-30 05:43:56 -05:00
7 changed files with 144 additions and 80 deletions

View File

@@ -13,6 +13,7 @@ Changes in 0.11.0
- BREAKING CHANGE: Renamed `ConnectionError` to `MongoEngineConnectionError` since the former is a built-in exception name in Python v3.x. #1428 - BREAKING CHANGE: Renamed `ConnectionError` to `MongoEngineConnectionError` since the former is a built-in exception name in Python v3.x. #1428
- BREAKING CHANGE: Dropped Python 2.6 support. #1428 - BREAKING CHANGE: Dropped Python 2.6 support. #1428
- BREAKING CHANGE: `from mongoengine.base import ErrorClass` won't work anymore for any error from `mongoengine.errors` (e.g. `ValidationError`). Use `from mongoengine.errors import ErrorClass instead`. #1428 - BREAKING CHANGE: `from mongoengine.base import ErrorClass` won't work anymore for any error from `mongoengine.errors` (e.g. `ValidationError`). Use `from mongoengine.errors import ErrorClass instead`. #1428
- BREAKING CHANGE: Accessing a broken reference will raise a `DoesNotExist` error. In the past it used to return `None`. #1334
- Fixed absent rounding for DecimalField when `force_string` is set. #1103 - Fixed absent rounding for DecimalField when `force_string` is set. #1103
Changes in 0.10.8 Changes in 0.10.8

View File

@@ -429,7 +429,7 @@ class StrictDict(object):
def __eq__(self, other): def __eq__(self, other):
return self.items() == other.items() return self.items() == other.items()
def __neq__(self, other): def __ne__(self, other):
return self.items() != other.items() return self.items() != other.items()
@classmethod @classmethod

View File

@@ -332,68 +332,20 @@ class Document(BaseDocument):
signals.pre_save_post_validation.send(self.__class__, document=self, signals.pre_save_post_validation.send(self.__class__, document=self,
created=created, **signal_kwargs) created=created, **signal_kwargs)
try:
collection = self._get_collection()
if self._meta.get('auto_create_index', True): if self._meta.get('auto_create_index', True):
self.ensure_indexes() self.ensure_indexes()
try:
# Save a new document or update an existing one
if created: if created:
if force_insert: object_id = self._save_create(doc, force_insert, write_concern)
object_id = collection.insert(doc, **write_concern)
else: else:
object_id = collection.save(doc, **write_concern) object_id, created = self._save_update(doc, save_condition,
# In PyMongo 3.0, the save() call calls internally the _update() call write_concern)
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = (
self._qs.filter(pk=pk_as_mongo_obj).first() and
self._qs.filter(pk=pk_as_mongo_obj).first().pk
) # TODO doesn't this make 2 queries?
else:
object_id = doc['_id']
updates, removals = self._delta()
# Need to add shard key to query, or you get an error
if save_condition is not None:
select_dict = transform.query(self.__class__,
**save_condition)
else:
select_dict = {}
select_dict['_id'] = object_id
shard_key = self._meta.get('shard_key', tuple())
for k in shard_key:
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
def is_new_object(last_error):
if last_error is not None:
updated = last_error.get('updatedExisting')
if updated is not None:
return not updated
return created
update_query = {}
if updates:
update_query['$set'] = updates
if removals:
update_query['$unset'] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern)
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
created = is_new_object(last_error)
if cascade is None: if cascade is None:
cascade = self._meta.get( cascade = (self._meta.get('cascade', False) or
'cascade', False) or cascade_kwargs is not None cascade_kwargs is not None)
if cascade: if cascade:
kwargs = { kwargs = {
@@ -406,6 +358,7 @@ class Document(BaseDocument):
kwargs.update(cascade_kwargs) kwargs.update(cascade_kwargs)
kwargs['_refs'] = _refs kwargs['_refs'] = _refs
self.cascade_save(**kwargs) self.cascade_save(**kwargs)
except pymongo.errors.DuplicateKeyError as err: except pymongo.errors.DuplicateKeyError as err:
message = u'Tried to save duplicate unique keys (%s)' message = u'Tried to save duplicate unique keys (%s)'
raise NotUniqueError(message % six.text_type(err)) raise NotUniqueError(message % six.text_type(err))
@@ -418,16 +371,91 @@ class Document(BaseDocument):
raise NotUniqueError(message % six.text_type(err)) raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err)) raise OperationError(message % six.text_type(err))
# Make sure we store the PK on this document now that it's saved
id_field = self._meta['id_field'] id_field = self._meta['id_field']
if created or id_field not in self._meta.get('shard_key', []): if created or id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id) self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self.__class__, document=self, signals.post_save.send(self.__class__, document=self,
created=created, **signal_kwargs) created=created, **signal_kwargs)
self._clear_changed_fields() self._clear_changed_fields()
self._created = False self._created = False
return self return self
def _save_create(self, doc, force_insert, write_concern):
"""Save a new document.
Helper method, should only be used inside save().
"""
collection = self._get_collection()
if force_insert:
return collection.insert(doc, **write_concern)
object_id = collection.save(doc, **write_concern)
# In PyMongo 3.0, the save() call calls internally the _update() call
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = (
self._qs.filter(pk=pk_as_mongo_obj).first() and
self._qs.filter(pk=pk_as_mongo_obj).first().pk
) # TODO doesn't this make 2 queries?
return object_id
def _save_update(self, doc, save_condition, write_concern):
"""Update an existing document.
Helper method, should only be used inside save().
"""
collection = self._get_collection()
object_id = doc['_id']
created = False
select_dict = {}
if save_condition is not None:
select_dict = transform.query(self.__class__, **save_condition)
select_dict['_id'] = object_id
# Need to add shard key to query, or you get an error
shard_key = self._meta.get('shard_key', tuple())
for k in shard_key:
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
updates, removals = self._delta()
update_query = {}
if updates:
update_query['$set'] = updates
if removals:
update_query['$unset'] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern)
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
if last_error is not None:
updated_existing = last_error.get('updatedExisting')
if updated_existing is False:
created = True
# !!! This is bad, means we accidentally created a new,
# potentially corrupted document. See
# https://github.com/MongoEngine/mongoengine/issues/564
return object_id, created
def cascade_save(self, **kwargs): def cascade_save(self, **kwargs):
"""Recursively save any references and generic references on the """Recursively save any references and generic references on the
document. document.

View File

@@ -50,8 +50,8 @@ class FieldDoesNotExist(Exception):
or an :class:`~mongoengine.EmbeddedDocument`. or an :class:`~mongoengine.EmbeddedDocument`.
To avoid this behavior on data loading, To avoid this behavior on data loading,
you should the :attr:`strict` to ``False`` you should set the :attr:`strict` to ``False``
in the :attr:`meta` dictionnary. in the :attr:`meta` dictionary.
""" """

View File

@@ -296,22 +296,25 @@ class BaseQuerySet(object):
result = None result = None
return result return result
def insert(self, doc_or_docs, load_bulk=True, def insert(self, doc_or_docs, load_bulk=True, write_concern=None,
write_concern=None, signal_kwargs=None): signal_kwargs=None, continue_on_error=None):
"""bulk insert documents """bulk insert documents
:param doc_or_docs: a document or list of documents to be inserted :param doc_or_docs: a document or list of documents to be inserted
:param load_bulk (optional): If True returns the list of document :param load_bulk (optional): If True returns the list of document
instances instances
:param write_concern: Extra keyword arguments are passed down to :param write_concern: Optional keyword argument passed down to
:meth:`~pymongo.collection.Collection.insert` :meth:`~pymongo.collection.Collection.insert`, representing
which will be used as options for the resultant the write concern. For example,
``getLastError`` command. For example, ``insert(..., write_concert={w: 2, fsync: True})`` will
``insert(..., {w: 2, fsync: True})`` will wait until at least wait until at least two servers have recorded the write
two servers have recorded the write and will force an fsync on and will force an fsync on each server being written to.
each server being written to.
:parm signal_kwargs: (optional) kwargs dictionary to be passed to :parm signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls. the signal calls.
:param continue_on_error: Optional keyword argument passed down to
:meth:`~pymongo.collection.Collection.insert`. Defines what
to do when a document cannot be inserted (e.g. due to
duplicate IDs). Read PyMongo's docs for more info.
By default returns document instances, set ``load_bulk`` to False to By default returns document instances, set ``load_bulk`` to False to
return just ``ObjectIds`` return just ``ObjectIds``
@@ -322,12 +325,10 @@ class BaseQuerySet(object):
""" """
Document = _import_class('Document') Document = _import_class('Document')
if write_concern is None: # Determine if we're inserting one doc or more
write_concern = {}
docs = doc_or_docs docs = doc_or_docs
return_one = False return_one = False
if isinstance(docs, Document) or issubclass(docs.__class__, Document): if isinstance(docs, Document):
return_one = True return_one = True
docs = [docs] docs = [docs]
@@ -344,9 +345,16 @@ class BaseQuerySet(object):
signals.pre_bulk_insert.send(self._document, signals.pre_bulk_insert.send(self._document,
documents=docs, **signal_kwargs) documents=docs, **signal_kwargs)
# Resolve optional insert kwargs
insert_kwargs = {}
if write_concern is not None:
insert_kwargs.update(write_concern)
if continue_on_error is not None:
insert_kwargs['continue_on_error'] = continue_on_error
raw = [doc.to_mongo() for doc in docs] raw = [doc.to_mongo() for doc in docs]
try: try:
ids = self._collection.insert(raw, **write_concern) ids = self._collection.insert(raw, **insert_kwargs)
except pymongo.errors.DuplicateKeyError as err: except pymongo.errors.DuplicateKeyError as err:
message = 'Could not save document (%s)' message = 'Could not save document (%s)'
raise NotUniqueError(message % six.text_type(err)) raise NotUniqueError(message % six.text_type(err))

View File

@@ -1985,7 +1985,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(content, User.objects.first().groups[0].content) self.assertEqual(content, User.objects.first().groups[0].content)
def test_reference_miss(self): def test_reference_miss(self):
"""Ensure an exception is raised when dereferencing unknow document """Ensure an exception is raised when dereferencing unknown document
""" """
class Foo(Document): class Foo(Document):

View File

@@ -766,8 +766,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(record.embed.field, 2) self.assertEqual(record.embed.field, 2)
def test_bulk_insert(self): def test_bulk_insert(self):
"""Ensure that bulk insert works """Ensure that bulk insert works."""
"""
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
name = StringField() name = StringField()
@@ -885,9 +884,37 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Blog.objects.count(), 2) self.assertEqual(Blog.objects.count(), 2)
Blog.objects.insert([blog2, blog3], def test_bulk_insert_continue_on_error(self):
write_concern={"w": 0, 'continue_on_error': True}) """Ensure that bulk insert works with the continue_on_error option."""
self.assertEqual(Blog.objects.count(), 3)
class Person(Document):
email = EmailField(unique=True)
Person.drop_collection()
Person.objects.insert([
Person(email='alice@example.com'),
Person(email='bob@example.com')
])
self.assertEqual(Person.objects.count(), 2)
new_docs = [
Person(email='alice@example.com'), # dupe
Person(email='bob@example.com'), # dupe
Person(email='steve@example.com') # new one
]
# By default inserting dupe docs should fail and no new docs should
# be inserted.
with self.assertRaises(NotUniqueError):
Person.objects.insert(new_docs)
self.assertEqual(Person.objects.count(), 2)
# With continue_on_error, new doc should be inserted, even though we
# still get a NotUniqueError caused by the other 2 dupes.
with self.assertRaises(NotUniqueError):
Person.objects.insert(new_docs, continue_on_error=True)
self.assertEqual(Person.objects.count(), 3)
def test_get_changed_fields_query_count(self): def test_get_changed_fields_query_count(self):