QuerySet.only field name translation and polymorphism fix
This commit is contained in:
		| @@ -456,7 +456,18 @@ class QuerySet(object): | ||||
|          | ||||
|         :param *fields: fields to include | ||||
|         """ | ||||
|         self._loaded_fields = list(fields) | ||||
|         self._loaded_fields = [] | ||||
|         for field in fields: | ||||
|             if '.' in field: | ||||
|                 raise InvalidQueryError('Subfields cannot be used as ' | ||||
|                                         'arguments to QuerySet.only') | ||||
|             # Translate field name | ||||
|             field_name = QuerySet._lookup_field(self._document, field)[-1].name | ||||
|             self._loaded_fields.append(field_name) | ||||
|  | ||||
|         # _cls is needed for polymorphism | ||||
|         if self._document._meta.get('allow_inheritance'): | ||||
|             self._loaded_fields += ['_cls'] | ||||
|         return self | ||||
|  | ||||
|     def order_by(self, *keys): | ||||
|   | ||||
| @@ -258,6 +258,39 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|     def test_only(self): | ||||
|         """Ensure that QuerySet.only only returns the requested fields. | ||||
|         """ | ||||
|         person = self.Person(name='test', age=25) | ||||
|         person.save() | ||||
|  | ||||
|         obj = self.Person.objects.only('name').get() | ||||
|         self.assertEqual(obj.name, person.name) | ||||
|         self.assertEqual(obj.age, None) | ||||
|  | ||||
|         obj = self.Person.objects.only('age').get() | ||||
|         self.assertEqual(obj.name, None) | ||||
|         self.assertEqual(obj.age, person.age) | ||||
|  | ||||
|         obj = self.Person.objects.only('name', 'age').get() | ||||
|         self.assertEqual(obj.name, person.name) | ||||
|         self.assertEqual(obj.age, person.age) | ||||
|  | ||||
|         # Check polymorphism still works | ||||
|         class Employee(self.Person): | ||||
|             salary = IntField(name='wage') | ||||
|  | ||||
|         employee = Employee(name='test employee', age=40, salary=30000) | ||||
|         employee.save() | ||||
|  | ||||
|         obj = self.Person.objects(id=employee.id).only('age').get() | ||||
|         self.assertTrue(isinstance(obj, Employee)) | ||||
|  | ||||
|         # Check field names are looked up properly | ||||
|         obj = Employee.objects(id=employee.id).only('salary').get() | ||||
|         self.assertEqual(obj.salary, employee.salary) | ||||
|         self.assertEqual(obj.name, None) | ||||
|  | ||||
|     def test_find_embedded(self): | ||||
|         """Ensure that an embedded document is properly returned from a query. | ||||
|         """ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user