Compare commits

..

2 Commits

Author SHA1 Message Date
Stefan Wojcik
3e44ca4577 fix pymongo < 3 2016-12-29 12:12:50 -05:00
Stefan Wojcik
d378bd09a5 test write concern 2016-12-29 10:45:52 -05:00
14 changed files with 183 additions and 396 deletions

View File

@@ -35,22 +35,16 @@ setup.py install``.
Dependencies Dependencies
============ ============
All of the dependencies can easily be installed via `pip <https://pip.pypa.io/>`_. At the very least, you'll need these two packages to use MongoEngine:
- pymongo>=2.7.1 - pymongo>=2.7.1
- six>=1.10.0 - sphinx (optional - for documentation generation)
If you utilize a ``DateTimeField``, you might also use a more flexible date parser:
Optional Dependencies
---------------------
- **Image Fields**: Pillow>=2.0.0
- dateutil>=2.1.0 - dateutil>=2.1.0
If you need to use an ``ImageField`` or ``ImageGridFsProxy``: .. note
MongoEngine always runs it's test suite against the latest patch version of each dependecy. e.g.: PyMongo 3.0.1
- Pillow>=2.0.0
If you want to generate the documentation (e.g. to contribute to it):
- sphinx
Examples Examples
======== ========

View File

@@ -5,8 +5,6 @@ Changelog
Development Development
=========== ===========
- (Fill this out as you fix issues and develop you features). - (Fill this out as you fix issues and develop you features).
- POTENTIAL BREAKING CHANGE: Fixed limit/skip/hint/batch_size chaining #1476
- POTENTIAL BREAKING CHANGE: Changed a public `QuerySet.clone_into` method to a private `QuerySet._clone_into` #1476
- Fixed connecting to a replica set with PyMongo 2.x #1436 - Fixed connecting to a replica set with PyMongo 2.x #1436
- Fixed an obscure error message when filtering by `field__in=non_iterable`. #1237 - Fixed an obscure error message when filtering by `field__in=non_iterable`. #1237
@@ -15,7 +13,6 @@ Changes in 0.11.0
- BREAKING CHANGE: Renamed `ConnectionError` to `MongoEngineConnectionError` since the former is a built-in exception name in Python v3.x. #1428 - BREAKING CHANGE: Renamed `ConnectionError` to `MongoEngineConnectionError` since the former is a built-in exception name in Python v3.x. #1428
- BREAKING CHANGE: Dropped Python 2.6 support. #1428 - BREAKING CHANGE: Dropped Python 2.6 support. #1428
- BREAKING CHANGE: `from mongoengine.base import ErrorClass` won't work anymore for any error from `mongoengine.errors` (e.g. `ValidationError`). Use `from mongoengine.errors import ErrorClass instead`. #1428 - BREAKING CHANGE: `from mongoengine.base import ErrorClass` won't work anymore for any error from `mongoengine.errors` (e.g. `ValidationError`). Use `from mongoengine.errors import ErrorClass instead`. #1428
- BREAKING CHANGE: Accessing a broken reference will raise a `DoesNotExist` error. In the past it used to return `None`. #1334
- Fixed absent rounding for DecimalField when `force_string` is set. #1103 - Fixed absent rounding for DecimalField when `force_string` is set. #1103
Changes in 0.10.8 Changes in 0.10.8

View File

@@ -361,6 +361,11 @@ Its value can take any of the following constants:
In Django, be sure to put all apps that have such delete rule declarations in In Django, be sure to put all apps that have such delete rule declarations in
their :file:`models.py` in the :const:`INSTALLED_APPS` tuple. their :file:`models.py` in the :const:`INSTALLED_APPS` tuple.
.. warning::
Signals are not triggered when doing cascading updates / deletes - if this
is required you must manually handle the update / delete.
Generic reference fields Generic reference fields
'''''''''''''''''''''''' ''''''''''''''''''''''''
A second kind of reference field also exists, A second kind of reference field also exists,

View File

@@ -142,4 +142,11 @@ cleaner looking while still allowing manual execution of the callback::
modified = DateTimeField() modified = DateTimeField()
ReferenceFields and Signals
---------------------------
Currently `reverse_delete_rule` does not trigger signals on the other part of
the relationship. If this is required you must manually handle the
reverse deletion.
.. _blinker: http://pypi.python.org/pypi/blinker .. _blinker: http://pypi.python.org/pypi/blinker

View File

@@ -2,20 +2,6 @@
Upgrading Upgrading
######### #########
Development
***********
(Fill this out whenever you introduce breaking changes to MongoEngine)
This release includes various fixes for the `BaseQuerySet` methods and how they
are chained together. Since version 0.10.1 applying limit/skip/hint/batch_size
to an already-existing queryset wouldn't modify the underlying PyMongo cursor.
This has been fixed now, so you'll need to make sure that your code didn't rely
on the broken implementation.
Additionally, a public `BaseQuerySet.clone_into` has been renamed to a private
`_clone_into`. If you directly used that method in your code, you'll need to
rename its occurrences.
0.11.0 0.11.0
****** ******
This release includes a major rehaul of MongoEngine's code quality and This release includes a major rehaul of MongoEngine's code quality and

View File

@@ -429,7 +429,7 @@ class StrictDict(object):
def __eq__(self, other): def __eq__(self, other):
return self.items() == other.items() return self.items() == other.items()
def __ne__(self, other): def __neq__(self, other):
return self.items() != other.items() return self.items() != other.items()
@classmethod @classmethod

View File

@@ -41,7 +41,7 @@ class BaseField(object):
""" """
:param db_field: The database field to store this field in :param db_field: The database field to store this field in
(defaults to the name of the field) (defaults to the name of the field)
:param name: Deprecated - use db_field :param name: Depreciated - use db_field
:param required: If the field is required. Whether it has to have a :param required: If the field is required. Whether it has to have a
value or not. Defaults to False. value or not. Defaults to False.
:param default: (optional) The default value for this field if no value :param default: (optional) The default value for this field if no value
@@ -81,17 +81,6 @@ class BaseField(object):
self.sparse = sparse self.sparse = sparse
self._owner_document = None self._owner_document = None
# Validate the db_field
if isinstance(self.db_field, six.string_types) and (
'.' in self.db_field or
'\0' in self.db_field or
self.db_field.startswith('$')
):
raise ValueError(
'field names cannot contain dots (".") or null characters '
'("\\0"), and they must not start with a dollar sign ("$").'
)
# Detect and report conflicts between metadata and base properties. # Detect and report conflicts between metadata and base properties.
conflicts = set(dir(self)) & set(kwargs) conflicts = set(dir(self)) & set(kwargs)
if conflicts: if conflicts:
@@ -189,18 +178,14 @@ class BaseField(object):
pass pass
def _validate_choices(self, value): def _validate_choices(self, value):
"""Validate that value is a valid choice for this field."""
Document = _import_class('Document') Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument') EmbeddedDocument = _import_class('EmbeddedDocument')
# Field choices can be given as an iterable (e.g. tuple/list/set) of
# values or an iterable of value-label pairs, e.g. ('XS', 'Extra Small').
# It the latter case, extract just the values for comparison.
choice_list = self.choices choice_list = self.choices
if isinstance(next(iter(choice_list)), (list, tuple)): if isinstance(choice_list[0], (list, tuple)):
choice_list = [val for val, label in choice_list] choice_list = [k for k, _ in choice_list]
# Validate Document/EmbeddedDocument choices # Choices which are other types of Documents
if isinstance(value, (Document, EmbeddedDocument)): if isinstance(value, (Document, EmbeddedDocument)):
if not any(isinstance(value, c) for c in choice_list): if not any(isinstance(value, c) for c in choice_list):
self.error( self.error(
@@ -208,8 +193,7 @@ class BaseField(object):
six.text_type(choice_list) six.text_type(choice_list)
) )
) )
# Choices which are types other than Documents
# Validate any other type of choices
elif value not in choice_list: elif value not in choice_list:
self.error('Value must be one of %s' % six.text_type(choice_list)) self.error('Value must be one of %s' % six.text_type(choice_list))

View File

@@ -332,20 +332,68 @@ class Document(BaseDocument):
signals.pre_save_post_validation.send(self.__class__, document=self, signals.pre_save_post_validation.send(self.__class__, document=self,
created=created, **signal_kwargs) created=created, **signal_kwargs)
if self._meta.get('auto_create_index', True):
self.ensure_indexes()
try: try:
# Save a new document or update an existing one collection = self._get_collection()
if self._meta.get('auto_create_index', True):
self.ensure_indexes()
if created: if created:
object_id = self._save_create(doc, force_insert, write_concern) if force_insert:
object_id = collection.insert(doc, **write_concern)
else:
object_id = collection.save(doc, **write_concern)
# In PyMongo 3.0, the save() call calls internally the _update() call
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = (
self._qs.filter(pk=pk_as_mongo_obj).first() and
self._qs.filter(pk=pk_as_mongo_obj).first().pk
) # TODO doesn't this make 2 queries?
else: else:
object_id, created = self._save_update(doc, save_condition, object_id = doc['_id']
write_concern) updates, removals = self._delta()
# Need to add shard key to query, or you get an error
if save_condition is not None:
select_dict = transform.query(self.__class__,
**save_condition)
else:
select_dict = {}
select_dict['_id'] = object_id
shard_key = self._meta.get('shard_key', tuple())
for k in shard_key:
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
def is_new_object(last_error):
if last_error is not None:
updated = last_error.get('updatedExisting')
if updated is not None:
return not updated
return created
update_query = {}
if updates:
update_query['$set'] = updates
if removals:
update_query['$unset'] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern)
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
created = is_new_object(last_error)
if cascade is None: if cascade is None:
cascade = (self._meta.get('cascade', False) or cascade = self._meta.get(
cascade_kwargs is not None) 'cascade', False) or cascade_kwargs is not None
if cascade: if cascade:
kwargs = { kwargs = {
@@ -358,7 +406,6 @@ class Document(BaseDocument):
kwargs.update(cascade_kwargs) kwargs.update(cascade_kwargs)
kwargs['_refs'] = _refs kwargs['_refs'] = _refs
self.cascade_save(**kwargs) self.cascade_save(**kwargs)
except pymongo.errors.DuplicateKeyError as err: except pymongo.errors.DuplicateKeyError as err:
message = u'Tried to save duplicate unique keys (%s)' message = u'Tried to save duplicate unique keys (%s)'
raise NotUniqueError(message % six.text_type(err)) raise NotUniqueError(message % six.text_type(err))
@@ -371,91 +418,16 @@ class Document(BaseDocument):
raise NotUniqueError(message % six.text_type(err)) raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err)) raise OperationError(message % six.text_type(err))
# Make sure we store the PK on this document now that it's saved
id_field = self._meta['id_field'] id_field = self._meta['id_field']
if created or id_field not in self._meta.get('shard_key', []): if created or id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id) self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self.__class__, document=self, signals.post_save.send(self.__class__, document=self,
created=created, **signal_kwargs) created=created, **signal_kwargs)
self._clear_changed_fields() self._clear_changed_fields()
self._created = False self._created = False
return self return self
def _save_create(self, doc, force_insert, write_concern):
"""Save a new document.
Helper method, should only be used inside save().
"""
collection = self._get_collection()
if force_insert:
return collection.insert(doc, **write_concern)
object_id = collection.save(doc, **write_concern)
# In PyMongo 3.0, the save() call calls internally the _update() call
# but they forget to return the _id value passed back, therefore getting it back here
# Correct behaviour in 2.X and in 3.0.1+ versions
if not object_id and pymongo.version_tuple == (3, 0):
pk_as_mongo_obj = self._fields.get(self._meta['id_field']).to_mongo(self.pk)
object_id = (
self._qs.filter(pk=pk_as_mongo_obj).first() and
self._qs.filter(pk=pk_as_mongo_obj).first().pk
) # TODO doesn't this make 2 queries?
return object_id
def _save_update(self, doc, save_condition, write_concern):
"""Update an existing document.
Helper method, should only be used inside save().
"""
collection = self._get_collection()
object_id = doc['_id']
created = False
select_dict = {}
if save_condition is not None:
select_dict = transform.query(self.__class__, **save_condition)
select_dict['_id'] = object_id
# Need to add shard key to query, or you get an error
shard_key = self._meta.get('shard_key', tuple())
for k in shard_key:
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
updates, removals = self._delta()
update_query = {}
if updates:
update_query['$set'] = updates
if removals:
update_query['$unset'] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern)
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
if last_error is not None:
updated_existing = last_error.get('updatedExisting')
if updated_existing is False:
created = True
# !!! This is bad, means we accidentally created a new,
# potentially corrupted document. See
# https://github.com/MongoEngine/mongoengine/issues/564
return object_id, created
def cascade_save(self, **kwargs): def cascade_save(self, **kwargs):
"""Recursively save any references and generic references on the """Recursively save any references and generic references on the
document. document.

View File

@@ -50,8 +50,8 @@ class FieldDoesNotExist(Exception):
or an :class:`~mongoengine.EmbeddedDocument`. or an :class:`~mongoengine.EmbeddedDocument`.
To avoid this behavior on data loading, To avoid this behavior on data loading,
you should set the :attr:`strict` to ``False`` you should the :attr:`strict` to ``False``
in the :attr:`meta` dictionary. in the :attr:`meta` dictionnary.
""" """

View File

@@ -888,6 +888,10 @@ class ReferenceField(BaseField):
Foo.register_delete_rule(Bar, 'foo', NULLIFY) Foo.register_delete_rule(Bar, 'foo', NULLIFY)
.. note ::
`reverse_delete_rule` does not trigger pre / post delete signals to be
triggered.
.. versionchanged:: 0.5 added `reverse_delete_rule` .. versionchanged:: 0.5 added `reverse_delete_rule`
""" """

View File

@@ -86,7 +86,6 @@ class BaseQuerySet(object):
self._batch_size = None self._batch_size = None
self.only_fields = [] self.only_fields = []
self._max_time_ms = None self._max_time_ms = None
self._comment = None
def __call__(self, q_obj=None, class_check=True, read_preference=None, def __call__(self, q_obj=None, class_check=True, read_preference=None,
**query): **query):
@@ -707,36 +706,39 @@ class BaseQuerySet(object):
with switch_db(self._document, alias) as cls: with switch_db(self._document, alias) as cls:
collection = cls._get_collection() collection = cls._get_collection()
return self._clone_into(self.__class__(self._document, collection)) return self.clone_into(self.__class__(self._document, collection))
def clone(self): def clone(self):
"""Create a copy of the current queryset.""" """Creates a copy of the current
return self._clone_into(self.__class__(self._document, self._collection_obj)) :class:`~mongoengine.queryset.QuerySet`
def _clone_into(self, new_qs): .. versionadded:: 0.5
"""Copy all of the relevant properties of this queryset to
a new queryset (which has to be an instance of
:class:`~mongoengine.queryset.base.BaseQuerySet`).
""" """
if not isinstance(new_qs, BaseQuerySet): return self.clone_into(self.__class__(self._document, self._collection_obj))
def clone_into(self, cls):
"""Creates a copy of the current
:class:`~mongoengine.queryset.base.BaseQuerySet` into another child class
"""
if not isinstance(cls, BaseQuerySet):
raise OperationError( raise OperationError(
'%s is not a subclass of BaseQuerySet' % new_qs.__name__) '%s is not a subclass of BaseQuerySet' % cls.__name__)
copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj',
'_where_clause', '_loaded_fields', '_ordering', '_snapshot', '_where_clause', '_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_class_check', '_slave_okay', '_read_preference', '_timeout', '_class_check', '_slave_okay', '_read_preference',
'_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce',
'_limit', '_skip', '_hint', '_auto_dereference', '_limit', '_skip', '_hint', '_auto_dereference',
'_search_text', 'only_fields', '_max_time_ms', '_comment') '_search_text', 'only_fields', '_max_time_ms')
for prop in copy_props: for prop in copy_props:
val = getattr(self, prop) val = getattr(self, prop)
setattr(new_qs, prop, copy.copy(val)) setattr(cls, prop, copy.copy(val))
if self._cursor_obj: if self._cursor_obj:
new_qs._cursor_obj = self._cursor_obj.clone() cls._cursor_obj = self._cursor_obj.clone()
return new_qs return cls
def select_related(self, max_depth=1): def select_related(self, max_depth=1):
"""Handles dereferencing of :class:`~bson.dbref.DBRef` objects or """Handles dereferencing of :class:`~bson.dbref.DBRef` objects or
@@ -758,11 +760,7 @@ class BaseQuerySet(object):
""" """
queryset = self.clone() queryset = self.clone()
queryset._limit = n if n != 0 else 1 queryset._limit = n if n != 0 else 1
# Return self to allow chaining
# If a cursor object has already been created, apply the limit to it.
if queryset._cursor_obj:
queryset._cursor_obj.limit(queryset._limit)
return queryset return queryset
def skip(self, n): def skip(self, n):
@@ -773,11 +771,6 @@ class BaseQuerySet(object):
""" """
queryset = self.clone() queryset = self.clone()
queryset._skip = n queryset._skip = n
# If a cursor object has already been created, apply the skip to it.
if queryset._cursor_obj:
queryset._cursor_obj.skip(queryset._skip)
return queryset return queryset
def hint(self, index=None): def hint(self, index=None):
@@ -795,11 +788,6 @@ class BaseQuerySet(object):
""" """
queryset = self.clone() queryset = self.clone()
queryset._hint = index queryset._hint = index
# If a cursor object has already been created, apply the hint to it.
if queryset._cursor_obj:
queryset._cursor_obj.hint(queryset._hint)
return queryset return queryset
def batch_size(self, size): def batch_size(self, size):
@@ -813,11 +801,6 @@ class BaseQuerySet(object):
""" """
queryset = self.clone() queryset = self.clone()
queryset._batch_size = size queryset._batch_size = size
# If a cursor object has already been created, apply the batch size to it.
if queryset._cursor_obj:
queryset._cursor_obj.batch_size(queryset._batch_size)
return queryset return queryset
def distinct(self, field): def distinct(self, field):
@@ -989,31 +972,13 @@ class BaseQuerySet(object):
def order_by(self, *keys): def order_by(self, *keys):
"""Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The
order may be specified by prepending each of the keys by a + or a -. order may be specified by prepending each of the keys by a + or a -.
Ascending order is assumed. If no keys are passed, existing ordering Ascending order is assumed.
is cleared instead.
:param keys: fields to order the query results by; keys may be :param keys: fields to order the query results by; keys may be
prefixed with **+** or **-** to determine the ordering direction prefixed with **+** or **-** to determine the ordering direction
""" """
queryset = self.clone() queryset = self.clone()
queryset._ordering = queryset._get_order_by(keys)
old_ordering = queryset._ordering
new_ordering = queryset._get_order_by(keys)
if queryset._cursor_obj:
# If a cursor object has already been created, apply the sort to it
if new_ordering:
queryset._cursor_obj.sort(new_ordering)
# If we're trying to clear a previous explicit ordering, we need
# to clear the cursor entirely (because PyMongo doesn't allow
# clearing an existing sort on a cursor).
elif old_ordering:
queryset._cursor_obj = None
queryset._ordering = new_ordering
return queryset return queryset
def comment(self, text): def comment(self, text):
@@ -1459,13 +1424,10 @@ class BaseQuerySet(object):
raise StopIteration raise StopIteration
raw_doc = self._cursor.next() raw_doc = self._cursor.next()
if self._as_pymongo: if self._as_pymongo:
return self._get_as_pymongo(raw_doc) return self._get_as_pymongo(raw_doc)
doc = self._document._from_son(raw_doc,
doc = self._document._from_son( _auto_dereference=self._auto_dereference, only_fields=self.only_fields)
raw_doc, _auto_dereference=self._auto_dereference,
only_fields=self.only_fields)
if self._scalar: if self._scalar:
return self._get_scalar(doc) return self._get_scalar(doc)
@@ -1475,6 +1437,7 @@ class BaseQuerySet(object):
def rewind(self): def rewind(self):
"""Rewind the cursor to its unevaluated state. """Rewind the cursor to its unevaluated state.
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
self._iter = False self._iter = False
@@ -1524,54 +1487,43 @@ class BaseQuerySet(object):
@property @property
def _cursor(self): def _cursor(self):
"""Return a PyMongo cursor object corresponding to this queryset.""" if self._cursor_obj is None:
# If _cursor_obj already exists, return it immediately. # In PyMongo 3+, we define the read preference on a collection
if self._cursor_obj is not None: # level, not a cursor level. Thus, we need to get a cloned
return self._cursor_obj # collection object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply where clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
self._cursor_obj.where(where_clause)
# Create a new PyMongo cursor. if self._ordering:
# XXX In PyMongo 3+, we define the read preference on a collection # Apply query ordering
# level, not a cursor level. Thus, we need to get a cloned collection self._cursor_obj.sort(self._ordering)
# object using `with_options` first. elif self._ordering is None and self._document._meta['ordering']:
if IS_PYMONGO_3 and self._read_preference is not None: # Otherwise, apply the ordering from the document model, unless
self._cursor_obj = self._collection\ # it's been explicitly cleared via order_by with no arguments
.with_options(read_preference=self._read_preference)\ order = self._get_order_by(self._document._meta['ordering'])
.find(self._query, **self._cursor_args) self._cursor_obj.sort(order)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply "where" clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
self._cursor_obj.where(where_clause)
# Apply ordering to the cursor. if self._limit is not None:
# XXX self._ordering can be equal to: self._cursor_obj.limit(self._limit)
# * None if we didn't explicitly call order_by on this queryset.
# * A list of PyMongo-style sorting tuples.
# * An empty list if we explicitly called order_by() without any
# arguments. This indicates that we want to clear the default
# ordering.
if self._ordering:
# explicit ordering
self._cursor_obj.sort(self._ordering)
elif self._ordering is None and self._document._meta['ordering']:
# default ordering
order = self._get_order_by(self._document._meta['ordering'])
self._cursor_obj.sort(order)
if self._limit is not None: if self._skip is not None:
self._cursor_obj.limit(self._limit) self._cursor_obj.skip(self._skip)
if self._skip is not None: if self._hint != -1:
self._cursor_obj.skip(self._skip) self._cursor_obj.hint(self._hint)
if self._hint != -1: if self._batch_size is not None:
self._cursor_obj.hint(self._hint) self._cursor_obj.batch_size(self._batch_size)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
return self._cursor_obj return self._cursor_obj
@@ -1746,13 +1698,7 @@ class BaseQuerySet(object):
return ret return ret
def _get_order_by(self, keys): def _get_order_by(self, keys):
"""Given a list of MongoEngine-style sort keys, return a list """Creates a list of order by fields"""
of sorting tuples that can be applied to a PyMongo cursor. For
example:
>>> qs._get_order_by(['-last_name', 'first_name'])
[('last_name', -1), ('first_name', 1)]
"""
key_list = [] key_list = []
for key in keys: for key in keys:
if not key: if not key:
@@ -1765,19 +1711,17 @@ class BaseQuerySet(object):
direction = pymongo.ASCENDING direction = pymongo.ASCENDING
if key[0] == '-': if key[0] == '-':
direction = pymongo.DESCENDING direction = pymongo.DESCENDING
if key[0] in ('-', '+'): if key[0] in ('-', '+'):
key = key[1:] key = key[1:]
key = key.replace('__', '.') key = key.replace('__', '.')
try: try:
key = self._document._translate_field_name(key) key = self._document._translate_field_name(key)
except Exception: except Exception:
# TODO this exception should be more specific
pass pass
key_list.append((key, direction)) key_list.append((key, direction))
if self._cursor_obj and key_list:
self._cursor_obj.sort(key_list)
return key_list return key_list
def _get_scalar(self, doc): def _get_scalar(self, doc):
@@ -1875,21 +1819,10 @@ class BaseQuerySet(object):
return code return code
def _chainable_method(self, method_name, val): def _chainable_method(self, method_name, val):
"""Call a particular method on the PyMongo cursor call a particular chainable method
with the provided value.
"""
queryset = self.clone() queryset = self.clone()
method = getattr(queryset._cursor, method_name)
# Get an existing cursor object or create a new one method(val)
cursor = queryset._cursor
# Find the requested method on the cursor and call it with the
# provided value
getattr(cursor, method_name)(val)
# Cache the value on the queryset._{method_name}
setattr(queryset, '_' + method_name, val) setattr(queryset, '_' + method_name, val)
return queryset return queryset
# Deprecated # Deprecated

View File

@@ -136,15 +136,13 @@ class QuerySet(BaseQuerySet):
return self._len return self._len
def no_cache(self): def no_cache(self):
"""Convert to a non-caching queryset """Convert to a non_caching queryset
.. versionadded:: 0.8.3 Convert to non caching queryset .. versionadded:: 0.8.3 Convert to non caching queryset
""" """
if self._result_cache is not None: if self._result_cache is not None:
raise OperationError('QuerySet already cached') raise OperationError('QuerySet already cached')
return self.clone_into(QuerySetNoCache(self._document, self._collection))
return self._clone_into(QuerySetNoCache(self._document,
self._collection))
class QuerySetNoCache(BaseQuerySet): class QuerySetNoCache(BaseQuerySet):
@@ -155,7 +153,7 @@ class QuerySetNoCache(BaseQuerySet):
.. versionadded:: 0.8.3 Convert to caching queryset .. versionadded:: 0.8.3 Convert to caching queryset
""" """
return self._clone_into(QuerySet(self._document, self._collection)) return self.clone_into(QuerySet(self._document, self._collection))
def __repr__(self): def __repr__(self):
"""Provides the string representation of the QuerySet """Provides the string representation of the QuerySet

View File

@@ -306,24 +306,6 @@ class FieldTest(unittest.TestCase):
person.id = '497ce96f395f2f052a494fd4' person.id = '497ce96f395f2f052a494fd4'
person.validate() person.validate()
def test_db_field_validation(self):
"""Ensure that db_field doesn't accept invalid values."""
# dot in the name
with self.assertRaises(ValueError):
class User(Document):
name = StringField(db_field='user.name')
# name starting with $
with self.assertRaises(ValueError):
class User(Document):
name = StringField(db_field='$name')
# name containing a null character
with self.assertRaises(ValueError):
class User(Document):
name = StringField(db_field='name\0')
def test_string_validation(self): def test_string_validation(self):
"""Ensure that invalid values cannot be assigned to string fields. """Ensure that invalid values cannot be assigned to string fields.
""" """
@@ -1985,7 +1967,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(content, User.objects.first().groups[0].content) self.assertEqual(content, User.objects.first().groups[0].content)
def test_reference_miss(self): def test_reference_miss(self):
"""Ensure an exception is raised when dereferencing unknown document """Ensure an exception is raised when dereferencing unknow document
""" """
class Foo(Document): class Foo(Document):
@@ -3205,22 +3187,8 @@ class FieldTest(unittest.TestCase):
shirt.size = "XS" shirt.size = "XS"
self.assertRaises(ValidationError, shirt.validate) self.assertRaises(ValidationError, shirt.validate)
def test_choices_as_set(self):
"""Ensure that sets can be used as field choices"""
class Shirt(Document):
size = StringField(choices={'S', 'M', 'L', 'XL', 'XXL'})
Shirt.drop_collection() Shirt.drop_collection()
shirt = Shirt()
shirt.validate()
shirt.size = "S"
shirt.validate()
shirt.size = "XS"
self.assertRaises(ValidationError, shirt.validate)
def test_choices_validation_documents(self): def test_choices_validation_documents(self):
""" """
Ensure fields with document choices validate given a valid choice. Ensure fields with document choices validate given a valid choice.
@@ -4005,25 +3973,30 @@ class FieldTest(unittest.TestCase):
"""Tests if a `FieldDoesNotExist` exception is raised when trying to """Tests if a `FieldDoesNotExist` exception is raised when trying to
instanciate a document with a field that's not defined. instanciate a document with a field that's not defined.
""" """
class Doc(Document):
foo = StringField()
with self.assertRaises(FieldDoesNotExist): class Doc(Document):
foo = StringField(db_field='f')
def test():
Doc(bar='test') Doc(bar='test')
self.assertRaises(FieldDoesNotExist, test)
def test_undefined_field_exception_with_strict(self): def test_undefined_field_exception_with_strict(self):
"""Tests if a `FieldDoesNotExist` exception is raised when trying to """Tests if a `FieldDoesNotExist` exception is raised when trying to
instanciate a document with a field that's not defined, instanciate a document with a field that's not defined,
even when strict is set to False. even when strict is set to False.
""" """
class Doc(Document): class Doc(Document):
foo = StringField() foo = StringField(db_field='f')
meta = {'strict': False} meta = {'strict': False}
with self.assertRaises(FieldDoesNotExist): def test():
Doc(bar='test') Doc(bar='test')
self.assertRaises(FieldDoesNotExist, test)
def test_long_field_is_considered_as_int64(self): def test_long_field_is_considered_as_int64(self):
""" """
Tests that long fields are stored as long in mongo, even if long value Tests that long fields are stored as long in mongo, even if long value

View File

@@ -106,111 +106,58 @@ class QuerySetTest(unittest.TestCase):
list(BlogPost.objects(author2__name="test")) list(BlogPost.objects(author2__name="test"))
def test_find(self): def test_find(self):
"""Ensure that a query returns a valid set of results.""" """Ensure that a query returns a valid set of results.
user_a = self.Person.objects.create(name='User A', age=20) """
user_b = self.Person.objects.create(name='User B', age=30) self.Person(name="User A", age=20).save()
self.Person(name="User B", age=30).save()
# Find all people in the collection # Find all people in the collection
people = self.Person.objects people = self.Person.objects
self.assertEqual(people.count(), 2) self.assertEqual(people.count(), 2)
results = list(people) results = list(people)
self.assertTrue(isinstance(results[0], self.Person)) self.assertTrue(isinstance(results[0], self.Person))
self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode))) self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode)))
self.assertEqual(results[0].name, "User A")
self.assertEqual(results[0], user_a)
self.assertEqual(results[0].name, 'User A')
self.assertEqual(results[0].age, 20) self.assertEqual(results[0].age, 20)
self.assertEqual(results[1].name, "User B")
self.assertEqual(results[1], user_b)
self.assertEqual(results[1].name, 'User B')
self.assertEqual(results[1].age, 30) self.assertEqual(results[1].age, 30)
# Filter people by age # Use a query to filter the people found to just person1
people = self.Person.objects(age=20) people = self.Person.objects(age=20)
self.assertEqual(people.count(), 1) self.assertEqual(people.count(), 1)
person = people.next() person = people.next()
self.assertEqual(person, user_a)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
self.assertEqual(person.age, 20) self.assertEqual(person.age, 20)
def test_limit(self): # Test limit
"""Ensure that QuerySet.limit works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Test limit on a new queryset
people = list(self.Person.objects.limit(1)) people = list(self.Person.objects.limit(1))
self.assertEqual(len(people), 1) self.assertEqual(len(people), 1)
self.assertEqual(people[0], user_a) self.assertEqual(people[0].name, 'User A')
# Test limit on an existing queryset # Test skip
people = self.Person.objects
self.assertEqual(len(people), 2)
people2 = people.limit(1)
self.assertEqual(len(people), 2)
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_a)
# Test chaining of only after limit
person = self.Person.objects().limit(1).only('name').first()
self.assertEqual(person, user_a)
self.assertEqual(person.name, 'User A')
self.assertEqual(person.age, None)
def test_skip(self):
"""Ensure that QuerySet.skip works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
# Test skip on a new queryset
people = list(self.Person.objects.skip(1)) people = list(self.Person.objects.skip(1))
self.assertEqual(len(people), 1) self.assertEqual(len(people), 1)
self.assertEqual(people[0], user_b) self.assertEqual(people[0].name, 'User B')
# Test skip on an existing queryset person3 = self.Person(name="User C", age=40)
people = self.Person.objects person3.save()
self.assertEqual(len(people), 2)
people2 = people.skip(1)
self.assertEqual(len(people), 2)
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_b)
# Test chaining of only after skip
person = self.Person.objects().skip(1).only('name').first()
self.assertEqual(person, user_b)
self.assertEqual(person.name, 'User B')
self.assertEqual(person.age, None)
def test_slice(self):
"""Ensure slicing a queryset works as expected."""
user_a = self.Person.objects.create(name='User A', age=20)
user_b = self.Person.objects.create(name='User B', age=30)
user_c = self.Person.objects.create(name="User C", age=40)
# Test slice limit # Test slice limit
people = list(self.Person.objects[:2]) people = list(self.Person.objects[:2])
self.assertEqual(len(people), 2) self.assertEqual(len(people), 2)
self.assertEqual(people[0], user_a) self.assertEqual(people[0].name, 'User A')
self.assertEqual(people[1], user_b) self.assertEqual(people[1].name, 'User B')
# Test slice skip # Test slice skip
people = list(self.Person.objects[1:]) people = list(self.Person.objects[1:])
self.assertEqual(len(people), 2) self.assertEqual(len(people), 2)
self.assertEqual(people[0], user_b) self.assertEqual(people[0].name, 'User B')
self.assertEqual(people[1], user_c) self.assertEqual(people[1].name, 'User C')
# Test slice limit and skip # Test slice limit and skip
people = list(self.Person.objects[1:2]) people = list(self.Person.objects[1:2])
self.assertEqual(len(people), 1) self.assertEqual(len(people), 1)
self.assertEqual(people[0], user_b) self.assertEqual(people[0].name, 'User B')
# Test slice limit and skip on an existing queryset
people = self.Person.objects
self.assertEqual(len(people), 3)
people2 = people[1:2]
self.assertEqual(len(people2), 1)
self.assertEqual(people2[0], user_b)
# Test slice limit and skip cursor reset # Test slice limit and skip cursor reset
qs = self.Person.objects[1:2] qs = self.Person.objects[1:2]
@@ -221,7 +168,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(people), 1) self.assertEqual(len(people), 1)
self.assertEqual(people[0].name, 'User B') self.assertEqual(people[0].name, 'User B')
# Test empty slice
people = list(self.Person.objects[1:1]) people = list(self.Person.objects[1:1])
self.assertEqual(len(people), 0) self.assertEqual(len(people), 0)
@@ -241,6 +187,12 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual("[<Person: Person object>, <Person: Person object>]", self.assertEqual("[<Person: Person object>, <Person: Person object>]",
"%s" % self.Person.objects[51:53]) "%s" % self.Person.objects[51:53])
# Test only after limit
self.assertEqual(self.Person.objects().limit(2).only('name')[0].age, None)
# Test only after skip
self.assertEqual(self.Person.objects().skip(2).only('name')[0].age, None)
def test_find_one(self): def test_find_one(self):
"""Ensure that a query using find_one returns a valid result. """Ensure that a query using find_one returns a valid result.
""" """
@@ -1274,7 +1226,6 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
# default ordering should be used by default
with db_ops_tracker() as q: with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').first() BlogPost.objects.filter(title='whatever').first()
self.assertEqual(len(q.get_ops()), 1) self.assertEqual(len(q.get_ops()), 1)
@@ -1283,28 +1234,11 @@ class QuerySetTest(unittest.TestCase):
{'published_date': -1} {'published_date': -1}
) )
# calling order_by() should clear the default ordering
with db_ops_tracker() as q: with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by().first() BlogPost.objects.filter(title='whatever').order_by().first()
self.assertEqual(len(q.get_ops()), 1) self.assertEqual(len(q.get_ops()), 1)
self.assertFalse('$orderby' in q.get_ops()[0]['query']) self.assertFalse('$orderby' in q.get_ops()[0]['query'])
# calling an explicit order_by should use a specified sort
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by('published_date').first()
self.assertEqual(len(q.get_ops()), 1)
self.assertEqual(
q.get_ops()[0]['query']['$orderby'],
{'published_date': 1}
)
# calling order_by() after an explicit sort should clear it
with db_ops_tracker() as q:
qs = BlogPost.objects.filter(title='whatever').order_by('published_date')
qs.order_by().first()
self.assertEqual(len(q.get_ops()), 1)
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
def test_no_ordering_for_get(self): def test_no_ordering_for_get(self):
""" Ensure that Doc.objects.get doesn't use any ordering. """ Ensure that Doc.objects.get doesn't use any ordering.
""" """