diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index a1e1245f..f542cc87 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -452,7 +452,6 @@ class QuerySet(object): self._mongo_query = None self._cursor_obj = None self._class_check = class_check - self._slave_okay = slave_okay return self def filter(self, *q_objs, **query): @@ -504,18 +503,23 @@ class QuerySet(object): return self._collection_obj + @property + def _cursor_args(self): + cursor_args = { + 'snapshot': self._snapshot, + 'timeout': self._timeout, + 'slave_okay': self._slave_okay + } + if self._loaded_fields: + cursor_args['fields'] = self._loaded_fields.as_dict() + return cursor_args + @property def _cursor(self): if self._cursor_obj is None: - cursor_args = { - 'snapshot': self._snapshot, - 'timeout': self._timeout, - 'slave_okay': self._slave_okay - } - if self._loaded_fields: - cursor_args['fields'] = self._loaded_fields.as_dict() + self._cursor_obj = self._collection.find(self._query, - **cursor_args) + **self._cursor_args) # Apply where clauses to cursor if self._where_clause: self._cursor_obj.where(self._where_clause) @@ -772,7 +776,7 @@ class QuerySet(object): id_field = self._document._meta['id_field'] object_id = self._document._fields[id_field].to_mongo(object_id) - result = self._collection.find_one({'_id': object_id}) + result = self._collection.find_one({'_id': object_id}, **self._cursor_args) if result is not None: result = self._document._from_son(result) return result @@ -788,7 +792,8 @@ class QuerySet(object): """ doc_map = {} - docs = self._collection.find({'_id': {'$in': object_ids}}) + docs = self._collection.find({'_id': {'$in': object_ids}}, + **self._cursor_args) for doc in docs: doc_map[doc['_id']] = self._document._from_son(doc) @@ -1085,6 +1090,7 @@ class QuerySet(object): :param enabled: whether or not snapshot mode is enabled """ self._snapshot = enabled + return self def timeout(self, enabled): """Enable or disable the default mongod timeout when querying. @@ -1092,6 +1098,15 @@ class QuerySet(object): :param enabled: whether or not the timeout is used """ self._timeout = enabled + return self + + def slave_okay(self, enabled): + """Enable or disable the slave_okay when querying. + + :param enabled: whether or not the slave_okay is enabled + """ + self._slave_okay = enabled + return self def delete(self, safe=False): """Delete the documents matched by the query. diff --git a/tests/queryset.py b/tests/queryset.py index 28d44861..1947254b 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -422,11 +422,35 @@ class QuerySetTest(unittest.TestCase): person2.save() # Retrieve the first person from the database - person = self.Person.objects(slave_okay=True).first() + person = self.Person.objects.slave_okay(True).first() self.assertTrue(isinstance(person, self.Person)) self.assertEqual(person.name, "User A") self.assertEqual(person.age, 20) + def test_cursor_args(self): + """Ensures the cursor args can be set as expected + """ + p = self.Person.objects + # Check default + self.assertEqual(p._cursor_args, + {'snapshot': False, 'slave_okay': False, 'timeout': True}) + + p.snapshot(False).slave_okay(False).timeout(False) + self.assertEqual(p._cursor_args, + {'snapshot': False, 'slave_okay': False, 'timeout': False}) + + p.snapshot(True).slave_okay(False).timeout(False) + self.assertEqual(p._cursor_args, + {'snapshot': True, 'slave_okay': False, 'timeout': False}) + + p.snapshot(True).slave_okay(True).timeout(False) + self.assertEqual(p._cursor_args, + {'snapshot': True, 'slave_okay': True, 'timeout': False}) + + p.snapshot(True).slave_okay(True).timeout(True) + self.assertEqual(p._cursor_args, + {'snapshot': True, 'slave_okay': True, 'timeout': True}) + def test_repeated_iteration(self): """Ensure that QuerySet rewinds itself one iteration finishes. """