fix-#399: Not overriding default values when loading a subset of fields
This commit is contained in:
		| @@ -55,6 +55,10 @@ class BaseDocument(object): | |||||||
|                         "Multiple values for keyword argument '" + name + "'") |                         "Multiple values for keyword argument '" + name + "'") | ||||||
|                 values[name] = value |                 values[name] = value | ||||||
|         __auto_convert = values.pop("__auto_convert", True) |         __auto_convert = values.pop("__auto_convert", True) | ||||||
|  |  | ||||||
|  |         # 399: set default values only to fields loaded from DB | ||||||
|  |         __only_fields = set(values.pop("__only_fields", values)) | ||||||
|  |  | ||||||
|         signals.pre_init.send(self.__class__, document=self, values=values) |         signals.pre_init.send(self.__class__, document=self, values=values) | ||||||
|  |  | ||||||
|         if self.STRICT and not self._dynamic: |         if self.STRICT and not self._dynamic: | ||||||
| @@ -69,7 +73,7 @@ class BaseDocument(object): | |||||||
|  |  | ||||||
|         # Assign default values to instance |         # Assign default values to instance | ||||||
|         for key, field in self._fields.iteritems(): |         for key, field in self._fields.iteritems(): | ||||||
|             if self._db_field_map.get(key, key) in values: |             if self._db_field_map.get(key, key) in __only_fields: | ||||||
|                 continue |                 continue | ||||||
|             value = getattr(self, key, None) |             value = getattr(self, key, None) | ||||||
|             setattr(self, key, value) |             setattr(self, key, value) | ||||||
| @@ -610,7 +614,7 @@ class BaseDocument(object): | |||||||
|         return cls._meta.get('collection', None) |         return cls._meta.get('collection', None) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _from_son(cls, son, _auto_dereference=True): |     def _from_son(cls, son, _auto_dereference=True, only_fields=[]): | ||||||
|         """Create an instance of a Document (subclass) from a PyMongo SON. |         """Create an instance of a Document (subclass) from a PyMongo SON. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
| @@ -658,10 +662,11 @@ class BaseDocument(object): | |||||||
|         if cls.STRICT: |         if cls.STRICT: | ||||||
|             data = dict((k, v) |             data = dict((k, v) | ||||||
|                         for k, v in data.iteritems() if k in cls._fields) |                         for k, v in data.iteritems() if k in cls._fields) | ||||||
|         obj = cls(__auto_convert=False, _created=False, **data) |         obj = cls(__auto_convert=False, _created=False, __only_fields=only_fields, **data) | ||||||
|         obj._changed_fields = changed_fields |         obj._changed_fields = changed_fields | ||||||
|         if not _auto_dereference: |         if not _auto_dereference: | ||||||
|             obj._fields = fields |             obj._fields = fields | ||||||
|  |  | ||||||
|         return obj |         return obj | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|   | |||||||
| @@ -81,6 +81,7 @@ class BaseQuerySet(object): | |||||||
|         self._limit = None |         self._limit = None | ||||||
|         self._skip = None |         self._skip = None | ||||||
|         self._hint = -1  # Using -1 as None is a valid value for hint |         self._hint = -1  # Using -1 as None is a valid value for hint | ||||||
|  |         self.only_fields = [] | ||||||
|  |  | ||||||
|     def __call__(self, q_obj=None, class_check=True, slave_okay=False, |     def __call__(self, q_obj=None, class_check=True, slave_okay=False, | ||||||
|                  read_preference=None, **query): |                  read_preference=None, **query): | ||||||
| @@ -151,12 +152,13 @@ class BaseQuerySet(object): | |||||||
|             if queryset._scalar: |             if queryset._scalar: | ||||||
|                 return queryset._get_scalar( |                 return queryset._get_scalar( | ||||||
|                     queryset._document._from_son(queryset._cursor[key], |                     queryset._document._from_son(queryset._cursor[key], | ||||||
|                                                  _auto_dereference=self._auto_dereference)) |                                                  _auto_dereference=self._auto_dereference, | ||||||
|  |                                                  only_fields=self.only_fields)) | ||||||
|  |  | ||||||
|             if queryset._as_pymongo: |             if queryset._as_pymongo: | ||||||
|                 return queryset._get_as_pymongo(queryset._cursor[key]) |                 return queryset._get_as_pymongo(queryset._cursor[key]) | ||||||
|             return queryset._document._from_son(queryset._cursor[key], |             return queryset._document._from_son(queryset._cursor[key], | ||||||
|                                                 _auto_dereference=self._auto_dereference) |                                                 _auto_dereference=self._auto_dereference, only_fields=self.only_fields) | ||||||
|         raise AttributeError |         raise AttributeError | ||||||
|  |  | ||||||
|     def __iter__(self): |     def __iter__(self): | ||||||
| @@ -570,10 +572,10 @@ class BaseQuerySet(object): | |||||||
|  |  | ||||||
|         if full_response: |         if full_response: | ||||||
|             if result["value"] is not None: |             if result["value"] is not None: | ||||||
|                 result["value"] = self._document._from_son(result["value"]) |                 result["value"] = self._document._from_son(result["value"], only_fields=self.only_fields) | ||||||
|         else: |         else: | ||||||
|             if result is not None: |             if result is not None: | ||||||
|                 result = self._document._from_son(result) |                 result = self._document._from_son(result, only_fields=self.only_fields) | ||||||
|  |  | ||||||
|         return result |         return result | ||||||
|  |  | ||||||
| @@ -608,13 +610,13 @@ class BaseQuerySet(object): | |||||||
|         if self._scalar: |         if self._scalar: | ||||||
|             for doc in docs: |             for doc in docs: | ||||||
|                 doc_map[doc['_id']] = self._get_scalar( |                 doc_map[doc['_id']] = self._get_scalar( | ||||||
|                     self._document._from_son(doc)) |                     self._document._from_son(doc, only_fields=self.only_fields)) | ||||||
|         elif self._as_pymongo: |         elif self._as_pymongo: | ||||||
|             for doc in docs: |             for doc in docs: | ||||||
|                 doc_map[doc['_id']] = self._get_as_pymongo(doc) |                 doc_map[doc['_id']] = self._get_as_pymongo(doc) | ||||||
|         else: |         else: | ||||||
|             for doc in docs: |             for doc in docs: | ||||||
|                 doc_map[doc['_id']] = self._document._from_son(doc) |                 doc_map[doc['_id']] = self._document._from_son(doc, only_fields=self.only_fields) | ||||||
|  |  | ||||||
|         return doc_map |         return doc_map | ||||||
|  |  | ||||||
| @@ -667,7 +669,7 @@ class BaseQuerySet(object): | |||||||
|                       '_timeout', '_class_check', '_slave_okay', '_read_preference', |                       '_timeout', '_class_check', '_slave_okay', '_read_preference', | ||||||
|                       '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', |                       '_iter', '_scalar', '_as_pymongo', '_as_pymongo_coerce', | ||||||
|                       '_limit', '_skip', '_hint', '_auto_dereference', |                       '_limit', '_skip', '_hint', '_auto_dereference', | ||||||
|                       '_search_text', '_include_text_scores') |                       '_search_text', '_include_text_scores', 'only_fields') | ||||||
|  |  | ||||||
|         for prop in copy_props: |         for prop in copy_props: | ||||||
|             val = getattr(self, prop) |             val = getattr(self, prop) | ||||||
| @@ -785,6 +787,7 @@ class BaseQuerySet(object): | |||||||
|         .. versionchanged:: 0.5 - Added subfield support |         .. versionchanged:: 0.5 - Added subfield support | ||||||
|         """ |         """ | ||||||
|         fields = dict([(f, QueryFieldList.ONLY) for f in fields]) |         fields = dict([(f, QueryFieldList.ONLY) for f in fields]) | ||||||
|  |         self.only_fields = fields.keys() | ||||||
|         return self.fields(True, **fields) |         return self.fields(True, **fields) | ||||||
|  |  | ||||||
|     def exclude(self, *fields): |     def exclude(self, *fields): | ||||||
| @@ -972,7 +975,7 @@ class BaseQuerySet(object): | |||||||
|     def from_json(self, json_data): |     def from_json(self, json_data): | ||||||
|         """Converts json data to unsaved objects""" |         """Converts json data to unsaved objects""" | ||||||
|         son_data = json_util.loads(json_data) |         son_data = json_util.loads(json_data) | ||||||
|         return [self._document._from_son(data) for data in son_data] |         return [self._document._from_son(data, only_fields=self.only_fields) for data in son_data] | ||||||
|  |  | ||||||
|     def aggregate(self, *pipeline, **kwargs): |     def aggregate(self, *pipeline, **kwargs): | ||||||
|         """ |         """ | ||||||
| @@ -1324,7 +1327,7 @@ class BaseQuerySet(object): | |||||||
|         if self._as_pymongo: |         if self._as_pymongo: | ||||||
|             return self._get_as_pymongo(raw_doc) |             return self._get_as_pymongo(raw_doc) | ||||||
|         doc = self._document._from_son(raw_doc, |         doc = self._document._from_son(raw_doc, | ||||||
|                                        _auto_dereference=self._auto_dereference) |                                        _auto_dereference=self._auto_dereference, only_fields=self.only_fields) | ||||||
|         if self._scalar: |         if self._scalar: | ||||||
|             return self._get_scalar(doc) |             return self._get_scalar(doc) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2618,5 +2618,20 @@ class InstanceTest(unittest.TestCase): | |||||||
|         self.assertTrue(obj3 != dbref2) |         self.assertTrue(obj3 != dbref2) | ||||||
|         self.assertTrue(dbref2 != obj3) |         self.assertTrue(dbref2 != obj3) | ||||||
|  |  | ||||||
|  |     def test_default_values(self): | ||||||
|  |         class Person(Document): | ||||||
|  |             created_on = DateTimeField(default=lambda: datetime.utcnow()) | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         p = Person(name='alon') | ||||||
|  |         p.save() | ||||||
|  |         orig_created_on = Person.objects().only('created_on')[0].created_on | ||||||
|  |  | ||||||
|  |         p2 = Person.objects().only('name')[0] | ||||||
|  |         p2.name = 'alon2' | ||||||
|  |         p2.save() | ||||||
|  |         p3 = Person.objects().only('created_on')[0] | ||||||
|  |         self.assertEquals(orig_created_on, p3.created_on) | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user