Compare commits
	
		
			12 Commits
		
	
	
		
			v0.10.7
			...
			fix-strict
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | c78e5079d4 | ||
|  | ab69e50361 | ||
|  | eb743beaa3 | ||
|  | 0007535a46 | ||
|  | 8391af026c | ||
|  | 800f656dcf | ||
|  | 088c5f49d9 | ||
|  | d8d98b6143 | ||
|  | 02fb3b9315 | ||
|  | 4f87db784e | ||
|  | 7e6287b925 | ||
|  | 999cdfd997 | 
| @@ -52,10 +52,14 @@ Some simple examples of what MongoEngine code looks like: | ||||
|  | ||||
| .. code :: python | ||||
|  | ||||
|     from mongoengine import * | ||||
|     connect('mydb') | ||||
|  | ||||
|     class BlogPost(Document): | ||||
|         title = StringField(required=True, max_length=200) | ||||
|         posted = DateTimeField(default=datetime.datetime.now) | ||||
|         tags = ListField(StringField(max_length=50)) | ||||
|         meta = {'allow_inheritance': True} | ||||
|  | ||||
|     class TextPost(BlogPost): | ||||
|         content = StringField(required=True) | ||||
|   | ||||
| @@ -4,7 +4,9 @@ Changelog | ||||
|  | ||||
| Changes in 0.10.8 | ||||
| ================= | ||||
| - Fill this in as PRs for v0.10.8 are merged | ||||
| - Added ability to specify an authentication mechanism (e.g. X.509) #1333 | ||||
| - Added support for falsey primary keys (e.g. doc.pk = 0) #1354 | ||||
| - Fixed BaseQuerySet#sum/average for fields w/ explicit db_field #1417 | ||||
|  | ||||
| Changes in 0.10.7 | ||||
| ================= | ||||
|   | ||||
| @@ -438,7 +438,7 @@ class StrictDict(object): | ||||
|                 __slots__ = allowed_keys_tuple | ||||
|  | ||||
|                 def __repr__(self): | ||||
|                     return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys()) | ||||
|                     return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) | ||||
|  | ||||
|             cls._classes[allowed_keys] = SpecificStrictDict | ||||
|         return cls._classes[allowed_keys] | ||||
|   | ||||
| @@ -121,7 +121,7 @@ class BaseDocument(object): | ||||
|                 else: | ||||
|                     self._data[key] = value | ||||
|  | ||||
|         # Set any get_fieldname_display methods | ||||
|         # Set any get_<field>_display methods | ||||
|         self.__set_field_display() | ||||
|  | ||||
|         if self._dynamic: | ||||
| @@ -1005,19 +1005,18 @@ class BaseDocument(object): | ||||
|         return '.'.join(parts) | ||||
|  | ||||
|     def __set_field_display(self): | ||||
|         """Dynamically set the display value for a field with choices""" | ||||
|         for attr_name, field in self._fields.items(): | ||||
|             if field.choices: | ||||
|                 if self._dynamic: | ||||
|                     obj = self | ||||
|                 else: | ||||
|                     obj = type(self) | ||||
|                 setattr(obj, | ||||
|                         'get_%s_display' % attr_name, | ||||
|                         partial(self.__get_field_display, field=field)) | ||||
|         """For each field that specifies choices, create a | ||||
|         get_<field>_display method. | ||||
|         """ | ||||
|         fields_with_choices = [(n, f) for n, f in self._fields.items() | ||||
|                                if f.choices] | ||||
|         for attr_name, field in fields_with_choices: | ||||
|             setattr(self, | ||||
|                     'get_%s_display' % attr_name, | ||||
|                     partial(self.__get_field_display, field=field)) | ||||
|  | ||||
|     def __get_field_display(self, field): | ||||
|         """Returns the display value for a choice field""" | ||||
|         """Return the display value for a choice field""" | ||||
|         value = getattr(self, field.name) | ||||
|         if field.choices and isinstance(field.choices[0], (list, tuple)): | ||||
|             return dict(field.choices).get(value, value) | ||||
|   | ||||
| @@ -6,6 +6,7 @@ __all__ = ['ConnectionError', 'connect', 'register_connection', | ||||
|  | ||||
|  | ||||
| DEFAULT_CONNECTION_NAME = 'default' | ||||
|  | ||||
| if IS_PYMONGO_3: | ||||
|     READ_PREFERENCE = ReadPreference.PRIMARY | ||||
| else: | ||||
| @@ -25,6 +26,7 @@ _dbs = {} | ||||
| def register_connection(alias, name=None, host=None, port=None, | ||||
|                         read_preference=READ_PREFERENCE, | ||||
|                         username=None, password=None, authentication_source=None, | ||||
|                         authentication_mechanism=None, | ||||
|                         **kwargs): | ||||
|     """Add a connection. | ||||
|  | ||||
| @@ -38,6 +40,9 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|     :param username: username to authenticate with | ||||
|     :param password: password to authenticate with | ||||
|     :param authentication_source: database to authenticate against | ||||
|     :param authentication_mechanism: database authentication mechanisms. | ||||
|         By default, use SCRAM-SHA-1 with MongoDB 3.0 and later, | ||||
|         MONGODB-CR (MongoDB Challenge Response protocol) for older servers. | ||||
|     :param is_mock: explicitly use mongomock for this connection | ||||
|         (can also be done by using `mongomock://` as db host prefix) | ||||
|     :param kwargs: allow ad-hoc parameters to be passed into the pymongo driver | ||||
| @@ -53,9 +58,11 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|         'read_preference': read_preference, | ||||
|         'username': username, | ||||
|         'password': password, | ||||
|         'authentication_source': authentication_source | ||||
|         'authentication_source': authentication_source, | ||||
|         'authentication_mechanism': authentication_mechanism | ||||
|     } | ||||
|  | ||||
|     # Handle uri style connections | ||||
|     conn_host = conn_settings['host'] | ||||
|     # host can be a list or a string, so if string, force to a list | ||||
|     if isinstance(conn_host, str_types): | ||||
| @@ -82,6 +89,8 @@ def register_connection(alias, name=None, host=None, port=None, | ||||
|                 conn_settings['replicaSet'] = True | ||||
|             if 'authsource' in uri_options: | ||||
|                 conn_settings['authentication_source'] = uri_options['authsource'] | ||||
|             if 'authmechanism' in uri_options: | ||||
|                 conn_settings['authentication_mechanism'] = uri_options['authmechanism'] | ||||
|         else: | ||||
|             resolved_hosts.append(entity) | ||||
|     conn_settings['host'] = resolved_hosts | ||||
| @@ -123,6 +132,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|         conn_settings.pop('username', None) | ||||
|         conn_settings.pop('password', None) | ||||
|         conn_settings.pop('authentication_source', None) | ||||
|         conn_settings.pop('authentication_mechanism', None) | ||||
|  | ||||
|         is_mock = conn_settings.pop('is_mock', None) | ||||
|         if is_mock: | ||||
| @@ -157,6 +167,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|                 connection_settings.pop('username', None) | ||||
|                 connection_settings.pop('password', None) | ||||
|                 connection_settings.pop('authentication_source', None) | ||||
|                 connection_settings.pop('authentication_mechanism', None) | ||||
|                 if conn_settings == connection_settings and _connections.get(db_alias, None): | ||||
|                     connection = _connections[db_alias] | ||||
|                     break | ||||
| @@ -176,11 +187,13 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|         conn = get_connection(alias) | ||||
|         conn_settings = _connection_settings[alias] | ||||
|         db = conn[conn_settings['name']] | ||||
|         auth_kwargs = {'source': conn_settings['authentication_source']} | ||||
|         if conn_settings['authentication_mechanism'] is not None: | ||||
|             auth_kwargs['mechanism'] = conn_settings['authentication_mechanism'] | ||||
|         # Authenticate if necessary | ||||
|         if conn_settings['username'] and conn_settings['password']: | ||||
|             db.authenticate(conn_settings['username'], | ||||
|                             conn_settings['password'], | ||||
|                             source=conn_settings['authentication_source']) | ||||
|         if conn_settings['username'] and (conn_settings['password'] or | ||||
|                                           conn_settings['authentication_mechanism'] == 'MONGODB-X509'): | ||||
|             db.authenticate(conn_settings['username'], conn_settings['password'], **auth_kwargs) | ||||
|         _dbs[alias] = db | ||||
|     return _dbs[alias] | ||||
|  | ||||
|   | ||||
| @@ -472,7 +472,7 @@ class Document(BaseDocument): | ||||
|         Raises :class:`OperationError` if called on an object that has not yet | ||||
|         been saved. | ||||
|         """ | ||||
|         if not self.pk: | ||||
|         if self.pk is None: | ||||
|             if kwargs.get('upsert', False): | ||||
|                 query = self.to_mongo() | ||||
|                 if "_cls" in query: | ||||
| @@ -604,7 +604,7 @@ class Document(BaseDocument): | ||||
|         elif "max_depth" in kwargs: | ||||
|             max_depth = kwargs["max_depth"] | ||||
|  | ||||
|         if not self.pk: | ||||
|         if self.pk is None: | ||||
|             raise self.DoesNotExist("Document does not exist") | ||||
|         obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( | ||||
|             **self._object_key).only(*fields).limit( | ||||
| @@ -655,7 +655,7 @@ class Document(BaseDocument): | ||||
|     def to_dbref(self): | ||||
|         """Returns an instance of :class:`~bson.dbref.DBRef` useful in | ||||
|         `__raw__` queries.""" | ||||
|         if not self.pk: | ||||
|         if self.pk is None: | ||||
|             msg = "Only saved documents can have a valid dbref" | ||||
|             raise OperationError(msg) | ||||
|         return DBRef(self.__class__._get_collection_name(), self.pk) | ||||
|   | ||||
| @@ -577,7 +577,7 @@ class EmbeddedDocumentField(BaseField): | ||||
|         return self.document_type._fields.get(member_name) | ||||
|  | ||||
|     def prepare_query_value(self, op, value): | ||||
|         if not isinstance(value, self.document_type): | ||||
|         if value is not None and not isinstance(value, self.document_type): | ||||
|             value = self.document_type._from_son(value) | ||||
|         super(EmbeddedDocumentField, self).prepare_query_value(op, value) | ||||
|         return self.to_mongo(value) | ||||
|   | ||||
| @@ -933,6 +933,14 @@ class BaseQuerySet(object): | ||||
|         queryset._ordering = queryset._get_order_by(keys) | ||||
|         return queryset | ||||
|  | ||||
|     def comment(self, text): | ||||
|         """Add a comment to the query. | ||||
|  | ||||
|         See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment | ||||
|         for details. | ||||
|         """ | ||||
|         return self._chainable_method("comment", text) | ||||
|  | ||||
|     def explain(self, format=False): | ||||
|         """Return an explain plan record for the | ||||
|         :class:`~mongoengine.queryset.QuerySet`\ 's cursor. | ||||
| @@ -1271,9 +1279,10 @@ class BaseQuerySet(object): | ||||
|         :param field: the field to sum over; use dot notation to refer to | ||||
|             embedded document fields | ||||
|         """ | ||||
|         db_field = self._fields_to_dbfields([field]).pop() | ||||
|         pipeline = [ | ||||
|             {'$match': self._query}, | ||||
|             {'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}} | ||||
|             {'$group': {'_id': 'sum', 'total': {'$sum': '$' + db_field}}} | ||||
|         ] | ||||
|  | ||||
|         # if we're performing a sum over a list field, we sum up all the | ||||
| @@ -1300,9 +1309,10 @@ class BaseQuerySet(object): | ||||
|         :param field: the field to average over; use dot notation to refer to | ||||
|             embedded document fields | ||||
|         """ | ||||
|         db_field = self._fields_to_dbfields([field]).pop() | ||||
|         pipeline = [ | ||||
|             {'$match': self._query}, | ||||
|             {'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}} | ||||
|             {'$group': {'_id': 'avg', 'total': {'$avg': '$' + db_field}}} | ||||
|         ] | ||||
|  | ||||
|         # if we're performing an average over a list field, we average out | ||||
|   | ||||
| @@ -2,10 +2,8 @@ | ||||
| import unittest | ||||
| import sys | ||||
|  | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| import pymongo | ||||
| from random import randint | ||||
|  | ||||
| from nose.plugins.skip import SkipTest | ||||
| from datetime import datetime | ||||
| @@ -17,11 +15,9 @@ __all__ = ("IndexesTest", ) | ||||
|  | ||||
|  | ||||
| class IndexesTest(unittest.TestCase): | ||||
|     _MAX_RAND = 10 ** 10 | ||||
|  | ||||
|     def setUp(self): | ||||
|         self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND)) | ||||
|         self.connection = connect(db=self.db_name) | ||||
|         self.connection = connect(db='mongoenginetest') | ||||
|         self.db = get_db() | ||||
|  | ||||
|         class Person(Document): | ||||
|   | ||||
| @@ -3202,5 +3202,20 @@ class InstanceTest(unittest.TestCase): | ||||
|             self.assertEqual(b._instance, a) | ||||
|         self.assertEqual(idx, 2) | ||||
|  | ||||
|     def test_falsey_pk(self): | ||||
|         """Ensure that we can create and update a document with Falsey PK. | ||||
|         """ | ||||
|         class Person(Document): | ||||
|             age = IntField(primary_key=True) | ||||
|             height = FloatField() | ||||
|  | ||||
|         person = Person() | ||||
|         person.age = 0 | ||||
|         person.height = 1.89 | ||||
|         person.save() | ||||
|  | ||||
|         person.update(set__height=2.0) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -3001,28 +3001,32 @@ class FieldTest(unittest.TestCase): | ||||
|                 ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), | ||||
|                 ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) | ||||
|             style = StringField(max_length=3, choices=( | ||||
|                 ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') | ||||
|                 ('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W') | ||||
|  | ||||
|         Shirt.drop_collection() | ||||
|  | ||||
|         shirt = Shirt() | ||||
|         shirt1 = Shirt() | ||||
|         shirt2 = Shirt() | ||||
|  | ||||
|         self.assertEqual(shirt.get_size_display(), None) | ||||
|         self.assertEqual(shirt.get_style_display(), 'Small') | ||||
|         # Make sure get_<field>_display returns the default value (or None) | ||||
|         self.assertEqual(shirt1.get_size_display(), None) | ||||
|         self.assertEqual(shirt1.get_style_display(), 'Wide') | ||||
|  | ||||
|         shirt.size = "XXL" | ||||
|         shirt.style = "B" | ||||
|         self.assertEqual(shirt.get_size_display(), 'Extra Extra Large') | ||||
|         self.assertEqual(shirt.get_style_display(), 'Baggy') | ||||
|         shirt1.size = 'XXL' | ||||
|         shirt1.style = 'B' | ||||
|         shirt2.size = 'M' | ||||
|         shirt2.style = 'S' | ||||
|         self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large') | ||||
|         self.assertEqual(shirt1.get_style_display(), 'Baggy') | ||||
|         self.assertEqual(shirt2.get_size_display(), 'Medium') | ||||
|         self.assertEqual(shirt2.get_style_display(), 'Small') | ||||
|  | ||||
|         # Set as Z - an invalid choice | ||||
|         shirt.size = "Z" | ||||
|         shirt.style = "Z" | ||||
|         self.assertEqual(shirt.get_size_display(), 'Z') | ||||
|         self.assertEqual(shirt.get_style_display(), 'Z') | ||||
|         self.assertRaises(ValidationError, shirt.validate) | ||||
|  | ||||
|         Shirt.drop_collection() | ||||
|         shirt1.size = 'Z' | ||||
|         shirt1.style = 'Z' | ||||
|         self.assertEqual(shirt1.get_size_display(), 'Z') | ||||
|         self.assertEqual(shirt1.get_style_display(), 'Z') | ||||
|         self.assertRaises(ValidationError, shirt1.validate) | ||||
|  | ||||
|     def test_simple_choices_validation(self): | ||||
|         """Ensure that value is in a container of allowed values. | ||||
|   | ||||
| @@ -339,7 +339,6 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|     def test_update_write_concern(self): | ||||
|         """Test that passing write_concern works""" | ||||
|  | ||||
|         self.Person.drop_collection() | ||||
|  | ||||
|         write_concern = {"fsync": True} | ||||
| @@ -1239,7 +1238,8 @@ class QuerySetTest(unittest.TestCase): | ||||
|             self.assertFalse('$orderby' in q.get_ops()[0]['query']) | ||||
|  | ||||
|     def test_find_embedded(self): | ||||
|         """Ensure that an embedded document is properly returned from a query. | ||||
|         """Ensure that an embedded document is properly returned from | ||||
|         a query. | ||||
|         """ | ||||
|         class User(EmbeddedDocument): | ||||
|             name = StringField() | ||||
| @@ -1250,16 +1250,31 @@ class QuerySetTest(unittest.TestCase): | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         post = BlogPost(content='Had a good coffee today...') | ||||
|         post.author = User(name='Test User') | ||||
|         post.save() | ||||
|         BlogPost.objects.create( | ||||
|             author=User(name='Test User'), | ||||
|             content='Had a good coffee today...' | ||||
|         ) | ||||
|  | ||||
|         result = BlogPost.objects.first() | ||||
|         self.assertTrue(isinstance(result.author, User)) | ||||
|         self.assertEqual(result.author.name, 'Test User') | ||||
|  | ||||
|     def test_find_empty_embedded(self): | ||||
|         """Ensure that you can save and find an empty embedded document.""" | ||||
|         class User(EmbeddedDocument): | ||||
|             name = StringField() | ||||
|  | ||||
|         class BlogPost(Document): | ||||
|             content = StringField() | ||||
|             author = EmbeddedDocumentField(User) | ||||
|  | ||||
|         BlogPost.drop_collection() | ||||
|  | ||||
|         BlogPost.objects.create(content='Anonymous post...') | ||||
|  | ||||
|         result = BlogPost.objects.get(author=None) | ||||
|         self.assertEqual(result.author, None) | ||||
|  | ||||
|     def test_find_dict_item(self): | ||||
|         """Ensure that DictField items may be found. | ||||
|         """ | ||||
| @@ -2199,6 +2214,21 @@ class QuerySetTest(unittest.TestCase): | ||||
|             a.author.name for a in Author.objects.order_by('-author__age')] | ||||
|         self.assertEqual(names, ['User A', 'User B', 'User C']) | ||||
|  | ||||
|     def test_comment(self): | ||||
|         """Make sure adding a comment to the query works.""" | ||||
|         class User(Document): | ||||
|             age = IntField() | ||||
|  | ||||
|         with db_ops_tracker() as q: | ||||
|             adult = (User.objects.filter(age__gte=18) | ||||
|                 .comment('looking for an adult') | ||||
|                 .first()) | ||||
|             ops = q.get_ops() | ||||
|             self.assertEqual(len(ops), 1) | ||||
|             op = ops[0] | ||||
|             self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}}) | ||||
|             self.assertEqual(op['query']['$comment'], 'looking for an adult') | ||||
|  | ||||
|     def test_map_reduce(self): | ||||
|         """Ensure map/reduce is both mapping and reducing. | ||||
|         """ | ||||
| @@ -2838,6 +2868,34 @@ class QuerySetTest(unittest.TestCase): | ||||
|             sum([a for a in ages if a >= 50]) | ||||
|         ) | ||||
|  | ||||
|     def test_sum_over_db_field(self): | ||||
|         """Ensure that a field mapped to a db field with a different name | ||||
|         can be summed over correctly. | ||||
|         """ | ||||
|         class UserVisit(Document): | ||||
|             num_visits = IntField(db_field='visits') | ||||
|  | ||||
|         UserVisit.drop_collection() | ||||
|  | ||||
|         UserVisit.objects.create(num_visits=10) | ||||
|         UserVisit.objects.create(num_visits=5) | ||||
|  | ||||
|         self.assertEqual(UserVisit.objects.sum('num_visits'), 15) | ||||
|  | ||||
|     def test_average_over_db_field(self): | ||||
|         """Ensure that a field mapped to a db field with a different name | ||||
|         can have its average computed correctly. | ||||
|         """ | ||||
|         class UserVisit(Document): | ||||
|             num_visits = IntField(db_field='visits') | ||||
|  | ||||
|         UserVisit.drop_collection() | ||||
|  | ||||
|         UserVisit.objects.create(num_visits=20) | ||||
|         UserVisit.objects.create(num_visits=10) | ||||
|  | ||||
|         self.assertEqual(UserVisit.objects.average('num_visits'), 15) | ||||
|  | ||||
|     def test_embedded_average(self): | ||||
|         class Pay(EmbeddedDocument): | ||||
|             value = DecimalField() | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import unittest | ||||
|  | ||||
| from mongoengine.base.datastructures import StrictDict, SemiStrictDict | ||||
|  | ||||
|  | ||||
| @@ -13,6 +14,14 @@ class TestStrictDict(unittest.TestCase): | ||||
|         d = self.dtype(a=1, b=1, c=1) | ||||
|         self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) | ||||
|  | ||||
|     def test_repr(self): | ||||
|         d = self.dtype(a=1, b=2, c=3) | ||||
|         self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}') | ||||
|  | ||||
|         # make sure quotes are escaped properly | ||||
|         d = self.dtype(a='"', b="'", c="") | ||||
|         self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') | ||||
|  | ||||
|     def test_init_fails_on_nonexisting_attrs(self): | ||||
|         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user