From 9bbd8dbe624c385c68404834884f20b304b5b64f Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 4 Jan 2013 09:41:08 +0000 Subject: [PATCH] Querysets now return clones and are no longer edit in place Fixes #56 --- docs/changelog.rst | 1 + docs/upgrade.rst | 22 + mongoengine/connection.py | 6 +- mongoengine/queryset/queryset.py | 1182 ++++++++++++++++-------------- tests/queryset/queryset.py | 11 +- tests/queryset/visitor.py | 4 +- 6 files changed, 656 insertions(+), 570 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index c6493038..4fd3e143 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -26,6 +26,7 @@ Changes in 0.8.X - Fix Django timezone support (#151) - Simplified Q objects, removed QueryTreeTransformerVisitor (#98) (#171) - FileFields now copyable (#198) +- Querysets now return clones and are no longer edit in place (#56) Changes in 0.7.9 ================ diff --git a/docs/upgrade.rst b/docs/upgrade.rst index bf48527c..9c6c9a9d 100644 --- a/docs/upgrade.rst +++ b/docs/upgrade.rst @@ -56,6 +56,28 @@ you will need to declare :attr:`allow_inheritance` in the meta data like so: :: meta = {'allow_inheritance': True} +Querysets +~~~~~~~~~ + +Querysets now return clones and should no longer be considered editable in +place. This brings us in line with how Django's querysets work and removes a +long running gotcha. 
If you edit your querysets inplace you will have to
update your code like so: ::

    # Old code:
    mammals = Animal.objects(type="mammal")
    mammals.filter(order="Carnivora") # Returns a cloned queryset that isn't assigned to anything - so this will break in 0.8
    [m for m in mammals] # This will return all mammals in 0.8 as the 2nd filter returned a new queryset

    # Update example a) assign queryset after a change:
    mammals = Animal.objects(type="mammal")
    carnivores = mammals.filter(order="Carnivora") # Reassign the new queryset so filter can be applied
    [m for m in carnivores] # This will return all carnivores

    # Update example b) chain the queryset:
    mammals = Animal.objects(type="mammal").filter(order="Carnivora") # The final queryset is assigned to mammals
    [m for m in mammals] # This will return all carnivores

Indexes
------- diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 1ccbbe31..87308ba3 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -28,8 +28,10 @@ def register_connection(alias, name, host='localhost', port=27017, :param name: the name of the specific database to use :param host: the host name of the :program:`mongod` instance to connect to :param port: the port that the :program:`mongod` instance is running on - :param is_slave: whether the connection can act as a slave ** Depreciated pymongo 2.0.1+ - :param read_preference: The read preference for the collection ** Added pymongo 2.1 + :param is_slave: whether the connection can act as a slave + ** Depreciated pymongo 2.0.1+ + :param read_preference: The read preference for the collection + ** Added pymongo 2.1 :param slaves: a list of aliases of slave connections; each of these must be a registered connection that has :attr:`is_slave` set to ``True`` :param username: username to authenticate with diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 3ea9f232..239975f3 100644 --- 
a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -42,7 +42,6 @@ class QuerySet(object): providing :class:`~mongoengine.Document` objects as the results. """ __dereference = False - __none = False def __init__(self, document, collection): self._document = document @@ -60,6 +59,7 @@ class QuerySet(object): self._read_preference = None self._iter = False self._scalar = [] + self._none = False self._as_pymongo = False self._as_pymongo_coerce = False @@ -71,35 +71,9 @@ class QuerySet(object): self._cursor_obj = None self._limit = None self._skip = None + self._slice = None self._hint = -1 # Using -1 as None is a valid value for hint - def clone(self): - """Creates a copy of the current - :class:`~mongoengine.queryset.QuerySet` - - .. versionadded:: 0.5 - """ - c = self.__class__(self._document, self._collection_obj) - - copy_props = ('_initial_query', '_query_obj', '_where_clause', - '_loaded_fields', '_ordering', '_snapshot', - '_timeout', '_limit', '_skip', '_slave_okay', '_hint', - '_read_preference') - - for prop in copy_props: - val = getattr(self, prop) - setattr(c, prop, copy.deepcopy(val)) - - return c - - @property - def _query(self): - if self._mongo_query is None: - self._mongo_query = self._query_obj.to_query(self._document) - if self._class_check: - self._mongo_query.update(self._initial_query) - return self._mongo_query - def __call__(self, q_obj=None, class_check=True, slave_okay=False, read_preference=None, **query): """Filter the selected documents by calling the @@ -121,87 +95,94 @@ class QuerySet(object): if q_obj: # make sure proper query object is passed if not isinstance(q_obj, QNode): - raise InvalidQueryError('Not a query object: %s. Did you intend to use key=value?' % q_obj) + msg = ("Not a query object: %s. " + "Did you intend to use key=value?" 
% q_obj) + raise InvalidQueryError(msg) query &= q_obj - self._query_obj &= query - self._mongo_query = None - self._cursor_obj = None + + queryset = self.clone() + queryset._query_obj &= query + queryset._mongo_query = None + queryset._cursor_obj = None if read_preference is not None: - self.read_preference(read_preference) - self._class_check = class_check + queryset.read_preference(read_preference) + queryset._class_check = class_check + return queryset + + + def __iter__(self): + """Support iterator protocol""" + self.rewind() return self - def filter(self, *q_objs, **query): - """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` + def __len__(self): + return self.count() + + def __getitem__(self, key): + """Support skip and limit using getitem and slicing syntax. """ - return self.__call__(*q_objs, **query) + queryset = self.clone() + + # Slice provided + if isinstance(key, slice): + try: + queryset._cursor_obj = queryset._cursor[key] + queryset._slice = key + queryset._skip, queryset._limit = key.start, key.stop + except IndexError, err: + # PyMongo raises an error if key.start == key.stop, catch it, + # bin it, kill it. + start = key.start or 0 + if start >= 0 and key.stop >= 0 and key.step is None: + if start == key.stop: + queryset.limit(0) + queryset._skip = key.start + queryset._limit = key.stop - start + return queryset + raise err + # Allow further QuerySet modifications to be performed + return queryset + # Integer index provided + elif isinstance(key, int): + if queryset._scalar: + return queryset._get_scalar( + queryset._document._from_son(queryset._cursor[key])) + if queryset._as_pymongo: + return queryset._get_as_pymongo(queryset._cursor.next()) + return queryset._document._from_son(queryset._cursor[key]) + raise AttributeError + + def __repr__(self): + """Provides the string representation of the QuerySet + + .. versionchanged:: 0.6.13 Now doesnt modify the cursor + """ + + if self._iter: + return '.. queryset mid-iteration ..' 
+ + data = [] + for i in xrange(REPR_OUTPUT_SIZE + 1): + try: + data.append(self.next()) + except StopIteration: + break + if len(data) > REPR_OUTPUT_SIZE: + data[-1] = "...(remaining elements truncated)..." + + self.rewind() + return repr(data) + + # Core functions def all(self): """Returns all documents.""" return self.__call__() - def ensure_index(self, **kwargs): - """Deprecated use :func:`~Document.ensure_index`""" - msg = ("Doc.objects()._ensure_index() is deprecated. " - "Use Doc.ensure_index() instead.") - warnings.warn(msg, DeprecationWarning) - self._document.__class__.ensure_index(**kwargs) - return self - - def _ensure_indexes(self): - """Deprecated use :func:`~Document.ensure_indexes`""" - msg = ("Doc.objects()._ensure_indexes() is deprecated. " - "Use Doc.ensure_indexes() instead.") - warnings.warn(msg, DeprecationWarning) - self._document.__class__.ensure_indexes() - - @property - def _collection(self): - """Property that returns the collection object. This allows us to - perform operations only if the collection is accessed. 
+ def filter(self, *q_objs, **query): + """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` """ - return self._collection_obj - - @property - def _cursor_args(self): - cursor_args = { - 'snapshot': self._snapshot, - 'timeout': self._timeout, - 'slave_okay': self._slave_okay, - } - if self._read_preference is not None: - cursor_args['read_preference'] = self._read_preference - if self._loaded_fields: - cursor_args['fields'] = self._loaded_fields.as_dict() - return cursor_args - - @property - def _cursor(self): - if self._cursor_obj is None: - - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) - # Apply where clauses to cursor - if self._where_clause: - self._cursor_obj.where(self._where_clause) - - if self._ordering: - # Apply query ordering - self._cursor_obj.sort(self._ordering) - elif self._document._meta['ordering']: - # Otherwise, apply the ordering from the document model - self.order_by(*self._document._meta['ordering']) - self._cursor_obj.sort(self._ordering) - - if self._limit is not None: - self._cursor_obj.limit(self._limit - (self._skip or 0)) - - if self._skip is not None: - self._cursor_obj.skip(self._skip) - - if self._hint != -1: - self._cursor_obj.hint(self._hint) - return self._cursor_obj + return self.__call__(*q_objs, **query) def get(self, *q_objs, **query): """Retrieve the the matching object raising @@ -212,22 +193,29 @@ class QuerySet(object): .. versionadded:: 0.3 """ - self.limit(2) - self.__call__(*q_objs, **query) + queryset = self.__call__(*q_objs, **query) + queryset = queryset.limit(2) try: - result = self.next() + result = queryset.next() except StopIteration: msg = ("%s matching query does not exist." 
- % self._document._class_name) - raise self._document.DoesNotExist(msg) + % queryset._document._class_name) + raise queryset._document.DoesNotExist(msg) try: - self.next() + queryset.next() except StopIteration: return result - self.rewind() - message = u'%d items returned, instead of 1' % self.count() - raise self._document.MultipleObjectsReturned(message) + queryset.rewind() + message = u'%d items returned, instead of 1' % queryset.count() + raise queryset._document.MultipleObjectsReturned(message) + + def create(self, **kwargs): + """Create new object. Returns the saved object instance. + + .. versionadded:: 0.4 + """ + return self._document(**kwargs).save() def get_or_create(self, write_options=None, auto_save=True, *q_objs, **query): @@ -277,20 +265,12 @@ class QuerySet(object): doc.save(write_options=write_options) return doc, True - def create(self, **kwargs): - """Create new object. Returns the saved object instance. - - .. versionadded:: 0.4 - """ - doc = self._document(**kwargs) - doc.save() - return doc - def first(self): """Retrieve the first object matching the query. """ + queryset = self.clone() try: - result = self[0] + result = queryset[0] except IndexError: result = None return result @@ -367,6 +347,117 @@ class QuerySet(object): self._document, documents=results, loaded=True) return return_one and results[0] or results + def count(self): + """Count the selected elements in the query. + """ + if self._limit == 0: + return 0 + return self._cursor.count(with_limit_and_skip=True) + + def delete(self, safe=False): + """Delete the documents matched by the query. 
+ + :param safe: check if the operation succeeded before returning + """ + queryset = self.clone() + doc = queryset._document + + # Handle deletes where skips or limits have been applied + if queryset._skip or queryset._limit: + for doc in queryset: + doc.delete() + return + + delete_rules = doc._meta.get('delete_rules') or {} + # Check for DENY rules before actually deleting/nullifying any other + # references + for rule_entry in delete_rules: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == DENY and document_cls.objects( + **{field_name + '__in': self}).count() > 0: + msg = ("Could not delete document (%s.%s refers to it)" + % (document_cls.__name__, field_name)) + raise OperationError(msg) + + for rule_entry in delete_rules: + document_cls, field_name = rule_entry + rule = doc._meta['delete_rules'][rule_entry] + if rule == CASCADE: + ref_q = document_cls.objects(**{field_name + '__in': self}) + ref_q_count = ref_q.count() + if (doc != document_cls and ref_q_count > 0 + or (doc == document_cls and ref_q_count > 0)): + ref_q.delete(safe=safe) + elif rule == NULLIFY: + document_cls.objects(**{field_name + '__in': self}).update( + safe_update=safe, + **{'unset__%s' % field_name: 1}) + elif rule == PULL: + document_cls.objects(**{field_name + '__in': self}).update( + safe_update=safe, + **{'pull_all__%s' % field_name: self}) + + queryset._collection.remove(queryset._query, safe=safe) + + def update(self, safe_update=True, upsert=False, multi=True, + write_options=None, **update): + """Perform an atomic update on the fields matched by the query. When + ``safe_update`` is used, the number of affected documents is returned. + + :param safe_update: check if the operation succeeded before returning + :param upsert: Any existing document with that "_id" is overwritten. + :param write_options: extra keyword arguments for + :meth:`~pymongo.collection.Collection.update` + + .. 
versionadded:: 0.2 + """ + if not update: + raise OperationError("No update parameters, would remove data") + + if not write_options: + write_options = {} + + queryset = self.clone() + query = queryset._query + update = transform.update(queryset._document, **update) + + # If doing an atomic upsert on an inheritable class + # then ensure we add _cls to the update operation + if upsert and '_cls' in query: + if '$set' in update: + update["$set"]["_cls"] = queryset._document._class_name + else: + update["$set"] = {"_cls": queryset._document._class_name} + + try: + ret = queryset._collection.update(query, update, multi=multi, + upsert=upsert, safe=safe_update, + **write_options) + if ret is not None and 'n' in ret: + return ret['n'] + except pymongo.errors.OperationFailure, err: + if unicode(err) == u'multi not coded yet': + message = u'update() method requires MongoDB 1.1.3+' + raise OperationError(message) + raise OperationError(u'Update failed (%s)' % unicode(err)) + + def update_one(self, safe_update=True, upsert=False, write_options=None, + **update): + """Perform an atomic update on first field matched by the query. When + ``safe_update`` is used, the number of affected documents is returned. + + :param safe_update: check if the operation succeeded before returning + :param upsert: Any existing document with that "_id" is overwritten. + :param write_options: extra keyword arguments for + :meth:`~pymongo.collection.Collection.update` + :param update: Django-style update keyword arguments + + .. versionadded:: 0.2 + """ + return self.update(safe_update=True, upsert=upsert, multi=False, + write_options=None, **update) + def with_id(self, object_id): """Retrieve the object matching the id provided. Uses `object_id` only and raises InvalidQueryError if a filter has been applied. @@ -375,10 +466,11 @@ class QuerySet(object): .. 
versionchanged:: 0.6 Raises InvalidQueryError if filter has been set """ - if not self._query_obj.empty: + queryset = self.clone() + if not queryset._query_obj.empty: msg = "Cannot use a filter whilst using `with_id`" raise InvalidQueryError(msg) - return self.filter(pk=object_id).first() + return queryset.filter(pk=object_id).first() def in_bulk(self, object_ids): """Retrieve a set of documents by their ids. @@ -406,139 +498,48 @@ class QuerySet(object): return doc_map - def next(self): - """Wrap the result in a :class:`~mongoengine.Document` object. - """ - self._iter = True - try: - if self._limit == 0 or self.__none: - raise StopIteration - if self._scalar: - return self._get_scalar(self._document._from_son( - self._cursor.next())) - if self._as_pymongo: - return self._get_as_pymongo(self._cursor.next()) - - return self._document._from_son(self._cursor.next()) - except StopIteration, e: - self.rewind() - raise e - - def rewind(self): - """Rewind the cursor to its unevaluated state. - - .. versionadded:: 0.3 - """ - self._iter = False - self._cursor.rewind() - def none(self): """Helper that just returns a list""" - self.__none = True - return self + queryset = self.clone() + queryset._none = True + return queryset - def count(self): - """Count the selected elements in the query. + def clone(self): + """Creates a copy of the current + :class:`~mongoengine.queryset.QuerySet` + + .. 
versionadded:: 0.5 """ - if self._limit == 0: - return 0 - return self._cursor.count(with_limit_and_skip=True) + c = self.__class__(self._document, self._collection_obj) - def __len__(self): - return self.count() + copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', + '_where_clause', '_loaded_fields', '_ordering', '_snapshot', + '_timeout', '_class_check', '_slave_okay', '_read_preference', + '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', + '_limit', '_skip', '_slice', '_hint') - def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, - scope=None): - """Perform a map/reduce query using the current query spec - and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, - it must be the last call made, as it does not return a maleable - ``QuerySet``. + for prop in copy_props: + val = getattr(self, prop) + setattr(c, prop, copy.copy(val)) - See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` - and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` - tests in ``tests.queryset.QuerySetTest`` for usage examples. + if self._cursor_obj: + c._cursor_obj = self._cursor_obj.clone() - :param map_f: map function, as :class:`~bson.code.Code` or string - :param reduce_f: reduce function, as - :class:`~bson.code.Code` or string - :param output: output collection name, if set to 'inline' will try to - use :class:`~pymongo.collection.Collection.inline_map_reduce` - This can also be a dictionary containing output options - see: http://docs.mongodb.org/manual/reference/commands/#mapReduce - :param finalize_f: finalize function, an optional function that - performs any post-reduction processing. - :param scope: values to insert into map/reduce global scope. Optional. - :param limit: number of objects from current query to provide - to map/reduce method + if self._slice: + c._cursor_obj[self._slice] - Returns an iterator yielding - :class:`~mongoengine.document.MapReduceDocument`. + return c - .. 
note:: + def select_related(self, max_depth=1): + """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to + a maximum depth in order to cut down the number queries to mongodb. - Map/Reduce changed in server version **>= 1.7.4**. The PyMongo - :meth:`~pymongo.collection.Collection.map_reduce` helper requires - PyMongo version **>= 1.11**. - - .. versionchanged:: 0.5 - - removed ``keep_temp`` keyword argument, which was only relevant - for MongoDB server versions older than 1.7.4 - - .. versionadded:: 0.3 + .. versionadded:: 0.5 """ - MapReduceDocument = _import_class('MapReduceDocument') - - if not hasattr(self._collection, "map_reduce"): - raise NotImplementedError("Requires MongoDB >= 1.7.1") - - map_f_scope = {} - if isinstance(map_f, Code): - map_f_scope = map_f.scope - map_f = unicode(map_f) - map_f = Code(self._sub_js_fields(map_f), map_f_scope) - - reduce_f_scope = {} - if isinstance(reduce_f, Code): - reduce_f_scope = reduce_f.scope - reduce_f = unicode(reduce_f) - reduce_f_code = self._sub_js_fields(reduce_f) - reduce_f = Code(reduce_f_code, reduce_f_scope) - - mr_args = {'query': self._query} - - if finalize_f: - finalize_f_scope = {} - if isinstance(finalize_f, Code): - finalize_f_scope = finalize_f.scope - finalize_f = unicode(finalize_f) - finalize_f_code = self._sub_js_fields(finalize_f) - finalize_f = Code(finalize_f_code, finalize_f_scope) - mr_args['finalize'] = finalize_f - - if scope: - mr_args['scope'] = scope - - if limit: - mr_args['limit'] = limit - - if output == 'inline' and not self._ordering: - map_reduce_function = 'inline_map_reduce' - else: - map_reduce_function = 'map_reduce' - mr_args['out'] = output - - results = getattr(self._collection, map_reduce_function)( - map_f, reduce_f, **mr_args) - - if map_reduce_function == 'map_reduce': - results = results.find() - - if self._ordering: - results = results.sort(self._ordering) - - for doc in results: - yield MapReduceDocument(self._document, self._collection, - doc['_id'], 
doc['value']) + # Make select related work the same for querysets + max_depth += 1 + queryset = self.clone() + return queryset._dereference(queryset, max_depth=max_depth) def limit(self, n): """Limit the number of returned documents to `n`. This may also be @@ -546,14 +547,15 @@ class QuerySet(object): :param n: the maximum number of objects to return """ + queryset = self.clone() if n == 0: - self._cursor.limit(1) + queryset._cursor.limit(1) else: - self._cursor.limit(n) - self._limit = n + queryset._cursor.limit(n) + queryset._limit = n # Return self to allow chaining - return self + return queryset def skip(self, n): """Skip `n` documents before returning the results. This may also be @@ -561,9 +563,10 @@ class QuerySet(object): :param n: the number of objects to skip before returning results """ - self._cursor.skip(n) - self._skip = n - return self + queryset = self.clone() + queryset._cursor.skip(n) + queryset._skip = n + return queryset def hint(self, index=None): """Added 'hint' support, telling Mongo the proper index to use for the @@ -578,39 +581,10 @@ class QuerySet(object): .. versionadded:: 0.5 """ - self._cursor.hint(index) - self._hint = index - return self - - def __getitem__(self, key): - """Support skip and limit using getitem and slicing syntax. - """ - # Slice provided - if isinstance(key, slice): - try: - self._cursor_obj = self._cursor[key] - self._skip, self._limit = key.start, key.stop - except IndexError, err: - # PyMongo raises an error if key.start == key.stop, catch it, - # bin it, kill it. 
- start = key.start or 0 - if start >= 0 and key.stop >= 0 and key.step is None: - if start == key.stop: - self.limit(0) - self._skip, self._limit = key.start, key.stop - start - return self - raise err - # Allow further QuerySet modifications to be performed - return self - # Integer index provided - elif isinstance(key, int): - if self._scalar: - return self._get_scalar(self._document._from_son( - self._cursor[key])) - if self._as_pymongo: - return self._get_as_pymongo(self._cursor.next()) - return self._document._from_son(self._cursor[key]) - raise AttributeError + queryset = self.clone() + queryset._cursor.hint(index) + queryset._hint = index + return queryset def distinct(self, field): """Return a list of distinct values for a given field. @@ -621,8 +595,9 @@ class QuerySet(object): .. versionchanged:: 0.5 - Fixed handling references .. versionchanged:: 0.6 - Improved db_field refrence handling """ - return self._dereference(self._cursor.distinct(field), 1, - name=field, instance=self._document) + queryset = self.clone() + return queryset._dereference(queryset._cursor.distinct(field), 1, + name=field, instance=queryset._document) def only(self, *fields): """Load only a subset of this document's fields. :: @@ -679,11 +654,12 @@ class QuerySet(object): cleaned_fields.append((key, value)) fields = sorted(cleaned_fields, key=operator.itemgetter(1)) + queryset = self.clone() for value, group in itertools.groupby(fields, lambda x: x[1]): fields = [field for field, value in group] - fields = self._fields_to_dbfields(fields) - self._loaded_fields += QueryFieldList(fields, value=value) - return self + fields = queryset._fields_to_dbfields(fields) + queryset._loaded_fields += QueryFieldList(fields, value=value) + return queryset def all_fields(self): """Include all fields. Reset all previously calls of .only() or @@ -693,18 +669,10 @@ class QuerySet(object): .. 
versionadded:: 0.5 """ - self._loaded_fields = QueryFieldList( - always_include=self._loaded_fields.always_include) - return self - - def _fields_to_dbfields(self, fields): - """Translate fields paths to its db equivalents""" - ret = [] - for field in fields: - field = ".".join(f.db_field for f in - self._document._lookup_field(field.split('.'))) - ret.append(field) - return ret + queryset = self.clone() + queryset._loaded_fields = QueryFieldList( + always_include=queryset._loaded_fields.always_include) + return queryset def order_by(self, *keys): """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The @@ -714,25 +682,9 @@ class QuerySet(object): :param keys: fields to order the query results by; keys may be prefixed with **+** or **-** to determine the ordering direction """ - key_list = [] - for key in keys: - if not key: - continue - direction = pymongo.ASCENDING - if key[0] == '-': - direction = pymongo.DESCENDING - if key[0] in ('-', '+'): - key = key[1:] - key = key.replace('__', '.') - try: - key = self._document._translate_field_name(key) - except: - pass - key_list.append((key, direction)) - - self._ordering = key_list - - return self + queryset = self.clone() + queryset._ordering = self._get_order_by(keys) + return queryset def explain(self, format=False): """Return an explain plan record for the @@ -740,7 +692,6 @@ class QuerySet(object): :param format: format the plan before returning it """ - plan = self._cursor.explain() if format: plan = pprint.pformat(plan) @@ -753,8 +704,9 @@ class QuerySet(object): ..versionchanged:: 0.5 - made chainable """ - self._snapshot = enabled - return self + queryset = self.clone() + queryset._snapshot = enabled + return queryset def timeout(self, enabled): """Enable or disable the default mongod timeout when querying. 
@@ -763,16 +715,18 @@ class QuerySet(object): ..versionchanged:: 0.5 - made chainable """ - self._timeout = enabled - return self + queryset = self.clone() + queryset._timeout = enabled + return queryset def slave_okay(self, enabled): """Enable or disable the slave_okay when querying. :param enabled: whether or not the slave_okay is enabled """ - self._slave_okay = enabled - return self + queryset = self.clone() + queryset._slave_okay = enabled + return queryset def read_preference(self, read_preference): """Change the read_preference when querying. @@ -781,170 +735,9 @@ class QuerySet(object): preference. """ validate_read_preference('read_preference', read_preference) - self._read_preference = read_preference - return self - - def delete(self, safe=False): - """Delete the documents matched by the query. - - :param safe: check if the operation succeeded before returning - """ - doc = self._document - - # Handle deletes where skips or limits have been applied - if self._skip or self._limit: - for doc in self: - doc.delete() - return - - delete_rules = doc._meta.get('delete_rules') or {} - # Check for DENY rules before actually deleting/nullifying any other - # references - for rule_entry in delete_rules: - document_cls, field_name = rule_entry - rule = doc._meta['delete_rules'][rule_entry] - if rule == DENY and document_cls.objects( - **{field_name + '__in': self}).count() > 0: - msg = ("Could not delete document (%s.%s refers to it)" - % (document_cls.__name__, field_name)) - raise OperationError(msg) - - for rule_entry in delete_rules: - document_cls, field_name = rule_entry - rule = doc._meta['delete_rules'][rule_entry] - if rule == CASCADE: - ref_q = document_cls.objects(**{field_name + '__in': self}) - ref_q_count = ref_q.count() - if (doc != document_cls and ref_q_count > 0 - or (doc == document_cls and ref_q_count > 0)): - ref_q.delete(safe=safe) - elif rule == NULLIFY: - document_cls.objects(**{field_name + '__in': self}).update( - safe_update=safe, - 
**{'unset__%s' % field_name: 1}) - elif rule == PULL: - document_cls.objects(**{field_name + '__in': self}).update( - safe_update=safe, - **{'pull_all__%s' % field_name: self}) - - self._collection.remove(self._query, safe=safe) - - def update(self, safe_update=True, upsert=False, multi=True, - write_options=None, **update): - """Perform an atomic update on the fields matched by the query. When - ``safe_update`` is used, the number of affected documents is returned. - - :param safe_update: check if the operation succeeded before returning - :param upsert: Any existing document with that "_id" is overwritten. - :param write_options: extra keyword arguments for - :meth:`~pymongo.collection.Collection.update` - - .. versionadded:: 0.2 - """ - if not update: - raise OperationError("No update parameters, would remove data") - - if not write_options: - write_options = {} - - query = self._query - update = transform.update(self._document, **update) - - # If doing an atomic upsert on an inheritable class - # then ensure we add _cls to the update operation - if upsert and '_cls' in query: - if '$set' in update: - update["$set"]["_cls"] = self._document._class_name - else: - update["$set"] = {"_cls": self._document._class_name} - - try: - ret = self._collection.update(query, update, multi=multi, - upsert=upsert, safe=safe_update, - **write_options) - if ret is not None and 'n' in ret: - return ret['n'] - except pymongo.errors.OperationFailure, err: - if unicode(err) == u'multi not coded yet': - message = u'update() method requires MongoDB 1.1.3+' - raise OperationError(message) - raise OperationError(u'Update failed (%s)' % unicode(err)) - - def update_one(self, safe_update=True, upsert=False, write_options=None, - **update): - """Perform an atomic update on first field matched by the query. When - ``safe_update`` is used, the number of affected documents is returned. 
- - :param safe_update: check if the operation succeeded before returning - :param upsert: Any existing document with that "_id" is overwritten. - :param write_options: extra keyword arguments for - :meth:`~pymongo.collection.Collection.update` - :param update: Django-style update keyword arguments - - .. versionadded:: 0.2 - """ - return self.update(safe_update=True, upsert=upsert, multi=False, - write_options=None, **update) - - def __iter__(self): - self.rewind() - return self - - def _get_scalar(self, doc): - - def lookup(obj, name): - chunks = name.split('__') - for chunk in chunks: - obj = getattr(obj, chunk) - return obj - - data = [lookup(doc, n) for n in self._scalar] - if len(data) == 1: - return data[0] - - return tuple(data) - - def _get_as_pymongo(self, row): - # Extract which fields paths we should follow if .fields(...) was - # used. If not, handle all fields. - if not getattr(self, '__as_pymongo_fields', None): - self.__as_pymongo_fields = [] - for field in self._loaded_fields.fields - set(['_cls', '_id', '_types']): - self.__as_pymongo_fields.append(field) - while '.' in field: - field, _ = field.rsplit('.', 1) - self.__as_pymongo_fields.append(field) - - all_fields = not self.__as_pymongo_fields - - def clean(data, path=None): - path = path or '' - - if isinstance(data, dict): - new_data = {} - for key, value in data.iteritems(): - new_path = '%s.%s' % (path, key) if path else key - if all_fields or new_path in self.__as_pymongo_fields: - new_data[key] = clean(value, path=new_path) - data = new_data - elif isinstance(data, list): - data = [clean(d, path=path) for d in data] - else: - if self._as_pymongo_coerce: - # If we need to coerce types, we need to determine the - # type of this field and use the corresponding .to_python(...) 
- from mongoengine.fields import EmbeddedDocumentField - obj = self._document - for chunk in path.split('.'): - obj = getattr(obj, chunk, None) - if obj is None: - break - elif isinstance(obj, EmbeddedDocumentField): - obj = obj.document_type - if obj and data is not None: - data = obj.to_python(data) - return data - return clean(row) + queryset = self.clone() + queryset._read_preference = read_preference + return queryset def scalar(self, *fields): """Instead of returning Document instances, return either a specific @@ -955,14 +748,15 @@ class QuerySet(object): :param fields: One or more fields to return instead of a Document. """ - self._scalar = list(fields) + queryset = self.clone() + queryset._scalar = list(fields) if fields: - self.only(*fields) + queryset = queryset.only(*fields) else: - self.all_fields() + queryset = queryset.all_fields() - return self + return queryset def values_list(self, *fields): """An alias for scalar""" @@ -972,36 +766,122 @@ class QuerySet(object): """Instead of returning Document instances, return raw values from pymongo. - :param coerce_type: Field types (if applicable) would be use to coerce types. + :param coerce_type: Field types (if applicable) would be use to + coerce types. """ - self._as_pymongo = True - self._as_pymongo_coerce = coerce_types - return self + queryset = self.clone() + queryset._as_pymongo = True + queryset._as_pymongo_coerce = coerce_types + return queryset - def _sub_js_fields(self, code): - """When fields are specified with [~fieldname] syntax, where - *fieldname* is the Python name of a field, *fieldname* will be - substituted for the MongoDB name of the field (specified using the - :attr:`name` keyword argument in a field's constructor). 
+ # JSON Helpers + + def to_json(self): + """Converts a queryset to JSON""" + queryset = self.clone() + return json_util.dumps(queryset._collection_obj.find(queryset._query)) + + def from_json(self, json_data): + """Converts json data to unsaved objects""" + son_data = json_util.loads(json_data) + return [self._document._from_son(data) for data in son_data] + + # JS functionality + + def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, + scope=None): + """Perform a map/reduce query using the current query spec + and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, + it must be the last call made, as it does not return a maleable + ``QuerySet``. + + See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` + and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` + tests in ``tests.queryset.QuerySetTest`` for usage examples. + + :param map_f: map function, as :class:`~bson.code.Code` or string + :param reduce_f: reduce function, as + :class:`~bson.code.Code` or string + :param output: output collection name, if set to 'inline' will try to + use :class:`~pymongo.collection.Collection.inline_map_reduce` + This can also be a dictionary containing output options + see: http://docs.mongodb.org/manual/reference/commands/#mapReduce + :param finalize_f: finalize function, an optional function that + performs any post-reduction processing. + :param scope: values to insert into map/reduce global scope. Optional. + :param limit: number of objects from current query to provide + to map/reduce method + + Returns an iterator yielding + :class:`~mongoengine.document.MapReduceDocument`. + + .. note:: + + Map/Reduce changed in server version **>= 1.7.4**. The PyMongo + :meth:`~pymongo.collection.Collection.map_reduce` helper requires + PyMongo version **>= 1.11**. + + .. versionchanged:: 0.5 + - removed ``keep_temp`` keyword argument, which was only relevant + for MongoDB server versions older than 1.7.4 + + .. 
versionadded:: 0.3 """ - def field_sub(match): - # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') - fields = self._document._lookup_field(field_name) - # Substitute the correct name for the field into the javascript - return u'["%s"]' % fields[-1].db_field + queryset = self.clone() - def field_path_sub(match): - # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') - fields = self._document._lookup_field(field_name) - # Substitute the correct name for the field into the javascript - return ".".join([f.db_field for f in fields]) + MapReduceDocument = _import_class('MapReduceDocument') - code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) - code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, - code) - return code + if not hasattr(self._collection, "map_reduce"): + raise NotImplementedError("Requires MongoDB >= 1.7.1") + + map_f_scope = {} + if isinstance(map_f, Code): + map_f_scope = map_f.scope + map_f = unicode(map_f) + map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) + + reduce_f_scope = {} + if isinstance(reduce_f, Code): + reduce_f_scope = reduce_f.scope + reduce_f = unicode(reduce_f) + reduce_f_code = queryset._sub_js_fields(reduce_f) + reduce_f = Code(reduce_f_code, reduce_f_scope) + + mr_args = {'query': queryset._query} + + if finalize_f: + finalize_f_scope = {} + if isinstance(finalize_f, Code): + finalize_f_scope = finalize_f.scope + finalize_f = unicode(finalize_f) + finalize_f_code = queryset._sub_js_fields(finalize_f) + finalize_f = Code(finalize_f_code, finalize_f_scope) + mr_args['finalize'] = finalize_f + + if scope: + mr_args['scope'] = scope + + if limit: + mr_args['limit'] = limit + + if output == 'inline' and not queryset._ordering: + map_reduce_function = 'inline_map_reduce' + else: + map_reduce_function = 'map_reduce' + mr_args['out'] = output + + results = getattr(queryset._collection, 
map_reduce_function)( + map_f, reduce_f, **mr_args) + + if map_reduce_function == 'map_reduce': + results = results.find() + + if queryset._ordering: + results = results.sort(queryset._ordering) + + for doc in results: + yield MapReduceDocument(queryset._document, queryset._collection, + doc['_id'], doc['value']) def exec_js(self, code, *fields, **options): """Execute a Javascript function on the server. A list of fields may be @@ -1025,24 +905,26 @@ class QuerySet(object): :param options: options that you want available to the function (accessed in Javascript through the ``options`` object) """ - code = self._sub_js_fields(code) + queryset = self.clone() - fields = [self._document._translate_field_name(f) for f in fields] - collection = self._document._get_collection_name() + code = queryset._sub_js_fields(code) + + fields = [queryset._document._translate_field_name(f) for f in fields] + collection = queryset._document._get_collection_name() scope = { 'collection': collection, 'options': options or {}, } - query = self._query - if self._where_clause: - query['$where'] = self._where_clause + query = queryset._query + if queryset._where_clause: + query['$where'] = queryset._where_clause scope['query'] = query code = Code(code, scope=scope) - db = self._document._get_db() + db = queryset._document._get_db() return db.eval(code, *fields) def where(self, where_clause): @@ -1056,9 +938,10 @@ class QuerySet(object): .. versionadded:: 0.5 """ - where_clause = self._sub_js_fields(where_clause) - self._where_clause = where_clause - return self + queryset = self.clone() + where_clause = queryset._sub_js_fields(where_clause) + queryset._where_clause = where_clause + return queryset def sum(self, field): """Sum over the values of the specified field. 
@@ -1157,6 +1040,101 @@ class QuerySet(object): normalize=normalize) return self._item_frequencies_exec_js(field, normalize=normalize) + # Iterator helpers + + def next(self): + """Wrap the result in a :class:`~mongoengine.Document` object. + """ + self._iter = True + try: + if self._limit == 0 or self._none: + raise StopIteration + if self._scalar: + return self._get_scalar(self._document._from_son( + self._cursor.next())) + if self._as_pymongo: + return self._get_as_pymongo(self._cursor.next()) + + return self._document._from_son(self._cursor.next()) + except StopIteration, e: + self.rewind() + raise e + + def rewind(self): + """Rewind the cursor to its unevaluated state. + + .. versionadded:: 0.3 + """ + self._iter = False + self._cursor.rewind() + + # Properties + + @property + def _collection(self): + """Property that returns the collection object. This allows us to + perform operations only if the collection is accessed. + """ + return self._collection_obj + + @property + def _cursor_args(self): + cursor_args = { + 'snapshot': self._snapshot, + 'timeout': self._timeout, + 'slave_okay': self._slave_okay, + } + if self._read_preference is not None: + cursor_args['read_preference'] = self._read_preference + if self._loaded_fields: + cursor_args['fields'] = self._loaded_fields.as_dict() + return cursor_args + + @property + def _cursor(self): + if self._cursor_obj is None: + + self._cursor_obj = self._collection.find(self._query, + **self._cursor_args) + # Apply where clauses to cursor + if self._where_clause: + where_clause = self._sub_js_fields(self._where_clause) + self._cursor_obj.where(where_clause) + + if self._ordering: + # Apply query ordering + self._cursor_obj.sort(self._ordering) + elif self._document._meta['ordering']: + # Otherwise, apply the ordering from the document model + order = self._get_order_by(self._document._meta['ordering']) + self._cursor_obj.sort(order) + + if self._limit is not None: + self._cursor_obj.limit(self._limit - (self._skip or 
0)) + + if self._skip is not None: + self._cursor_obj.skip(self._skip) + + if self._hint != -1: + self._cursor_obj.hint(self._hint) + return self._cursor_obj + + @property + def _query(self): + if self._mongo_query is None: + self._mongo_query = self._query_obj.to_query(self._document) + if self._class_check: + self._mongo_query.update(self._initial_query) + return self._mongo_query + + @property + def _dereference(self): + if not self.__dereference: + self.__dereference = _import_class('DeReference')() + return self.__dereference + + # Helper Functions + def _item_frequencies_map_reduce(self, field, normalize=False): map_func = """ function() { @@ -1269,48 +1247,130 @@ class QuerySet(object): return frequencies - def __repr__(self): - """Provides the string representation of the QuerySet + def _fields_to_dbfields(self, fields): + """Translate fields paths to its db equivalents""" + ret = [] + for field in fields: + field = ".".join(f.db_field for f in + self._document._lookup_field(field.split('.'))) + ret.append(field) + return ret - .. versionchanged:: 0.6.13 Now doesnt modify the cursor + def _get_order_by(self, keys): + """Creates a list of order by fields """ - - if self._iter: - return '.. queryset mid-iteration ..' - - data = [] - for i in xrange(REPR_OUTPUT_SIZE + 1): + key_list = [] + for key in keys: + if not key: + continue + direction = pymongo.ASCENDING + if key[0] == '-': + direction = pymongo.DESCENDING + if key[0] in ('-', '+'): + key = key[1:] + key = key.replace('__', '.') try: - data.append(self.next()) - except StopIteration: - break - if len(data) > REPR_OUTPUT_SIZE: - data[-1] = "...(remaining elements truncated)..." 
+ key = self._document._translate_field_name(key) + except: + pass + key_list.append((key, direction)) + return key_list - self.rewind() - return repr(data) + def _get_scalar(self, doc): - def select_related(self, max_depth=1): - """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to - a maximum depth in order to cut down the number queries to mongodb. + def lookup(obj, name): + chunks = name.split('__') + for chunk in chunks: + obj = getattr(obj, chunk) + return obj - .. versionadded:: 0.5 + data = [lookup(doc, n) for n in self._scalar] + if len(data) == 1: + return data[0] + + return tuple(data) + + def _get_as_pymongo(self, row): + # Extract which fields paths we should follow if .fields(...) was + # used. If not, handle all fields. + if not getattr(self, '__as_pymongo_fields', None): + self.__as_pymongo_fields = [] + for field in self._loaded_fields.fields - set(['_cls', '_id']): + self.__as_pymongo_fields.append(field) + while '.' in field: + field, _ = field.rsplit('.', 1) + self.__as_pymongo_fields.append(field) + + all_fields = not self.__as_pymongo_fields + + def clean(data, path=None): + path = path or '' + + if isinstance(data, dict): + new_data = {} + for key, value in data.iteritems(): + new_path = '%s.%s' % (path, key) if path else key + if all_fields or new_path in self.__as_pymongo_fields: + new_data[key] = clean(value, path=new_path) + data = new_data + elif isinstance(data, list): + data = [clean(d, path=path) for d in data] + else: + if self._as_pymongo_coerce: + # If we need to coerce types, we need to determine the + # type of this field and use the corresponding + # .to_python(...) 
+ from mongoengine.fields import EmbeddedDocumentField + obj = self._document + for chunk in path.split('.'): + obj = getattr(obj, chunk, None) + if obj is None: + break + elif isinstance(obj, EmbeddedDocumentField): + obj = obj.document_type + if obj and data is not None: + data = obj.to_python(data) + return data + return clean(row) + + def _sub_js_fields(self, code): + """When fields are specified with [~fieldname] syntax, where + *fieldname* is the Python name of a field, *fieldname* will be + substituted for the MongoDB name of the field (specified using the + :attr:`name` keyword argument in a field's constructor). """ - # Make select related work the same for querysets - max_depth += 1 - return self._dereference(self, max_depth=max_depth) + def field_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = self._document._lookup_field(field_name) + # Substitute the correct name for the field into the javascript + return u'["%s"]' % fields[-1].db_field - def to_json(self): - """Converts a queryset to JSON""" - return json_util.dumps(self._collection_obj.find(self._query)) + def field_path_sub(match): + # Extract just the field name, and look up the field objects + field_name = match.group(1).split('.') + fields = self._document._lookup_field(field_name) + # Substitute the correct name for the field into the javascript + return ".".join([f.db_field for f in fields]) - def from_json(self, json_data): - """Converts json data to unsaved objects""" - son_data = json_util.loads(json_data) - return [self._document._from_son(data) for data in son_data] + code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) + code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, + code) + return code - @property - def _dereference(self): - if not self.__dereference: - self.__dereference = _import_class('DeReference')() - return self.__dereference + # Deprecated + + def 
ensure_index(self, **kwargs): + """Deprecated use :func:`~Document.ensure_index`""" + msg = ("Doc.objects()._ensure_index() is deprecated. " + "Use Doc.ensure_index() instead.") + warnings.warn(msg, DeprecationWarning) + self._document.__class__.ensure_index(**kwargs) + return self + + def _ensure_indexes(self): + """Deprecated use :func:`~Document.ensure_indexes`""" + msg = ("Doc.objects()._ensure_indexes() is deprecated. " + "Use Doc.ensure_indexes() instead.") + warnings.warn(msg, DeprecationWarning) + self._document.__class__.ensure_indexes() diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index bad3d360..bf64a565 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -713,19 +713,19 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(p._cursor_args, {'snapshot': False, 'slave_okay': False, 'timeout': True}) - p.snapshot(False).slave_okay(False).timeout(False) + p = p.snapshot(False).slave_okay(False).timeout(False) self.assertEqual(p._cursor_args, {'snapshot': False, 'slave_okay': False, 'timeout': False}) - p.snapshot(True).slave_okay(False).timeout(False) + p = p.snapshot(True).slave_okay(False).timeout(False) self.assertEqual(p._cursor_args, {'snapshot': True, 'slave_okay': False, 'timeout': False}) - p.snapshot(True).slave_okay(True).timeout(False) + p = p.snapshot(True).slave_okay(True).timeout(False) self.assertEqual(p._cursor_args, {'snapshot': True, 'slave_okay': True, 'timeout': False}) - p.snapshot(True).slave_okay(True).timeout(True) + p = p.snapshot(True).slave_okay(True).timeout(True) self.assertEqual(p._cursor_args, {'snapshot': True, 'slave_okay': True, 'timeout': True}) @@ -773,7 +773,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(len(docs), 1000) # Limit and skip - self.assertEqual('[, , ]', "%s" % docs[1:4]) + docs = docs[1:4] + self.assertEqual('[, , ]', "%s" % docs) self.assertEqual(docs.count(), 3) self.assertEqual(len(docs), 3) diff --git a/tests/queryset/visitor.py 
b/tests/queryset/visitor.py index 4af39e87..98815dbc 100644 --- a/tests/queryset/visitor.py +++ b/tests/queryset/visitor.py @@ -202,8 +202,8 @@ class QTest(unittest.TestCase): self.assertEqual(test2.count(), 3) self.assertFalse(test2 == test) - test2.filter(x=6) - self.assertEqual(test2.count(), 1) + test3 = test2.filter(x=6) + self.assertEqual(test3.count(), 1) self.assertEqual(test.count(), 3) def test_q(self):