Compare commits

..

8 Commits

Author SHA1 Message Date
Stefan Wojcik
d1997667ab fix flake8 2016-12-10 12:59:39 -05:00
Stefan Wojcik
9f2e44e600 Deprecate Python v2.6 2016-12-09 13:20:21 -05:00
Stefan Wójcik
1b9432824b Add ability to filter the generic reference field by ObjectId and DBRef (#1425) 2016-12-09 12:56:06 -05:00
rmendocna
25e0f12976 fix delete cascade for models without a literal id field: replace with pk (#1247) 2016-12-05 22:54:21 -05:00
Stefan Wójcik
f168682a68 Dont let the MongoDB URI override connection settings it doesnt explicitly specify (#1421) 2016-12-05 22:31:00 -05:00
Stefan Wójcik
d25058a46d Implement BaseQuerySet.batch_size (#1426) 2016-12-05 22:13:22 -05:00
Stefan Wójcik
4d0c092d9f Fix iteration within iteration (#1427) 2016-12-05 09:38:24 -05:00
Stefan Wójcik
15714ef855 Fix __repr__ method of the StrictDict (#1424) 2016-12-04 16:10:59 -05:00
11 changed files with 246 additions and 41 deletions

View File

@@ -1,7 +1,7 @@
language: python language: python
python: python:
- '2.6' - '2.6' # TODO remove in v0.11.0
- '2.7' - '2.7'
- '3.3' - '3.3'
- '3.4' - '3.4'

View File

@@ -25,7 +25,8 @@ _dbs = {}
def register_connection(alias, name=None, host=None, port=None, def register_connection(alias, name=None, host=None, port=None,
read_preference=READ_PREFERENCE, read_preference=READ_PREFERENCE,
username=None, password=None, authentication_source=None, username=None, password=None,
authentication_source=None,
authentication_mechanism=None, authentication_mechanism=None,
**kwargs): **kwargs):
"""Add a connection. """Add a connection.
@@ -70,20 +71,26 @@ def register_connection(alias, name=None, host=None, port=None,
resolved_hosts = [] resolved_hosts = []
for entity in conn_host: for entity in conn_host:
# Handle uri style connections
# Handle Mongomock
if entity.startswith('mongomock://'): if entity.startswith('mongomock://'):
conn_settings['is_mock'] = True conn_settings['is_mock'] = True
# `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://`
resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1)) resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1))
# Handle URI style connections, only updating connection params which
# were explicitly specified in the URI.
elif '://' in entity: elif '://' in entity:
uri_dict = uri_parser.parse_uri(entity) uri_dict = uri_parser.parse_uri(entity)
resolved_hosts.append(entity) resolved_hosts.append(entity)
conn_settings.update({
'name': uri_dict.get('database') or name, if uri_dict.get('database'):
'username': uri_dict.get('username'), conn_settings['name'] = uri_dict.get('database')
'password': uri_dict.get('password'),
'read_preference': read_preference, for param in ('read_preference', 'username', 'password'):
}) if uri_dict.get(param):
conn_settings[param] = uri_dict[param]
uri_options = uri_dict['options'] uri_options = uri_dict['options']
if 'replicaset' in uri_options: if 'replicaset' in uri_options:
conn_settings['replicaSet'] = True conn_settings['replicaSet'] = True

View File

@@ -1249,7 +1249,7 @@ class GenericReferenceField(BaseField):
if document is None: if document is None:
return None return None
if isinstance(document, (dict, SON)): if isinstance(document, (dict, SON, ObjectId, DBRef)):
return document return document
id_field_name = document.__class__._meta['id_field'] id_field_name = document.__class__._meta['id_field']

View File

@@ -1,9 +1,22 @@
"""Helper functions and types to aid with Python 2.5 - 3 support.""" """Helper functions and types to aid with Python 2.6 - 3 support."""
import sys import sys
import warnings
import pymongo import pymongo
# Show a deprecation warning for people using Python v2.6
# TODO remove in mongoengine v0.11.0
if sys.version_info[0] == 2 and sys.version_info[1] == 6:
warnings.warn(
'Python v2.6 support is deprecated and is going to be dropped '
'entirely in the upcoming v0.11.0 release. Update your Python '
'version if you want to have access to the latest features and '
'bug fixes in MongoEngine.',
DeprecationWarning
)
if pymongo.version_tuple[0] < 3: if pymongo.version_tuple[0] < 3:
IS_PYMONGO_3 = False IS_PYMONGO_3 = False
else: else:

View File

@@ -82,6 +82,7 @@ class BaseQuerySet(object):
self._limit = None self._limit = None
self._skip = None self._skip = None
self._hint = -1 # Using -1 as None is a valid value for hint self._hint = -1 # Using -1 as None is a valid value for hint
self._batch_size = None
self.only_fields = [] self.only_fields = []
self._max_time_ms = None self._max_time_ms = None
@@ -275,6 +276,8 @@ class BaseQuerySet(object):
except StopIteration: except StopIteration:
return result return result
# If we were able to retrieve the 2nd doc, rewind the cursor and
# raise the MultipleObjectsReturned exception.
queryset.rewind() queryset.rewind()
message = u'%d items returned, instead of 1' % queryset.count() message = u'%d items returned, instead of 1' % queryset.count()
raise queryset._document.MultipleObjectsReturned(message) raise queryset._document.MultipleObjectsReturned(message)
@@ -444,7 +447,7 @@ class BaseQuerySet(object):
if doc._collection == document_cls._collection: if doc._collection == document_cls._collection:
for ref in queryset: for ref in queryset:
cascade_refs.add(ref.id) cascade_refs.add(ref.id)
ref_q = document_cls.objects(**{field_name + '__in': self, 'id__nin': cascade_refs}) ref_q = document_cls.objects(**{field_name + '__in': self, 'pk__nin': cascade_refs})
ref_q_count = ref_q.count() ref_q_count = ref_q.count()
if ref_q_count > 0: if ref_q_count > 0:
ref_q.delete(write_concern=write_concern, cascade_refs=cascade_refs) ref_q.delete(write_concern=write_concern, cascade_refs=cascade_refs)
@@ -781,6 +784,19 @@ class BaseQuerySet(object):
queryset._hint = index queryset._hint = index
return queryset return queryset
def batch_size(self, size):
"""Limit the number of documents returned in a single batch (each
batch requires a round trip to the server).
See http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.batch_size
for details.
:param size: desired size of each batch.
"""
queryset = self.clone()
queryset._batch_size = size
return queryset
def distinct(self, field): def distinct(self, field):
"""Return a list of distinct values for a given field. """Return a list of distinct values for a given field.
@@ -1467,6 +1483,9 @@ class BaseQuerySet(object):
if self._hint != -1: if self._hint != -1:
self._cursor_obj.hint(self._hint) self._cursor_obj.hint(self._hint)
if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)
return self._cursor_obj return self._cursor_obj
def __deepcopy__(self, memo): def __deepcopy__(self, memo):

View File

@@ -27,9 +27,10 @@ class QuerySet(BaseQuerySet):
in batches of ``ITER_CHUNK_SIZE``. in batches of ``ITER_CHUNK_SIZE``.
If ``self._has_more`` the cursor hasn't been exhausted so cache then If ``self._has_more`` the cursor hasn't been exhausted so cache then
batch. Otherwise iterate the result_cache. batch. Otherwise iterate the result_cache.
""" """
self._iter = True self._iter = True
if self._has_more: if self._has_more:
return self._iter_results() return self._iter_results()
@@ -42,10 +43,12 @@ class QuerySet(BaseQuerySet):
""" """
if self._len is not None: if self._len is not None:
return self._len return self._len
# Populate the result cache with *all* of the docs in the cursor
if self._has_more: if self._has_more:
# populate the cache
list(self._iter_results()) list(self._iter_results())
# Cache the length of the complete result cache and return it
self._len = len(self._result_cache) self._len = len(self._result_cache)
return self._len return self._len
@@ -64,18 +67,33 @@ class QuerySet(BaseQuerySet):
def _iter_results(self): def _iter_results(self):
"""A generator for iterating over the result cache. """A generator for iterating over the result cache.
Also populates the cache if there are more possible results to yield. Also populates the cache if there are more possible results to
Raises StopIteration when there are no more results""" yield. Raises StopIteration when there are no more results.
"""
if self._result_cache is None: if self._result_cache is None:
self._result_cache = [] self._result_cache = []
pos = 0 pos = 0
while True: while True:
upper = len(self._result_cache)
while pos < upper: # For all positions lower than the length of the current result
# cache, serve the docs straight from the cache w/o hitting the
# database.
# XXX it's VERY important to compute the len within the `while`
# condition because the result cache might expand mid-iteration
# (e.g. if we call len(qs) inside a loop that iterates over the
# queryset). Fortunately len(list) is O(1) in Python, so this
# doesn't cause performance issues.
while pos < len(self._result_cache):
yield self._result_cache[pos] yield self._result_cache[pos]
pos += 1 pos += 1
# Raise StopIteration if we already established there were no more
# docs in the db cursor.
if not self._has_more: if not self._has_more:
raise StopIteration raise StopIteration
# Otherwise, populate more of the cache and repeat.
if len(self._result_cache) <= pos: if len(self._result_cache) <= pos:
self._populate_cache() self._populate_cache()
@@ -86,12 +104,22 @@ class QuerySet(BaseQuerySet):
""" """
if self._result_cache is None: if self._result_cache is None:
self._result_cache = [] self._result_cache = []
if self._has_more:
try: # Skip populating the cache if we already established there are no
for i in xrange(ITER_CHUNK_SIZE): # more docs to pull from the database.
self._result_cache.append(self.next()) if not self._has_more:
except StopIteration: return
self._has_more = False
# Pull in ITER_CHUNK_SIZE docs from the database and store them in
# the result cache.
try:
for i in xrange(ITER_CHUNK_SIZE):
self._result_cache.append(self.next())
except StopIteration:
# Getting this exception means there are no more docs in the
# db cursor. Set _has_more to False so that we can use that
# information in other places.
self._has_more = False
def count(self, with_limit_and_skip=False): def count(self, with_limit_and_skip=False):
"""Count the selected elements in the query. """Count the selected elements in the query.

View File

@@ -1,6 +1,7 @@
from collections import defaultdict from collections import defaultdict
from bson import SON from bson import ObjectId, SON
from bson.dbref import DBRef
import pymongo import pymongo
from mongoengine.base.fields import UPDATE_OPERATORS from mongoengine.base.fields import UPDATE_OPERATORS
@@ -26,6 +27,7 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
STRING_OPERATORS + CUSTOM_OPERATORS) STRING_OPERATORS + CUSTOM_OPERATORS)
# TODO make this less complex
def query(_doc_cls=None, **kwargs): def query(_doc_cls=None, **kwargs):
"""Transform a query from Django-style format to Mongo format. """Transform a query from Django-style format to Mongo format.
""" """
@@ -62,6 +64,7 @@ def query(_doc_cls=None, **kwargs):
parts = [] parts = []
CachedReferenceField = _import_class('CachedReferenceField') CachedReferenceField = _import_class('CachedReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
cleaned_fields = [] cleaned_fields = []
for field in fields: for field in fields:
@@ -101,6 +104,16 @@ def query(_doc_cls=None, **kwargs):
# 'in', 'nin' and 'all' require a list of values # 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
# If we're querying a GenericReferenceField, we need to alter the
# key depending on the value:
# * If the value is a DBRef, the key should be "field_name._ref".
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
if isinstance(field, GenericReferenceField):
if isinstance(value, DBRef):
parts[-1] += '._ref'
elif isinstance(value, ObjectId):
parts[-1] += '._ref.$id'
# if op and op not in COMPARISON_OPERATORS: # if op and op not in COMPARISON_OPERATORS:
if op: if op:
if op in GEO_OPERATORS: if op in GEO_OPERATORS:
@@ -128,11 +141,13 @@ def query(_doc_cls=None, **kwargs):
for i, part in indices: for i, part in indices:
parts.insert(i, part) parts.insert(i, part)
key = '.'.join(parts) key = '.'.join(parts)
if op is None or key not in mongo_query: if op is None or key not in mongo_query:
mongo_query[key] = value mongo_query[key] = value
elif key in mongo_query: elif key in mongo_query:
if key in mongo_query and isinstance(mongo_query[key], dict): if isinstance(mongo_query[key], dict):
mongo_query[key].update(value) mongo_query[key].update(value)
# $max/minDistance needs to come last - convert to SON # $max/minDistance needs to come last - convert to SON
value_dict = mongo_query[key] value_dict = mongo_query[key]

View File

@@ -9,5 +9,5 @@ tests = tests
[flake8] [flake8]
ignore=E501,F401,F403,F405,I201 ignore=E501,F401,F403,F405,I201
exclude=build,dist,docs,venv,.tox,.eggs,tests exclude=build,dist,docs,venv,.tox,.eggs,tests
max-complexity=42 max-complexity=45
application-import-names=mongoengine,tests application-import-names=mongoengine,tests

View File

@@ -2810,6 +2810,38 @@ class FieldTest(unittest.TestCase):
Post.drop_collection() Post.drop_collection()
User.drop_collection() User.drop_collection()
def test_generic_reference_filter_by_dbref(self):
"""Ensure we can search for a specific generic reference by
providing its DBRef.
"""
class Doc(Document):
ref = GenericReferenceField()
Doc.drop_collection()
doc1 = Doc.objects.create()
doc2 = Doc.objects.create(ref=doc1)
doc = Doc.objects.get(ref=DBRef('doc', doc1.pk))
self.assertEqual(doc, doc2)
def test_generic_reference_filter_by_objectid(self):
"""Ensure we can search for a specific generic reference by
providing its ObjectId.
"""
class Doc(Document):
ref = GenericReferenceField()
Doc.drop_collection()
doc1 = Doc.objects.create()
doc2 = Doc.objects.create(ref=doc1)
self.assertTrue(isinstance(doc1.pk, ObjectId))
doc = Doc.objects.get(ref=doc1.pk)
self.assertEqual(doc, doc2)
def test_binary_fields(self): def test_binary_fields(self):
"""Ensure that binary fields can be stored and retrieved. """Ensure that binary fields can be stored and retrieved.
""" """

View File

@@ -337,6 +337,34 @@ class QuerySetTest(unittest.TestCase):
query = query.filter(boolfield=True) query = query.filter(boolfield=True)
self.assertEqual(query.count(), 1) self.assertEqual(query.count(), 1)
def test_batch_size(self):
"""Ensure that batch_size works."""
class A(Document):
s = StringField()
A.drop_collection()
for i in range(100):
A.objects.create(s=str(i))
# test iterating over the result set
cnt = 0
for a in A.objects.batch_size(10):
cnt += 1
self.assertEqual(cnt, 100)
# test chaining
qs = A.objects.all()
qs = qs.limit(10).batch_size(20).skip(91)
cnt = 0
for a in qs:
cnt += 1
self.assertEqual(cnt, 9)
# test invalid batch size
qs = A.objects.batch_size(-1)
self.assertRaises(ValueError, lambda: list(qs))
def test_update_write_concern(self): def test_update_write_concern(self):
"""Test that passing write_concern works""" """Test that passing write_concern works"""
self.Person.drop_collection() self.Person.drop_collection()
@@ -4890,6 +4918,56 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(1, Doc.objects(item__type__="axe").count()) self.assertEqual(1, Doc.objects(item__type__="axe").count())
def test_len_during_iteration(self):
"""Tests that calling len on a queryset during iteration doesn't
stop paging.
"""
class Data(Document):
pass
for i in xrange(300):
Data().save()
records = Data.objects.limit(250)
# This should pull all 250 docs from mongo and populate the result
# cache
len(records)
# Assert that iterating over documents in the qs touches every
# document even if we call len(qs) midway through the iteration.
for i, r in enumerate(records):
if i == 58:
len(records)
self.assertEqual(i, 249)
# Assert the same behavior is true even if we didn't pre-populate the
# result cache.
records = Data.objects.limit(250)
for i, r in enumerate(records):
if i == 58:
len(records)
self.assertEqual(i, 249)
def test_iteration_within_iteration(self):
"""You should be able to reliably iterate over all the documents
in a given queryset even if there are multiple iterations of it
happening at the same time.
"""
class Data(Document):
pass
for i in xrange(300):
Data().save()
qs = Data.objects.limit(250)
for i, doc in enumerate(qs):
for j, doc2 in enumerate(qs):
pass
self.assertEqual(i, 249)
self.assertEqual(j, 249)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -174,19 +174,9 @@ class ConnectionTest(unittest.TestCase):
c.mongoenginetest.system.users.remove({}) c.mongoenginetest.system.users.remove({})
def test_connect_uri_without_db(self): def test_connect_uri_without_db(self):
"""Ensure connect() method works properly with uri's without database_name """Ensure connect() method works properly if the URI doesn't
include a database name.
""" """
c = connect(db='mongoenginetest', alias='admin')
c.admin.system.users.remove({})
c.mongoenginetest.system.users.remove({})
c.admin.add_user("admin", "password")
c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password")
if not IS_PYMONGO_3:
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
connect("mongoenginetest", host='mongodb://localhost/') connect("mongoenginetest", host='mongodb://localhost/')
conn = get_connection() conn = get_connection()
@@ -196,8 +186,31 @@ class ConnectionTest(unittest.TestCase):
self.assertTrue(isinstance(db, pymongo.database.Database)) self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'mongoenginetest') self.assertEqual(db.name, 'mongoenginetest')
c.admin.system.users.remove({}) def test_connect_uri_default_db(self):
c.mongoenginetest.system.users.remove({}) """Ensure connect() defaults to the right database name if
the URI and the database_name don't explicitly specify it.
"""
connect(host='mongodb://localhost/')
conn = get_connection()
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
db = get_db()
self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'test')
def test_uri_without_credentials_doesnt_override_conn_settings(self):
"""Ensure connect() uses the username & password params if the URI
doesn't explicitly specify them.
"""
c = connect(host='mongodb://localhost/mongoenginetest',
username='user',
password='pass')
# OperationFailure means that mongoengine attempted authentication
# w/ the provided username/password and failed - that's the desired
# behavior. If the MongoDB URI overrode the credentials, no such
# authentication attempt (and hence no failure) would be expected.
self.assertRaises(OperationFailure, get_db)
def test_connect_uri_with_authsource(self): def test_connect_uri_with_authsource(self):
"""Ensure that the connect() method works well with """Ensure that the connect() method works well with