Compare commits

...

14 Commits

Author SHA1 Message Date
Stefan Wojcik
34ba527e6d include a link to PyMongo docs for batch_size 2016-12-05 11:27:46 -05:00
Stefan Wojcik
ea9027755f drop redundant test 2016-12-05 00:17:03 -05:00
Stefan Wojcik
43668a93a2 implement BaseQuerySet.batch_size 2016-12-04 19:28:26 -05:00
Stefan Wójcik
15714ef855 Fix __repr__ method of the StrictDict (#1424) 2016-12-04 16:10:59 -05:00
Stefan Wójcik
eb743beaa3 fix doc.get_<field>_display + unit test inspired by #1279 (#1419) 2016-12-04 00:34:24 -05:00
Stefan Wójcik
0007535a46 Add support for cursor.comment (#1420) 2016-12-04 00:33:42 -05:00
Stefan Wójcik
8391af026c Fix filtering by embedded_doc=None (#1422) 2016-12-04 00:32:53 -05:00
Stefan Wójcik
800f656dcf remove unnecessary randomness in indexes tests (#1423) 2016-12-04 00:31:54 -05:00
Stefan Wojcik
088c5f49d9 update the changelog 2016-12-03 16:32:14 -05:00
Ollie Ford
d8d98b6143 Support Falsey primary_keys (#1354) 2016-12-03 16:10:05 -05:00
zeez
02fb3b9315 Support for authentication mechanism #905 (#1333) 2016-12-03 16:08:24 -05:00
Francesc Elies
4f87db784e Make the README example easier to replicate (#1382) 2016-12-02 22:05:20 -05:00
Jérôme Lafréchoux
7e6287b925 Merge pull request #1417 from MongoEngine/fix-db-field-in-sum-and-average
Fix BaseQuerySet#sum and BaseQuerySet#average for fields that specify a db_field
2016-12-02 20:53:48 +01:00
Stefan Wojcik
999cdfd997 Fix BaseQuerySet#sum and BaseQuerySet#average for fields that specify a db_field 2016-12-02 11:32:38 -05:00
13 changed files with 220 additions and 65 deletions

View File

@@ -52,10 +52,14 @@ Some simple examples of what MongoEngine code looks like:
.. code :: python
from mongoengine import *
connect('mydb')
class BlogPost(Document):
title = StringField(required=True, max_length=200)
posted = DateTimeField(default=datetime.datetime.now)
tags = ListField(StringField(max_length=50))
meta = {'allow_inheritance': True}
class TextPost(BlogPost):
content = StringField(required=True)

View File

@@ -4,7 +4,9 @@ Changelog
Changes in 0.10.8
=================
- Fill this in as PRs for v0.10.8 are merged
- Added ability to specify an authentication mechanism (e.g. X.509) #1333
- Added support for falsey primary keys (e.g. doc.pk = 0) #1354
- Fixed BaseQuerySet#sum/average for fields w/ explicit db_field #1417
Changes in 0.10.7
=================

View File

@@ -438,7 +438,7 @@ class StrictDict(object):
__slots__ = allowed_keys_tuple
def __repr__(self):
return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys())
return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
cls._classes[allowed_keys] = SpecificStrictDict
return cls._classes[allowed_keys]

View File

@@ -121,7 +121,7 @@ class BaseDocument(object):
else:
self._data[key] = value
# Set any get_fieldname_display methods
# Set any get_<field>_display methods
self.__set_field_display()
if self._dynamic:
@@ -1005,19 +1005,18 @@ class BaseDocument(object):
return '.'.join(parts)
def __set_field_display(self):
"""Dynamically set the display value for a field with choices"""
for attr_name, field in self._fields.items():
if field.choices:
if self._dynamic:
obj = self
else:
obj = type(self)
setattr(obj,
'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field))
"""For each field that specifies choices, create a
get_<field>_display method.
"""
fields_with_choices = [(n, f) for n, f in self._fields.items()
if f.choices]
for attr_name, field in fields_with_choices:
setattr(self,
'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field))
def __get_field_display(self, field):
"""Returns the display value for a choice field"""
"""Return the display value for a choice field"""
value = getattr(self, field.name)
if field.choices and isinstance(field.choices[0], (list, tuple)):
return dict(field.choices).get(value, value)

View File

@@ -6,6 +6,7 @@ __all__ = ['ConnectionError', 'connect', 'register_connection',
DEFAULT_CONNECTION_NAME = 'default'
if IS_PYMONGO_3:
READ_PREFERENCE = ReadPreference.PRIMARY
else:
@@ -25,6 +26,7 @@ _dbs = {}
def register_connection(alias, name=None, host=None, port=None,
read_preference=READ_PREFERENCE,
username=None, password=None, authentication_source=None,
authentication_mechanism=None,
**kwargs):
"""Add a connection.
@@ -38,6 +40,9 @@ def register_connection(alias, name=None, host=None, port=None,
:param username: username to authenticate with
:param password: password to authenticate with
:param authentication_source: database to authenticate against
:param authentication_mechanism: database authentication mechanisms.
By default, use SCRAM-SHA-1 with MongoDB 3.0 and later,
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
:param is_mock: explicitly use mongomock for this connection
(can also be done by using `mongomock://` as db host prefix)
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
@@ -53,9 +58,11 @@ def register_connection(alias, name=None, host=None, port=None,
'read_preference': read_preference,
'username': username,
'password': password,
'authentication_source': authentication_source
'authentication_source': authentication_source,
'authentication_mechanism': authentication_mechanism
}
# Handle uri style connections
conn_host = conn_settings['host']
# host can be a list or a string, so if string, force to a list
if isinstance(conn_host, str_types):
@@ -82,6 +89,8 @@ def register_connection(alias, name=None, host=None, port=None,
conn_settings['replicaSet'] = True
if 'authsource' in uri_options:
conn_settings['authentication_source'] = uri_options['authsource']
if 'authmechanism' in uri_options:
conn_settings['authentication_mechanism'] = uri_options['authmechanism']
else:
resolved_hosts.append(entity)
conn_settings['host'] = resolved_hosts
@@ -123,6 +132,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn_settings.pop('username', None)
conn_settings.pop('password', None)
conn_settings.pop('authentication_source', None)
conn_settings.pop('authentication_mechanism', None)
is_mock = conn_settings.pop('is_mock', None)
if is_mock:
@@ -157,6 +167,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
connection_settings.pop('username', None)
connection_settings.pop('password', None)
connection_settings.pop('authentication_source', None)
connection_settings.pop('authentication_mechanism', None)
if conn_settings == connection_settings and _connections.get(db_alias, None):
connection = _connections[db_alias]
break
@@ -176,11 +187,13 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn = get_connection(alias)
conn_settings = _connection_settings[alias]
db = conn[conn_settings['name']]
auth_kwargs = {'source': conn_settings['authentication_source']}
if conn_settings['authentication_mechanism'] is not None:
auth_kwargs['mechanism'] = conn_settings['authentication_mechanism']
# Authenticate if necessary
if conn_settings['username'] and conn_settings['password']:
db.authenticate(conn_settings['username'],
conn_settings['password'],
source=conn_settings['authentication_source'])
if conn_settings['username'] and (conn_settings['password'] or
conn_settings['authentication_mechanism'] == 'MONGODB-X509'):
db.authenticate(conn_settings['username'], conn_settings['password'], **auth_kwargs)
_dbs[alias] = db
return _dbs[alias]

View File

@@ -472,7 +472,7 @@ class Document(BaseDocument):
Raises :class:`OperationError` if called on an object that has not yet
been saved.
"""
if not self.pk:
if self.pk is None:
if kwargs.get('upsert', False):
query = self.to_mongo()
if "_cls" in query:
@@ -604,7 +604,7 @@ class Document(BaseDocument):
elif "max_depth" in kwargs:
max_depth = kwargs["max_depth"]
if not self.pk:
if self.pk is None:
raise self.DoesNotExist("Document does not exist")
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
**self._object_key).only(*fields).limit(
@@ -655,7 +655,7 @@ class Document(BaseDocument):
def to_dbref(self):
"""Returns an instance of :class:`~bson.dbref.DBRef` useful in
`__raw__` queries."""
if not self.pk:
if self.pk is None:
msg = "Only saved documents can have a valid dbref"
raise OperationError(msg)
return DBRef(self.__class__._get_collection_name(), self.pk)

View File

@@ -577,7 +577,7 @@ class EmbeddedDocumentField(BaseField):
return self.document_type._fields.get(member_name)
def prepare_query_value(self, op, value):
if not isinstance(value, self.document_type):
if value is not None and not isinstance(value, self.document_type):
value = self.document_type._from_son(value)
super(EmbeddedDocumentField, self).prepare_query_value(op, value)
return self.to_mongo(value)

View File

@@ -82,6 +82,7 @@ class BaseQuerySet(object):
self._limit = None
self._skip = None
self._hint = -1 # Using -1 as None is a valid value for hint
self._batch_size = None
self.only_fields = []
self._max_time_ms = None
@@ -781,6 +782,19 @@ class BaseQuerySet(object):
queryset._hint = index
return queryset
def batch_size(self, size):
"""Limit the number of documents returned in a single batch (each
batch requires a round trip to the server).
See http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.batch_size
for details.
:param size: desired size of each batch.
"""
queryset = self.clone()
queryset._batch_size = size
return queryset
def distinct(self, field):
"""Return a list of distinct values for a given field.
@@ -933,6 +947,14 @@ class BaseQuerySet(object):
queryset._ordering = queryset._get_order_by(keys)
return queryset
def comment(self, text):
"""Add a comment to the query.
See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment
for details.
"""
return self._chainable_method("comment", text)
def explain(self, format=False):
"""Return an explain plan record for the
:class:`~mongoengine.queryset.QuerySet`\ 's cursor.
@@ -1271,9 +1293,10 @@ class BaseQuerySet(object):
:param field: the field to sum over; use dot notation to refer to
embedded document fields
"""
db_field = self._fields_to_dbfields([field]).pop()
pipeline = [
{'$match': self._query},
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}}
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + db_field}}}
]
# if we're performing a sum over a list field, we sum up all the
@@ -1300,9 +1323,10 @@ class BaseQuerySet(object):
:param field: the field to average over; use dot notation to refer to
embedded document fields
"""
db_field = self._fields_to_dbfields([field]).pop()
pipeline = [
{'$match': self._query},
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}}
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + db_field}}}
]
# if we're performing an average over a list field, we average out
@@ -1457,6 +1481,9 @@ class BaseQuerySet(object):
if self._hint != -1:
self._cursor_obj.hint(self._hint)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
return self._cursor_obj
def __deepcopy__(self, memo):

View File

@@ -2,10 +2,8 @@
import unittest
import sys
sys.path[0:0] = [""]
import pymongo
from random import randint
from nose.plugins.skip import SkipTest
from datetime import datetime
@@ -17,11 +15,9 @@ __all__ = ("IndexesTest", )
class IndexesTest(unittest.TestCase):
_MAX_RAND = 10 ** 10
def setUp(self):
self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND))
self.connection = connect(db=self.db_name)
self.connection = connect(db='mongoenginetest')
self.db = get_db()
class Person(Document):

View File

@@ -3202,5 +3202,20 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(b._instance, a)
self.assertEqual(idx, 2)
def test_falsey_pk(self):
"""Ensure that we can create and update a document with Falsey PK.
"""
class Person(Document):
age = IntField(primary_key=True)
height = FloatField()
person = Person()
person.age = 0
person.height = 1.89
person.save()
person.update(set__height=2.0)
if __name__ == '__main__':
unittest.main()

View File

@@ -3001,28 +3001,32 @@ class FieldTest(unittest.TestCase):
('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
style = StringField(max_length=3, choices=(
('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S')
('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W')
Shirt.drop_collection()
shirt = Shirt()
shirt1 = Shirt()
shirt2 = Shirt()
self.assertEqual(shirt.get_size_display(), None)
self.assertEqual(shirt.get_style_display(), 'Small')
# Make sure get_<field>_display returns the default value (or None)
self.assertEqual(shirt1.get_size_display(), None)
self.assertEqual(shirt1.get_style_display(), 'Wide')
shirt.size = "XXL"
shirt.style = "B"
self.assertEqual(shirt.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt.get_style_display(), 'Baggy')
shirt1.size = 'XXL'
shirt1.style = 'B'
shirt2.size = 'M'
shirt2.style = 'S'
self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt1.get_style_display(), 'Baggy')
self.assertEqual(shirt2.get_size_display(), 'Medium')
self.assertEqual(shirt2.get_style_display(), 'Small')
# Set as Z - an invalid choice
shirt.size = "Z"
shirt.style = "Z"
self.assertEqual(shirt.get_size_display(), 'Z')
self.assertEqual(shirt.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt.validate)
Shirt.drop_collection()
shirt1.size = 'Z'
shirt1.style = 'Z'
self.assertEqual(shirt1.get_size_display(), 'Z')
self.assertEqual(shirt1.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt1.validate)
def test_simple_choices_validation(self):
"""Ensure that value is in a container of allowed values.

View File

@@ -337,9 +337,36 @@ class QuerySetTest(unittest.TestCase):
query = query.filter(boolfield=True)
self.assertEqual(query.count(), 1)
def test_batch_size(self):
"""Ensure that batch_size works."""
class A(Document):
s = StringField()
A.drop_collection()
for i in range(100):
A.objects.create(s=str(i))
# test iterating over the result set
cnt = 0
for a in A.objects.batch_size(10):
cnt += 1
self.assertEqual(cnt, 100)
# test chaining
qs = A.objects.all()
qs = qs.limit(10).batch_size(20).skip(91)
cnt = 0
for a in qs:
cnt += 1
self.assertEqual(cnt, 9)
# test invalid batch size
qs = A.objects.batch_size(-1)
self.assertRaises(ValueError, lambda: list(qs))
def test_update_write_concern(self):
"""Test that passing write_concern works"""
self.Person.drop_collection()
write_concern = {"fsync": True}
@@ -1239,7 +1266,8 @@ class QuerySetTest(unittest.TestCase):
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query.
"""Ensure that an embedded document is properly returned from
a query.
"""
class User(EmbeddedDocument):
name = StringField()
@@ -1250,16 +1278,31 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
post = BlogPost(content='Had a good coffee today...')
post.author = User(name='Test User')
post.save()
BlogPost.objects.create(
author=User(name='Test User'),
content='Had a good coffee today...'
)
result = BlogPost.objects.first()
self.assertTrue(isinstance(result.author, User))
self.assertEqual(result.author.name, 'Test User')
def test_find_empty_embedded(self):
"""Ensure that you can save and find an empty embedded document."""
class User(EmbeddedDocument):
name = StringField()
class BlogPost(Document):
content = StringField()
author = EmbeddedDocumentField(User)
BlogPost.drop_collection()
BlogPost.objects.create(content='Anonymous post...')
result = BlogPost.objects.get(author=None)
self.assertEqual(result.author, None)
def test_find_dict_item(self):
"""Ensure that DictField items may be found.
"""
@@ -2199,6 +2242,21 @@ class QuerySetTest(unittest.TestCase):
a.author.name for a in Author.objects.order_by('-author__age')]
self.assertEqual(names, ['User A', 'User B', 'User C'])
def test_comment(self):
"""Make sure adding a comment to the query works."""
class User(Document):
age = IntField()
with db_ops_tracker() as q:
adult = (User.objects.filter(age__gte=18)
.comment('looking for an adult')
.first())
ops = q.get_ops()
self.assertEqual(len(ops), 1)
op = ops[0]
self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}})
self.assertEqual(op['query']['$comment'], 'looking for an adult')
def test_map_reduce(self):
"""Ensure map/reduce is both mapping and reducing.
"""
@@ -2838,6 +2896,34 @@ class QuerySetTest(unittest.TestCase):
sum([a for a in ages if a >= 50])
)
def test_sum_over_db_field(self):
"""Ensure that a field mapped to a db field with a different name
can be summed over correctly.
"""
class UserVisit(Document):
num_visits = IntField(db_field='visits')
UserVisit.drop_collection()
UserVisit.objects.create(num_visits=10)
UserVisit.objects.create(num_visits=5)
self.assertEqual(UserVisit.objects.sum('num_visits'), 15)
def test_average_over_db_field(self):
"""Ensure that a field mapped to a db field with a different name
can have its average computed correctly.
"""
class UserVisit(Document):
num_visits = IntField(db_field='visits')
UserVisit.drop_collection()
UserVisit.objects.create(num_visits=20)
UserVisit.objects.create(num_visits=10)
self.assertEqual(UserVisit.objects.average('num_visits'), 15)
def test_embedded_average(self):
class Pay(EmbeddedDocument):
value = DecimalField()

View File

@@ -1,4 +1,5 @@
import unittest
from mongoengine.base.datastructures import StrictDict, SemiStrictDict
@@ -13,6 +14,14 @@ class TestStrictDict(unittest.TestCase):
d = self.dtype(a=1, b=1, c=1)
self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
def test_repr(self):
d = self.dtype(a=1, b=2, c=3)
self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}')
# make sure quotes are escaped properly
d = self.dtype(a='"', b="'", c="")
self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}')
def test_init_fails_on_nonexisting_attrs(self):
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))