Compare commits
7 Commits
fix-get-fi
...
fix-iterat
Author | SHA1 | Date | |
---|---|---|---|
|
894678da39 | ||
|
0a66a4b8a9 | ||
|
15714ef855 | ||
|
eb743beaa3 | ||
|
0007535a46 | ||
|
8391af026c | ||
|
800f656dcf |
@@ -438,7 +438,7 @@ class StrictDict(object):
|
||||
__slots__ = allowed_keys_tuple
|
||||
|
||||
def __repr__(self):
|
||||
return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys())
|
||||
return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
|
||||
|
||||
cls._classes[allowed_keys] = SpecificStrictDict
|
||||
return cls._classes[allowed_keys]
|
||||
|
@@ -121,7 +121,7 @@ class BaseDocument(object):
|
||||
else:
|
||||
self._data[key] = value
|
||||
|
||||
# Set any get_fieldname_display methods
|
||||
# Set any get_<field>_display methods
|
||||
self.__set_field_display()
|
||||
|
||||
if self._dynamic:
|
||||
@@ -1005,19 +1005,18 @@ class BaseDocument(object):
|
||||
return '.'.join(parts)
|
||||
|
||||
def __set_field_display(self):
|
||||
"""Dynamically set the display value for a field with choices"""
|
||||
for attr_name, field in self._fields.items():
|
||||
if field.choices:
|
||||
if self._dynamic:
|
||||
obj = self
|
||||
else:
|
||||
obj = type(self)
|
||||
setattr(obj,
|
||||
'get_%s_display' % attr_name,
|
||||
partial(self.__get_field_display, field=field))
|
||||
"""For each field that specifies choices, create a
|
||||
get_<field>_display method.
|
||||
"""
|
||||
fields_with_choices = [(n, f) for n, f in self._fields.items()
|
||||
if f.choices]
|
||||
for attr_name, field in fields_with_choices:
|
||||
setattr(self,
|
||||
'get_%s_display' % attr_name,
|
||||
partial(self.__get_field_display, field=field))
|
||||
|
||||
def __get_field_display(self, field):
|
||||
"""Returns the display value for a choice field"""
|
||||
"""Return the display value for a choice field"""
|
||||
value = getattr(self, field.name)
|
||||
if field.choices and isinstance(field.choices[0], (list, tuple)):
|
||||
return dict(field.choices).get(value, value)
|
||||
|
@@ -577,7 +577,7 @@ class EmbeddedDocumentField(BaseField):
|
||||
return self.document_type._fields.get(member_name)
|
||||
|
||||
def prepare_query_value(self, op, value):
|
||||
if not isinstance(value, self.document_type):
|
||||
if value is not None and not isinstance(value, self.document_type):
|
||||
value = self.document_type._from_son(value)
|
||||
super(EmbeddedDocumentField, self).prepare_query_value(op, value)
|
||||
return self.to_mongo(value)
|
||||
|
@@ -275,6 +275,8 @@ class BaseQuerySet(object):
|
||||
except StopIteration:
|
||||
return result
|
||||
|
||||
# If we were able to retrieve the 2nd doc, rewind the cursor and
|
||||
# raise the MultipleObjectsReturned exception.
|
||||
queryset.rewind()
|
||||
message = u'%d items returned, instead of 1' % queryset.count()
|
||||
raise queryset._document.MultipleObjectsReturned(message)
|
||||
@@ -933,6 +935,14 @@ class BaseQuerySet(object):
|
||||
queryset._ordering = queryset._get_order_by(keys)
|
||||
return queryset
|
||||
|
||||
def comment(self, text):
|
||||
"""Add a comment to the query.
|
||||
|
||||
See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment
|
||||
for details.
|
||||
"""
|
||||
return self._chainable_method("comment", text)
|
||||
|
||||
def explain(self, format=False):
|
||||
"""Return an explain plan record for the
|
||||
:class:`~mongoengine.queryset.QuerySet`\ 's cursor.
|
||||
|
@@ -27,9 +27,10 @@ class QuerySet(BaseQuerySet):
|
||||
in batches of ``ITER_CHUNK_SIZE``.
|
||||
|
||||
If ``self._has_more`` the cursor hasn't been exhausted so cache then
|
||||
batch. Otherwise iterate the result_cache.
|
||||
batch. Otherwise iterate the result_cache.
|
||||
"""
|
||||
self._iter = True
|
||||
|
||||
if self._has_more:
|
||||
return self._iter_results()
|
||||
|
||||
@@ -42,10 +43,12 @@ class QuerySet(BaseQuerySet):
|
||||
"""
|
||||
if self._len is not None:
|
||||
return self._len
|
||||
|
||||
# Populate the result cache with *all* of the docs in the cursor
|
||||
if self._has_more:
|
||||
# populate the cache
|
||||
list(self._iter_results())
|
||||
|
||||
# Cache the length of the complete result cache and return it
|
||||
self._len = len(self._result_cache)
|
||||
return self._len
|
||||
|
||||
@@ -64,18 +67,33 @@ class QuerySet(BaseQuerySet):
|
||||
def _iter_results(self):
|
||||
"""A generator for iterating over the result cache.
|
||||
|
||||
Also populates the cache if there are more possible results to yield.
|
||||
Raises StopIteration when there are no more results"""
|
||||
Also populates the cache if there are more possible results to
|
||||
yield. Raises StopIteration when there are no more results.
|
||||
"""
|
||||
if self._result_cache is None:
|
||||
self._result_cache = []
|
||||
|
||||
pos = 0
|
||||
while True:
|
||||
upper = len(self._result_cache)
|
||||
while pos < upper:
|
||||
|
||||
# For all positions lower than the length of the current result
|
||||
# cache, serve the docs straight from the cache w/o hitting the
|
||||
# database.
|
||||
# XXX it's VERY important to compute the len within the `while`
|
||||
# condition because the result cache might expand mid-iteration
|
||||
# (e.g. if we call len(qs) inside a loop that iterates over the
|
||||
# queryset). Fortunately len(list) is O(1) in Python, so this
|
||||
# doesn't cause performance issues.
|
||||
while pos < len(self._result_cache):
|
||||
yield self._result_cache[pos]
|
||||
pos += 1
|
||||
|
||||
# Raise StopIteration if we already established there were no more
|
||||
# docs in the db cursor.
|
||||
if not self._has_more:
|
||||
raise StopIteration
|
||||
|
||||
# Otherwise, populate more of the cache and repeat.
|
||||
if len(self._result_cache) <= pos:
|
||||
self._populate_cache()
|
||||
|
||||
@@ -86,12 +104,22 @@ class QuerySet(BaseQuerySet):
|
||||
"""
|
||||
if self._result_cache is None:
|
||||
self._result_cache = []
|
||||
if self._has_more:
|
||||
try:
|
||||
for i in xrange(ITER_CHUNK_SIZE):
|
||||
self._result_cache.append(self.next())
|
||||
except StopIteration:
|
||||
self._has_more = False
|
||||
|
||||
# Skip populating the cache if we already established there are no
|
||||
# more docs to pull from the database.
|
||||
if not self._has_more:
|
||||
return
|
||||
|
||||
# Pull in ITER_CHUNK_SIZE docs from the database and store them in
|
||||
# the result cache.
|
||||
try:
|
||||
for i in xrange(ITER_CHUNK_SIZE):
|
||||
self._result_cache.append(self.next())
|
||||
except StopIteration:
|
||||
# Getting this exception means there are no more docs in the
|
||||
# db cursor. Set _has_more to False so that we can use that
|
||||
# information in other places.
|
||||
self._has_more = False
|
||||
|
||||
def count(self, with_limit_and_skip=False):
|
||||
"""Count the selected elements in the query.
|
||||
|
@@ -2,10 +2,8 @@
|
||||
import unittest
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
import pymongo
|
||||
from random import randint
|
||||
|
||||
from nose.plugins.skip import SkipTest
|
||||
from datetime import datetime
|
||||
@@ -17,11 +15,9 @@ __all__ = ("IndexesTest", )
|
||||
|
||||
|
||||
class IndexesTest(unittest.TestCase):
|
||||
_MAX_RAND = 10 ** 10
|
||||
|
||||
def setUp(self):
|
||||
self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND))
|
||||
self.connection = connect(db=self.db_name)
|
||||
self.connection = connect(db='mongoenginetest')
|
||||
self.db = get_db()
|
||||
|
||||
class Person(Document):
|
||||
|
@@ -3001,28 +3001,32 @@ class FieldTest(unittest.TestCase):
|
||||
('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
|
||||
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
|
||||
style = StringField(max_length=3, choices=(
|
||||
('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S')
|
||||
('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W')
|
||||
|
||||
Shirt.drop_collection()
|
||||
|
||||
shirt = Shirt()
|
||||
shirt1 = Shirt()
|
||||
shirt2 = Shirt()
|
||||
|
||||
self.assertEqual(shirt.get_size_display(), None)
|
||||
self.assertEqual(shirt.get_style_display(), 'Small')
|
||||
# Make sure get_<field>_display returns the default value (or None)
|
||||
self.assertEqual(shirt1.get_size_display(), None)
|
||||
self.assertEqual(shirt1.get_style_display(), 'Wide')
|
||||
|
||||
shirt.size = "XXL"
|
||||
shirt.style = "B"
|
||||
self.assertEqual(shirt.get_size_display(), 'Extra Extra Large')
|
||||
self.assertEqual(shirt.get_style_display(), 'Baggy')
|
||||
shirt1.size = 'XXL'
|
||||
shirt1.style = 'B'
|
||||
shirt2.size = 'M'
|
||||
shirt2.style = 'S'
|
||||
self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large')
|
||||
self.assertEqual(shirt1.get_style_display(), 'Baggy')
|
||||
self.assertEqual(shirt2.get_size_display(), 'Medium')
|
||||
self.assertEqual(shirt2.get_style_display(), 'Small')
|
||||
|
||||
# Set as Z - an invalid choice
|
||||
shirt.size = "Z"
|
||||
shirt.style = "Z"
|
||||
self.assertEqual(shirt.get_size_display(), 'Z')
|
||||
self.assertEqual(shirt.get_style_display(), 'Z')
|
||||
self.assertRaises(ValidationError, shirt.validate)
|
||||
|
||||
Shirt.drop_collection()
|
||||
shirt1.size = 'Z'
|
||||
shirt1.style = 'Z'
|
||||
self.assertEqual(shirt1.get_size_display(), 'Z')
|
||||
self.assertEqual(shirt1.get_style_display(), 'Z')
|
||||
self.assertRaises(ValidationError, shirt1.validate)
|
||||
|
||||
def test_simple_choices_validation(self):
|
||||
"""Ensure that value is in a container of allowed values.
|
||||
|
@@ -339,7 +339,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
def test_update_write_concern(self):
|
||||
"""Test that passing write_concern works"""
|
||||
|
||||
self.Person.drop_collection()
|
||||
|
||||
write_concern = {"fsync": True}
|
||||
@@ -1239,7 +1238,8 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
|
||||
|
||||
def test_find_embedded(self):
|
||||
"""Ensure that an embedded document is properly returned from a query.
|
||||
"""Ensure that an embedded document is properly returned from
|
||||
a query.
|
||||
"""
|
||||
class User(EmbeddedDocument):
|
||||
name = StringField()
|
||||
@@ -1250,16 +1250,31 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post = BlogPost(content='Had a good coffee today...')
|
||||
post.author = User(name='Test User')
|
||||
post.save()
|
||||
BlogPost.objects.create(
|
||||
author=User(name='Test User'),
|
||||
content='Had a good coffee today...'
|
||||
)
|
||||
|
||||
result = BlogPost.objects.first()
|
||||
self.assertTrue(isinstance(result.author, User))
|
||||
self.assertEqual(result.author.name, 'Test User')
|
||||
|
||||
def test_find_empty_embedded(self):
|
||||
"""Ensure that you can save and find an empty embedded document."""
|
||||
class User(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
class BlogPost(Document):
|
||||
content = StringField()
|
||||
author = EmbeddedDocumentField(User)
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
BlogPost.objects.create(content='Anonymous post...')
|
||||
|
||||
result = BlogPost.objects.get(author=None)
|
||||
self.assertEqual(result.author, None)
|
||||
|
||||
def test_find_dict_item(self):
|
||||
"""Ensure that DictField items may be found.
|
||||
"""
|
||||
@@ -2199,6 +2214,21 @@ class QuerySetTest(unittest.TestCase):
|
||||
a.author.name for a in Author.objects.order_by('-author__age')]
|
||||
self.assertEqual(names, ['User A', 'User B', 'User C'])
|
||||
|
||||
def test_comment(self):
|
||||
"""Make sure adding a comment to the query works."""
|
||||
class User(Document):
|
||||
age = IntField()
|
||||
|
||||
with db_ops_tracker() as q:
|
||||
adult = (User.objects.filter(age__gte=18)
|
||||
.comment('looking for an adult')
|
||||
.first())
|
||||
ops = q.get_ops()
|
||||
self.assertEqual(len(ops), 1)
|
||||
op = ops[0]
|
||||
self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}})
|
||||
self.assertEqual(op['query']['$comment'], 'looking for an adult')
|
||||
|
||||
def test_map_reduce(self):
|
||||
"""Ensure map/reduce is both mapping and reducing.
|
||||
"""
|
||||
@@ -4860,6 +4890,56 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(1, Doc.objects(item__type__="axe").count())
|
||||
|
||||
def test_len_during_iteration(self):
|
||||
"""Tests that calling len on a queyset during iteration doesn't
|
||||
stop paging.
|
||||
"""
|
||||
class Data(Document):
|
||||
pass
|
||||
|
||||
for i in xrange(300):
|
||||
Data().save()
|
||||
|
||||
records = Data.objects.limit(250)
|
||||
|
||||
# This should pull all 250 docs from mongo and populate the result
|
||||
# cache
|
||||
len(records)
|
||||
|
||||
# Assert that iterating over documents in the qs touches every
|
||||
# document even if we call len(qs) midway through the iteration.
|
||||
for i, r in enumerate(records):
|
||||
if i == 58:
|
||||
len(records)
|
||||
self.assertEqual(i, 249)
|
||||
|
||||
# Assert the same behavior is true even if we didn't pre-populate the
|
||||
# result cache.
|
||||
records = Data.objects.limit(250)
|
||||
for i, r in enumerate(records):
|
||||
if i == 58:
|
||||
len(records)
|
||||
self.assertEqual(i, 249)
|
||||
|
||||
def test_iteration_within_iteration(self):
|
||||
"""You should be able to reliably iterate over all the documents
|
||||
in a given queryset even if there are multiple iterations of it
|
||||
happening at the same time.
|
||||
"""
|
||||
class Data(Document):
|
||||
pass
|
||||
|
||||
for i in xrange(300):
|
||||
Data().save()
|
||||
|
||||
qs = Data.objects.limit(250)
|
||||
for i, doc in enumerate(qs):
|
||||
for j, doc2 in enumerate(qs):
|
||||
pass
|
||||
|
||||
self.assertEqual(i, 249)
|
||||
self.assertEqual(j, 249)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import unittest
|
||||
|
||||
from mongoengine.base.datastructures import StrictDict, SemiStrictDict
|
||||
|
||||
|
||||
@@ -13,6 +14,14 @@ class TestStrictDict(unittest.TestCase):
|
||||
d = self.dtype(a=1, b=1, c=1)
|
||||
self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
|
||||
|
||||
def test_repr(self):
|
||||
d = self.dtype(a=1, b=2, c=3)
|
||||
self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}')
|
||||
|
||||
# make sure quotes are escaped properly
|
||||
d = self.dtype(a='"', b="'", c="")
|
||||
self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}')
|
||||
|
||||
def test_init_fails_on_nonexisting_attrs(self):
|
||||
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))
|
||||
|
||||
|
Reference in New Issue
Block a user