Compare commits
	
		
			1 Commits
		
	
	
		
			fix-iterat
			...
			simpler-in
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | ea82cb80f6 | 
| @@ -438,7 +438,7 @@ class StrictDict(object): | |||||||
|                 __slots__ = allowed_keys_tuple |                 __slots__ = allowed_keys_tuple | ||||||
|  |  | ||||||
|                 def __repr__(self): |                 def __repr__(self): | ||||||
|                     return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) |                     return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys()) | ||||||
|  |  | ||||||
|             cls._classes[allowed_keys] = SpecificStrictDict |             cls._classes[allowed_keys] = SpecificStrictDict | ||||||
|         return cls._classes[allowed_keys] |         return cls._classes[allowed_keys] | ||||||
|   | |||||||
| @@ -121,7 +121,7 @@ class BaseDocument(object): | |||||||
|                 else: |                 else: | ||||||
|                     self._data[key] = value |                     self._data[key] = value | ||||||
|  |  | ||||||
|         # Set any get_<field>_display methods |         # Set any get_fieldname_display methods | ||||||
|         self.__set_field_display() |         self.__set_field_display() | ||||||
|  |  | ||||||
|         if self._dynamic: |         if self._dynamic: | ||||||
| @@ -1005,18 +1005,19 @@ class BaseDocument(object): | |||||||
|         return '.'.join(parts) |         return '.'.join(parts) | ||||||
|  |  | ||||||
|     def __set_field_display(self): |     def __set_field_display(self): | ||||||
|         """For each field that specifies choices, create a |         """Dynamically set the display value for a field with choices""" | ||||||
|         get_<field>_display method. |         for attr_name, field in self._fields.items(): | ||||||
|         """ |             if field.choices: | ||||||
|         fields_with_choices = [(n, f) for n, f in self._fields.items() |                 if self._dynamic: | ||||||
|                                if f.choices] |                     obj = self | ||||||
|         for attr_name, field in fields_with_choices: |                 else: | ||||||
|             setattr(self, |                     obj = type(self) | ||||||
|                     'get_%s_display' % attr_name, |                 setattr(obj, | ||||||
|                     partial(self.__get_field_display, field=field)) |                         'get_%s_display' % attr_name, | ||||||
|  |                         partial(self.__get_field_display, field=field)) | ||||||
|  |  | ||||||
|     def __get_field_display(self, field): |     def __get_field_display(self, field): | ||||||
|         """Return the display value for a choice field""" |         """Returns the display value for a choice field""" | ||||||
|         value = getattr(self, field.name) |         value = getattr(self, field.name) | ||||||
|         if field.choices and isinstance(field.choices[0], (list, tuple)): |         if field.choices and isinstance(field.choices[0], (list, tuple)): | ||||||
|             return dict(field.choices).get(value, value) |             return dict(field.choices).get(value, value) | ||||||
|   | |||||||
| @@ -577,7 +577,7 @@ class EmbeddedDocumentField(BaseField): | |||||||
|         return self.document_type._fields.get(member_name) |         return self.document_type._fields.get(member_name) | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         if value is not None and not isinstance(value, self.document_type): |         if not isinstance(value, self.document_type): | ||||||
|             value = self.document_type._from_son(value) |             value = self.document_type._from_son(value) | ||||||
|         super(EmbeddedDocumentField, self).prepare_query_value(op, value) |         super(EmbeddedDocumentField, self).prepare_query_value(op, value) | ||||||
|         return self.to_mongo(value) |         return self.to_mongo(value) | ||||||
|   | |||||||
| @@ -275,8 +275,6 @@ class BaseQuerySet(object): | |||||||
|         except StopIteration: |         except StopIteration: | ||||||
|             return result |             return result | ||||||
|  |  | ||||||
|         # If we were able to retrieve the 2nd doc, rewind the cursor and |  | ||||||
|         # raise the MultipleObjectsReturned exception. |  | ||||||
|         queryset.rewind() |         queryset.rewind() | ||||||
|         message = u'%d items returned, instead of 1' % queryset.count() |         message = u'%d items returned, instead of 1' % queryset.count() | ||||||
|         raise queryset._document.MultipleObjectsReturned(message) |         raise queryset._document.MultipleObjectsReturned(message) | ||||||
| @@ -935,14 +933,6 @@ class BaseQuerySet(object): | |||||||
|         queryset._ordering = queryset._get_order_by(keys) |         queryset._ordering = queryset._get_order_by(keys) | ||||||
|         return queryset |         return queryset | ||||||
|  |  | ||||||
|     def comment(self, text): |  | ||||||
|         """Add a comment to the query. |  | ||||||
|  |  | ||||||
|         See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment |  | ||||||
|         for details. |  | ||||||
|         """ |  | ||||||
|         return self._chainable_method("comment", text) |  | ||||||
|  |  | ||||||
|     def explain(self, format=False): |     def explain(self, format=False): | ||||||
|         """Return an explain plan record for the |         """Return an explain plan record for the | ||||||
|         :class:`~mongoengine.queryset.QuerySet`\ 's cursor. |         :class:`~mongoengine.queryset.QuerySet`\ 's cursor. | ||||||
|   | |||||||
| @@ -27,10 +27,9 @@ class QuerySet(BaseQuerySet): | |||||||
|         in batches of ``ITER_CHUNK_SIZE``. |         in batches of ``ITER_CHUNK_SIZE``. | ||||||
|  |  | ||||||
|         If ``self._has_more`` the cursor hasn't been exhausted so cache then |         If ``self._has_more`` the cursor hasn't been exhausted so cache then | ||||||
|         batch. Otherwise iterate the result_cache. |         batch.  Otherwise iterate the result_cache. | ||||||
|         """ |         """ | ||||||
|         self._iter = True |         self._iter = True | ||||||
|  |  | ||||||
|         if self._has_more: |         if self._has_more: | ||||||
|             return self._iter_results() |             return self._iter_results() | ||||||
|  |  | ||||||
| @@ -43,12 +42,10 @@ class QuerySet(BaseQuerySet): | |||||||
|         """ |         """ | ||||||
|         if self._len is not None: |         if self._len is not None: | ||||||
|             return self._len |             return self._len | ||||||
|  |  | ||||||
|         # Populate the result cache with *all* of the docs in the cursor |  | ||||||
|         if self._has_more: |         if self._has_more: | ||||||
|  |             # populate the cache | ||||||
|             list(self._iter_results()) |             list(self._iter_results()) | ||||||
|  |  | ||||||
|         # Cache the length of the complete result cache and return it |  | ||||||
|         self._len = len(self._result_cache) |         self._len = len(self._result_cache) | ||||||
|         return self._len |         return self._len | ||||||
|  |  | ||||||
| @@ -67,33 +64,18 @@ class QuerySet(BaseQuerySet): | |||||||
|     def _iter_results(self): |     def _iter_results(self): | ||||||
|         """A generator for iterating over the result cache. |         """A generator for iterating over the result cache. | ||||||
|  |  | ||||||
|         Also populates the cache if there are more possible results to |         Also populates the cache if there are more possible results to yield. | ||||||
|         yield. Raises StopIteration when there are no more results. |         Raises StopIteration when there are no more results""" | ||||||
|         """ |  | ||||||
|         if self._result_cache is None: |         if self._result_cache is None: | ||||||
|             self._result_cache = [] |             self._result_cache = [] | ||||||
|  |  | ||||||
|         pos = 0 |         pos = 0 | ||||||
|         while True: |         while True: | ||||||
|  |             upper = len(self._result_cache) | ||||||
|             # For all positions lower than the length of the current result |             while pos < upper: | ||||||
|             # cache, serve the docs straight from the cache w/o hitting the |  | ||||||
|             # database. |  | ||||||
|             # XXX it's VERY important to compute the len within the `while` |  | ||||||
|             # condition because the result cache might expand mid-iteration |  | ||||||
|             # (e.g. if we call len(qs) inside a loop that iterates over the |  | ||||||
|             # queryset). Fortunately len(list) is O(1) in Python, so this |  | ||||||
|             # doesn't cause performance issues. |  | ||||||
|             while pos < len(self._result_cache): |  | ||||||
|                 yield self._result_cache[pos] |                 yield self._result_cache[pos] | ||||||
|                 pos += 1 |                 pos += 1 | ||||||
|  |  | ||||||
|             # Raise StopIteration if we already established there were no more |  | ||||||
|             # docs in the db cursor. |  | ||||||
|             if not self._has_more: |             if not self._has_more: | ||||||
|                 raise StopIteration |                 raise StopIteration | ||||||
|  |  | ||||||
|             # Otherwise, populate more of the cache and repeat. |  | ||||||
|             if len(self._result_cache) <= pos: |             if len(self._result_cache) <= pos: | ||||||
|                 self._populate_cache() |                 self._populate_cache() | ||||||
|  |  | ||||||
| @@ -104,22 +86,12 @@ class QuerySet(BaseQuerySet): | |||||||
|         """ |         """ | ||||||
|         if self._result_cache is None: |         if self._result_cache is None: | ||||||
|             self._result_cache = [] |             self._result_cache = [] | ||||||
|  |         if self._has_more: | ||||||
|         # Skip populating the cache if we already established there are no |             try: | ||||||
|         # more docs to pull from the database. |                 for i in xrange(ITER_CHUNK_SIZE): | ||||||
|         if not self._has_more: |                     self._result_cache.append(self.next()) | ||||||
|             return |             except StopIteration: | ||||||
|  |                 self._has_more = False | ||||||
|         # Pull in ITER_CHUNK_SIZE docs from the database and store them in |  | ||||||
|         # the result cache. |  | ||||||
|         try: |  | ||||||
|             for i in xrange(ITER_CHUNK_SIZE): |  | ||||||
|                 self._result_cache.append(self.next()) |  | ||||||
|         except StopIteration: |  | ||||||
|             # Getting this exception means there are no more docs in the |  | ||||||
|             # db cursor. Set _has_more to False so that we can use that |  | ||||||
|             # information in other places. |  | ||||||
|             self._has_more = False |  | ||||||
|  |  | ||||||
|     def count(self, with_limit_and_skip=False): |     def count(self, with_limit_and_skip=False): | ||||||
|         """Count the selected elements in the query. |         """Count the selected elements in the query. | ||||||
|   | |||||||
| @@ -1047,7 +1047,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|     def test_list_assignment(self): |     def test_list_assignment(self): | ||||||
|         """Ensure that list field element assignment and slicing work |         """Ensure that list field element assignment and slicing work  | ||||||
|         """ |         """ | ||||||
|         class BlogPost(Document): |         class BlogPost(Document): | ||||||
|             info = ListField() |             info = ListField() | ||||||
| @@ -1057,12 +1057,12 @@ class FieldTest(unittest.TestCase): | |||||||
|         post = BlogPost() |         post = BlogPost() | ||||||
|         post.info = ['e1', 'e2', 3, '4', 5] |         post.info = ['e1', 'e2', 3, '4', 5] | ||||||
|         post.save() |         post.save() | ||||||
|  |          | ||||||
|         post.info[0] = 1 |         post.info[0] = 1 | ||||||
|         post.save() |         post.save() | ||||||
|         post.reload() |         post.reload() | ||||||
|         self.assertEqual(post.info[0], 1) |         self.assertEqual(post.info[0], 1) | ||||||
|  |          | ||||||
|         post.info[1:3] = ['n2', 'n3'] |         post.info[1:3] = ['n2', 'n3'] | ||||||
|         post.save() |         post.save() | ||||||
|         post.reload() |         post.reload() | ||||||
| @@ -1209,7 +1209,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         self.assertEqual(simple.widgets, [4]) |         self.assertEqual(simple.widgets, [4]) | ||||||
|  |  | ||||||
|     def test_list_field_with_negative_indices(self): |     def test_list_field_with_negative_indices(self): | ||||||
|  |          | ||||||
|         class Simple(Document): |         class Simple(Document): | ||||||
|             widgets = ListField() |             widgets = ListField() | ||||||
|  |  | ||||||
| @@ -1823,7 +1823,7 @@ class FieldTest(unittest.TestCase): | |||||||
|                                'parent': "50a234ea469ac1eda42d347d"}) |                                'parent': "50a234ea469ac1eda42d347d"}) | ||||||
|         mongoed = p1.to_mongo() |         mongoed = p1.to_mongo() | ||||||
|         self.assertTrue(isinstance(mongoed['parent'], ObjectId)) |         self.assertTrue(isinstance(mongoed['parent'], ObjectId)) | ||||||
|  |          | ||||||
|     def test_cached_reference_field_get_and_save(self): |     def test_cached_reference_field_get_and_save(self): | ||||||
|         """ |         """ | ||||||
|         Tests #1047: CachedReferenceField creates DBRefs on to_python, but can't save them on to_mongo |         Tests #1047: CachedReferenceField creates DBRefs on to_python, but can't save them on to_mongo | ||||||
| @@ -1835,11 +1835,11 @@ class FieldTest(unittest.TestCase): | |||||||
|         class Ocorrence(Document): |         class Ocorrence(Document): | ||||||
|             person = StringField() |             person = StringField() | ||||||
|             animal = CachedReferenceField(Animal) |             animal = CachedReferenceField(Animal) | ||||||
|  |          | ||||||
|         Animal.drop_collection() |         Animal.drop_collection() | ||||||
|         Ocorrence.drop_collection() |         Ocorrence.drop_collection() | ||||||
|  |          | ||||||
|         Ocorrence(person="testte", |         Ocorrence(person="testte",  | ||||||
|                   animal=Animal(name="Leopard", tag="heavy").save()).save() |                   animal=Animal(name="Leopard", tag="heavy").save()).save() | ||||||
|         p = Ocorrence.objects.get() |         p = Ocorrence.objects.get() | ||||||
|         p.person = 'new_testte' |         p.person = 'new_testte' | ||||||
| @@ -3001,32 +3001,28 @@ class FieldTest(unittest.TestCase): | |||||||
|                 ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), |                 ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), | ||||||
|                 ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) |                 ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) | ||||||
|             style = StringField(max_length=3, choices=( |             style = StringField(max_length=3, choices=( | ||||||
|                 ('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W') |                 ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') | ||||||
|  |  | ||||||
|         Shirt.drop_collection() |         Shirt.drop_collection() | ||||||
|  |  | ||||||
|         shirt1 = Shirt() |         shirt = Shirt() | ||||||
|         shirt2 = Shirt() |  | ||||||
|  |  | ||||||
|         # Make sure get_<field>_display returns the default value (or None) |         self.assertEqual(shirt.get_size_display(), None) | ||||||
|         self.assertEqual(shirt1.get_size_display(), None) |         self.assertEqual(shirt.get_style_display(), 'Small') | ||||||
|         self.assertEqual(shirt1.get_style_display(), 'Wide') |  | ||||||
|  |  | ||||||
|         shirt1.size = 'XXL' |         shirt.size = "XXL" | ||||||
|         shirt1.style = 'B' |         shirt.style = "B" | ||||||
|         shirt2.size = 'M' |         self.assertEqual(shirt.get_size_display(), 'Extra Extra Large') | ||||||
|         shirt2.style = 'S' |         self.assertEqual(shirt.get_style_display(), 'Baggy') | ||||||
|         self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large') |  | ||||||
|         self.assertEqual(shirt1.get_style_display(), 'Baggy') |  | ||||||
|         self.assertEqual(shirt2.get_size_display(), 'Medium') |  | ||||||
|         self.assertEqual(shirt2.get_style_display(), 'Small') |  | ||||||
|  |  | ||||||
|         # Set as Z - an invalid choice |         # Set as Z - an invalid choice | ||||||
|         shirt1.size = 'Z' |         shirt.size = "Z" | ||||||
|         shirt1.style = 'Z' |         shirt.style = "Z" | ||||||
|         self.assertEqual(shirt1.get_size_display(), 'Z') |         self.assertEqual(shirt.get_size_display(), 'Z') | ||||||
|         self.assertEqual(shirt1.get_style_display(), 'Z') |         self.assertEqual(shirt.get_style_display(), 'Z') | ||||||
|         self.assertRaises(ValidationError, shirt1.validate) |         self.assertRaises(ValidationError, shirt.validate) | ||||||
|  |  | ||||||
|  |         Shirt.drop_collection() | ||||||
|  |  | ||||||
|     def test_simple_choices_validation(self): |     def test_simple_choices_validation(self): | ||||||
|         """Ensure that value is in a container of allowed values. |         """Ensure that value is in a container of allowed values. | ||||||
|   | |||||||
| @@ -339,6 +339,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|     def test_update_write_concern(self): |     def test_update_write_concern(self): | ||||||
|         """Test that passing write_concern works""" |         """Test that passing write_concern works""" | ||||||
|  |  | ||||||
|         self.Person.drop_collection() |         self.Person.drop_collection() | ||||||
|  |  | ||||||
|         write_concern = {"fsync": True} |         write_concern = {"fsync": True} | ||||||
| @@ -1238,8 +1239,7 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             self.assertFalse('$orderby' in q.get_ops()[0]['query']) |             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||||
|  |  | ||||||
|     def test_find_embedded(self): |     def test_find_embedded(self): | ||||||
|         """Ensure that an embedded document is properly returned from |         """Ensure that an embedded document is properly returned from a query. | ||||||
|         a query. |  | ||||||
|         """ |         """ | ||||||
|         class User(EmbeddedDocument): |         class User(EmbeddedDocument): | ||||||
|             name = StringField() |             name = StringField() | ||||||
| @@ -1250,31 +1250,16 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         BlogPost.objects.create( |         post = BlogPost(content='Had a good coffee today...') | ||||||
|             author=User(name='Test User'), |         post.author = User(name='Test User') | ||||||
|             content='Had a good coffee today...' |         post.save() | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         result = BlogPost.objects.first() |         result = BlogPost.objects.first() | ||||||
|         self.assertTrue(isinstance(result.author, User)) |         self.assertTrue(isinstance(result.author, User)) | ||||||
|         self.assertEqual(result.author.name, 'Test User') |         self.assertEqual(result.author.name, 'Test User') | ||||||
|  |  | ||||||
|     def test_find_empty_embedded(self): |  | ||||||
|         """Ensure that you can save and find an empty embedded document.""" |  | ||||||
|         class User(EmbeddedDocument): |  | ||||||
|             name = StringField() |  | ||||||
|  |  | ||||||
|         class BlogPost(Document): |  | ||||||
|             content = StringField() |  | ||||||
|             author = EmbeddedDocumentField(User) |  | ||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|         BlogPost.objects.create(content='Anonymous post...') |  | ||||||
|  |  | ||||||
|         result = BlogPost.objects.get(author=None) |  | ||||||
|         self.assertEqual(result.author, None) |  | ||||||
|  |  | ||||||
|     def test_find_dict_item(self): |     def test_find_dict_item(self): | ||||||
|         """Ensure that DictField items may be found. |         """Ensure that DictField items may be found. | ||||||
|         """ |         """ | ||||||
| @@ -2214,21 +2199,6 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             a.author.name for a in Author.objects.order_by('-author__age')] |             a.author.name for a in Author.objects.order_by('-author__age')] | ||||||
|         self.assertEqual(names, ['User A', 'User B', 'User C']) |         self.assertEqual(names, ['User A', 'User B', 'User C']) | ||||||
|  |  | ||||||
|     def test_comment(self): |  | ||||||
|         """Make sure adding a comment to the query works.""" |  | ||||||
|         class User(Document): |  | ||||||
|             age = IntField() |  | ||||||
|  |  | ||||||
|         with db_ops_tracker() as q: |  | ||||||
|             adult = (User.objects.filter(age__gte=18) |  | ||||||
|                 .comment('looking for an adult') |  | ||||||
|                 .first()) |  | ||||||
|             ops = q.get_ops() |  | ||||||
|             self.assertEqual(len(ops), 1) |  | ||||||
|             op = ops[0] |  | ||||||
|             self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}}) |  | ||||||
|             self.assertEqual(op['query']['$comment'], 'looking for an adult') |  | ||||||
|  |  | ||||||
|     def test_map_reduce(self): |     def test_map_reduce(self): | ||||||
|         """Ensure map/reduce is both mapping and reducing. |         """Ensure map/reduce is both mapping and reducing. | ||||||
|         """ |         """ | ||||||
| @@ -4890,56 +4860,6 @@ class QuerySetTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         self.assertEqual(1, Doc.objects(item__type__="axe").count()) |         self.assertEqual(1, Doc.objects(item__type__="axe").count()) | ||||||
|  |  | ||||||
|     def test_len_during_iteration(self): |  | ||||||
|         """Tests that calling len on a queyset during iteration doesn't |  | ||||||
|         stop paging. |  | ||||||
|         """ |  | ||||||
|         class Data(Document): |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|         for i in xrange(300): |  | ||||||
|             Data().save() |  | ||||||
|  |  | ||||||
|         records = Data.objects.limit(250) |  | ||||||
|  |  | ||||||
|         # This should pull all 250 docs from mongo and populate the result |  | ||||||
|         # cache |  | ||||||
|         len(records) |  | ||||||
|  |  | ||||||
|         # Assert that iterating over documents in the qs touches every |  | ||||||
|         # document even if we call len(qs) midway through the iteration. |  | ||||||
|         for i, r in enumerate(records): |  | ||||||
|             if i == 58: |  | ||||||
|                 len(records) |  | ||||||
|         self.assertEqual(i, 249) |  | ||||||
|  |  | ||||||
|         # Assert the same behavior is true even if we didn't pre-populate the |  | ||||||
|         # result cache. |  | ||||||
|         records = Data.objects.limit(250) |  | ||||||
|         for i, r in enumerate(records): |  | ||||||
|             if i == 58: |  | ||||||
|                 len(records) |  | ||||||
|         self.assertEqual(i, 249) |  | ||||||
|  |  | ||||||
|     def test_iteration_within_iteration(self): |  | ||||||
|         """You should be able to reliably iterate over all the documents |  | ||||||
|         in a given queryset even if there are multiple iterations of it |  | ||||||
|         happening at the same time. |  | ||||||
|         """ |  | ||||||
|         class Data(Document): |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|         for i in xrange(300): |  | ||||||
|             Data().save() |  | ||||||
|  |  | ||||||
|         qs = Data.objects.limit(250) |  | ||||||
|         for i, doc in enumerate(qs): |  | ||||||
|             for j, doc2 in enumerate(qs): |  | ||||||
|                 pass |  | ||||||
|  |  | ||||||
|         self.assertEqual(i, 249) |  | ||||||
|         self.assertEqual(j, 249) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -1,6 +1,5 @@ | |||||||
| import unittest | import unittest | ||||||
|  | from mongoengine.base.datastructures import StrictDict, SemiStrictDict  | ||||||
| from mongoengine.base.datastructures import StrictDict, SemiStrictDict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestStrictDict(unittest.TestCase): | class TestStrictDict(unittest.TestCase): | ||||||
| @@ -14,17 +13,9 @@ class TestStrictDict(unittest.TestCase): | |||||||
|         d = self.dtype(a=1, b=1, c=1) |         d = self.dtype(a=1, b=1, c=1) | ||||||
|         self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) |         self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) | ||||||
|  |  | ||||||
|     def test_repr(self): |  | ||||||
|         d = self.dtype(a=1, b=2, c=3) |  | ||||||
|         self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') |  | ||||||
|  |  | ||||||
|         # make sure quotes are escaped properly |  | ||||||
|         d = self.dtype(a='"', b="'", c="") |  | ||||||
|         self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') |  | ||||||
|  |  | ||||||
|     def test_init_fails_on_nonexisting_attrs(self): |     def test_init_fails_on_nonexisting_attrs(self): | ||||||
|         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) |         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) | ||||||
|  |          | ||||||
|     def test_eq(self): |     def test_eq(self): | ||||||
|         d = self.dtype(a=1, b=1, c=1) |         d = self.dtype(a=1, b=1, c=1) | ||||||
|         dd = self.dtype(a=1, b=1, c=1) |         dd = self.dtype(a=1, b=1, c=1) | ||||||
| @@ -33,7 +24,7 @@ class TestStrictDict(unittest.TestCase): | |||||||
|         g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1) |         g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1) | ||||||
|         h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) |         h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) | ||||||
|         i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) |         i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) | ||||||
|  |          | ||||||
|         self.assertEqual(d, dd) |         self.assertEqual(d, dd) | ||||||
|         self.assertNotEqual(d, e) |         self.assertNotEqual(d, e) | ||||||
|         self.assertNotEqual(d, f) |         self.assertNotEqual(d, f) | ||||||
| @@ -47,19 +38,19 @@ class TestStrictDict(unittest.TestCase): | |||||||
|         d.a = 1 |         d.a = 1 | ||||||
|         self.assertEqual(d.a, 1) |         self.assertEqual(d.a, 1) | ||||||
|         self.assertRaises(AttributeError, lambda: d.b) |         self.assertRaises(AttributeError, lambda: d.b) | ||||||
|  |      | ||||||
|     def test_setattr_raises_on_nonexisting_attr(self): |     def test_setattr_raises_on_nonexisting_attr(self): | ||||||
|         d = self.dtype() |         d = self.dtype() | ||||||
|  |  | ||||||
|         def _f(): |         def _f(): | ||||||
|             d.x = 1 |             d.x = 1 | ||||||
|         self.assertRaises(AttributeError, _f) |         self.assertRaises(AttributeError, _f) | ||||||
|  |      | ||||||
|     def test_setattr_getattr_special(self): |     def test_setattr_getattr_special(self): | ||||||
|         d = self.strict_dict_class(["items"]) |         d = self.strict_dict_class(["items"]) | ||||||
|         d.items = 1 |         d.items = 1 | ||||||
|         self.assertEqual(d.items, 1) |         self.assertEqual(d.items, 1) | ||||||
|  |      | ||||||
|     def test_get(self): |     def test_get(self): | ||||||
|         d = self.dtype(a=1) |         d = self.dtype(a=1) | ||||||
|         self.assertEqual(d.get('a'), 1) |         self.assertEqual(d.get('a'), 1) | ||||||
| @@ -97,7 +88,7 @@ class TestSemiSrictDict(TestStrictDict): | |||||||
|     def test_init_succeeds_with_nonexisting_attrs(self): |     def test_init_succeeds_with_nonexisting_attrs(self): | ||||||
|         d = self.dtype(a=1, b=1, c=1, x=2) |         d = self.dtype(a=1, b=1, c=1, x=2) | ||||||
|         self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2)) |         self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2)) | ||||||
|  |     | ||||||
|     def test_iter_with_nonexisting_attrs(self): |     def test_iter_with_nonexisting_attrs(self): | ||||||
|         d = self.dtype(a=1, b=1, c=1, x=2) |         d = self.dtype(a=1, b=1, c=1, x=2) | ||||||
|         self.assertEqual(list(d), ['a', 'b', 'c', 'x']) |         self.assertEqual(list(d), ['a', 'b', 'c', 'x']) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user