Updated slave_okay syntax
Now inline with .timeout() and .snapshot(). Made them chainable - so its easier to use and added tests for cursor_args
This commit is contained in:
		| @@ -452,7 +452,6 @@ class QuerySet(object): | |||||||
|         self._mongo_query = None |         self._mongo_query = None | ||||||
|         self._cursor_obj = None |         self._cursor_obj = None | ||||||
|         self._class_check = class_check |         self._class_check = class_check | ||||||
|         self._slave_okay = slave_okay |  | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def filter(self, *q_objs, **query): |     def filter(self, *q_objs, **query): | ||||||
| @@ -504,18 +503,23 @@ class QuerySet(object): | |||||||
|  |  | ||||||
|         return self._collection_obj |         return self._collection_obj | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def _cursor_args(self): | ||||||
|  |         cursor_args = { | ||||||
|  |             'snapshot': self._snapshot, | ||||||
|  |             'timeout': self._timeout, | ||||||
|  |             'slave_okay': self._slave_okay | ||||||
|  |         } | ||||||
|  |         if self._loaded_fields: | ||||||
|  |             cursor_args['fields'] = self._loaded_fields.as_dict() | ||||||
|  |         return cursor_args | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def _cursor(self): |     def _cursor(self): | ||||||
|         if self._cursor_obj is None: |         if self._cursor_obj is None: | ||||||
|             cursor_args = { |  | ||||||
|                 'snapshot': self._snapshot, |  | ||||||
|                 'timeout': self._timeout, |  | ||||||
|                 'slave_okay': self._slave_okay |  | ||||||
|             } |  | ||||||
|             if self._loaded_fields: |  | ||||||
|                 cursor_args['fields'] = self._loaded_fields.as_dict() |  | ||||||
|             self._cursor_obj = self._collection.find(self._query, |             self._cursor_obj = self._collection.find(self._query, | ||||||
|                                                      **cursor_args) |                                                      **self._cursor_args) | ||||||
|             # Apply where clauses to cursor |             # Apply where clauses to cursor | ||||||
|             if self._where_clause: |             if self._where_clause: | ||||||
|                 self._cursor_obj.where(self._where_clause) |                 self._cursor_obj.where(self._where_clause) | ||||||
| @@ -772,7 +776,7 @@ class QuerySet(object): | |||||||
|         id_field = self._document._meta['id_field'] |         id_field = self._document._meta['id_field'] | ||||||
|         object_id = self._document._fields[id_field].to_mongo(object_id) |         object_id = self._document._fields[id_field].to_mongo(object_id) | ||||||
|  |  | ||||||
|         result = self._collection.find_one({'_id': object_id}) |         result = self._collection.find_one({'_id': object_id}, **self._cursor_args) | ||||||
|         if result is not None: |         if result is not None: | ||||||
|             result = self._document._from_son(result) |             result = self._document._from_son(result) | ||||||
|         return result |         return result | ||||||
| @@ -788,7 +792,8 @@ class QuerySet(object): | |||||||
|         """ |         """ | ||||||
|         doc_map = {} |         doc_map = {} | ||||||
|  |  | ||||||
|         docs = self._collection.find({'_id': {'$in': object_ids}}) |         docs = self._collection.find({'_id': {'$in': object_ids}}, | ||||||
|  |                                      **self._cursor_args) | ||||||
|         for doc in docs: |         for doc in docs: | ||||||
|             doc_map[doc['_id']] = self._document._from_son(doc) |             doc_map[doc['_id']] = self._document._from_son(doc) | ||||||
|  |  | ||||||
| @@ -1085,6 +1090,7 @@ class QuerySet(object): | |||||||
|         :param enabled: whether or not snapshot mode is enabled |         :param enabled: whether or not snapshot mode is enabled | ||||||
|         """ |         """ | ||||||
|         self._snapshot = enabled |         self._snapshot = enabled | ||||||
|  |         return self | ||||||
|  |  | ||||||
|     def timeout(self, enabled): |     def timeout(self, enabled): | ||||||
|         """Enable or disable the default mongod timeout when querying. |         """Enable or disable the default mongod timeout when querying. | ||||||
| @@ -1092,6 +1098,15 @@ class QuerySet(object): | |||||||
|         :param enabled: whether or not the timeout is used |         :param enabled: whether or not the timeout is used | ||||||
|         """ |         """ | ||||||
|         self._timeout = enabled |         self._timeout = enabled | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |     def slave_okay(self, enabled): | ||||||
|  |         """Enable or disable the slave_okay when querying. | ||||||
|  |  | ||||||
|  |         :param enabled: whether or not the slave_okay is enabled | ||||||
|  |         """ | ||||||
|  |         self._slave_okay = enabled | ||||||
|  |         return self | ||||||
|  |  | ||||||
|     def delete(self, safe=False): |     def delete(self, safe=False): | ||||||
|         """Delete the documents matched by the query. |         """Delete the documents matched by the query. | ||||||
|   | |||||||
| @@ -422,11 +422,35 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         person2.save() |         person2.save() | ||||||
|  |  | ||||||
|         # Retrieve the first person from the database |         # Retrieve the first person from the database | ||||||
|         person = self.Person.objects(slave_okay=True).first() |         person = self.Person.objects.slave_okay(True).first() | ||||||
|         self.assertTrue(isinstance(person, self.Person)) |         self.assertTrue(isinstance(person, self.Person)) | ||||||
|         self.assertEqual(person.name, "User A") |         self.assertEqual(person.name, "User A") | ||||||
|         self.assertEqual(person.age, 20) |         self.assertEqual(person.age, 20) | ||||||
|  |  | ||||||
|  |     def test_cursor_args(self): | ||||||
|  |         """Ensures the cursor args can be set as expected | ||||||
|  |         """ | ||||||
|  |         p = self.Person.objects | ||||||
|  |         # Check default | ||||||
|  |         self.assertEqual(p._cursor_args, | ||||||
|  |                 {'snapshot': False, 'slave_okay': False, 'timeout': True}) | ||||||
|  |  | ||||||
|  |         p.snapshot(False).slave_okay(False).timeout(False) | ||||||
|  |         self.assertEqual(p._cursor_args, | ||||||
|  |                 {'snapshot': False, 'slave_okay': False, 'timeout': False}) | ||||||
|  |  | ||||||
|  |         p.snapshot(True).slave_okay(False).timeout(False) | ||||||
|  |         self.assertEqual(p._cursor_args, | ||||||
|  |                 {'snapshot': True, 'slave_okay': False, 'timeout': False}) | ||||||
|  |  | ||||||
|  |         p.snapshot(True).slave_okay(True).timeout(False) | ||||||
|  |         self.assertEqual(p._cursor_args, | ||||||
|  |                 {'snapshot': True, 'slave_okay': True, 'timeout': False}) | ||||||
|  |  | ||||||
|  |         p.snapshot(True).slave_okay(True).timeout(True) | ||||||
|  |         self.assertEqual(p._cursor_args, | ||||||
|  |                 {'snapshot': True, 'slave_okay': True, 'timeout': True}) | ||||||
|  |  | ||||||
|     def test_repeated_iteration(self): |     def test_repeated_iteration(self): | ||||||
|         """Ensure that QuerySet rewinds itself one iteration finishes. |         """Ensure that QuerySet rewinds itself one iteration finishes. | ||||||
|         """ |         """ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user