Compare commits


2 Commits

Author          SHA1         Message                                 Date
Stefan Wojcik   c78e5079d4   extra test for escaping quotes          2016-12-04 15:33:48 -05:00
Stefan Wojcik   ab69e50361   fix __repr__ method of the StrictDict   2016-12-04 15:29:03 -05:00
11 changed files with 41 additions and 246 deletions

View File

@@ -1,7 +1,7 @@
 language: python
 python:
-- '2.6' # TODO remove in v0.11.0
+- '2.6'
 - '2.7'
 - '3.3'
 - '3.4'

View File

@@ -25,8 +25,7 @@ _dbs = {}
 def register_connection(alias, name=None, host=None, port=None,
                         read_preference=READ_PREFERENCE,
-                        username=None, password=None,
-                        authentication_source=None,
+                        username=None, password=None, authentication_source=None,
                         authentication_mechanism=None,
                         **kwargs):
     """Add a connection.
@@ -71,26 +70,20 @@ def register_connection(alias, name=None, host=None, port=None,
     resolved_hosts = []
     for entity in conn_host:
-        # Handle Mongomock
+        # Handle uri style connections
         if entity.startswith('mongomock://'):
             conn_settings['is_mock'] = True
             # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://`
             resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1))
-
-        # Handle URI style connections, only updating connection params which
-        # were explicitly specified in the URI.
         elif '://' in entity:
             uri_dict = uri_parser.parse_uri(entity)
             resolved_hosts.append(entity)
-
-            if uri_dict.get('database'):
-                conn_settings['name'] = uri_dict.get('database')
-
-            for param in ('read_preference', 'username', 'password'):
-                if uri_dict.get(param):
-                    conn_settings[param] = uri_dict[param]
-
+            conn_settings.update({
+                'name': uri_dict.get('database') or name,
+                'username': uri_dict.get('username'),
+                'password': uri_dict.get('password'),
+                'read_preference': read_preference,
+            })
             uri_options = uri_dict['options']
             if 'replicaset' in uri_options:
                 conn_settings['replicaSet'] = True
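
Note on the behavioral difference in the hunk above: the minus side only copies connection parameters that the URI explicitly sets, so explicit username/password kwargs survive URI parsing, while the plus side unconditionally overwrites them. A minimal sketch of the conditional merge (the uri_dict shape mirrors pymongo's uri_parser.parse_uri output; the values are illustrative):

# Illustrative sketch of the "only update what the URI specifies" merge.
# `uri_dict` mimics the dict returned by pymongo's uri_parser.parse_uri.
conn_settings = {'username': 'user', 'password': 'pass'}  # explicit kwargs
uri_dict = {'database': 'mongoenginetest', 'username': None, 'password': None}

if uri_dict.get('database'):
    conn_settings['name'] = uri_dict['database']

for param in ('read_preference', 'username', 'password'):
    if uri_dict.get(param):  # params absent from the URI are left untouched
        conn_settings[param] = uri_dict[param]

print(conn_settings)
# {'username': 'user', 'password': 'pass', 'name': 'mongoenginetest'}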

View File

@@ -1249,7 +1249,7 @@ class GenericReferenceField(BaseField):
         if document is None:
             return None
 
-        if isinstance(document, (dict, SON, ObjectId, DBRef)):
+        if isinstance(document, (dict, SON)):
             return document
 
         id_field_name = document.__class__._meta['id_field']
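
The widened isinstance check on the minus side is what lets GenericReferenceField.to_mongo pass already-serialized values (a raw ObjectId or DBRef) straight through instead of treating them as documents. A hedged sketch of that pass-through behavior (the function name is hypothetical; the SON case is omitted for brevity):

from bson import ObjectId
from bson.dbref import DBRef

def to_mongo_sketch(document):
    # Already-serialized values go out unchanged; anything else would be
    # treated as a Document and dereferenced (omitted in this sketch).
    if isinstance(document, (dict, ObjectId, DBRef)):
        return document
    raise TypeError('expected a dict, ObjectId or DBRef in this sketch')

assert isinstance(to_mongo_sketch(ObjectId()), ObjectId)
assert isinstance(to_mongo_sketch(DBRef('doc', ObjectId())), DBRef)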

View File

@@ -1,22 +1,9 @@
-"""Helper functions and types to aid with Python 2.6 - 3 support."""
+"""Helper functions and types to aid with Python 2.5 - 3 support."""
 
 import sys
-import warnings
 
 import pymongo
 
-# Show a deprecation warning for people using Python v2.6
-# TODO remove in mongoengine v0.11.0
-if sys.version_info[0] == 2 and sys.version_info[1] == 6:
-    warnings.warn(
-        'Python v2.6 support is deprecated and is going to be dropped '
-        'entirely in the upcoming v0.11.0 release. Update your Python '
-        'version if you want to have access to the latest features and '
-        'bug fixes in MongoEngine.',
-        DeprecationWarning
-    )
-
 if pymongo.version_tuple[0] < 3:
     IS_PYMONGO_3 = False
 else:
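
The block on the minus side is the stock stdlib pattern for a one-time, import-time deprecation notice. A self-contained sketch of the same idea (the message wording is illustrative):

import sys
import warnings

# Emit a DeprecationWarning once at import time on a deprecated interpreter.
if sys.version_info[:2] == (2, 6):
    warnings.warn(
        'Python v2.6 support is deprecated and will be dropped in a '
        'future release.',
        DeprecationWarning
    )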

View File

@@ -82,7 +82,6 @@ class BaseQuerySet(object):
         self._limit = None
         self._skip = None
         self._hint = -1  # Using -1 as None is a valid value for hint
-        self._batch_size = None
         self.only_fields = []
         self._max_time_ms = None
@@ -276,8 +275,6 @@ class BaseQuerySet(object):
         except StopIteration:
             return result
 
-        # If we were able to retrieve the 2nd doc, rewind the cursor and
-        # raise the MultipleObjectsReturned exception.
         queryset.rewind()
         message = u'%d items returned, instead of 1' % queryset.count()
         raise queryset._document.MultipleObjectsReturned(message)
@@ -447,7 +444,7 @@ class BaseQuerySet(object):
                 if doc._collection == document_cls._collection:
                     for ref in queryset:
                         cascade_refs.add(ref.id)
-            ref_q = document_cls.objects(**{field_name + '__in': self, 'pk__nin': cascade_refs})
+            ref_q = document_cls.objects(**{field_name + '__in': self, 'id__nin': cascade_refs})
             ref_q_count = ref_q.count()
             if ref_q_count > 0:
                 ref_q.delete(write_concern=write_concern, cascade_refs=cascade_refs)
@@ -784,19 +781,6 @@ class BaseQuerySet(object):
             queryset._hint = index
         return queryset
 
-    def batch_size(self, size):
-        """Limit the number of documents returned in a single batch (each
-        batch requires a round trip to the server).
-
-        See http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.batch_size
-        for details.
-
-        :param size: desired size of each batch.
-        """
-        queryset = self.clone()
-        queryset._batch_size = size
-        return queryset
-
     def distinct(self, field):
         """Return a list of distinct values for a given field.
@@ -1483,9 +1467,6 @@ class BaseQuerySet(object):
         if self._hint != -1:
             self._cursor_obj.hint(self._hint)
 
-        if self._batch_size is not None:
-            self._cursor_obj.batch_size(self._batch_size)
-
         return self._cursor_obj
 
     def __deepcopy__(self, memo):
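
For reference, the batch_size method on the minus side of this diff is a thin chainable wrapper over pymongo's Cursor.batch_size: it only controls how many documents each server round trip fetches, not how many are returned overall. A usage sketch, assuming a version of MongoEngine that has the method, a local mongod, and a hypothetical Article model:

from mongoengine import Document, StringField, connect

class Article(Document):  # hypothetical model
    title = StringField()

connect('mongoenginetest')  # assumes a local mongod

# Fetch at most 10 documents per round trip; chains like other modifiers.
for article in Article.objects.batch_size(10).limit(25):
    print(article.title)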

View File

@@ -27,10 +27,9 @@ class QuerySet(BaseQuerySet):
         in batches of ``ITER_CHUNK_SIZE``.
 
         If ``self._has_more`` the cursor hasn't been exhausted so cache then
         batch. Otherwise iterate the result_cache.
         """
         self._iter = True
-
         if self._has_more:
             return self._iter_results()
@@ -43,12 +42,10 @@ class QuerySet(BaseQuerySet):
         """
         if self._len is not None:
             return self._len
-
-        # Populate the result cache with *all* of the docs in the cursor
         if self._has_more:
+            # populate the cache
             list(self._iter_results())
 
-        # Cache the length of the complete result cache and return it
         self._len = len(self._result_cache)
         return self._len
@@ -67,33 +64,18 @@ class QuerySet(BaseQuerySet):
     def _iter_results(self):
         """A generator for iterating over the result cache.
 
-        Also populates the cache if there are more possible results to
-        yield. Raises StopIteration when there are no more results.
-        """
+        Also populates the cache if there are more possible results to yield.
+        Raises StopIteration when there are no more results"""
         if self._result_cache is None:
             self._result_cache = []
-
         pos = 0
         while True:
-
-            # For all positions lower than the length of the current result
-            # cache, serve the docs straight from the cache w/o hitting the
-            # database.
-            # XXX it's VERY important to compute the len within the `while`
-            # condition because the result cache might expand mid-iteration
-            # (e.g. if we call len(qs) inside a loop that iterates over the
-            # queryset). Fortunately len(list) is O(1) in Python, so this
-            # doesn't cause performance issues.
-            while pos < len(self._result_cache):
+            upper = len(self._result_cache)
+            while pos < upper:
                 yield self._result_cache[pos]
                 pos += 1
 
-            # Raise StopIteration if we already established there were no more
-            # docs in the db cursor.
             if not self._has_more:
                 raise StopIteration
 
-            # Otherwise, populate more of the cache and repeat.
             if len(self._result_cache) <= pos:
                 self._populate_cache()
@@ -104,22 +86,12 @@ class QuerySet(BaseQuerySet):
         """
         if self._result_cache is None:
             self._result_cache = []
-
-        # Skip populating the cache if we already established there are no
-        # more docs to pull from the database.
-        if not self._has_more:
-            return
-
-        # Pull in ITER_CHUNK_SIZE docs from the database and store them in
-        # the result cache.
-        try:
-            for i in xrange(ITER_CHUNK_SIZE):
-                self._result_cache.append(self.next())
-        except StopIteration:
-            # Getting this exception means there are no more docs in the
-            # db cursor. Set _has_more to False so that we can use that
-            # information in other places.
-            self._has_more = False
+        if self._has_more:
+            try:
+                for i in xrange(ITER_CHUNK_SIZE):
+                    self._result_cache.append(self.next())
+            except StopIteration:
+                self._has_more = False
 
     def count(self, with_limit_and_skip=False):
         """Count the selected elements in the query.

View File

@@ -1,7 +1,6 @@
 from collections import defaultdict
 
-from bson import ObjectId, SON
-from bson.dbref import DBRef
+from bson import SON
 import pymongo
 
 from mongoengine.base.fields import UPDATE_OPERATORS
@@ -27,7 +26,6 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
                    STRING_OPERATORS + CUSTOM_OPERATORS)
 
 
-# TODO make this less complex
 def query(_doc_cls=None, **kwargs):
     """Transform a query from Django-style format to Mongo format.
     """
@@ -64,7 +62,6 @@ def query(_doc_cls=None, **kwargs):
         parts = []
 
         CachedReferenceField = _import_class('CachedReferenceField')
-        GenericReferenceField = _import_class('GenericReferenceField')
 
         cleaned_fields = []
         for field in fields:
@@ -104,16 +101,6 @@ def query(_doc_cls=None, **kwargs):
                 # 'in', 'nin' and 'all' require a list of values
                 value = [field.prepare_query_value(op, v) for v in value]
 
-                # If we're querying a GenericReferenceField, we need to alter the
-                # key depending on the value:
-                # * If the value is a DBRef, the key should be "field_name._ref".
-                # * If the value is an ObjectId, the key should be "field_name._ref.$id".
-                if isinstance(field, GenericReferenceField):
-                    if isinstance(value, DBRef):
-                        parts[-1] += '._ref'
-                    elif isinstance(value, ObjectId):
-                        parts[-1] += '._ref.$id'
-
         # if op and op not in COMPARISON_OPERATORS:
         if op:
             if op in GEO_OPERATORS:
@@ -141,13 +128,11 @@ def query(_doc_cls=None, **kwargs):
         for i, part in indices:
             parts.insert(i, part)
-
         key = '.'.join(parts)
-
         if op is None or key not in mongo_query:
             mongo_query[key] = value
         elif key in mongo_query:
-            if isinstance(mongo_query[key], dict):
+            if key in mongo_query and isinstance(mongo_query[key], dict):
                 mongo_query[key].update(value)
                 # $max/minDistance needs to come last - convert to SON
                 value_dict = mongo_query[key]
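
The branch removed above rewrites the query key because a generic reference is stored as a subdocument whose _ref field holds a DBRef (alongside a _cls marker): filtering by a full DBRef targets field._ref, while filtering by a bare ObjectId has to reach into the DBRef's $id. A sketch of the resulting Mongo-level filters (the field name 'ref' and collection name 'doc' are illustrative):

from bson import ObjectId
from bson.dbref import DBRef

oid = ObjectId()

# Value is a DBRef -> match the stored reference subdocument directly.
query_by_dbref = {'ref._ref': DBRef('doc', oid)}

# Value is a bare ObjectId -> match the $id inside the stored DBRef.
query_by_objectid = {'ref._ref.$id': oid}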

View File

@@ -9,5 +9,5 @@ tests = tests
 [flake8]
 ignore=E501,F401,F403,F405,I201
 exclude=build,dist,docs,venv,.tox,.eggs,tests
-max-complexity=45
+max-complexity=42
 application-import-names=mongoengine,tests

View File

@@ -2810,38 +2810,6 @@ class FieldTest(unittest.TestCase):
         Post.drop_collection()
         User.drop_collection()
 
-    def test_generic_reference_filter_by_dbref(self):
-        """Ensure we can search for a specific generic reference by
-        providing its DBRef.
-        """
-        class Doc(Document):
-            ref = GenericReferenceField()
-
-        Doc.drop_collection()
-
-        doc1 = Doc.objects.create()
-        doc2 = Doc.objects.create(ref=doc1)
-
-        doc = Doc.objects.get(ref=DBRef('doc', doc1.pk))
-        self.assertEqual(doc, doc2)
-
-    def test_generic_reference_filter_by_objectid(self):
-        """Ensure we can search for a specific generic reference by
-        providing its ObjectId.
-        """
-        class Doc(Document):
-            ref = GenericReferenceField()
-
-        Doc.drop_collection()
-
-        doc1 = Doc.objects.create()
-        doc2 = Doc.objects.create(ref=doc1)
-
-        self.assertTrue(isinstance(doc1.pk, ObjectId))
-
-        doc = Doc.objects.get(ref=doc1.pk)
-        self.assertEqual(doc, doc2)
-
     def test_binary_fields(self):
         """Ensure that binary fields can be stored and retrieved.
         """

View File

@@ -337,34 +337,6 @@ class QuerySetTest(unittest.TestCase):
         query = query.filter(boolfield=True)
         self.assertEqual(query.count(), 1)
 
-    def test_batch_size(self):
-        """Ensure that batch_size works."""
-        class A(Document):
-            s = StringField()
-
-        A.drop_collection()
-        for i in range(100):
-            A.objects.create(s=str(i))
-
-        # test iterating over the result set
-        cnt = 0
-        for a in A.objects.batch_size(10):
-            cnt += 1
-        self.assertEqual(cnt, 100)
-
-        # test chaining
-        qs = A.objects.all()
-        qs = qs.limit(10).batch_size(20).skip(91)
-        cnt = 0
-        for a in qs:
-            cnt += 1
-        self.assertEqual(cnt, 9)
-
-        # test invalid batch size
-        qs = A.objects.batch_size(-1)
-        self.assertRaises(ValueError, lambda: list(qs))
-
     def test_update_write_concern(self):
         """Test that passing write_concern works"""
         self.Person.drop_collection()
@@ -4918,56 +4890,6 @@ class QuerySetTest(unittest.TestCase):
         self.assertEqual(1, Doc.objects(item__type__="axe").count())
 
-    def test_len_during_iteration(self):
-        """Tests that calling len on a queryset during iteration doesn't
-        stop paging.
-        """
-        class Data(Document):
-            pass
-
-        for i in xrange(300):
-            Data().save()
-
-        records = Data.objects.limit(250)
-
-        # This should pull all 250 docs from mongo and populate the result
-        # cache
-        len(records)
-
-        # Assert that iterating over documents in the qs touches every
-        # document even if we call len(qs) midway through the iteration.
-        for i, r in enumerate(records):
-            if i == 58:
-                len(records)
-        self.assertEqual(i, 249)
-
-        # Assert the same behavior is true even if we didn't pre-populate the
-        # result cache.
-        records = Data.objects.limit(250)
-        for i, r in enumerate(records):
-            if i == 58:
-                len(records)
-        self.assertEqual(i, 249)
-
-    def test_iteration_within_iteration(self):
-        """You should be able to reliably iterate over all the documents
-        in a given queryset even if there are multiple iterations of it
-        happening at the same time.
-        """
-        class Data(Document):
-            pass
-
-        for i in xrange(300):
-            Data().save()
-
-        qs = Data.objects.limit(250)
-        for i, doc in enumerate(qs):
-            for j, doc2 in enumerate(qs):
-                pass
-
-        self.assertEqual(i, 249)
-        self.assertEqual(j, 249)
-
 if __name__ == '__main__':
     unittest.main()

View File

@@ -174,9 +174,19 @@ class ConnectionTest(unittest.TestCase):
         c.mongoenginetest.system.users.remove({})
 
     def test_connect_uri_without_db(self):
-        """Ensure connect() method works properly if the URI doesn't
-        include a database name.
+        """Ensure connect() method works properly with uri's without database_name
         """
+        c = connect(db='mongoenginetest', alias='admin')
+        c.admin.system.users.remove({})
+        c.mongoenginetest.system.users.remove({})
+
+        c.admin.add_user("admin", "password")
+        c.admin.authenticate("admin", "password")
+        c.mongoenginetest.add_user("username", "password")
+
+        if not IS_PYMONGO_3:
+            self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
+
         connect("mongoenginetest", host='mongodb://localhost/')
 
         conn = get_connection()
@@ -186,31 +196,8 @@ class ConnectionTest(unittest.TestCase):
         self.assertTrue(isinstance(db, pymongo.database.Database))
         self.assertEqual(db.name, 'mongoenginetest')
 
-    def test_connect_uri_default_db(self):
-        """Ensure connect() defaults to the right database name if
-        the URI and the database_name don't explicitly specify it.
-        """
-        connect(host='mongodb://localhost/')
-        conn = get_connection()
-        self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
-
-        db = get_db()
-        self.assertTrue(isinstance(db, pymongo.database.Database))
-        self.assertEqual(db.name, 'test')
-
-    def test_uri_without_credentials_doesnt_override_conn_settings(self):
-        """Ensure connect() uses the username & password params if the URI
-        doesn't explicitly specify them.
-        """
-        c = connect(host='mongodb://localhost/mongoenginetest',
-                    username='user',
-                    password='pass')
-
-        # OperationFailure means that mongoengine attempted authentication
-        # w/ the provided username/password and failed - that's the desired
-        # behavior. If the MongoDB URI would override the credentials
-        self.assertRaises(OperationFailure, get_db)
+        c.admin.system.users.remove({})
+        c.mongoenginetest.system.users.remove({})
 
     def test_connect_uri_with_authsource(self):
         """Ensure that the connect() method works well with