Compare commits

..

1 Commits

Author SHA1 Message Date
Stefan Wojcik
50923d809d fix doc.get_<field>_display + unit test inspired by #1279 2016-12-03 17:26:39 -05:00
9 changed files with 23 additions and 113 deletions

View File

@@ -438,7 +438,7 @@ class StrictDict(object):
__slots__ = allowed_keys_tuple
def __repr__(self):
return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys())
cls._classes[allowed_keys] = SpecificStrictDict
return cls._classes[allowed_keys]

View File

@@ -577,7 +577,7 @@ class EmbeddedDocumentField(BaseField):
return self.document_type._fields.get(member_name)
def prepare_query_value(self, op, value):
if value is not None and not isinstance(value, self.document_type):
if not isinstance(value, self.document_type):
value = self.document_type._from_son(value)
super(EmbeddedDocumentField, self).prepare_query_value(op, value)
return self.to_mongo(value)
@@ -1249,7 +1249,7 @@ class GenericReferenceField(BaseField):
if document is None:
return None
if isinstance(document, (dict, SON, ObjectId, DBRef)):
if isinstance(document, (dict, SON)):
return document
id_field_name = document.__class__._meta['id_field']

View File

@@ -933,14 +933,6 @@ class BaseQuerySet(object):
queryset._ordering = queryset._get_order_by(keys)
return queryset
def comment(self, text):
"""Add a comment to the query.
See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment
for details.
"""
return self._chainable_method("comment", text)
def explain(self, format=False):
"""Return an explain plan record for the
:class:`~mongoengine.queryset.QuerySet`\ 's cursor.

View File

@@ -1,7 +1,6 @@
from collections import defaultdict
from bson import ObjectId, SON
from bson.dbref import DBRef
from bson import SON
import pymongo
from mongoengine.base.fields import UPDATE_OPERATORS
@@ -27,7 +26,6 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
STRING_OPERATORS + CUSTOM_OPERATORS)
# TODO make this less complex
def query(_doc_cls=None, **kwargs):
"""Transform a query from Django-style format to Mongo format.
"""
@@ -64,7 +62,6 @@ def query(_doc_cls=None, **kwargs):
parts = []
CachedReferenceField = _import_class('CachedReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
cleaned_fields = []
for field in fields:
@@ -104,16 +101,6 @@ def query(_doc_cls=None, **kwargs):
# 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value]
# If we're querying a GenericReferenceField, we need to alter the
# key depending on the value:
# * If the value is a DBRef, the key should be "field_name._ref".
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
if isinstance(field, GenericReferenceField):
if isinstance(value, DBRef):
parts[-1] += '._ref'
elif isinstance(value, ObjectId):
parts[-1] += '._ref.$id'
# if op and op not in COMPARISON_OPERATORS:
if op:
if op in GEO_OPERATORS:
@@ -141,13 +128,11 @@ def query(_doc_cls=None, **kwargs):
for i, part in indices:
parts.insert(i, part)
key = '.'.join(parts)
if op is None or key not in mongo_query:
mongo_query[key] = value
elif key in mongo_query:
if isinstance(mongo_query[key], dict):
if key in mongo_query and isinstance(mongo_query[key], dict):
mongo_query[key].update(value)
# $max/minDistance needs to come last - convert to SON
value_dict = mongo_query[key]

View File

@@ -9,5 +9,5 @@ tests = tests
[flake8]
ignore=E501,F401,F403,F405,I201
exclude=build,dist,docs,venv,.tox,.eggs,tests
max-complexity=45
max-complexity=42
application-import-names=mongoengine,tests

View File

@@ -2,8 +2,10 @@
import unittest
import sys
sys.path[0:0] = [""]
import pymongo
from random import randint
from nose.plugins.skip import SkipTest
from datetime import datetime
@@ -15,9 +17,11 @@ __all__ = ("IndexesTest", )
class IndexesTest(unittest.TestCase):
_MAX_RAND = 10 ** 10
def setUp(self):
self.connection = connect(db='mongoenginetest')
self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND))
self.connection = connect(db=self.db_name)
self.db = get_db()
class Person(Document):

View File

@@ -2810,38 +2810,6 @@ class FieldTest(unittest.TestCase):
Post.drop_collection()
User.drop_collection()
def test_generic_reference_filter_by_dbref(self):
"""Ensure we can search for a specific generic reference by
providing its ObjectId.
"""
class Doc(Document):
ref = GenericReferenceField()
Doc.drop_collection()
doc1 = Doc.objects.create()
doc2 = Doc.objects.create(ref=doc1)
doc = Doc.objects.get(ref=DBRef('doc', doc1.pk))
self.assertEqual(doc, doc2)
def test_generic_reference_filter_by_objectid(self):
"""Ensure we can search for a specific generic reference by
providing its DBRef.
"""
class Doc(Document):
ref = GenericReferenceField()
Doc.drop_collection()
doc1 = Doc.objects.create()
doc2 = Doc.objects.create(ref=doc1)
self.assertTrue(isinstance(doc1.pk, ObjectId))
doc = Doc.objects.get(ref=doc1.pk)
self.assertEqual(doc, doc2)
def test_binary_fields(self):
"""Ensure that binary fields can be stored and retrieved.
"""

View File

@@ -339,6 +339,7 @@ class QuerySetTest(unittest.TestCase):
def test_update_write_concern(self):
"""Test that passing write_concern works"""
self.Person.drop_collection()
write_concern = {"fsync": True}
@@ -1238,8 +1239,7 @@ class QuerySetTest(unittest.TestCase):
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from
a query.
"""Ensure that an embedded document is properly returned from a query.
"""
class User(EmbeddedDocument):
name = StringField()
@@ -1250,31 +1250,16 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
BlogPost.objects.create(
author=User(name='Test User'),
content='Had a good coffee today...'
)
post = BlogPost(content='Had a good coffee today...')
post.author = User(name='Test User')
post.save()
result = BlogPost.objects.first()
self.assertTrue(isinstance(result.author, User))
self.assertEqual(result.author.name, 'Test User')
def test_find_empty_embedded(self):
"""Ensure that you can save and find an empty embedded document."""
class User(EmbeddedDocument):
name = StringField()
class BlogPost(Document):
content = StringField()
author = EmbeddedDocumentField(User)
BlogPost.drop_collection()
BlogPost.objects.create(content='Anonymous post...')
result = BlogPost.objects.get(author=None)
self.assertEqual(result.author, None)
def test_find_dict_item(self):
"""Ensure that DictField items may be found.
"""
@@ -2214,21 +2199,6 @@ class QuerySetTest(unittest.TestCase):
a.author.name for a in Author.objects.order_by('-author__age')]
self.assertEqual(names, ['User A', 'User B', 'User C'])
def test_comment(self):
"""Make sure adding a comment to the query works."""
class User(Document):
age = IntField()
with db_ops_tracker() as q:
adult = (User.objects.filter(age__gte=18)
.comment('looking for an adult')
.first())
ops = q.get_ops()
self.assertEqual(len(ops), 1)
op = ops[0]
self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}})
self.assertEqual(op['query']['$comment'], 'looking for an adult')
def test_map_reduce(self):
"""Ensure map/reduce is both mapping and reducing.
"""

View File

@@ -1,6 +1,5 @@
import unittest
from mongoengine.base.datastructures import StrictDict, SemiStrictDict
from mongoengine.base.datastructures import StrictDict, SemiStrictDict
class TestStrictDict(unittest.TestCase):
@@ -14,17 +13,9 @@ class TestStrictDict(unittest.TestCase):
d = self.dtype(a=1, b=1, c=1)
self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
def test_repr(self):
d = self.dtype(a=1, b=2, c=3)
self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}')
# make sure quotes are escaped properly
d = self.dtype(a='"', b="'", c="")
self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}')
def test_init_fails_on_nonexisting_attrs(self):
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))
def test_eq(self):
d = self.dtype(a=1, b=1, c=1)
dd = self.dtype(a=1, b=1, c=1)
@@ -33,7 +24,7 @@ class TestStrictDict(unittest.TestCase):
g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1)
h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1)
i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2)
self.assertEqual(d, dd)
self.assertNotEqual(d, e)
self.assertNotEqual(d, f)
@@ -47,19 +38,19 @@ class TestStrictDict(unittest.TestCase):
d.a = 1
self.assertEqual(d.a, 1)
self.assertRaises(AttributeError, lambda: d.b)
def test_setattr_raises_on_nonexisting_attr(self):
d = self.dtype()
def _f():
d.x = 1
self.assertRaises(AttributeError, _f)
def test_setattr_getattr_special(self):
d = self.strict_dict_class(["items"])
d.items = 1
self.assertEqual(d.items, 1)
def test_get(self):
d = self.dtype(a=1)
self.assertEqual(d.get('a'), 1)
@@ -97,7 +88,7 @@ class TestSemiSrictDict(TestStrictDict):
def test_init_succeeds_with_nonexisting_attrs(self):
d = self.dtype(a=1, b=1, c=1, x=2)
self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2))
def test_iter_with_nonexisting_attrs(self):
d = self.dtype(a=1, b=1, c=1, x=2)
self.assertEqual(list(d), ['a', 'b', 'c', 'x'])