Merge pull request #709 from wpjunior/cached-reference-field

CachedReferenceField implementation
This commit is contained in:
Wilson Júnior 2014-08-03 21:38:06 -03:00
commit 4814066c67
8 changed files with 796 additions and 144 deletions

View File

@ -84,6 +84,7 @@ Fields
.. autoclass:: mongoengine.fields.MapField .. autoclass:: mongoengine.fields.MapField
.. autoclass:: mongoengine.fields.ReferenceField .. autoclass:: mongoengine.fields.ReferenceField
.. autoclass:: mongoengine.fields.GenericReferenceField .. autoclass:: mongoengine.fields.GenericReferenceField
.. autoclass:: mongoengine.fields.CachedReferenceField
.. autoclass:: mongoengine.fields.BinaryField .. autoclass:: mongoengine.fields.BinaryField
.. autoclass:: mongoengine.fields.FileField .. autoclass:: mongoengine.fields.FileField
.. autoclass:: mongoengine.fields.ImageField .. autoclass:: mongoengine.fields.ImageField

View File

@ -23,9 +23,10 @@ __all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
NON_FIELD_ERRORS = '__all__' NON_FIELD_ERRORS = '__all__'
class BaseDocument(object): class BaseDocument(object):
__slots__ = ('_changed_fields', '_initialised', '_created', '_data', __slots__ = ('_changed_fields', '_initialised', '_created', '_data',
'_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__')
_dynamic = False _dynamic = False
_dynamic_lock = True _dynamic_lock = True
@ -50,7 +51,8 @@ class BaseDocument(object):
for value in args: for value in args:
name = next(field) name = next(field)
if name in values: if name in values:
raise TypeError("Multiple values for keyword argument '" + name + "'") raise TypeError(
"Multiple values for keyword argument '" + name + "'")
values[name] = value values[name] = value
__auto_convert = values.pop("__auto_convert", True) __auto_convert = values.pop("__auto_convert", True)
signals.pre_init.send(self.__class__, document=self, values=values) signals.pre_init.send(self.__class__, document=self, values=values)
@ -58,7 +60,8 @@ class BaseDocument(object):
if self.STRICT and not self._dynamic: if self.STRICT and not self._dynamic:
self._data = StrictDict.create(allowed_keys=self._fields_ordered)() self._data = StrictDict.create(allowed_keys=self._fields_ordered)()
else: else:
self._data = SemiStrictDict.create(allowed_keys=self._fields_ordered)() self._data = SemiStrictDict.create(
allowed_keys=self._fields_ordered)()
_created = values.pop("_created", True) _created = values.pop("_created", True)
self._data = {} self._data = {}
@ -144,8 +147,8 @@ class BaseDocument(object):
self__created = True self__created = True
if (self._is_document and not self__created and if (self._is_document and not self__created and
name in self._meta.get('shard_key', tuple()) and name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value): self._data.get(name) != value):
OperationError = _import_class('OperationError') OperationError = _import_class('OperationError')
msg = "Shard Keys are immutable. Tried to update %s" % name msg = "Shard Keys are immutable. Tried to update %s" % name
raise OperationError(msg) raise OperationError(msg)
@ -156,8 +159,8 @@ class BaseDocument(object):
self__initialised = False self__initialised = False
# Check if the user has created a new instance of a class # Check if the user has created a new instance of a class
if (self._is_document and self__initialised if (self._is_document and self__initialised
and self__created and name == self._meta['id_field']): and self__created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False) super(BaseDocument, self).__setattr__('_created', False)
super(BaseDocument, self).__setattr__(name, value) super(BaseDocument, self).__setattr__(name, value)
@ -174,7 +177,7 @@ class BaseDocument(object):
if isinstance(data["_data"], SON): if isinstance(data["_data"], SON):
data["_data"] = self.__class__._from_son(data["_data"])._data data["_data"] = self.__class__._from_son(data["_data"])._data
for k in ('_changed_fields', '_initialised', '_created', '_data', for k in ('_changed_fields', '_initialised', '_created', '_data',
'_dynamic_fields'): '_dynamic_fields'):
if k in data: if k in data:
setattr(self, k, data[k]) setattr(self, k, data[k])
if '_fields_ordered' in data: if '_fields_ordered' in data:
@ -257,23 +260,41 @@ class BaseDocument(object):
""" """
pass pass
def to_mongo(self, use_db_field=True): def to_mongo(self, use_db_field=True, fields=[]):
"""Return as SON data ready for use with MongoDB. """
Return as SON data ready for use with MongoDB.
""" """
data = SON() data = SON()
data["_id"] = None data["_id"] = None
data['_cls'] = self._class_name data['_cls'] = self._class_name
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
# only root fields ['test1.a', 'test2'] => ['test1', 'test2']
root_fields = set([f.split('.')[0] for f in fields])
for field_name in self: for field_name in self:
if root_fields and field_name not in root_fields:
continue
value = self._data.get(field_name, None) value = self._data.get(field_name, None)
field = self._fields.get(field_name) field = self._fields.get(field_name)
if field is None and self._dynamic: if field is None and self._dynamic:
field = self._dynamic_fields.get(field_name) field = self._dynamic_fields.get(field_name)
if value is not None: if value is not None:
EmbeddedDocument = _import_class("EmbeddedDocument")
if isinstance(value, (EmbeddedDocument)) and use_db_field==False: if isinstance(field, (EmbeddedDocumentField)):
value = field.to_mongo(value, use_db_field) if fields:
key = '%s.' % field_name
embedded_fields = [
i.replace(key, '') for i in fields
if i.startswith(key)]
else:
embedded_fields = []
value = field.to_mongo(value, use_db_field=use_db_field,
fields=embedded_fields)
else: else:
value = field.to_mongo(value) value = field.to_mongo(value)
@ -299,7 +320,7 @@ class BaseDocument(object):
# Only add _cls if allow_inheritance is True # Only add _cls if allow_inheritance is True
if (not hasattr(self, '_meta') or if (not hasattr(self, '_meta') or
not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)):
data.pop('_cls') data.pop('_cls')
return data return data
@ -321,7 +342,8 @@ class BaseDocument(object):
self._data.get(name)) for name in self._fields_ordered] self._data.get(name)) for name in self._fields_ordered]
EmbeddedDocumentField = _import_class("EmbeddedDocumentField") EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") GenericEmbeddedDocumentField = _import_class(
"GenericEmbeddedDocumentField")
for field, value in fields: for field, value in fields:
if value is not None: if value is not None:
@ -352,7 +374,8 @@ class BaseDocument(object):
"""Converts a document to JSON. """Converts a document to JSON.
:param use_db_field: Set to True by default but enables the output of the json structure with the field names and not the mongodb store db_names in case of set to False :param use_db_field: Set to True by default but enables the output of the json structure with the field names and not the mongodb store db_names in case of set to False
""" """
use_db_field = kwargs.pop('use_db_field') if kwargs.has_key('use_db_field') else True use_db_field = kwargs.pop('use_db_field') if kwargs.has_key(
'use_db_field') else True
return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs) return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)
@classmethod @classmethod
@ -387,7 +410,7 @@ class BaseDocument(object):
# Convert lists / values so we can watch for any changes on them # Convert lists / values so we can watch for any changes on them
if (isinstance(value, (list, tuple)) and if (isinstance(value, (list, tuple)) and
not isinstance(value, BaseList)): not isinstance(value, BaseList)):
value = BaseList(value, self, name) value = BaseList(value, self, name)
elif isinstance(value, dict) and not isinstance(value, BaseDict): elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, self, name) value = BaseDict(value, self, name)
@ -452,9 +475,10 @@ class BaseDocument(object):
if hasattr(value, '_get_changed_fields'): if hasattr(value, '_get_changed_fields'):
changed = value._get_changed_fields(inspected) changed = value._get_changed_fields(inspected)
changed_fields += ["%s%s" % (list_key, k) changed_fields += ["%s%s" % (list_key, k)
for k in changed if k] for k in changed if k]
elif isinstance(value, (list, tuple, dict)): elif isinstance(value, (list, tuple, dict)):
self._nestable_types_changed_fields(changed_fields, list_key, value, inspected) self._nestable_types_changed_fields(
changed_fields, list_key, value, inspected)
def _get_changed_fields(self, inspected=None): def _get_changed_fields(self, inspected=None):
"""Returns a list of all fields that have explicitly been changed. """Returns a list of all fields that have explicitly been changed.
@ -484,16 +508,17 @@ class BaseDocument(object):
if isinstance(field, ReferenceField): if isinstance(field, ReferenceField):
continue continue
elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument))
and db_field_name not in changed_fields): and db_field_name not in changed_fields):
# Find all embedded fields that have been changed # Find all embedded fields that have been changed
changed = data._get_changed_fields(inspected) changed = data._get_changed_fields(inspected)
changed_fields += ["%s%s" % (key, k) for k in changed if k] changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif (isinstance(data, (list, tuple, dict)) and elif (isinstance(data, (list, tuple, dict)) and
db_field_name not in changed_fields): db_field_name not in changed_fields):
if (hasattr(field, 'field') and if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)): isinstance(field.field, ReferenceField)):
continue continue
self._nestable_types_changed_fields(changed_fields, key, data, inspected) self._nestable_types_changed_fields(
changed_fields, key, data, inspected)
return changed_fields return changed_fields
def _delta(self): def _delta(self):
@ -539,7 +564,7 @@ class BaseDocument(object):
# If we've set a value that ain't the default value dont unset it. # If we've set a value that ain't the default value dont unset it.
default = None default = None
if (self._dynamic and len(parts) and parts[0] in if (self._dynamic and len(parts) and parts[0] in
self._dynamic_fields): self._dynamic_fields):
del(set_data[path]) del(set_data[path])
unset_data[path] = 1 unset_data[path] = 1
continue continue
@ -625,13 +650,14 @@ class BaseDocument(object):
if errors_dict: if errors_dict:
errors = "\n".join(["%s - %s" % (k, v) errors = "\n".join(["%s - %s" % (k, v)
for k, v in errors_dict.items()]) for k, v in errors_dict.items()])
msg = ("Invalid data to create a `%s` instance.\n%s" msg = ("Invalid data to create a `%s` instance.\n%s"
% (cls._class_name, errors)) % (cls._class_name, errors))
raise InvalidDocumentError(msg) raise InvalidDocumentError(msg)
if cls.STRICT: if cls.STRICT:
data = dict((k, v) for k,v in data.iteritems() if k in cls._fields) data = dict((k, v)
for k, v in data.iteritems() if k in cls._fields)
obj = cls(__auto_convert=False, _created=False, **data) obj = cls(__auto_convert=False, _created=False, **data)
obj._changed_fields = changed_fields obj._changed_fields = changed_fields
if not _auto_dereference: if not _auto_dereference:
@ -778,7 +804,7 @@ class BaseDocument(object):
# Grab any embedded document field unique indexes # Grab any embedded document field unique indexes
if (field.__class__.__name__ == "EmbeddedDocumentField" and if (field.__class__.__name__ == "EmbeddedDocumentField" and
field.document_type != cls): field.document_type != cls):
field_namespace = "%s." % field_name field_namespace = "%s." % field_name
doc_cls = field.document_type doc_cls = field.document_type
unique_indexes += doc_cls._unique_with_indexes(field_namespace) unique_indexes += doc_cls._unique_with_indexes(field_namespace)
@ -794,7 +820,8 @@ class BaseDocument(object):
geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField", geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField",
"PointField", "LineStringField", "PolygonField"] "PointField", "LineStringField", "PolygonField"]
geo_field_types = tuple([_import_class(field) for field in geo_field_type_names]) geo_field_types = tuple([_import_class(field)
for field in geo_field_type_names])
for field in cls._fields.values(): for field in cls._fields.values():
if not isinstance(field, geo_field_types): if not isinstance(field, geo_field_types):
@ -804,13 +831,14 @@ class BaseDocument(object):
if field_cls in inspected: if field_cls in inspected:
continue continue
if hasattr(field_cls, '_geo_indices'): if hasattr(field_cls, '_geo_indices'):
geo_indices += field_cls._geo_indices(inspected, parent_field=field.db_field) geo_indices += field_cls._geo_indices(
inspected, parent_field=field.db_field)
elif field._geo_index: elif field._geo_index:
field_name = field.db_field field_name = field.db_field
if parent_field: if parent_field:
field_name = "%s.%s" % (parent_field, field_name) field_name = "%s.%s" % (parent_field, field_name)
geo_indices.append({'fields': geo_indices.append({'fields':
[(field_name, field._geo_index)]}) [(field_name, field._geo_index)]})
return geo_indices return geo_indices
@classmethod @classmethod

View File

@ -11,10 +11,12 @@ from mongoengine.errors import ValidationError
from mongoengine.base.common import ALLOW_INHERITANCE from mongoengine.base.common import ALLOW_INHERITANCE
from mongoengine.base.datastructures import BaseDict, BaseList from mongoengine.base.datastructures import BaseDict, BaseList
__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") __all__ = ("BaseField", "ComplexBaseField",
"ObjectIdField", "GeoJsonBaseField")
class BaseField(object): class BaseField(object):
"""A base class for fields in a MongoDB document. Instances of this class """A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema. may be added to subclasses of `Document` to define a document's schema.
@ -60,6 +62,7 @@ class BaseField(object):
used when generating model forms from the document model. used when generating model forms from the document model.
""" """
self.db_field = (db_field or name) if not primary_key else '_id' self.db_field = (db_field or name) if not primary_key else '_id'
if name: if name:
msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
@ -105,7 +108,7 @@ class BaseField(object):
if instance._initialised: if instance._initialised:
try: try:
if (self.name not in instance._data or if (self.name not in instance._data or
instance._data[self.name] != value): instance._data[self.name] != value):
instance._mark_as_changed(self.name) instance._mark_as_changed(self.name)
except: except:
# Values cant be compared eg: naive and tz datetimes # Values cant be compared eg: naive and tz datetimes
@ -175,6 +178,7 @@ class BaseField(object):
class ComplexBaseField(BaseField): class ComplexBaseField(BaseField):
"""Handles complex fields, such as lists / dictionaries. """Handles complex fields, such as lists / dictionaries.
Allows for nesting of embedded documents inside complex types. Allows for nesting of embedded documents inside complex types.
@ -197,7 +201,7 @@ class ComplexBaseField(BaseField):
GenericReferenceField = _import_class('GenericReferenceField') GenericReferenceField = _import_class('GenericReferenceField')
dereference = (self._auto_dereference and dereference = (self._auto_dereference and
(self.field is None or isinstance(self.field, (self.field is None or isinstance(self.field,
(GenericReferenceField, ReferenceField)))) (GenericReferenceField, ReferenceField))))
_dereference = _import_class("DeReference")() _dereference = _import_class("DeReference")()
@ -212,7 +216,7 @@ class ComplexBaseField(BaseField):
# Convert lists / values so we can watch for any changes on them # Convert lists / values so we can watch for any changes on them
if (isinstance(value, (list, tuple)) and if (isinstance(value, (list, tuple)) and
not isinstance(value, BaseList)): not isinstance(value, BaseList)):
value = BaseList(value, instance, self.name) value = BaseList(value, instance, self.name)
instance._data[self.name] = value instance._data[self.name] = value
elif isinstance(value, dict) and not isinstance(value, BaseDict): elif isinstance(value, dict) and not isinstance(value, BaseDict):
@ -220,8 +224,8 @@ class ComplexBaseField(BaseField):
instance._data[self.name] = value instance._data[self.name] = value
if (self._auto_dereference and instance._initialised and if (self._auto_dereference and instance._initialised and
isinstance(value, (BaseList, BaseDict)) isinstance(value, (BaseList, BaseDict))
and not value._dereferenced): and not value._dereferenced):
value = _dereference( value = _dereference(
value, max_depth=1, instance=instance, name=self.name value, max_depth=1, instance=instance, name=self.name
) )
@ -384,6 +388,7 @@ class ComplexBaseField(BaseField):
class ObjectIdField(BaseField): class ObjectIdField(BaseField):
"""A field wrapper around MongoDB's ObjectIds. """A field wrapper around MongoDB's ObjectIds.
""" """
@ -412,6 +417,7 @@ class ObjectIdField(BaseField):
class GeoJsonBaseField(BaseField): class GeoJsonBaseField(BaseField):
"""A geo json field storing a geojson style object. """A geo json field storing a geojson style object.
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """
@ -435,7 +441,8 @@ class GeoJsonBaseField(BaseField):
if isinstance(value, dict): if isinstance(value, dict):
if set(value.keys()) == set(['type', 'coordinates']): if set(value.keys()) == set(['type', 'coordinates']):
if value['type'] != self._type: if value['type'] != self._type:
self.error('%s type must be "%s"' % (self._name, self._type)) self.error('%s type must be "%s"' %
(self._name, self._type))
return self.validate(value['coordinates']) return self.validate(value['coordinates'])
else: else:
self.error('%s can only accept a valid GeoJson dictionary' self.error('%s can only accept a valid GeoJson dictionary'

View File

@ -16,6 +16,7 @@ __all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass')
class DocumentMetaclass(type): class DocumentMetaclass(type):
"""Metaclass for all documents. """Metaclass for all documents.
""" """
@ -29,6 +30,7 @@ class DocumentMetaclass(type):
return super_new(cls, name, bases, attrs) return super_new(cls, name, bases, attrs)
attrs['_is_document'] = attrs.get('_is_document', False) attrs['_is_document'] = attrs.get('_is_document', False)
attrs['_cached_reference_fields'] = []
# EmbeddedDocuments could have meta data for inheritance # EmbeddedDocuments could have meta data for inheritance
if 'meta' in attrs: if 'meta' in attrs:
@ -90,7 +92,7 @@ class DocumentMetaclass(type):
# Set _fields and db_field maps # Set _fields and db_field maps
attrs['_fields'] = doc_fields attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k))
for k, v in doc_fields.iteritems()]) for k, v in doc_fields.iteritems()])
attrs['_reverse_db_field_map'] = dict( attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems()) (v, k) for k, v in attrs['_db_field_map'].iteritems())
@ -105,7 +107,7 @@ class DocumentMetaclass(type):
class_name = [name] class_name = [name]
for base in flattened_bases: for base in flattened_bases:
if (not getattr(base, '_is_base_cls', True) and if (not getattr(base, '_is_base_cls', True) and
not getattr(base, '_meta', {}).get('abstract', True)): not getattr(base, '_meta', {}).get('abstract', True)):
# Collate heirarchy for _cls and _subclasses # Collate heirarchy for _cls and _subclasses
class_name.append(base.__name__) class_name.append(base.__name__)
@ -115,7 +117,7 @@ class DocumentMetaclass(type):
allow_inheritance = base._meta.get('allow_inheritance', allow_inheritance = base._meta.get('allow_inheritance',
ALLOW_INHERITANCE) ALLOW_INHERITANCE)
if (allow_inheritance is not True and if (allow_inheritance is not True and
not base._meta.get('abstract')): not base._meta.get('abstract')):
raise ValueError('Document %s may not be subclassed' % raise ValueError('Document %s may not be subclassed' %
base.__name__) base.__name__)
@ -141,7 +143,8 @@ class DocumentMetaclass(type):
base._subclasses += (_cls,) base._subclasses += (_cls,)
base._types = base._subclasses # TODO depreciate _types base._types = base._subclasses # TODO depreciate _types
Document, EmbeddedDocument, DictField = cls._import_classes() (Document, EmbeddedDocument, DictField,
CachedReferenceField) = cls._import_classes()
if issubclass(new_class, Document): if issubclass(new_class, Document):
new_class._collection = None new_class._collection = None
@ -170,6 +173,20 @@ class DocumentMetaclass(type):
f = field f = field
f.owner_document = new_class f.owner_document = new_class
delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING)
if isinstance(f, CachedReferenceField):
if issubclass(new_class, EmbeddedDocument):
raise InvalidDocumentError(
"CachedReferenceFields is not allowed in EmbeddedDocuments")
if not f.document_type:
raise InvalidDocumentError(
"Document is not avaiable to sync")
if f.auto_sync:
f.start_listener()
f.document_type._cached_reference_fields.append(f)
if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): if isinstance(f, ComplexBaseField) and hasattr(f, 'field'):
delete_rule = getattr(f.field, delete_rule = getattr(f.field,
'reverse_delete_rule', 'reverse_delete_rule',
@ -191,7 +208,7 @@ class DocumentMetaclass(type):
field.name, delete_rule) field.name, delete_rule)
if (field.name and hasattr(Document, field.name) and if (field.name and hasattr(Document, field.name) and
EmbeddedDocument not in new_class.mro()): EmbeddedDocument not in new_class.mro()):
msg = ("%s is a document method and not a valid " msg = ("%s is a document method and not a valid "
"field name" % field.name) "field name" % field.name)
raise InvalidDocumentError(msg) raise InvalidDocumentError(msg)
@ -224,10 +241,12 @@ class DocumentMetaclass(type):
Document = _import_class('Document') Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument') EmbeddedDocument = _import_class('EmbeddedDocument')
DictField = _import_class('DictField') DictField = _import_class('DictField')
return (Document, EmbeddedDocument, DictField) CachedReferenceField = _import_class('CachedReferenceField')
return (Document, EmbeddedDocument, DictField, CachedReferenceField)
class TopLevelDocumentMetaclass(DocumentMetaclass): class TopLevelDocumentMetaclass(DocumentMetaclass):
"""Metaclass for top-level documents (i.e. documents that have their own """Metaclass for top-level documents (i.e. documents that have their own
collection in the database. collection in the database.
""" """
@ -275,21 +294,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Find the parent document class # Find the parent document class
parent_doc_cls = [b for b in flattened_bases parent_doc_cls = [b for b in flattened_bases
if b.__class__ == TopLevelDocumentMetaclass] if b.__class__ == TopLevelDocumentMetaclass]
parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0] parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0]
# Prevent classes setting collection different to their parents # Prevent classes setting collection different to their parents
# If parent wasn't an abstract class # If parent wasn't an abstract class
if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) if (parent_doc_cls and 'collection' in attrs.get('_meta', {})
and not parent_doc_cls._meta.get('abstract', True)): and not parent_doc_cls._meta.get('abstract', True)):
msg = "Trying to set a collection on a subclass (%s)" % name msg = "Trying to set a collection on a subclass (%s)" % name
warnings.warn(msg, SyntaxWarning) warnings.warn(msg, SyntaxWarning)
del(attrs['_meta']['collection']) del(attrs['_meta']['collection'])
# Ensure abstract documents have abstract bases # Ensure abstract documents have abstract bases
if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'):
if (parent_doc_cls and if (parent_doc_cls and
not parent_doc_cls._meta.get('abstract', False)): not parent_doc_cls._meta.get('abstract', False)):
msg = "Abstract document cannot have non-abstract base" msg = "Abstract document cannot have non-abstract base"
raise ValueError(msg) raise ValueError(msg)
return super_new(cls, name, bases, attrs) return super_new(cls, name, bases, attrs)
@ -306,7 +325,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Set collection in the meta if its callable # Set collection in the meta if its callable
if (getattr(base, '_is_document', False) and if (getattr(base, '_is_document', False) and
not base._meta.get('abstract')): not base._meta.get('abstract')):
collection = meta.get('collection', None) collection = meta.get('collection', None)
if callable(collection): if callable(collection):
meta['collection'] = collection(base) meta['collection'] = collection(base)
@ -318,7 +337,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
simple_class = all([b._meta.get('abstract') simple_class = all([b._meta.get('abstract')
for b in flattened_bases if hasattr(b, '_meta')]) for b in flattened_bases if hasattr(b, '_meta')])
if (not simple_class and meta['allow_inheritance'] is False and if (not simple_class and meta['allow_inheritance'] is False and
not meta['abstract']): not meta['abstract']):
raise ValueError('Only direct subclasses of Document may set ' raise ValueError('Only direct subclasses of Document may set '
'"allow_inheritance" to False') '"allow_inheritance" to False')
@ -378,7 +397,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
for exc in exceptions_to_merge: for exc in exceptions_to_merge:
name = exc.__name__ name = exc.__name__
parents = tuple(getattr(base, name) for base in flattened_bases parents = tuple(getattr(base, name) for base in flattened_bases
if hasattr(base, name)) or (exc,) if hasattr(base, name)) or (exc,)
# Create new exception and set to new_class # Create new exception and set to new_class
exception = type(name, parents, {'__module__': module}) exception = type(name, parents, {'__module__': module})
setattr(new_class, name, exception) setattr(new_class, name, exception)
@ -387,6 +406,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
class MetaDict(dict): class MetaDict(dict):
"""Custom dictionary for meta classes. """Custom dictionary for meta classes.
Handles the merging of set indexes Handles the merging of set indexes
""" """
@ -401,5 +421,6 @@ class MetaDict(dict):
class BasesTuple(tuple): class BasesTuple(tuple):
"""Special class to handle introspection of bases tuple in __new__""" """Special class to handle introspection of bases tuple in __new__"""
pass pass

View File

@ -25,6 +25,7 @@ def _import_class(cls_name):
'GenericEmbeddedDocumentField', 'GeoPointField', 'GenericEmbeddedDocumentField', 'GeoPointField',
'PointField', 'LineStringField', 'ListField', 'PointField', 'LineStringField', 'ListField',
'PolygonField', 'ReferenceField', 'StringField', 'PolygonField', 'ReferenceField', 'StringField',
'CachedReferenceField',
'ComplexBaseField', 'GeoJsonBaseField') 'ComplexBaseField', 'GeoJsonBaseField')
queryset_classes = ('OperationError',) queryset_classes = ('OperationError',)
deref_classes = ('DeReference',) deref_classes = ('DeReference',)

View File

@ -34,22 +34,24 @@ except ImportError:
Image = None Image = None
ImageOps = None ImageOps = None
__all__ = ['StringField', 'URLField', 'EmailField', 'IntField', 'LongField', __all__ = [
'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', 'StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField',
'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField',
'SortedListField', 'DictField', 'MapField', 'ReferenceField', 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField',
'GenericReferenceField', 'BinaryField', 'GridFSError', 'SortedListField', 'DictField', 'MapField', 'ReferenceField',
'GridFSProxy', 'FileField', 'ImageGridFsProxy', 'CachedReferenceField', 'GenericReferenceField', 'BinaryField',
'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', 'GridFSError', 'GridFSProxy', 'FileField', 'ImageGridFsProxy',
'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField', 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField',
'GeoJsonBaseField'] 'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField',
'GeoJsonBaseField']
RECURSIVE_REFERENCE_CONSTANT = 'self' RECURSIVE_REFERENCE_CONSTANT = 'self'
class StringField(BaseField): class StringField(BaseField):
"""A unicode string field. """A unicode string field.
""" """
@ -109,6 +111,7 @@ class StringField(BaseField):
class URLField(StringField): class URLField(StringField):
"""A field that validates input as an URL. """A field that validates input as an URL.
.. versionadded:: 0.3 .. versionadded:: 0.3
@ -116,7 +119,8 @@ class URLField(StringField):
_URL_REGEX = re.compile( _URL_REGEX = re.compile(
r'^(?:http|ftp)s?://' # http:// or https:// r'^(?:http|ftp)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain... # domain...
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'
r'localhost|' # localhost... r'localhost|' # localhost...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port r'(?::\d+)?' # optional port
@ -145,15 +149,19 @@ class URLField(StringField):
class EmailField(StringField): class EmailField(StringField):
"""A field that validates input as an E-Mail-Address. """A field that validates input as an E-Mail-Address.
.. versionadded:: 0.4 .. versionadded:: 0.4
""" """
EMAIL_REGEX = re.compile( EMAIL_REGEX = re.compile(
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom # dot-atom
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*"
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE # domain # quoted-string
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"'
# domain
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}$', re.IGNORECASE
) )
def validate(self, value): def validate(self, value):
@ -163,6 +171,7 @@ class EmailField(StringField):
class IntField(BaseField): class IntField(BaseField):
"""An 32-bit integer field. """An 32-bit integer field.
""" """
@ -197,6 +206,7 @@ class IntField(BaseField):
class LongField(BaseField): class LongField(BaseField):
"""An 64-bit integer field. """An 64-bit integer field.
""" """
@ -231,6 +241,7 @@ class LongField(BaseField):
class FloatField(BaseField): class FloatField(BaseField):
"""An floating point number field. """An floating point number field.
""" """
@ -265,6 +276,7 @@ class FloatField(BaseField):
class DecimalField(BaseField): class DecimalField(BaseField):
"""A fixed-point decimal number field. """A fixed-point decimal number field.
.. versionchanged:: 0.8 .. versionchanged:: 0.8
@ -338,6 +350,7 @@ class DecimalField(BaseField):
class BooleanField(BaseField): class BooleanField(BaseField):
"""A boolean field type. """A boolean field type.
.. versionadded:: 0.1.2 .. versionadded:: 0.1.2
@ -356,6 +369,7 @@ class BooleanField(BaseField):
class DateTimeField(BaseField): class DateTimeField(BaseField):
"""A datetime field. """A datetime field.
Uses the python-dateutil library if available alternatively use time.strptime Uses the python-dateutil library if available alternatively use time.strptime
@ -406,15 +420,15 @@ class DateTimeField(BaseField):
kwargs = {'microsecond': usecs} kwargs = {'microsecond': usecs}
try: # Seconds are optional, so try converting seconds first. try: # Seconds are optional, so try converting seconds first.
return datetime.datetime(*time.strptime(value, return datetime.datetime(*time.strptime(value,
'%Y-%m-%d %H:%M:%S')[:6], **kwargs) '%Y-%m-%d %H:%M:%S')[:6], **kwargs)
except ValueError: except ValueError:
try: # Try without seconds. try: # Try without seconds.
return datetime.datetime(*time.strptime(value, return datetime.datetime(*time.strptime(value,
'%Y-%m-%d %H:%M')[:5], **kwargs) '%Y-%m-%d %H:%M')[:5], **kwargs)
except ValueError: # Try without hour/minutes/seconds. except ValueError: # Try without hour/minutes/seconds.
try: try:
return datetime.datetime(*time.strptime(value, return datetime.datetime(*time.strptime(value,
'%Y-%m-%d')[:3], **kwargs) '%Y-%m-%d')[:3], **kwargs)
except ValueError: except ValueError:
return None return None
@ -423,6 +437,7 @@ class DateTimeField(BaseField):
class ComplexDateTimeField(StringField): class ComplexDateTimeField(StringField):
""" """
ComplexDateTimeField handles microseconds exactly instead of rounding ComplexDateTimeField handles microseconds exactly instead of rounding
like DateTimeField does. like DateTimeField does.
@ -525,6 +540,7 @@ class ComplexDateTimeField(StringField):
class EmbeddedDocumentField(BaseField): class EmbeddedDocumentField(BaseField):
"""An embedded document field - with a declared document_type. """An embedded document field - with a declared document_type.
Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`.
""" """
@ -551,10 +567,11 @@ class EmbeddedDocumentField(BaseField):
return self.document_type._from_son(value) return self.document_type._from_son(value)
return value return value
def to_mongo(self, value, use_db_field=True): def to_mongo(self, value, use_db_field=True, fields=[]):
if not isinstance(value, self.document_type): if not isinstance(value, self.document_type):
return value return value
return self.document_type.to_mongo(value, use_db_field) return self.document_type.to_mongo(value, use_db_field,
fields=fields)
def validate(self, value, clean=True): def validate(self, value, clean=True):
"""Make sure that the document instance is an instance of the """Make sure that the document instance is an instance of the
@ -574,6 +591,7 @@ class EmbeddedDocumentField(BaseField):
class GenericEmbeddedDocumentField(BaseField): class GenericEmbeddedDocumentField(BaseField):
"""A generic embedded document field - allows any """A generic embedded document field - allows any
:class:`~mongoengine.EmbeddedDocument` to be stored. :class:`~mongoengine.EmbeddedDocument` to be stored.
@ -612,6 +630,7 @@ class GenericEmbeddedDocumentField(BaseField):
class DynamicField(BaseField): class DynamicField(BaseField):
"""A truly dynamic field type capable of handling different and varying """A truly dynamic field type capable of handling different and varying
types of data. types of data.
@ -675,6 +694,7 @@ class DynamicField(BaseField):
class ListField(ComplexBaseField): class ListField(ComplexBaseField):
"""A list field that wraps a standard field, allowing multiple instances """A list field that wraps a standard field, allowing multiple instances
of the field to be used as a list in the database. of the field to be used as a list in the database.
@ -693,21 +713,22 @@ class ListField(ComplexBaseField):
"""Make sure that a list of valid fields is being used. """Make sure that a list of valid fields is being used.
""" """
if (not isinstance(value, (list, tuple, QuerySet)) or if (not isinstance(value, (list, tuple, QuerySet)) or
isinstance(value, basestring)): isinstance(value, basestring)):
self.error('Only lists and tuples may be used in a list field') self.error('Only lists and tuples may be used in a list field')
super(ListField, self).validate(value) super(ListField, self).validate(value)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if self.field: if self.field:
if op in ('set', 'unset') and (not isinstance(value, basestring) if op in ('set', 'unset') and (not isinstance(value, basestring)
and not isinstance(value, BaseDocument) and not isinstance(value, BaseDocument)
and hasattr(value, '__iter__')): and hasattr(value, '__iter__')):
return [self.field.prepare_query_value(op, v) for v in value] return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value) return self.field.prepare_query_value(op, value)
return super(ListField, self).prepare_query_value(op, value) return super(ListField, self).prepare_query_value(op, value)
class SortedListField(ListField): class SortedListField(ListField):
"""A ListField that sorts the contents of its list before writing to """A ListField that sorts the contents of its list before writing to
the database in order to ensure that a sorted list is always the database in order to ensure that a sorted list is always
retrieved. retrieved.
@ -739,6 +760,7 @@ class SortedListField(ListField):
reverse=self._order_reverse) reverse=self._order_reverse)
return sorted(value, reverse=self._order_reverse) return sorted(value, reverse=self._order_reverse)
def key_not_string(d): def key_not_string(d):
""" Helper function to recursively determine if any key in a dictionary is """ Helper function to recursively determine if any key in a dictionary is
not a string. not a string.
@ -747,6 +769,7 @@ def key_not_string(d):
if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)): if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)):
return True return True
def key_has_dot_or_dollar(d): def key_has_dot_or_dollar(d):
""" Helper function to recursively determine if any key in a dictionary """ Helper function to recursively determine if any key in a dictionary
contains a dot or a dollar sign. contains a dot or a dollar sign.
@ -755,7 +778,9 @@ def key_has_dot_or_dollar(d):
if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)): if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
return True return True
class DictField(ComplexBaseField): class DictField(ComplexBaseField):
"""A dictionary field that wraps a standard Python dictionary. This is """A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined. similar to an embedded document, but the structure is not defined.
@ -807,6 +832,7 @@ class DictField(ComplexBaseField):
class MapField(DictField): class MapField(DictField):
"""A field that maps a name to a specified field type. Similar to """A field that maps a name to a specified field type. Similar to
a DictField, except the 'value' of each item must match the specified a DictField, except the 'value' of each item must match the specified
field type. field type.
@ -822,6 +848,7 @@ class MapField(DictField):
class ReferenceField(BaseField): class ReferenceField(BaseField):
"""A reference to a document that will be automatically dereferenced on """A reference to a document that will be automatically dereferenced on
access (lazily). access (lazily).
@ -932,7 +959,7 @@ class ReferenceField(BaseField):
"""Convert a MongoDB-compatible type to a Python type. """Convert a MongoDB-compatible type to a Python type.
""" """
if (not self.dbref and if (not self.dbref and
not isinstance(value, (DBRef, Document, EmbeddedDocument))): not isinstance(value, (DBRef, Document, EmbeddedDocument))):
collection = self.document_type._get_collection_name() collection = self.document_type._get_collection_name()
value = DBRef(collection, self.document_type.id.to_python(value)) value = DBRef(collection, self.document_type.id.to_python(value))
return value return value
@ -955,7 +982,147 @@ class ReferenceField(BaseField):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
class CachedReferenceField(BaseField):
"""
A referencefield with cache fields to porpuse pseudo-joins
.. versionadded:: 0.9
"""
def __init__(self, document_type, fields=[], auto_sync=True, **kwargs):
"""Initialises the Cached Reference Field.
:param fields: A list of fields to be cached in document
:param auto_sync: if True documents are auto updated.
"""
if not isinstance(document_type, basestring) and \
not issubclass(document_type, (Document, basestring)):
self.error('Argument to CachedReferenceField constructor must be a'
' document class or a string')
self.auto_sync = auto_sync
self.document_type_obj = document_type
self.fields = fields
super(CachedReferenceField, self).__init__(**kwargs)
def start_listener(self):
from mongoengine import signals
signals.post_save.connect(self.on_document_pre_save,
sender=self.document_type)
def on_document_pre_save(self, sender, document, created, **kwargs):
if not created:
update_kwargs = dict(
('set__%s__%s' % (self.name, k), v)
for k, v in document._delta()[0].items()
if k in self.fields)
if update_kwargs:
filter_kwargs = {}
filter_kwargs[self.name] = document
self.owner_document.objects(
**filter_kwargs).update(**update_kwargs)
def to_python(self, value):
if isinstance(value, dict):
collection = self.document_type._get_collection_name()
value = DBRef(
collection, self.document_type.id.to_python(value['_id']))
return value
@property
def document_type(self):
if isinstance(self.document_type_obj, basestring):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj
def __get__(self, instance, owner):
if instance is None:
# Document class being used rather than a document object
return self
# Get value from document instance if available
value = instance._data.get(self.name)
self._auto_dereference = instance._fields[self.name]._auto_dereference
# Dereference DBRefs
if self._auto_dereference and isinstance(value, DBRef):
value = self.document_type._get_db().dereference(value)
if value is not None:
instance._data[self.name] = self.document_type._from_son(value)
return super(CachedReferenceField, self).__get__(instance, owner)
def to_mongo(self, document):
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]
doc_tipe = self.document_type
if isinstance(document, Document):
# We need the id from the saved object to create the DBRef
id_ = document.pk
if id_ is None:
self.error('You can only reference documents once they have'
' been saved to the database')
else:
self.error('Only accept a document object')
value = SON((
("_id", id_field.to_mongo(id_)),
))
value.update(dict(document.to_mongo(fields=self.fields)))
return value
def prepare_query_value(self, op, value):
if value is None:
return None
if isinstance(value, Document):
if value.pk is None:
self.error('You can only reference documents once they have'
' been saved to the database')
return {'_id': value.pk}
raise NotImplementedError
def validate(self, value):
if not isinstance(value, (self.document_type)):
self.error("A CachedReferenceField only accepts documents")
if isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been '
'saved to the database')
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
def sync_all(self):
"""
Sync all cached fields on demand.
Caution: this operation may be slower.
"""
update_key = 'set__%s' % self.name
for doc in self.document_type.objects:
filter_kwargs = {}
filter_kwargs[self.name] = doc
update_kwargs = {}
update_kwargs[update_key] = doc
self.owner_document.objects(
**filter_kwargs).update(**update_kwargs)
class GenericReferenceField(BaseField): class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass """A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily). that will be automatically dereferenced on access (lazily).
@ -974,6 +1141,7 @@ class GenericReferenceField(BaseField):
return self return self
value = instance._data.get(self.name) value = instance._data.get(self.name)
self._auto_dereference = instance._fields[self.name]._auto_dereference self._auto_dereference = instance._fields[self.name]._auto_dereference
if self._auto_dereference and isinstance(value, (dict, SON)): if self._auto_dereference and isinstance(value, (dict, SON)):
instance._data[self.name] = self.dereference(value) instance._data[self.name] = self.dereference(value)
@ -1036,6 +1204,7 @@ class GenericReferenceField(BaseField):
class BinaryField(BaseField): class BinaryField(BaseField):
"""A binary data field. """A binary data field.
""" """
@ -1056,7 +1225,7 @@ class BinaryField(BaseField):
if not isinstance(value, (bin_type, txt_type, Binary)): if not isinstance(value, (bin_type, txt_type, Binary)):
self.error("BinaryField only accepts instances of " self.error("BinaryField only accepts instances of "
"(%s, %s, Binary)" % ( "(%s, %s, Binary)" % (
bin_type.__name__, txt_type.__name__)) bin_type.__name__, txt_type.__name__))
if self.max_bytes is not None and len(value) > self.max_bytes: if self.max_bytes is not None and len(value) > self.max_bytes:
self.error('Binary value is too long') self.error('Binary value is too long')
@ -1067,6 +1236,7 @@ class GridFSError(Exception):
class GridFSProxy(object): class GridFSProxy(object):
"""Proxy object to handle writing and reading of files to and from GridFS """Proxy object to handle writing and reading of files to and from GridFS
.. versionadded:: 0.4 .. versionadded:: 0.4
@ -1121,7 +1291,8 @@ class GridFSProxy(object):
return '<%s: %s>' % (self.__class__.__name__, self.grid_id) return '<%s: %s>' % (self.__class__.__name__, self.grid_id)
def __str__(self): def __str__(self):
name = getattr(self.get(), 'filename', self.grid_id) if self.get() else '(no file)' name = getattr(
self.get(), 'filename', self.grid_id) if self.get() else '(no file)'
return '<%s: %s>' % (self.__class__.__name__, name) return '<%s: %s>' % (self.__class__.__name__, name)
def __eq__(self, other): def __eq__(self, other):
@ -1135,7 +1306,8 @@ class GridFSProxy(object):
@property @property
def fs(self): def fs(self):
if not self._fs: if not self._fs:
self._fs = gridfs.GridFS(get_db(self.db_alias), self.collection_name) self._fs = gridfs.GridFS(
get_db(self.db_alias), self.collection_name)
return self._fs return self._fs
def get(self, id=None): def get(self, id=None):
@ -1209,6 +1381,7 @@ class GridFSProxy(object):
class FileField(BaseField): class FileField(BaseField):
"""A GridFS storage field. """A GridFS storage field.
.. versionadded:: 0.4 .. versionadded:: 0.4
@ -1253,7 +1426,8 @@ class FileField(BaseField):
pass pass
# Create a new proxy object as we don't already have one # Create a new proxy object as we don't already have one
instance._data[key] = self.get_proxy_obj(key=key, instance=instance) instance._data[key] = self.get_proxy_obj(
key=key, instance=instance)
instance._data[key].put(value) instance._data[key].put(value)
else: else:
instance._data[key] = value instance._data[key] = value
@ -1291,11 +1465,13 @@ class FileField(BaseField):
class ImageGridFsProxy(GridFSProxy): class ImageGridFsProxy(GridFSProxy):
""" """
Proxy for ImageField Proxy for ImageField
versionadded: 0.6 versionadded: 0.6
""" """
def put(self, file_obj, **kwargs): def put(self, file_obj, **kwargs):
""" """
Insert a image in database Insert a image in database
@ -1341,7 +1517,8 @@ class ImageGridFsProxy(GridFSProxy):
size = field.thumbnail_size size = field.thumbnail_size
if size['force']: if size['force']:
thumbnail = ImageOps.fit(img, (size['width'], size['height']), Image.ANTIALIAS) thumbnail = ImageOps.fit(
img, (size['width'], size['height']), Image.ANTIALIAS)
else: else:
thumbnail = img.copy() thumbnail = img.copy()
thumbnail.thumbnail((size['width'], thumbnail.thumbnail((size['width'],
@ -1367,7 +1544,7 @@ class ImageGridFsProxy(GridFSProxy):
**kwargs) **kwargs)
def delete(self, *args, **kwargs): def delete(self, *args, **kwargs):
#deletes thumbnail # deletes thumbnail
out = self.get() out = self.get()
if out and out.thumbnail_id: if out and out.thumbnail_id:
self.fs.delete(out.thumbnail_id) self.fs.delete(out.thumbnail_id)
@ -1427,6 +1604,7 @@ class ImproperlyConfigured(Exception):
class ImageField(FileField): class ImageField(FileField):
""" """
A Image File storage field. A Image File storage field.
@ -1465,6 +1643,7 @@ class ImageField(FileField):
class SequenceField(BaseField): class SequenceField(BaseField):
"""Provides a sequental counter see: """Provides a sequental counter see:
http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers
@ -1534,7 +1713,7 @@ class SequenceField(BaseField):
data = collection.find_one({"_id": sequence_id}) data = collection.find_one({"_id": sequence_id})
if data: if data:
return self.value_decorator(data['next']+1) return self.value_decorator(data['next'] + 1)
return self.value_decorator(1) return self.value_decorator(1)
@ -1579,6 +1758,7 @@ class SequenceField(BaseField):
class UUIDField(BaseField): class UUIDField(BaseField):
"""A UUID field. """A UUID field.
.. versionadded:: 0.6 .. versionadded:: 0.6
@ -1631,6 +1811,7 @@ class UUIDField(BaseField):
class GeoPointField(BaseField): class GeoPointField(BaseField):
"""A list storing a longitude and latitude coordinate. """A list storing a longitude and latitude coordinate.
.. note:: this represents a generic point in a 2D plane and a legacy way of .. note:: this represents a generic point in a 2D plane and a legacy way of
@ -1651,13 +1832,16 @@ class GeoPointField(BaseField):
'of (x, y)') 'of (x, y)')
if not len(value) == 2: if not len(value) == 2:
self.error("Value (%s) must be a two-dimensional point" % repr(value)) self.error("Value (%s) must be a two-dimensional point" %
repr(value))
elif (not isinstance(value[0], (float, int)) or elif (not isinstance(value[0], (float, int)) or
not isinstance(value[1], (float, int))): not isinstance(value[1], (float, int))):
self.error("Both values (%s) in point must be float or int" % repr(value)) self.error(
"Both values (%s) in point must be float or int" % repr(value))
class PointField(GeoJsonBaseField): class PointField(GeoJsonBaseField):
"""A GeoJSON field storing a longitude and latitude coordinate. """A GeoJSON field storing a longitude and latitude coordinate.
The data is represented as: The data is represented as:
@ -1677,6 +1861,7 @@ class PointField(GeoJsonBaseField):
class LineStringField(GeoJsonBaseField): class LineStringField(GeoJsonBaseField):
"""A GeoJSON field storing a line of longitude and latitude coordinates. """A GeoJSON field storing a line of longitude and latitude coordinates.
The data is represented as: The data is represented as:
@ -1695,6 +1880,7 @@ class LineStringField(GeoJsonBaseField):
class PolygonField(GeoJsonBaseField): class PolygonField(GeoJsonBaseField):
"""A GeoJSON field storing a polygon of longitude and latitude coordinates. """A GeoJSON field storing a polygon of longitude and latitude coordinates.
The data is represented as: The data is represented as:

View File

@ -12,21 +12,21 @@ __all__ = ('query', 'update')
COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists', 'not', 'elemMatch') 'all', 'size', 'exists', 'not', 'elemMatch')
GEO_OPERATORS = ('within_distance', 'within_spherical_distance', GEO_OPERATORS = ('within_distance', 'within_spherical_distance',
'within_box', 'within_polygon', 'near', 'near_sphere', 'within_box', 'within_polygon', 'near', 'near_sphere',
'max_distance', 'geo_within', 'geo_within_box', 'max_distance', 'geo_within', 'geo_within_box',
'geo_within_polygon', 'geo_within_center', 'geo_within_polygon', 'geo_within_center',
'geo_within_sphere', 'geo_intersects') 'geo_within_sphere', 'geo_intersects')
STRING_OPERATORS = ('contains', 'icontains', 'startswith', STRING_OPERATORS = ('contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'istartswith', 'endswith', 'iendswith',
'exact', 'iexact') 'exact', 'iexact')
CUSTOM_OPERATORS = ('match',) CUSTOM_OPERATORS = ('match',)
MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
STRING_OPERATORS + CUSTOM_OPERATORS) STRING_OPERATORS + CUSTOM_OPERATORS)
UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push',
'push_all', 'pull', 'pull_all', 'add_to_set', 'push_all', 'pull', 'pull_all', 'add_to_set',
'set_on_insert') 'set_on_insert')
def query(_doc_cls=None, _field_operation=False, **query): def query(_doc_cls=None, _field_operation=False, **query):
@ -60,14 +60,20 @@ def query(_doc_cls=None, _field_operation=False, **query):
raise InvalidQueryError(e) raise InvalidQueryError(e)
parts = [] parts = []
CachedReferenceField = _import_class('CachedReferenceField')
cleaned_fields = [] cleaned_fields = []
for field in fields: for field in fields:
append_field = True append_field = True
if isinstance(field, basestring): if isinstance(field, basestring):
parts.append(field) parts.append(field)
append_field = False append_field = False
# is last and CachedReferenceField
elif isinstance(field, CachedReferenceField) and fields[-1] == field:
parts.append('%s._id' % field.db_field)
else: else:
parts.append(field.db_field) parts.append(field.db_field)
if append_field: if append_field:
cleaned_fields.append(field) cleaned_fields.append(field)
@ -79,13 +85,17 @@ def query(_doc_cls=None, _field_operation=False, **query):
if op in singular_ops: if op in singular_ops:
if isinstance(field, basestring): if isinstance(field, basestring):
if (op in STRING_OPERATORS and if (op in STRING_OPERATORS and
isinstance(value, basestring)): isinstance(value, basestring)):
StringField = _import_class('StringField') StringField = _import_class('StringField')
value = StringField.prepare_query_value(op, value) value = StringField.prepare_query_value(op, value)
else: else:
value = field value = field
else: else:
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
if isinstance(field, CachedReferenceField) and value:
value = value['_id']
elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
# 'in', 'nin' and 'all' require a list of values # 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
@ -125,10 +135,12 @@ def query(_doc_cls=None, _field_operation=False, **query):
continue continue
value_son[k] = v value_son[k] = v
if (get_connection().max_wire_version <= 1): if (get_connection().max_wire_version <= 1):
value_son['$maxDistance'] = value_dict['$maxDistance'] value_son['$maxDistance'] = value_dict[
'$maxDistance']
else: else:
value_son['$near'] = SON(value_son['$near']) value_son['$near'] = SON(value_son['$near'])
value_son['$near']['$maxDistance'] = value_dict['$maxDistance'] value_son['$near'][
'$maxDistance'] = value_dict['$maxDistance']
else: else:
for k, v in value_dict.iteritems(): for k, v in value_dict.iteritems():
if k == '$maxDistance': if k == '$maxDistance':
@ -264,7 +276,8 @@ def update(_doc_cls=None, **update):
if ListField in field_classes: if ListField in field_classes:
# Join all fields via dot notation to the last ListField # Join all fields via dot notation to the last ListField
# Then process as normal # Then process as normal
last_listField = len(cleaned_fields) - field_classes.index(ListField) last_listField = len(
cleaned_fields) - field_classes.index(ListField)
key = ".".join(parts[:last_listField]) key = ".".join(parts[:last_listField])
parts = parts[last_listField:] parts = parts[last_listField:]
parts.insert(0, key) parts.insert(0, key)

View File

@ -47,7 +47,8 @@ class FieldTest(unittest.TestCase):
# Confirm saving now would store values # Confirm saving now would store values
data_to_be_saved = sorted(person.to_mongo().keys()) data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid']) self.assertEqual(
data_to_be_saved, ['age', 'created', 'name', 'userid'])
self.assertTrue(person.validate() is None) self.assertTrue(person.validate() is None)
@ -63,7 +64,8 @@ class FieldTest(unittest.TestCase):
# Confirm introspection changes nothing # Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys()) data_to_be_saved = sorted(person.to_mongo().keys())
self.assertEqual(data_to_be_saved, ['age', 'created', 'name', 'userid']) self.assertEqual(
data_to_be_saved, ['age', 'created', 'name', 'userid'])
def test_default_values_set_to_None(self): def test_default_values_set_to_None(self):
"""Ensure that default field values are used when creating a document. """Ensure that default field values are used when creating a document.
@ -587,7 +589,8 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection() LogEntry.drop_collection()
# Post UTC - microseconds are rounded (down) nearest millisecond and dropped # Post UTC - microseconds are rounded (down) nearest millisecond and
# dropped
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
d2 = datetime.datetime(1970, 01, 01, 00, 00, 01) d2 = datetime.datetime(1970, 01, 01, 00, 00, 01)
log = LogEntry() log = LogEntry()
@ -688,7 +691,8 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection() LogEntry.drop_collection()
# Post UTC - microseconds are rounded (down) nearest millisecond and dropped - with default datetimefields # Post UTC - microseconds are rounded (down) nearest millisecond and
# dropped - with default datetimefields
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
log = LogEntry() log = LogEntry()
log.date = d1 log.date = d1
@ -696,14 +700,16 @@ class FieldTest(unittest.TestCase):
log.reload() log.reload()
self.assertEqual(log.date, d1) self.assertEqual(log.date, d1)
# Post UTC - microseconds are rounded (down) nearest millisecond - with default datetimefields # Post UTC - microseconds are rounded (down) nearest millisecond - with
# default datetimefields
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999) d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
log.date = d1 log.date = d1
log.save() log.save()
log.reload() log.reload()
self.assertEqual(log.date, d1) self.assertEqual(log.date, d1)
# Pre UTC dates microseconds below 1000 are dropped - with default datetimefields # Pre UTC dates microseconds below 1000 are dropped - with default
# datetimefields
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
log.date = d1 log.date = d1
log.save() log.save()
@ -929,12 +935,16 @@ class FieldTest(unittest.TestCase):
post.save() post.save()
self.assertEqual(BlogPost.objects.count(), 3) self.assertEqual(BlogPost.objects.count(), 3)
self.assertEqual(BlogPost.objects.filter(info__exact='test').count(), 1) self.assertEqual(
self.assertEqual(BlogPost.objects.filter(info__0__test='test').count(), 1) BlogPost.objects.filter(info__exact='test').count(), 1)
self.assertEqual(
BlogPost.objects.filter(info__0__test='test').count(), 1)
# Confirm handles non strings or non existing keys # Confirm handles non strings or non existing keys
self.assertEqual(BlogPost.objects.filter(info__0__test__exact='5').count(), 0) self.assertEqual(
self.assertEqual(BlogPost.objects.filter(info__100__test__exact='test').count(), 0) BlogPost.objects.filter(info__0__test__exact='5').count(), 0)
self.assertEqual(
BlogPost.objects.filter(info__100__test__exact='test').count(), 0)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_list_field_passed_in_value(self): def test_list_field_passed_in_value(self):
@ -951,7 +961,6 @@ class FieldTest(unittest.TestCase):
foo.bars.append(bar) foo.bars.append(bar)
self.assertEqual(repr(foo.bars), '[<Bar: Bar object>]') self.assertEqual(repr(foo.bars), '[<Bar: Bar object>]')
def test_list_field_strict(self): def test_list_field_strict(self):
"""Ensure that list field handles validation if provided a strict field type.""" """Ensure that list field handles validation if provided a strict field type."""
@ -1082,20 +1091,28 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(e2.mapping[1], IntegerSetting)) self.assertTrue(isinstance(e2.mapping[1], IntegerSetting))
# Test querying # Test querying
self.assertEqual(Simple.objects.filter(mapping__1__value=42).count(), 1) self.assertEqual(
self.assertEqual(Simple.objects.filter(mapping__2__number=1).count(), 1) Simple.objects.filter(mapping__1__value=42).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__2__complex__value=42).count(), 1) self.assertEqual(
self.assertEqual(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) Simple.objects.filter(mapping__2__number=1).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) self.assertEqual(
Simple.objects.filter(mapping__2__complex__value=42).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__2__list__0__value=42).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1)
# Confirm can update # Confirm can update
Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) Simple.objects().update(set__mapping__1=IntegerSetting(value=10))
self.assertEqual(Simple.objects.filter(mapping__1__value=10).count(), 1) self.assertEqual(
Simple.objects.filter(mapping__1__value=10).count(), 1)
Simple.objects().update( Simple.objects().update(
set__mapping__2__list__1=StringSetting(value='Boo')) set__mapping__2__list__1=StringSetting(value='Boo'))
self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) self.assertEqual(
self.assertEqual(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0)
self.assertEqual(
Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1)
Simple.drop_collection() Simple.drop_collection()
@ -1141,12 +1158,16 @@ class FieldTest(unittest.TestCase):
post.save() post.save()
self.assertEqual(BlogPost.objects.count(), 3) self.assertEqual(BlogPost.objects.count(), 3)
self.assertEqual(BlogPost.objects.filter(info__title__exact='test').count(), 1) self.assertEqual(
self.assertEqual(BlogPost.objects.filter(info__details__test__exact='test').count(), 1) BlogPost.objects.filter(info__title__exact='test').count(), 1)
self.assertEqual(
BlogPost.objects.filter(info__details__test__exact='test').count(), 1)
# Confirm handles non strings or non existing keys # Confirm handles non strings or non existing keys
self.assertEqual(BlogPost.objects.filter(info__details__test__exact=5).count(), 0) self.assertEqual(
self.assertEqual(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) BlogPost.objects.filter(info__details__test__exact=5).count(), 0)
self.assertEqual(
BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0)
post = BlogPost.objects.create(info={'title': 'original'}) post = BlogPost.objects.create(info={'title': 'original'})
post.info.update({'title': 'updated'}) post.info.update({'title': 'updated'})
@ -1207,19 +1228,26 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
# Test querying # Test querying
self.assertEqual(Simple.objects.filter(mapping__someint__value=42).count(), 1) self.assertEqual(
self.assertEqual(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) Simple.objects.filter(mapping__someint__value=42).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) self.assertEqual(
self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) Simple.objects.filter(mapping__nested_dict__number=1).count(), 1)
self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) self.assertEqual(
Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1)
self.assertEqual(
Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1)
# Confirm can update # Confirm can update
Simple.objects().update( Simple.objects().update(
set__mapping={"someint": IntegerSetting(value=10)}) set__mapping={"someint": IntegerSetting(value=10)})
Simple.objects().update( Simple.objects().update(
set__mapping__nested_dict__list__1=StringSetting(value='Boo')) set__mapping__nested_dict__list__1=StringSetting(value='Boo'))
self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) self.assertEqual(
self.assertEqual(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0)
self.assertEqual(
Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1)
Simple.drop_collection() Simple.drop_collection()
@ -1290,7 +1318,7 @@ class FieldTest(unittest.TestCase):
class Test(Document): class Test(Document):
my_map = MapField(field=EmbeddedDocumentField(Embedded), my_map = MapField(field=EmbeddedDocumentField(Embedded),
db_field='x') db_field='x')
Test.drop_collection() Test.drop_collection()
@ -1334,7 +1362,7 @@ class FieldTest(unittest.TestCase):
Log(name="wilson", visited={'friends': datetime.datetime.now()}).save() Log(name="wilson", visited={'friends': datetime.datetime.now()}).save()
self.assertEqual(1, Log.objects( self.assertEqual(1, Log.objects(
visited__friends__exists=True).count()) visited__friends__exists=True).count())
def test_embedded_db_field(self): def test_embedded_db_field(self):
@ -1477,6 +1505,375 @@ class FieldTest(unittest.TestCase):
mongoed = p1.to_mongo() mongoed = p1.to_mongo()
self.assertTrue(isinstance(mongoed['parent'], ObjectId)) self.assertTrue(isinstance(mongoed['parent'], ObjectId))
def test_cached_reference_fields(self):
class Animal(Document):
name = StringField()
tag = StringField()
class Ocorrence(Document):
person = StringField()
animal = CachedReferenceField(
Animal, fields=['tag'])
Animal.drop_collection()
Ocorrence.drop_collection()
a = Animal(name="Leopard", tag="heavy")
a.save()
self.assertEqual(Animal._cached_reference_fields, [Ocorrence.animal])
o = Ocorrence(person="teste", animal=a)
o.save()
p = Ocorrence(person="Wilson")
p.save()
self.assertEqual(Ocorrence.objects(animal=None).count(), 1)
self.assertEqual(
a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk})
self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
# counts
Ocorrence(person="teste 2").save()
Ocorrence(person="teste 3").save()
count = Ocorrence.objects(animal__tag='heavy').count()
self.assertEqual(count, 1)
ocorrence = Ocorrence.objects(animal__tag='heavy').first()
self.assertEqual(ocorrence.person, "teste")
self.assertTrue(isinstance(ocorrence.animal, Animal))
def test_cached_reference_field_decimal(self):
class PersonAuto(Document):
name = StringField()
salary = DecimalField()
class SocialTest(Document):
group = StringField()
person = CachedReferenceField(
PersonAuto,
fields=('salary',))
PersonAuto.drop_collection()
SocialTest.drop_collection()
p = PersonAuto(name="Alberto", salary=Decimal('7000.00'))
p.save()
s = SocialTest(group="dev", person=p)
s.save()
self.assertEqual(
SocialTest.objects._collection.find_one({'person.salary': 7000.00}), {
'_id': s.pk,
'group': s.group,
'person': {
'_id': p.pk,
'salary': 7000.00
}
})
def test_cached_reference_field_reference(self):
class Group(Document):
name = StringField()
class Person(Document):
name = StringField()
group = ReferenceField(Group)
class SocialData(Document):
obs = StringField()
tags = ListField(
StringField())
person = CachedReferenceField(
Person,
fields=('group',))
Group.drop_collection()
Person.drop_collection()
SocialData.drop_collection()
g1 = Group(name='dev')
g1.save()
g2 = Group(name="designers")
g2.save()
p1 = Person(name="Alberto", group=g1)
p1.save()
p2 = Person(name="Andre", group=g1)
p2.save()
p3 = Person(name="Afro design", group=g2)
p3.save()
s1 = SocialData(obs="testing 123", person=p1, tags=['tag1', 'tag2'])
s1.save()
s2 = SocialData(obs="testing 321", person=p3, tags=['tag3', 'tag4'])
s2.save()
self.assertEqual(SocialData.objects._collection.find_one(
{'tags': 'tag2'}), {
'_id': s1.pk,
'obs': 'testing 123',
'tags': ['tag1', 'tag2'],
'person': {
'_id': p1.pk,
'group': g1.pk
}
})
self.assertEqual(SocialData.objects(person__group=g2).count(), 1)
self.assertEqual(SocialData.objects(person__group=g2).first(), s2)
def test_cached_reference_field_update_all(self):
class Person(Document):
TYPES = (
('pf', "PF"),
('pj', "PJ")
)
name = StringField()
tp = StringField(
choices=TYPES
)
father = CachedReferenceField('self', fields=('tp',))
Person.drop_collection()
a1 = Person(name="Wilson Father", tp="pj")
a1.save()
a2 = Person(name='Wilson Junior', tp='pf', father=a1)
a2.save()
self.assertEqual(dict(a2.to_mongo()), {
"_id": a2.pk,
"name": u"Wilson Junior",
"tp": u"pf",
"father": {
"_id": a1.pk,
"tp": u"pj"
}
})
self.assertEqual(Person.objects(father=a1)._query, {
'father._id': a1.pk
})
self.assertEqual(Person.objects(father=a1).count(), 1)
Person.objects.update(set__tp="pf")
Person.father.sync_all()
a2.reload()
self.assertEqual(dict(a2.to_mongo()), {
"_id": a2.pk,
"name": u"Wilson Junior",
"tp": u"pf",
"father": {
"_id": a1.pk,
"tp": u"pf"
}
})
def test_cached_reference_fields_on_embedded_documents(self):
def build():
class Test(Document):
name = StringField()
type('WrongEmbeddedDocument', (
EmbeddedDocument,), {
'test': CachedReferenceField(Test)
})
self.assertRaises(InvalidDocumentError, build)
def test_cached_reference_auto_sync(self):
class Person(Document):
TYPES = (
('pf', "PF"),
('pj', "PJ")
)
name = StringField()
tp = StringField(
choices=TYPES
)
father = CachedReferenceField('self', fields=('tp',))
Person.drop_collection()
a1 = Person(name="Wilson Father", tp="pj")
a1.save()
a2 = Person(name='Wilson Junior', tp='pf', father=a1)
a2.save()
a1.tp = 'pf'
a1.save()
a2.reload()
self.assertEqual(dict(a2.to_mongo()), {
'_id': a2.pk,
'name': 'Wilson Junior',
'tp': 'pf',
'father': {
'_id': a1.pk,
'tp': 'pf'
}
})
def test_cached_reference_auto_sync_disabled(self):
class Persone(Document):
TYPES = (
('pf', "PF"),
('pj', "PJ")
)
name = StringField()
tp = StringField(
choices=TYPES
)
father = CachedReferenceField(
'self', fields=('tp',), auto_sync=False)
Persone.drop_collection()
a1 = Persone(name="Wilson Father", tp="pj")
a1.save()
a2 = Persone(name='Wilson Junior', tp='pf', father=a1)
a2.save()
a1.tp = 'pf'
a1.save()
self.assertEqual(Persone.objects._collection.find_one({'_id': a2.pk}), {
'_id': a2.pk,
'name': 'Wilson Junior',
'tp': 'pf',
'father': {
'_id': a1.pk,
'tp': 'pj'
}
})
def test_cached_reference_embedded_fields(self):
class Owner(EmbeddedDocument):
TPS = (
('n', "Normal"),
('u', "Urgent")
)
name = StringField()
tp = StringField(
verbose_name="Type",
db_field="t",
choices=TPS)
class Animal(Document):
name = StringField()
tag = StringField()
owner = EmbeddedDocumentField(Owner)
class Ocorrence(Document):
person = StringField()
animal = CachedReferenceField(
Animal, fields=['tag', 'owner.tp'])
Animal.drop_collection()
Ocorrence.drop_collection()
a = Animal(nam="Leopard", tag="heavy",
owner=Owner(tp='u', name="Wilson Júnior")
)
a.save()
o = Ocorrence(person="teste", animal=a)
o.save()
self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tp'])), {
'_id': a.pk,
'tag': 'heavy',
'owner': {
't': 'u'
}
})
self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u')
# counts
Ocorrence(person="teste 2").save()
Ocorrence(person="teste 3").save()
count = Ocorrence.objects(
animal__tag='heavy', animal__owner__tp='u').count()
self.assertEqual(count, 1)
ocorrence = Ocorrence.objects(
animal__tag='heavy',
animal__owner__tp='u').first()
self.assertEqual(ocorrence.person, "teste")
self.assertTrue(isinstance(ocorrence.animal, Animal))
def test_cached_reference_embedded_list_fields(self):
class Owner(EmbeddedDocument):
name = StringField()
tags = ListField(StringField())
class Animal(Document):
name = StringField()
tag = StringField()
owner = EmbeddedDocumentField(Owner)
class Ocorrence(Document):
person = StringField()
animal = CachedReferenceField(
Animal, fields=['tag', 'owner.tags'])
Animal.drop_collection()
Ocorrence.drop_collection()
a = Animal(nam="Leopard", tag="heavy",
owner=Owner(tags=['cool', 'funny'],
name="Wilson Júnior")
)
a.save()
o = Ocorrence(person="teste 2", animal=a)
o.save()
self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tags'])), {
'_id': a.pk,
'tag': 'heavy',
'owner': {
'tags': ['cool', 'funny']
}
})
self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
self.assertEqual(o.to_mongo()['animal']['owner']['tags'],
['cool', 'funny'])
# counts
Ocorrence(person="teste 2").save()
Ocorrence(person="teste 3").save()
query = Ocorrence.objects(
animal__tag='heavy', animal__owner__tags='cool')._query
self.assertEqual(
query, {'animal.owner.tags': 'cool', 'animal.tag': 'heavy'})
ocorrence = Ocorrence.objects(
animal__tag='heavy',
animal__owner__tags='cool').first()
self.assertEqual(ocorrence.person, "teste 2")
self.assertTrue(isinstance(ocorrence.animal, Animal))
def test_objectid_reference_fields(self): def test_objectid_reference_fields(self):
class Person(Document): class Person(Document):
@ -1834,8 +2231,7 @@ class FieldTest(unittest.TestCase):
Person(name="Wilson Jr").save() Person(name="Wilson Jr").save()
self.assertEqual(repr(Person.objects(city=None)), self.assertEqual(repr(Person.objects(city=None)),
"[<Person: Person object>]") "[<Person: Person object>]")
def test_generic_reference_choices(self): def test_generic_reference_choices(self):
"""Ensure that a GenericReferenceField can handle choices """Ensure that a GenericReferenceField can handle choices
@ -1982,7 +2378,8 @@ class FieldTest(unittest.TestCase):
attachment_required.blob = Binary(b('\xe6\x00\xc4\xff\x07')) attachment_required.blob = Binary(b('\xe6\x00\xc4\xff\x07'))
attachment_required.validate() attachment_required.validate()
attachment_size_limit = AttachmentSizeLimit(blob=b('\xe6\x00\xc4\xff\x07')) attachment_size_limit = AttachmentSizeLimit(
blob=b('\xe6\x00\xc4\xff\x07'))
self.assertRaises(ValidationError, attachment_size_limit.validate) self.assertRaises(ValidationError, attachment_size_limit.validate)
attachment_size_limit.blob = b('\xe6\x00\xc4\xff') attachment_size_limit.blob = b('\xe6\x00\xc4\xff')
attachment_size_limit.validate() attachment_size_limit.validate()
@ -2030,8 +2427,8 @@ class FieldTest(unittest.TestCase):
""" """
class Shirt(Document): class Shirt(Document):
size = StringField(max_length=3, choices=( size = StringField(max_length=3, choices=(
('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
style = StringField(max_length=3, choices=( style = StringField(max_length=3, choices=(
('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S')
@ -2061,7 +2458,7 @@ class FieldTest(unittest.TestCase):
""" """
class Shirt(Document): class Shirt(Document):
size = StringField(max_length=3, size = StringField(max_length=3,
choices=('S', 'M', 'L', 'XL', 'XXL')) choices=('S', 'M', 'L', 'XL', 'XXL'))
Shirt.drop_collection() Shirt.drop_collection()
@ -2179,7 +2576,6 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
self.assertEqual(c['next'], 1000) self.assertEqual(c['next'], 1000)
def test_sequence_field_get_next_value(self): def test_sequence_field_get_next_value(self):
class Person(Document): class Person(Document):
id = SequenceField(primary_key=True) id = SequenceField(primary_key=True)
@ -2368,7 +2764,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(1, post.comments[0].id) self.assertEqual(1, post.comments[0].id)
self.assertEqual(2, post.comments[1].id) self.assertEqual(2, post.comments[1].id)
def test_generic_embedded_document(self): def test_generic_embedded_document(self):
class Car(EmbeddedDocument): class Car(EmbeddedDocument):
name = StringField() name = StringField()
@ -2478,7 +2873,7 @@ class FieldTest(unittest.TestCase):
self.assertTrue('comments' in error.errors) self.assertTrue('comments' in error.errors)
self.assertTrue(1 in error.errors['comments']) self.assertTrue(1 in error.errors['comments'])
self.assertTrue(isinstance(error.errors['comments'][1]['content'], self.assertTrue(isinstance(error.errors['comments'][1]['content'],
ValidationError)) ValidationError))
# ValidationError.schema property # ValidationError.schema property
error_dict = error.to_dict() error_dict = error.to_dict()
@ -2604,11 +2999,11 @@ class FieldTest(unittest.TestCase):
DictFieldTest.drop_collection() DictFieldTest.drop_collection()
test = DictFieldTest(dictionary=None) test = DictFieldTest(dictionary=None)
test.dictionary # Just access to test getter test.dictionary # Just access to test getter
self.assertRaises(ValidationError, test.validate) self.assertRaises(ValidationError, test.validate)
test = DictFieldTest(dictionary=False) test = DictFieldTest(dictionary=False)
test.dictionary # Just access to test getter test.dictionary # Just access to test getter
self.assertRaises(ValidationError, test.validate) self.assertRaises(ValidationError, test.validate)