Added initial CachedReferenceField

This commit is contained in:
Wilson Júnior 2014-07-07 12:54:41 -03:00
parent c97eb5d63f
commit 30fdd3e184
3 changed files with 445 additions and 101 deletions

View File

@ -23,6 +23,7 @@ __all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
NON_FIELD_ERRORS = '__all__'
class BaseDocument(object):
__slots__ = ('_changed_fields', '_initialised', '_created', '_data',
'_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__')
@ -50,7 +51,8 @@ class BaseDocument(object):
for value in args:
name = next(field)
if name in values:
raise TypeError("Multiple values for keyword argument '" + name + "'")
raise TypeError(
"Multiple values for keyword argument '" + name + "'")
values[name] = value
__auto_convert = values.pop("__auto_convert", True)
signals.pre_init.send(self.__class__, document=self, values=values)
@ -58,7 +60,8 @@ class BaseDocument(object):
if self.STRICT and not self._dynamic:
self._data = StrictDict.create(allowed_keys=self._fields_ordered)()
else:
self._data = SemiStrictDict.create(allowed_keys=self._fields_ordered)()
self._data = SemiStrictDict.create(
allowed_keys=self._fields_ordered)()
_created = values.pop("_created", True)
self._data = {}
@ -257,26 +260,45 @@ class BaseDocument(object):
"""
pass
def to_mongo(self, use_db_field=True):
"""Return as SON data ready for use with MongoDB.
def to_mongo(self, use_db_field=True, fields=[]):
"""
Return as SON data ready for use with MongoDB.
"""
data = SON()
data["_id"] = None
data['_cls'] = self._class_name
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
# only root fields ['test1.a', 'test2'] => ['test1', 'test2']
root_fields = set([f.split('.')[0] for f in fields])
for field_name in self:
if root_fields and field_name not in root_fields:
continue
value = self._data.get(field_name, None)
field = self._fields.get(field_name)
if field is None and self._dynamic:
field = self._dynamic_fields.get(field_name)
if value is not None:
EmbeddedDocument = _import_class("EmbeddedDocument")
if isinstance(value, (EmbeddedDocument)) and use_db_field==False:
if isinstance(value, (EmbeddedDocument)) and \
not use_db_field:
value = field.to_mongo(value, use_db_field)
else:
value = field.to_mongo(value)
if isinstance(field, EmbeddedDocumentField) and fields:
key = '%s.' % field_name
value = field.to_mongo(value, fields=[
i.replace(key, '') for i in fields if i.startswith(key)])
elif value is not None:
value = field.to_mongo(value)
# Handle self generating fields
if value is None and field._auto_gen:
value = field.generate()
@ -321,7 +343,8 @@ class BaseDocument(object):
self._data.get(name)) for name in self._fields_ordered]
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField")
GenericEmbeddedDocumentField = _import_class(
"GenericEmbeddedDocumentField")
for field, value in fields:
if value is not None:
@ -352,7 +375,8 @@ class BaseDocument(object):
"""Converts a document to JSON.
:param use_db_field: Set to True by default but enables the output of the json structure with the field names and not the mongodb store db_names in case of set to False
"""
use_db_field = kwargs.pop('use_db_field') if kwargs.has_key('use_db_field') else True
use_db_field = kwargs.pop('use_db_field') if kwargs.has_key(
'use_db_field') else True
return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)
@classmethod
@ -454,7 +478,8 @@ class BaseDocument(object):
changed_fields += ["%s%s" % (list_key, k)
for k in changed if k]
elif isinstance(value, (list, tuple, dict)):
self._nestable_types_changed_fields(changed_fields, list_key, value, inspected)
self._nestable_types_changed_fields(
changed_fields, list_key, value, inspected)
def _get_changed_fields(self, inspected=None):
"""Returns a list of all fields that have explicitly been changed.
@ -493,7 +518,8 @@ class BaseDocument(object):
if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)):
continue
self._nestable_types_changed_fields(changed_fields, key, data, inspected)
self._nestable_types_changed_fields(
changed_fields, key, data, inspected)
return changed_fields
def _delta(self):
@ -631,7 +657,8 @@ class BaseDocument(object):
raise InvalidDocumentError(msg)
if cls.STRICT:
data = dict((k, v) for k,v in data.iteritems() if k in cls._fields)
data = dict((k, v)
for k, v in data.iteritems() if k in cls._fields)
obj = cls(__auto_convert=False, _created=False, **data)
obj._changed_fields = changed_fields
if not _auto_dereference:
@ -794,7 +821,8 @@ class BaseDocument(object):
geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField",
"PointField", "LineStringField", "PolygonField"]
geo_field_types = tuple([_import_class(field) for field in geo_field_type_names])
geo_field_types = tuple([_import_class(field)
for field in geo_field_type_names])
for field in cls._fields.values():
if not isinstance(field, geo_field_types):
@ -804,7 +832,8 @@ class BaseDocument(object):
if field_cls in inspected:
continue
if hasattr(field_cls, '_geo_indices'):
geo_indices += field_cls._geo_indices(inspected, parent_field=field.db_field)
geo_indices += field_cls._geo_indices(
inspected, parent_field=field.db_field)
elif field._geo_index:
field_name = field.db_field
if parent_field:

View File

@ -34,13 +34,14 @@ except ImportError:
Image = None
ImageOps = None
__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
__all__ = [
'StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField',
'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField',
'GenericEmbeddedDocumentField', 'DynamicField', 'ListField',
'SortedListField', 'DictField', 'MapField', 'ReferenceField',
'GenericReferenceField', 'BinaryField', 'GridFSError',
'GridFSProxy', 'FileField', 'ImageGridFsProxy',
'CachedReferenceField', 'GenericReferenceField', 'BinaryField',
'GridFSError', 'GridFSProxy', 'FileField', 'ImageGridFsProxy',
'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField',
'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField',
'GeoJsonBaseField']
@ -50,6 +51,7 @@ RECURSIVE_REFERENCE_CONSTANT = 'self'
class StringField(BaseField):
"""A unicode string field.
"""
@ -109,6 +111,7 @@ class StringField(BaseField):
class URLField(StringField):
"""A field that validates input as an URL.
.. versionadded:: 0.3
@ -116,7 +119,8 @@ class URLField(StringField):
_URL_REGEX = re.compile(
r'^(?:http|ftp)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain...
# domain...
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'
r'localhost|' # localhost...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port
@ -145,15 +149,19 @@ class URLField(StringField):
class EmailField(StringField):
"""A field that validates input as an E-Mail-Address.
.. versionadded:: 0.4
"""
EMAIL_REGEX = re.compile(
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE # domain
# dot-atom
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*"
# quoted-string
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"'
# domain
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE
)
def validate(self, value):
@ -163,6 +171,7 @@ class EmailField(StringField):
class IntField(BaseField):
"""An 32-bit integer field.
"""
@ -197,6 +206,7 @@ class IntField(BaseField):
class LongField(BaseField):
"""An 64-bit integer field.
"""
@ -231,6 +241,7 @@ class LongField(BaseField):
class FloatField(BaseField):
"""An floating point number field.
"""
@ -265,6 +276,7 @@ class FloatField(BaseField):
class DecimalField(BaseField):
"""A fixed-point decimal number field.
.. versionchanged:: 0.8
@ -338,6 +350,7 @@ class DecimalField(BaseField):
class BooleanField(BaseField):
"""A boolean field type.
.. versionadded:: 0.1.2
@ -356,6 +369,7 @@ class BooleanField(BaseField):
class DateTimeField(BaseField):
"""A datetime field.
Uses the python-dateutil library if available alternatively use time.strptime
@ -423,6 +437,7 @@ class DateTimeField(BaseField):
class ComplexDateTimeField(StringField):
"""
ComplexDateTimeField handles microseconds exactly instead of rounding
like DateTimeField does.
@ -525,6 +540,7 @@ class ComplexDateTimeField(StringField):
class EmbeddedDocumentField(BaseField):
"""An embedded document field - with a declared document_type.
Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`.
"""
@ -551,7 +567,7 @@ class EmbeddedDocumentField(BaseField):
return self.document_type._from_son(value)
return value
def to_mongo(self, value, use_db_field=True):
def to_mongo(self, value, use_db_field=True, fields=[]):
if not isinstance(value, self.document_type):
return value
return self.document_type.to_mongo(value, use_db_field)
@ -574,6 +590,7 @@ class EmbeddedDocumentField(BaseField):
class GenericEmbeddedDocumentField(BaseField):
"""A generic embedded document field - allows any
:class:`~mongoengine.EmbeddedDocument` to be stored.
@ -612,6 +629,7 @@ class GenericEmbeddedDocumentField(BaseField):
class DynamicField(BaseField):
"""A truly dynamic field type capable of handling different and varying
types of data.
@ -675,6 +693,7 @@ class DynamicField(BaseField):
class ListField(ComplexBaseField):
"""A list field that wraps a standard field, allowing multiple instances
of the field to be used as a list in the database.
@ -708,6 +727,7 @@ class ListField(ComplexBaseField):
class SortedListField(ListField):
"""A ListField that sorts the contents of its list before writing to
the database in order to ensure that a sorted list is always
retrieved.
@ -739,6 +759,7 @@ class SortedListField(ListField):
reverse=self._order_reverse)
return sorted(value, reverse=self._order_reverse)
def key_not_string(d):
""" Helper function to recursively determine if any key in a dictionary is
not a string.
@ -747,6 +768,7 @@ def key_not_string(d):
if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)):
return True
def key_has_dot_or_dollar(d):
""" Helper function to recursively determine if any key in a dictionary
contains a dot or a dollar sign.
@ -755,7 +777,9 @@ def key_has_dot_or_dollar(d):
if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
return True
class DictField(ComplexBaseField):
"""A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined.
@ -807,6 +831,7 @@ class DictField(ComplexBaseField):
class MapField(DictField):
"""A field that maps a name to a specified field type. Similar to
a DictField, except the 'value' of each item must match the specified
field type.
@ -822,6 +847,7 @@ class MapField(DictField):
class ReferenceField(BaseField):
"""A reference to a document that will be automatically dereferenced on
access (lazily).
@ -955,7 +981,106 @@ class ReferenceField(BaseField):
return self.document_type._fields.get(member_name)
class CachedReferenceField(BaseField):
"""
A referencefield with cache fields support
.. versionadded:: 0.9
"""
def __init__(self, document_type, fields=[], **kwargs):
"""Initialises the Cached Reference Field.
:param fields: A list of fields to be cached in document
"""
if not isinstance(document_type, basestring) and \
not issubclass(document_type, (Document, basestring)):
self.error('Argument to CachedReferenceField constructor must be a'
' document class or a string')
self.document_type_obj = document_type
self.fields = fields
super(CachedReferenceField, self).__init__(**kwargs)
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
"""
if isinstance(value, dict):
collection = self.document_type._get_collection_name()
value = DBRef(
collection, self.document_type.id.to_python(value['_id']))
return value
@property
def document_type(self):
if isinstance(self.document_type_obj, basestring):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj
def __get__(self, instance, owner):
"""Descriptor to allow lazy dereferencing.
"""
if instance is None:
# Document class being used rather than a document object
return self
# Get value from document instance if available
value = instance._data.get(self.name)
self._auto_dereference = instance._fields[self.name]._auto_dereference
# Dereference DBRefs
if self._auto_dereference and isinstance(value, DBRef):
value = self.document_type._get_db().dereference(value)
if value is not None:
instance._data[self.name] = self.document_type._from_son(value)
return super(CachedReferenceField, self).__get__(instance, owner)
def to_mongo(self, document):
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]
doc_tipe = self.document_type
if isinstance(document, Document):
# We need the id from the saved object to create the DBRef
id_ = document.pk
if id_ is None:
self.error('You can only reference documents once they have'
' been saved to the database')
else:
self.error('Only accept a document object')
value = {
"_id": id_field.to_mongo(id_)
}
value.update(dict(document.to_mongo(fields=self.fields)))
return value
def prepare_query_value(self, op, value):
if value is None:
return None
return self.to_mongo(value)
def validate(self, value):
if not isinstance(value, (self.document_type)):
self.error("A CachedReferenceField only accepts documents")
if isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been '
'saved to the database')
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily).
@ -974,6 +1099,7 @@ class GenericReferenceField(BaseField):
return self
value = instance._data.get(self.name)
self._auto_dereference = instance._fields[self.name]._auto_dereference
if self._auto_dereference and isinstance(value, (dict, SON)):
instance._data[self.name] = self.dereference(value)
@ -1036,6 +1162,7 @@ class GenericReferenceField(BaseField):
class BinaryField(BaseField):
"""A binary data field.
"""
@ -1067,6 +1194,7 @@ class GridFSError(Exception):
class GridFSProxy(object):
"""Proxy object to handle writing and reading of files to and from GridFS
.. versionadded:: 0.4
@ -1121,7 +1249,8 @@ class GridFSProxy(object):
return '<%s: %s>' % (self.__class__.__name__, self.grid_id)
def __str__(self):
name = getattr(self.get(), 'filename', self.grid_id) if self.get() else '(no file)'
name = getattr(
self.get(), 'filename', self.grid_id) if self.get() else '(no file)'
return '<%s: %s>' % (self.__class__.__name__, name)
def __eq__(self, other):
@ -1135,7 +1264,8 @@ class GridFSProxy(object):
@property
def fs(self):
if not self._fs:
self._fs = gridfs.GridFS(get_db(self.db_alias), self.collection_name)
self._fs = gridfs.GridFS(
get_db(self.db_alias), self.collection_name)
return self._fs
def get(self, id=None):
@ -1209,6 +1339,7 @@ class GridFSProxy(object):
class FileField(BaseField):
"""A GridFS storage field.
.. versionadded:: 0.4
@ -1253,7 +1384,8 @@ class FileField(BaseField):
pass
# Create a new proxy object as we don't already have one
instance._data[key] = self.get_proxy_obj(key=key, instance=instance)
instance._data[key] = self.get_proxy_obj(
key=key, instance=instance)
instance._data[key].put(value)
else:
instance._data[key] = value
@ -1291,11 +1423,13 @@ class FileField(BaseField):
class ImageGridFsProxy(GridFSProxy):
"""
Proxy for ImageField
versionadded: 0.6
"""
def put(self, file_obj, **kwargs):
"""
Insert a image in database
@ -1341,7 +1475,8 @@ class ImageGridFsProxy(GridFSProxy):
size = field.thumbnail_size
if size['force']:
thumbnail = ImageOps.fit(img, (size['width'], size['height']), Image.ANTIALIAS)
thumbnail = ImageOps.fit(
img, (size['width'], size['height']), Image.ANTIALIAS)
else:
thumbnail = img.copy()
thumbnail.thumbnail((size['width'],
@ -1367,7 +1502,7 @@ class ImageGridFsProxy(GridFSProxy):
**kwargs)
def delete(self, *args, **kwargs):
#deletes thumbnail
# deletes thumbnail
out = self.get()
if out and out.thumbnail_id:
self.fs.delete(out.thumbnail_id)
@ -1427,6 +1562,7 @@ class ImproperlyConfigured(Exception):
class ImageField(FileField):
"""
A Image File storage field.
@ -1465,6 +1601,7 @@ class ImageField(FileField):
class SequenceField(BaseField):
"""Provides a sequental counter see:
http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers
@ -1534,7 +1671,7 @@ class SequenceField(BaseField):
data = collection.find_one({"_id": sequence_id})
if data:
return self.value_decorator(data['next']+1)
return self.value_decorator(data['next'] + 1)
return self.value_decorator(1)
@ -1579,6 +1716,7 @@ class SequenceField(BaseField):
class UUIDField(BaseField):
"""A UUID field.
.. versionadded:: 0.6
@ -1631,6 +1769,7 @@ class UUIDField(BaseField):
class GeoPointField(BaseField):
"""A list storing a longitude and latitude coordinate.
.. note:: this represents a generic point in a 2D plane and a legacy way of
@ -1651,13 +1790,16 @@ class GeoPointField(BaseField):
'of (x, y)')
if not len(value) == 2:
self.error("Value (%s) must be a two-dimensional point" % repr(value))
self.error("Value (%s) must be a two-dimensional point" %
repr(value))
elif (not isinstance(value[0], (float, int)) or
not isinstance(value[1], (float, int))):
self.error("Both values (%s) in point must be float or int" % repr(value))
self.error(
"Both values (%s) in point must be float or int" % repr(value))
class PointField(GeoJsonBaseField):
"""A GeoJSON field storing a longitude and latitude coordinate.
The data is represented as:
@ -1677,6 +1819,7 @@ class PointField(GeoJsonBaseField):
class LineStringField(GeoJsonBaseField):
"""A GeoJSON field storing a line of longitude and latitude coordinates.
The data is represented as:
@ -1695,6 +1838,7 @@ class LineStringField(GeoJsonBaseField):
class PolygonField(GeoJsonBaseField):
"""A GeoJSON field storing a polygon of longitude and latitude coordinates.
The data is represented as:

View File

@ -47,7 +47,8 @@ class FieldTest(unittest.TestCase):
# Confirm saving now would store values
data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid'])
self.assertEqual(
data_to_be_saved, ['age', 'created', 'name', 'userid'])
self.assertTrue(person.validate() is None)
@ -63,7 +64,8 @@ class FieldTest(unittest.TestCase):
# Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid'])
self.assertEqual(
data_to_be_saved, ['age', 'created', 'name', 'userid'])
def test_default_values_set_to_None(self):
"""Ensure that default field values are used when creating a document.
@ -587,7 +589,8 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection()
# Post UTC - microseconds are rounded (down) nearest millisecond and dropped
# Post UTC - microseconds are rounded (down) nearest millisecond and
# dropped
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
d2 = datetime.datetime(1970, 01, 01, 00, 00, 01)
log = LogEntry()
@ -688,7 +691,8 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection()
# Post UTC - microseconds are rounded (down) nearest millisecond and dropped - with default datetimefields
# Post UTC - microseconds are rounded (down) nearest millisecond and
# dropped - with default datetimefields
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
log = LogEntry()
log.date = d1
@ -696,14 +700,16 @@ class FieldTest(unittest.TestCase):
log.reload()
self.assertEqual(log.date, d1)
# Post UTC - microseconds are rounded (down) nearest millisecond - with default datetimefields
# Post UTC - microseconds are rounded (down) nearest millisecond - with
# default datetimefields
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
log.date = d1
log.save()
log.reload()
self.assertEqual(log.date, d1)
# Pre UTC dates microseconds below 1000 are dropped - with default datetimefields
# Pre UTC dates microseconds below 1000 are dropped - with default
# datetimefields
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
log.date = d1
log.save()
@ -929,12 +935,16 @@ class FieldTest(unittest.TestCase):
post.save()
self.assertEqual(BlogPost.objects.count(), 3)
self.assertEqual(BlogPost.objects.filter(info__exact='test').count(), 1)
self.assertEqual(BlogPost.objects.filter(info__0__test='test').count(), 1)
self.assertEqual(
BlogPost.objects.filter(info__exact='test').count(), 1)
self.assertEqual(
BlogPost.objects.filter(info__0__test='test').count(), 1)
# Confirm handles non strings or non existing keys
self.assertEqual(BlogPost.objects.filter(info__0__test__exact='5').count(), 0)
self.assertEqual(BlogPost.objects.filter(info__100__test__exact='test').count(), 0)
self.assertEqual(
BlogPost.objects.filter(info__0__test__exact='5').count(), 0)
self.assertEqual(
BlogPost.objects.filter(info__100__test__exact='test').count(), 0)
BlogPost.drop_collection()
def test_list_field_passed_in_value(self):
@ -951,7 +961,6 @@ class FieldTest(unittest.TestCase):
foo.bars.append(bar)
self.assertEqual(repr(foo.bars), '[<Bar: Bar object>]')
def test_list_field_strict(self):
"""Ensure that list field handles validation if provided a strict field type."""
@ -1082,20 +1091,28 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(e2.mapping[1], IntegerSetting))
# Test querying
self.assertEqual(Simple.objects.filter(mapping__1__value=42).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__2__number=1).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__2__complex__value=42).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__1__value=42).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__2__number=1).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__2__complex__value=42).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__2__list__0__value=42).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1)
# Confirm can update
Simple.objects().update(set__mapping__1=IntegerSetting(value=10))
self.assertEqual(Simple.objects.filter(mapping__1__value=10).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__1__value=10).count(), 1)
Simple.objects().update(
set__mapping__2__list__1=StringSetting(value='Boo'))
self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0)
self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0)
self.assertEqual(
Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1)
Simple.drop_collection()
@ -1141,12 +1158,16 @@ class FieldTest(unittest.TestCase):
post.save()
self.assertEqual(BlogPost.objects.count(), 3)
self.assertEqual(BlogPost.objects.filter(info__title__exact='test').count(), 1)
self.assertEqual(BlogPost.objects.filter(info__details__test__exact='test').count(), 1)
self.assertEqual(
BlogPost.objects.filter(info__title__exact='test').count(), 1)
self.assertEqual(
BlogPost.objects.filter(info__details__test__exact='test').count(), 1)
# Confirm handles non strings or non existing keys
self.assertEqual(BlogPost.objects.filter(info__details__test__exact=5).count(), 0)
self.assertEqual(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0)
self.assertEqual(
BlogPost.objects.filter(info__details__test__exact=5).count(), 0)
self.assertEqual(
BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0)
post = BlogPost.objects.create(info={'title': 'original'})
post.info.update({'title': 'updated'})
@ -1207,19 +1228,26 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
# Test querying
self.assertEqual(Simple.objects.filter(mapping__someint__value=42).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__someint__value=42).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__nested_dict__number=1).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1)
# Confirm can update
Simple.objects().update(
set__mapping={"someint": IntegerSetting(value=10)})
Simple.objects().update(
set__mapping__nested_dict__list__1=StringSetting(value='Boo'))
self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0)
self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0)
self.assertEqual(
Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1)
Simple.drop_collection()
@ -1477,6 +1505,151 @@ class FieldTest(unittest.TestCase):
mongoed = p1.to_mongo()
self.assertTrue(isinstance(mongoed['parent'], ObjectId))
def test_cached_reference_fields(self):
class Animal(Document):
name = StringField()
tag = StringField()
class Ocorrence(Document):
person = StringField()
animal = CachedReferenceField(
Animal, fields=['tag'])
Animal.drop_collection()
Ocorrence.drop_collection()
a = Animal(nam="Leopard", tag="heavy")
a.save()
o = Ocorrence(person="teste", animal=a)
o.save()
self.assertEqual(
a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk})
self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
# counts
Ocorrence(person="teste 2").save()
Ocorrence(person="teste 3").save()
count = Ocorrence.objects(animal__tag='heavy').count()
self.assertEqual(count, 1)
ocorrence = Ocorrence.objects(animal__tag='heavy').first()
self.assertEqual(ocorrence.person, "teste")
self.assertTrue(isinstance(ocorrence.animal, Animal))
def test_cached_reference_embedded_fields(self):
class Owner(EmbeddedDocument):
TPS = (
('n', "Normal"),
('u', "Urgent")
)
name = StringField()
tp = StringField(
verbose_name="Type",
db_field="t",
choices=TPS)
class Animal(Document):
name = StringField()
tag = StringField()
owner = EmbeddedDocumentField(Owner)
class Ocorrence(Document):
person = StringField()
animal = CachedReferenceField(
Animal, fields=['tag', 'owner.tp'])
Animal.drop_collection()
Ocorrence.drop_collection()
a = Animal(nam="Leopard", tag="heavy",
owner=Owner(tp='u', name="Wilson Júnior")
)
a.save()
o = Ocorrence(person="teste", animal=a)
o.save()
self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tp'])), {
'_id': a.pk,
'tag': 'heavy',
'owner': {
't': 'u'
}
})
self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u')
# counts
Ocorrence(person="teste 2").save()
Ocorrence(person="teste 3").save()
count = Ocorrence.objects(
animal__tag='heavy', animal__owner__tp='u').count()
self.assertEqual(count, 1)
ocorrence = Ocorrence.objects(
animal__tag='heavy',
animal__owner__tp='u').first()
self.assertEqual(ocorrence.person, "teste")
self.assertTrue(isinstance(ocorrence.animal, Animal))
def test_cached_reference_embedded_list_fields(self):
class Owner(EmbeddedDocument):
name = StringField()
tags = ListField(StringField())
class Animal(Document):
name = StringField()
tag = StringField()
owner = EmbeddedDocumentField(Owner)
class Ocorrence(Document):
person = StringField()
animal = CachedReferenceField(
Animal, fields=['tag', 'owner.tags'])
Animal.drop_collection()
Ocorrence.drop_collection()
a = Animal(nam="Leopard", tag="heavy",
owner=Owner(tags=['cool', 'funny'],
name="Wilson Júnior")
)
a.save()
o = Ocorrence(person="teste 2", animal=a)
o.save()
self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tags'])), {
'_id': a.pk,
'tag': 'heavy',
'owner': {
'tags': ['cool', 'funny']
}
})
self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
self.assertEqual(o.to_mongo()['animal']['owner']['tags'],
['cool', 'funny'])
# counts
Ocorrence(person="teste 2").save()
Ocorrence(person="teste 3").save()
query = Ocorrence.objects(
animal__tag='heavy', animal__owner__tags='cool')._query
self.assertEqual(
query, {'animal.owner.tags': 'cool', 'animal.tag': 'heavy'})
ocorrence = Ocorrence.objects(
animal__tag='heavy',
animal__owner__tags='cool').first()
self.assertEqual(ocorrence.person, "teste 2")
self.assertTrue(isinstance(ocorrence.animal, Animal))
def test_objectid_reference_fields(self):
class Person(Document):
@ -1836,7 +2009,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(repr(Person.objects(city=None)),
"[<Person: Person object>]")
def test_generic_reference_choices(self):
"""Ensure that a GenericReferenceField can handle choices
"""
@ -1982,7 +2154,8 @@ class FieldTest(unittest.TestCase):
attachment_required.blob = Binary(b('\xe6\x00\xc4\xff\x07'))
attachment_required.validate()
attachment_size_limit = AttachmentSizeLimit(blob=b('\xe6\x00\xc4\xff\x07'))
attachment_size_limit = AttachmentSizeLimit(
blob=b('\xe6\x00\xc4\xff\x07'))
self.assertRaises(ValidationError, attachment_size_limit.validate)
attachment_size_limit.blob = b('\xe6\x00\xc4\xff')
attachment_size_limit.validate()
@ -2179,7 +2352,6 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
self.assertEqual(c['next'], 1000)
def test_sequence_field_get_next_value(self):
class Person(Document):
id = SequenceField(primary_key=True)
@ -2368,7 +2540,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(1, post.comments[0].id)
self.assertEqual(2, post.comments[1].id)
def test_generic_embedded_document(self):
class Car(EmbeddedDocument):
name = StringField()