Compare commits

1 Commit

Author          SHA1        Message                                                      Date
Stefan Wojcik   50923d809d  fix doc.get_<field>_display + unit test inspired by #1279   2016-12-03 17:26:39 -05:00
7 changed files with 31 additions and 154 deletions

View File

@@ -438,7 +438,7 @@ class StrictDict(object):
                 __slots__ = allowed_keys_tuple
 
                 def __repr__(self):
-                    return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
+                    return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys())
 
             cls._classes[allowed_keys] = SpecificStrictDict
         return cls._classes[allowed_keys]
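For illustration (not part of the diff): the two __repr__ variants above behave differently. The items()-based version on the left renders each value, while the iterkeys()-based version on the right reuses the key for both placeholders, so values never appear in the output; this is what the test_repr test removed near the bottom of this compare checks. A minimal standalone sketch:

# Both format expressions are copied from the hunk above, applied to plain Python data.
items = [('a', 1), ('b', 'x')]

left_repr = "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in items)
# -> {"a": 1, "b": 'x'}

right_repr = "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k, _ in items)
# -> {"a": 'a', "b": 'b'}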

View File

@@ -577,7 +577,7 @@ class EmbeddedDocumentField(BaseField):
         return self.document_type._fields.get(member_name)
 
     def prepare_query_value(self, op, value):
-        if value is not None and not isinstance(value, self.document_type):
+        if not isinstance(value, self.document_type):
             value = self.document_type._from_son(value)
         super(EmbeddedDocumentField, self).prepare_query_value(op, value)
         return self.to_mongo(value)
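For illustration (not part of the diff): why the `value is not None` guard on the left matters. This mirrors the test_find_empty_embedded test removed further down; a running local mongod is assumed and the database name is made up for the example.

from mongoengine import (Document, EmbeddedDocument, EmbeddedDocumentField,
                         StringField, connect)

class User(EmbeddedDocument):
    name = StringField()

class BlogPost(Document):
    content = StringField()
    author = EmbeddedDocumentField(User)

connect(db='mongoenginetest_example')
BlogPost.drop_collection()
BlogPost(content='Anonymous post...').save()

# With the guard, None passes straight through and matches documents whose
# `author` field is unset; without it, None would be handed to
# User._from_son() while the query is being built.
result = BlogPost.objects.get(author=None)
assert result.author is None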

View File

@@ -275,8 +275,6 @@ class BaseQuerySet(object):
         except StopIteration:
             return result
 
-        # If we were able to retrieve the 2nd doc, rewind the cursor and
-        # raise the MultipleObjectsReturned exception.
         queryset.rewind()
         message = u'%d items returned, instead of 1' % queryset.count()
         raise queryset._document.MultipleObjectsReturned(message)
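For illustration (not part of the diff): the behaviour the hunk above implements, i.e. get() raising MultipleObjectsReturned when more than one document matches. Model, data and database name are made up; a local mongod is assumed.

from mongoengine import Document, StringField, connect

class City(Document):
    name = StringField()

connect(db='mongoenginetest_example')
City.drop_collection()
City(name='Springfield').save()
City(name='Springfield').save()

try:
    City.objects.get(name='Springfield')
except City.MultipleObjectsReturned as exc:
    print(exc)  # "2 items returned, instead of 1"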
@@ -935,14 +933,6 @@ class BaseQuerySet(object):
         queryset._ordering = queryset._get_order_by(keys)
         return queryset
 
-    def comment(self, text):
-        """Add a comment to the query.
-
-        See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment
-        for details.
-        """
-        return self._chainable_method("comment", text)
-
     def explain(self, format=False):
         """Return an explain plan record for the
         :class:`~mongoengine.queryset.QuerySet`\ 's cursor.
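For illustration (not part of the diff): how the comment() method removed above is used on the left-hand side of this compare (it mirrors the test_comment test removed further down). Model and database name are made up; a local mongod is assumed.

from mongoengine import Document, IntField, connect

class User(Document):
    age = IntField()

connect(db='mongoenginetest_example')

# Attaches a $comment to the query MongoDB executes, so the filter
# {'age': {'$gte': 18}} shows up in the profiler/logs next to the given text.
adult = User.objects.filter(age__gte=18).comment('looking for an adult').first()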

View File

@@ -27,10 +27,9 @@ class QuerySet(BaseQuerySet):
         in batches of ``ITER_CHUNK_SIZE``.
 
         If ``self._has_more`` the cursor hasn't been exhausted so cache then
         batch. Otherwise iterate the result_cache.
         """
         self._iter = True
 
         if self._has_more:
             return self._iter_results()
@@ -43,12 +42,10 @@ class QuerySet(BaseQuerySet):
         """
         if self._len is not None:
             return self._len
 
-        # Populate the result cache with *all* of the docs in the cursor
         if self._has_more:
+            # populate the cache
             list(self._iter_results())
 
-        # Cache the length of the complete result cache and return it
         self._len = len(self._result_cache)
         return self._len
@@ -67,33 +64,18 @@ class QuerySet(BaseQuerySet):
     def _iter_results(self):
         """A generator for iterating over the result cache.
 
-        Also populates the cache if there are more possible results to
-        yield. Raises StopIteration when there are no more results.
-        """
+        Also populates the cache if there are more possible results to yield.
+        Raises StopIteration when there are no more results"""
         if self._result_cache is None:
             self._result_cache = []
         pos = 0
         while True:
-            # For all positions lower than the length of the current result
-            # cache, serve the docs straight from the cache w/o hitting the
-            # database.
-            # XXX it's VERY important to compute the len within the `while`
-            # condition because the result cache might expand mid-iteration
-            # (e.g. if we call len(qs) inside a loop that iterates over the
-            # queryset). Fortunately len(list) is O(1) in Python, so this
-            # doesn't cause performance issues.
-            while pos < len(self._result_cache):
+            upper = len(self._result_cache)
+            while pos < upper:
                 yield self._result_cache[pos]
                 pos += 1
 
-            # Raise StopIteration if we already established there were no more
-            # docs in the db cursor.
             if not self._has_more:
                 raise StopIteration
 
-            # Otherwise, populate more of the cache and repeat.
             if len(self._result_cache) <= pos:
                 self._populate_cache()
@@ -104,22 +86,12 @@ class QuerySet(BaseQuerySet):
         """
         if self._result_cache is None:
             self._result_cache = []
-
-        # Skip populating the cache if we already established there are no
-        # more docs to pull from the database.
-        if not self._has_more:
-            return
-
-        # Pull in ITER_CHUNK_SIZE docs from the database and store them in
-        # the result cache.
-        try:
-            for i in xrange(ITER_CHUNK_SIZE):
-                self._result_cache.append(self.next())
-        except StopIteration:
-            # Getting this exception means there are no more docs in the
-            # db cursor. Set _has_more to False so that we can use that
-            # information in other places.
-            self._has_more = False
+        if self._has_more:
+            try:
+                for i in xrange(ITER_CHUNK_SIZE):
+                    self._result_cache.append(self.next())
+            except StopIteration:
+                self._has_more = False
 
     def count(self, with_limit_and_skip=False):
         """Count the selected elements in the query.

View File

@@ -2,8 +2,10 @@
 import unittest
 import sys
+sys.path[0:0] = [""]
 
 import pymongo
+from random import randint
 from nose.plugins.skip import SkipTest
 from datetime import datetime
@@ -15,9 +17,11 @@ __all__ = ("IndexesTest", )
 class IndexesTest(unittest.TestCase):
+    _MAX_RAND = 10 ** 10
 
     def setUp(self):
-        self.connection = connect(db='mongoenginetest')
+        self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND))
+        self.connection = connect(db=self.db_name)
         self.db = get_db()
 
         class Person(Document):
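For illustration (not part of the diff): the "unique database name per test class" idea from the right-hand setUp above, shown in isolation. Class and database names are made up, a local mongod is assumed, and the tearDown is part of this sketch rather than something the hunk contains.

import unittest
from random import randint

from mongoengine import connect
from mongoengine.connection import get_db

class ExampleTestCase(unittest.TestCase):
    _MAX_RAND = 10 ** 10

    def setUp(self):
        # A random suffix keeps concurrent or aborted test runs from
        # sharing (and polluting) the same database.
        self.db_name = 'mongoenginetest_Example_' + str(randint(0, self._MAX_RAND))
        self.connection = connect(db=self.db_name)
        self.db = get_db()

    def tearDown(self):
        # Drop the throwaway database so runs don't accumulate on disk.
        self.connection.drop_database(self.db_name)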

View File

@@ -339,6 +339,7 @@ class QuerySetTest(unittest.TestCase):
     def test_update_write_concern(self):
         """Test that passing write_concern works"""
         self.Person.drop_collection()
 
         write_concern = {"fsync": True}
@@ -1238,8 +1239,7 @@ class QuerySetTest(unittest.TestCase):
         self.assertFalse('$orderby' in q.get_ops()[0]['query'])
 
     def test_find_embedded(self):
-        """Ensure that an embedded document is properly returned from
-        a query.
+        """Ensure that an embedded document is properly returned from a query.
         """
         class User(EmbeddedDocument):
             name = StringField()
@@ -1250,31 +1250,16 @@ class QuerySetTest(unittest.TestCase):
         BlogPost.drop_collection()
 
-        BlogPost.objects.create(
-            author=User(name='Test User'),
-            content='Had a good coffee today...'
-        )
+        post = BlogPost(content='Had a good coffee today...')
+        post.author = User(name='Test User')
+        post.save()
 
         result = BlogPost.objects.first()
         self.assertTrue(isinstance(result.author, User))
         self.assertEqual(result.author.name, 'Test User')
 
-    def test_find_empty_embedded(self):
-        """Ensure that you can save and find an empty embedded document."""
-        class User(EmbeddedDocument):
-            name = StringField()
-
-        class BlogPost(Document):
-            content = StringField()
-            author = EmbeddedDocumentField(User)
-
-        BlogPost.drop_collection()
-        BlogPost.objects.create(content='Anonymous post...')
-
-        result = BlogPost.objects.get(author=None)
-        self.assertEqual(result.author, None)
+        BlogPost.drop_collection()
 
     def test_find_dict_item(self):
         """Ensure that DictField items may be found.
         """
@@ -2214,21 +2199,6 @@ class QuerySetTest(unittest.TestCase):
             a.author.name for a in Author.objects.order_by('-author__age')]
         self.assertEqual(names, ['User A', 'User B', 'User C'])
 
-    def test_comment(self):
-        """Make sure adding a comment to the query works."""
-        class User(Document):
-            age = IntField()
-
-        with db_ops_tracker() as q:
-            adult = (User.objects.filter(age__gte=18)
-                     .comment('looking for an adult')
-                     .first())
-
-            ops = q.get_ops()
-            self.assertEqual(len(ops), 1)
-
-            op = ops[0]
-            self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}})
-            self.assertEqual(op['query']['$comment'], 'looking for an adult')
-
     def test_map_reduce(self):
         """Ensure map/reduce is both mapping and reducing.
         """
@@ -4890,56 +4860,6 @@ class QuerySetTest(unittest.TestCase):
         self.assertEqual(1, Doc.objects(item__type__="axe").count())
 
-    def test_len_during_iteration(self):
-        """Tests that calling len on a queyset during iteration doesn't
-        stop paging.
-        """
-        class Data(Document):
-            pass
-
-        for i in xrange(300):
-            Data().save()
-
-        records = Data.objects.limit(250)
-
-        # This should pull all 250 docs from mongo and populate the result
-        # cache
-        len(records)
-
-        # Assert that iterating over documents in the qs touches every
-        # document even if we call len(qs) midway through the iteration.
-        for i, r in enumerate(records):
-            if i == 58:
-                len(records)
-
-        self.assertEqual(i, 249)
-
-        # Assert the same behavior is true even if we didn't pre-populate the
-        # result cache.
-        records = Data.objects.limit(250)
-        for i, r in enumerate(records):
-            if i == 58:
-                len(records)
-
-        self.assertEqual(i, 249)
-
-    def test_iteration_within_iteration(self):
-        """You should be able to reliably iterate over all the documents
-        in a given queryset even if there are multiple iterations of it
-        happening at the same time.
-        """
-        class Data(Document):
-            pass
-
-        for i in xrange(300):
-            Data().save()
-
-        qs = Data.objects.limit(250)
-        for i, doc in enumerate(qs):
-            for j, doc2 in enumerate(qs):
-                pass
-
-        self.assertEqual(i, 249)
-        self.assertEqual(j, 249)
-
 if __name__ == '__main__':
     unittest.main()

View File

@@ -1,6 +1,5 @@
 import unittest
 
 from mongoengine.base.datastructures import StrictDict, SemiStrictDict
 
 
 class TestStrictDict(unittest.TestCase):
@@ -14,17 +13,9 @@ class TestStrictDict(unittest.TestCase):
         d = self.dtype(a=1, b=1, c=1)
         self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
 
-    def test_repr(self):
-        d = self.dtype(a=1, b=2, c=3)
-        self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}')
-
-        # make sure quotes are escaped properly
-        d = self.dtype(a='"', b="'", c="")
-        self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}')
-
     def test_init_fails_on_nonexisting_attrs(self):
         self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))
 
     def test_eq(self):
         d = self.dtype(a=1, b=1, c=1)
         dd = self.dtype(a=1, b=1, c=1)
@@ -33,7 +24,7 @@ class TestStrictDict(unittest.TestCase):
         g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1)
         h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1)
         i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2)
 
         self.assertEqual(d, dd)
         self.assertNotEqual(d, e)
         self.assertNotEqual(d, f)
@@ -47,19 +38,19 @@ class TestStrictDict(unittest.TestCase):
         d.a = 1
         self.assertEqual(d.a, 1)
         self.assertRaises(AttributeError, lambda: d.b)
 
     def test_setattr_raises_on_nonexisting_attr(self):
         d = self.dtype()
 
         def _f():
             d.x = 1
         self.assertRaises(AttributeError, _f)
 
     def test_setattr_getattr_special(self):
         d = self.strict_dict_class(["items"])
         d.items = 1
         self.assertEqual(d.items, 1)
 
     def test_get(self):
         d = self.dtype(a=1)
         self.assertEqual(d.get('a'), 1)
@@ -97,7 +88,7 @@ class TestSemiSrictDict(TestStrictDict):
     def test_init_succeeds_with_nonexisting_attrs(self):
         d = self.dtype(a=1, b=1, c=1, x=2)
         self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2))
 
     def test_iter_with_nonexisting_attrs(self):
         d = self.dtype(a=1, b=1, c=1, x=2)
         self.assertEqual(list(d), ['a', 'b', 'c', 'x'])
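For illustration (not part of the diff): how the StrictDict classes exercised by these tests are typically used. The behaviour shown follows the assertions above; the key names are made up for the example.

from mongoengine.base.datastructures import StrictDict

D = StrictDict.create(("a", "b", "c"))  # dict-like class restricted to these keys
d = D(a=1, b=2, c=3)

assert (d.a, d.b, d.c) == (1, 2, 3)     # attribute-style access to allowed keys
assert d.get("a") == 1                  # dict-style get() works as well

try:
    d.x = 1                             # keys outside the allowed set are rejected
except AttributeError:
    pass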