Compare commits
7 Commits
test-write
...
cleaner-sa
Author | SHA1 | Date | |
---|---|---|---|
|
a8889b6dfb | ||
|
d05301b3a1 | ||
|
a120eae5ae | ||
|
3d75573889 | ||
|
c6240ca415 | ||
|
74b37d11cf | ||
|
e07cb82c15 |
@@ -41,7 +41,7 @@ class BaseField(object):
|
|||||||
"""
|
"""
|
||||||
:param db_field: The database field to store this field in
|
:param db_field: The database field to store this field in
|
||||||
(defaults to the name of the field)
|
(defaults to the name of the field)
|
||||||
:param name: Depreciated - use db_field
|
:param name: Deprecated - use db_field
|
||||||
:param required: If the field is required. Whether it has to have a
|
:param required: If the field is required. Whether it has to have a
|
||||||
value or not. Defaults to False.
|
value or not. Defaults to False.
|
||||||
:param default: (optional) The default value for this field if no value
|
:param default: (optional) The default value for this field if no value
|
||||||
@@ -81,6 +81,17 @@ class BaseField(object):
|
|||||||
self.sparse = sparse
|
self.sparse = sparse
|
||||||
self._owner_document = None
|
self._owner_document = None
|
||||||
|
|
||||||
|
# Validate the db_field
|
||||||
|
if isinstance(self.db_field, six.string_types) and (
|
||||||
|
'.' in self.db_field or
|
||||||
|
'\0' in self.db_field or
|
||||||
|
self.db_field.startswith('$')
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
'field names cannot contain dots (".") or null characters '
|
||||||
|
'("\\0"), and they must not start with a dollar sign ("$").'
|
||||||
|
)
|
||||||
|
|
||||||
# Detect and report conflicts between metadata and base properties.
|
# Detect and report conflicts between metadata and base properties.
|
||||||
conflicts = set(dir(self)) & set(kwargs)
|
conflicts = set(dir(self)) & set(kwargs)
|
||||||
if conflicts:
|
if conflicts:
|
||||||
|
@@ -332,68 +332,20 @@ class Document(BaseDocument):
|
|||||||
signals.pre_save_post_validation.send(self.__class__, document=self,
|
signals.pre_save_post_validation.send(self.__class__, document=self,
|
||||||
created=created, **signal_kwargs)
|
created=created, **signal_kwargs)
|
||||||
|
|
||||||
|
if self._meta.get('auto_create_index', True):
|
||||||
|
self.ensure_indexes()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
collection = self._get_collection()
|
# Save a new document or update an existing one
|
||||||
if self._meta.get('auto_create_index', True):
|
|
||||||
self.ensure_indexes()
|
|
||||||
if created:
|
if created:
|
||||||
if force_insert:
|
object_id = self._save_create(doc, force_insert, write_concern)
|
||||||
object_id = collection.insert(doc, **write_concern)
|
|
||||||
else:
|
|
||||||
object_id = collection.save(doc, **write_concern)
|
|
||||||
# In PyMongo 3.0, the save() call calls internally the _update() call
|
|
||||||
# but they forget to return the _id value passed back, therefore getting it back here
|
|
||||||
# Correct behaviour in 2.X and in 3.0.1+ versions
|
|
||||||
if not object_id and pymongo.version_tuple == (3, 0):
|
|
||||||
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
|
|
||||||
object_id = (
|
|
||||||
self._qs.filter(pk=pk_as_mongo_obj).first() and
|
|
||||||
self._qs.filter(pk=pk_as_mongo_obj).first().pk
|
|
||||||
) # TODO doesn't this make 2 queries?
|
|
||||||
else:
|
else:
|
||||||
object_id = doc['_id']
|
object_id, created = self._save_update(doc, save_condition,
|
||||||
updates, removals = self._delta()
|
write_concern)
|
||||||
# Need to add shard key to query, or you get an error
|
|
||||||
if save_condition is not None:
|
|
||||||
select_dict = transform.query(self.__class__,
|
|
||||||
**save_condition)
|
|
||||||
else:
|
|
||||||
select_dict = {}
|
|
||||||
select_dict['_id'] = object_id
|
|
||||||
shard_key = self._meta.get('shard_key', tuple())
|
|
||||||
for k in shard_key:
|
|
||||||
path = self._lookup_field(k.split('.'))
|
|
||||||
actual_key = [p.db_field for p in path]
|
|
||||||
val = doc
|
|
||||||
for ak in actual_key:
|
|
||||||
val = val[ak]
|
|
||||||
select_dict['.'.join(actual_key)] = val
|
|
||||||
|
|
||||||
def is_new_object(last_error):
|
|
||||||
if last_error is not None:
|
|
||||||
updated = last_error.get('updatedExisting')
|
|
||||||
if updated is not None:
|
|
||||||
return not updated
|
|
||||||
return created
|
|
||||||
|
|
||||||
update_query = {}
|
|
||||||
|
|
||||||
if updates:
|
|
||||||
update_query['$set'] = updates
|
|
||||||
if removals:
|
|
||||||
update_query['$unset'] = removals
|
|
||||||
if updates or removals:
|
|
||||||
upsert = save_condition is None
|
|
||||||
last_error = collection.update(select_dict, update_query,
|
|
||||||
upsert=upsert, **write_concern)
|
|
||||||
if not upsert and last_error['n'] == 0:
|
|
||||||
raise SaveConditionError('Race condition preventing'
|
|
||||||
' document update detected')
|
|
||||||
created = is_new_object(last_error)
|
|
||||||
|
|
||||||
if cascade is None:
|
if cascade is None:
|
||||||
cascade = self._meta.get(
|
cascade = (self._meta.get('cascade', False) or
|
||||||
'cascade', False) or cascade_kwargs is not None
|
cascade_kwargs is not None)
|
||||||
|
|
||||||
if cascade:
|
if cascade:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@@ -406,6 +358,7 @@ class Document(BaseDocument):
|
|||||||
kwargs.update(cascade_kwargs)
|
kwargs.update(cascade_kwargs)
|
||||||
kwargs['_refs'] = _refs
|
kwargs['_refs'] = _refs
|
||||||
self.cascade_save(**kwargs)
|
self.cascade_save(**kwargs)
|
||||||
|
|
||||||
except pymongo.errors.DuplicateKeyError as err:
|
except pymongo.errors.DuplicateKeyError as err:
|
||||||
message = u'Tried to save duplicate unique keys (%s)'
|
message = u'Tried to save duplicate unique keys (%s)'
|
||||||
raise NotUniqueError(message % six.text_type(err))
|
raise NotUniqueError(message % six.text_type(err))
|
||||||
@@ -418,16 +371,91 @@ class Document(BaseDocument):
|
|||||||
raise NotUniqueError(message % six.text_type(err))
|
raise NotUniqueError(message % six.text_type(err))
|
||||||
raise OperationError(message % six.text_type(err))
|
raise OperationError(message % six.text_type(err))
|
||||||
|
|
||||||
|
# Make sure we store the PK on this document now that it's saved
|
||||||
id_field = self._meta['id_field']
|
id_field = self._meta['id_field']
|
||||||
if created or id_field not in self._meta.get('shard_key', []):
|
if created or id_field not in self._meta.get('shard_key', []):
|
||||||
self[id_field] = self._fields[id_field].to_python(object_id)
|
self[id_field] = self._fields[id_field].to_python(object_id)
|
||||||
|
|
||||||
signals.post_save.send(self.__class__, document=self,
|
signals.post_save.send(self.__class__, document=self,
|
||||||
created=created, **signal_kwargs)
|
created=created, **signal_kwargs)
|
||||||
|
|
||||||
self._clear_changed_fields()
|
self._clear_changed_fields()
|
||||||
self._created = False
|
self._created = False
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def _save_create(self, doc, force_insert, write_concern):
|
||||||
|
"""Save a new document.
|
||||||
|
|
||||||
|
Helper method, should only be used inside save().
|
||||||
|
"""
|
||||||
|
collection = self._get_collection()
|
||||||
|
|
||||||
|
if force_insert:
|
||||||
|
return collection.insert(doc, **write_concern)
|
||||||
|
|
||||||
|
object_id = collection.save(doc, **write_concern)
|
||||||
|
|
||||||
|
# In PyMongo 3.0, the save() call calls internally the _update() call
|
||||||
|
# but they forget to return the _id value passed back, therefore getting it back here
|
||||||
|
# Correct behaviour in 2.X and in 3.0.1+ versions
|
||||||
|
if not object_id and pymongo.version_tuple == (3, 0):
|
||||||
|
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
|
||||||
|
object_id = (
|
||||||
|
self._qs.filter(pk=pk_as_mongo_obj).first() and
|
||||||
|
self._qs.filter(pk=pk_as_mongo_obj).first().pk
|
||||||
|
) # TODO doesn't this make 2 queries?
|
||||||
|
|
||||||
|
return object_id
|
||||||
|
|
||||||
|
def _save_update(self, doc, save_condition, write_concern):
|
||||||
|
"""Update an existing document.
|
||||||
|
|
||||||
|
Helper method, should only be used inside save().
|
||||||
|
"""
|
||||||
|
collection = self._get_collection()
|
||||||
|
object_id = doc['_id']
|
||||||
|
created = False
|
||||||
|
|
||||||
|
select_dict = {}
|
||||||
|
if save_condition is not None:
|
||||||
|
select_dict = transform.query(self.__class__, **save_condition)
|
||||||
|
|
||||||
|
select_dict['_id'] = object_id
|
||||||
|
|
||||||
|
# Need to add shard key to query, or you get an error
|
||||||
|
shard_key = self._meta.get('shard_key', tuple())
|
||||||
|
for k in shard_key:
|
||||||
|
path = self._lookup_field(k.split('.'))
|
||||||
|
actual_key = [p.db_field for p in path]
|
||||||
|
val = doc
|
||||||
|
for ak in actual_key:
|
||||||
|
val = val[ak]
|
||||||
|
select_dict['.'.join(actual_key)] = val
|
||||||
|
|
||||||
|
updates, removals = self._delta()
|
||||||
|
update_query = {}
|
||||||
|
if updates:
|
||||||
|
update_query['$set'] = updates
|
||||||
|
if removals:
|
||||||
|
update_query['$unset'] = removals
|
||||||
|
if updates or removals:
|
||||||
|
upsert = save_condition is None
|
||||||
|
last_error = collection.update(select_dict, update_query,
|
||||||
|
upsert=upsert, **write_concern)
|
||||||
|
if not upsert and last_error['n'] == 0:
|
||||||
|
raise SaveConditionError('Race condition preventing'
|
||||||
|
' document update detected')
|
||||||
|
if last_error is not None:
|
||||||
|
updated_existing = last_error.get('updatedExisting')
|
||||||
|
if updated_existing is False:
|
||||||
|
created = True
|
||||||
|
# !!! This is bad, means we accidentally created a new,
|
||||||
|
# potentially corrupted document. See
|
||||||
|
# https://github.com/MongoEngine/mongoengine/issues/564
|
||||||
|
|
||||||
|
return object_id, created
|
||||||
|
|
||||||
def cascade_save(self, **kwargs):
|
def cascade_save(self, **kwargs):
|
||||||
"""Recursively save any references and generic references on the
|
"""Recursively save any references and generic references on the
|
||||||
document.
|
document.
|
||||||
|
@@ -306,6 +306,24 @@ class FieldTest(unittest.TestCase):
|
|||||||
person.id = '497ce96f395f2f052a494fd4'
|
person.id = '497ce96f395f2f052a494fd4'
|
||||||
person.validate()
|
person.validate()
|
||||||
|
|
||||||
|
def test_db_field_validation(self):
|
||||||
|
"""Ensure that db_field doesn't accept invalid values."""
|
||||||
|
|
||||||
|
# dot in the name
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
class User(Document):
|
||||||
|
name = StringField(db_field='user.name')
|
||||||
|
|
||||||
|
# name starting with $
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
class User(Document):
|
||||||
|
name = StringField(db_field='$name')
|
||||||
|
|
||||||
|
# name containing a null character
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
class User(Document):
|
||||||
|
name = StringField(db_field='name\0')
|
||||||
|
|
||||||
def test_string_validation(self):
|
def test_string_validation(self):
|
||||||
"""Ensure that invalid values cannot be assigned to string fields.
|
"""Ensure that invalid values cannot be assigned to string fields.
|
||||||
"""
|
"""
|
||||||
@@ -3973,30 +3991,25 @@ class FieldTest(unittest.TestCase):
|
|||||||
"""Tests if a `FieldDoesNotExist` exception is raised when trying to
|
"""Tests if a `FieldDoesNotExist` exception is raised when trying to
|
||||||
instanciate a document with a field that's not defined.
|
instanciate a document with a field that's not defined.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class Doc(Document):
|
class Doc(Document):
|
||||||
foo = StringField(db_field='f')
|
foo = StringField()
|
||||||
|
|
||||||
def test():
|
with self.assertRaises(FieldDoesNotExist):
|
||||||
Doc(bar='test')
|
Doc(bar='test')
|
||||||
|
|
||||||
self.assertRaises(FieldDoesNotExist, test)
|
|
||||||
|
|
||||||
def test_undefined_field_exception_with_strict(self):
|
def test_undefined_field_exception_with_strict(self):
|
||||||
"""Tests if a `FieldDoesNotExist` exception is raised when trying to
|
"""Tests if a `FieldDoesNotExist` exception is raised when trying to
|
||||||
instanciate a document with a field that's not defined,
|
instanciate a document with a field that's not defined,
|
||||||
even when strict is set to False.
|
even when strict is set to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class Doc(Document):
|
class Doc(Document):
|
||||||
foo = StringField(db_field='f')
|
foo = StringField()
|
||||||
meta = {'strict': False}
|
meta = {'strict': False}
|
||||||
|
|
||||||
def test():
|
with self.assertRaises(FieldDoesNotExist):
|
||||||
Doc(bar='test')
|
Doc(bar='test')
|
||||||
|
|
||||||
self.assertRaises(FieldDoesNotExist, test)
|
|
||||||
|
|
||||||
def test_long_field_is_considered_as_int64(self):
|
def test_long_field_is_considered_as_int64(self):
|
||||||
"""
|
"""
|
||||||
Tests that long fields are stored as long in mongo, even if long value
|
Tests that long fields are stored as long in mongo, even if long value
|
||||||
|
@@ -296,6 +296,19 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
conn = get_connection('t2')
|
conn = get_connection('t2')
|
||||||
self.assertFalse(get_tz_awareness(conn))
|
self.assertFalse(get_tz_awareness(conn))
|
||||||
|
|
||||||
|
def test_write_concern(self):
|
||||||
|
"""Ensure write concern can be specified in connect() via
|
||||||
|
a kwarg or as part of the connection URI.
|
||||||
|
"""
|
||||||
|
conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true')
|
||||||
|
conn2 = connect('testing', alias='conn2', w=1, j=True)
|
||||||
|
if IS_PYMONGO_3:
|
||||||
|
self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True})
|
||||||
|
self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True})
|
||||||
|
else:
|
||||||
|
self.assertEqual(dict(conn1.write_concern), {'w': 1, 'j': True})
|
||||||
|
self.assertEqual(dict(conn2.write_concern), {'w': 1, 'j': True})
|
||||||
|
|
||||||
def test_datetime(self):
|
def test_datetime(self):
|
||||||
connect('mongoenginetest', tz_aware=True)
|
connect('mongoenginetest', tz_aware=True)
|
||||||
d = datetime.datetime(2010, 5, 5, tzinfo=utc)
|
d = datetime.datetime(2010, 5, 5, tzinfo=utc)
|
||||||
|
Reference in New Issue
Block a user