Merge branch 'master' into 0.8M

Conflicts:
	AUTHORS
	docs/django.rst
	mongoengine/base.py
	mongoengine/queryset.py
	tests/fields/fields.py
	tests/queryset/queryset.py
	tests/test_dereference.py
	tests/test_document.py
This commit is contained in:
Ross Lawley
2013-04-17 11:57:53 +00:00
17 changed files with 474 additions and 101 deletions

View File

@@ -30,13 +30,22 @@ class BaseDocument(object):
_dynamic_lock = True
_initialised = False
def __init__(self, __auto_convert=True, **values):
def __init__(self, __auto_convert=True, *args, **values):
"""
Initialise a document or embedded document
:param __auto_convert: Try and will cast python objects to Object types
:param values: A dictionary of values for the document
"""
if args:
# Combine positional arguments with named arguments.
# We only want named arguments.
field = iter(self._fields_ordered)
for value in args:
name = next(field)
if name in values:
raise TypeError("Multiple values for keyword argument '" + name + "'")
values[name] = value
signals.pre_init.send(self.__class__, document=self, values=values)
@@ -117,15 +126,15 @@ class BaseDocument(object):
self._mark_as_changed(name)
if (self._is_document and not self._created and
name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value):
name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value):
OperationError = _import_class('OperationError')
msg = "Shard Keys are immutable. Tried to update %s" % name
raise OperationError(msg)
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
and self._created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
super(BaseDocument, self).__setattr__(name, value)
@@ -143,7 +152,10 @@ class BaseDocument(object):
self.__set_field_display()
def __iter__(self):
return iter(self._fields)
if 'id' in self._fields and 'id' not in self._fields_ordered:
return iter(('id', ) + self._fields_ordered)
return iter(self._fields_ordered)
def __getitem__(self, name):
"""Dictionary-style field access, return a field's value if present.
@@ -264,7 +276,7 @@ class BaseDocument(object):
for name, field in self._fields.items()]
if self._dynamic:
fields += [(field, self._data.get(name))
for name, field in self._dynamic_fields.items()]
for name, field in self._dynamic_fields.items()]
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField")
@@ -273,7 +285,7 @@ class BaseDocument(object):
if value is not None:
try:
if isinstance(field, (EmbeddedDocumentField,
GenericEmbeddedDocumentField)):
GenericEmbeddedDocumentField)):
field._validate(value, clean=clean)
else:
field._validate(value)
@@ -330,7 +342,7 @@ class BaseDocument(object):
# Convert lists / values so we can watch for any changes on them
if (isinstance(value, (list, tuple)) and
not isinstance(value, BaseList)):
not isinstance(value, BaseList)):
value = BaseList(value, self, name)
elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, self, name)
@@ -344,9 +356,25 @@ class BaseDocument(object):
return
key = self._db_field_map.get(key, key)
if (hasattr(self, '_changed_fields') and
key not in self._changed_fields):
key not in self._changed_fields):
self._changed_fields.append(key)
def _clear_changed_fields(self):
self._changed_fields = []
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
for field_name, field in self._fields.iteritems():
if (isinstance(field, ComplexBaseField) and
isinstance(field.field, EmbeddedDocumentField)):
field_value = getattr(self, field_name, None)
if field_value:
for idx in (field_value if isinstance(field_value, dict)
else xrange(len(field_value))):
field_value[idx]._clear_changed_fields()
elif isinstance(field, EmbeddedDocumentField):
field_value = getattr(self, field_name, None)
if field_value:
field_value._clear_changed_fields()
def _get_changed_fields(self, key='', inspected=None):
"""Returns a list of all fields that have explicitly been changed.
"""
@@ -418,7 +446,7 @@ class BaseDocument(object):
for p in parts:
if isinstance(d, DBRef):
break
elif p.isdigit():
elif isinstance(d, list) and p.isdigit():
d = d[int(p)]
elif hasattr(d, 'get'):
d = d.get(p)
@@ -449,7 +477,7 @@ class BaseDocument(object):
parts = path.split('.')
db_field_name = parts.pop()
for p in parts:
if p.isdigit():
if isinstance(d, list) and p.isdigit():
d = d[int(p)]
elif (hasattr(d, '__getattribute__') and
not isinstance(d, dict)):
@@ -514,7 +542,7 @@ class BaseDocument(object):
value = data[field.db_field]
try:
data[field_name] = (value if value is None
else field.to_python(value))
else field.to_python(value))
if field_name != field.db_field:
del data[field.db_field]
except (AttributeError, ValueError), e:
@@ -548,14 +576,14 @@ class BaseDocument(object):
geo_indices = cls._geo_indices()
unique_indices = cls._unique_with_indexes()
index_specs = [cls._build_index_spec(spec)
for spec in meta_indexes]
for spec in meta_indexes]
def merge_index_specs(index_specs, indices):
if not indices:
return index_specs
spec_fields = [v['fields']
for k, v in enumerate(index_specs)]
for k, v in enumerate(index_specs)]
# Merge unqiue_indexes with existing specs
for k, v in enumerate(indices):
if v['fields'] in spec_fields:
@@ -727,7 +755,7 @@ class BaseDocument(object):
field = DynamicField(db_field=field_name)
else:
raise LookUpError('Cannot resolve field "%s"'
% field_name)
% field_name)
else:
ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
@@ -744,7 +772,7 @@ class BaseDocument(object):
continue
elif not new_field:
raise LookUpError('Cannot resolve field "%s"'
% field_name)
% field_name)
field = new_field # update field to the new field type
fields.append(field)
return fields

View File

@@ -81,8 +81,12 @@ class BaseField(object):
def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document.
"""
instance._data[self.name] = value
if instance._initialised:
changed = False
if (self.name not in instance._data or
instance._data[self.name] != value):
changed = True
instance._data[self.name] = value
if changed and instance._initialised:
instance._mark_as_changed(self.name)
def error(self, message="", errors=None, field_name=None):

View File

@@ -78,7 +78,7 @@ class DocumentMetaclass(type):
# Count names to ensure no db_field redefinitions
field_names[attr_value.db_field] = field_names.get(
attr_value.db_field, 0) + 1
attr_value.db_field, 0) + 1
# Ensure no duplicate db_fields
duplicate_db_fields = [k for k, v in field_names.items() if v > 1]
@@ -90,9 +90,12 @@ class DocumentMetaclass(type):
# Set _fields and db_field maps
attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k))
for k, v in doc_fields.iteritems()])
for k, v in doc_fields.iteritems()])
attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
(v.creation_counter, v.name)
for v in doc_fields.itervalues()))
attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems())
(v, k) for k, v in attrs['_db_field_map'].iteritems())
#
# Set document hierarchy
@@ -101,7 +104,7 @@ class DocumentMetaclass(type):
class_name = [name]
for base in flattened_bases:
if (not getattr(base, '_is_base_cls', True) and
not getattr(base, '_meta', {}).get('abstract', True)):
not getattr(base, '_meta', {}).get('abstract', True)):
# Collate heirarchy for _cls and _subclasses
class_name.append(base.__name__)
@@ -109,11 +112,11 @@ class DocumentMetaclass(type):
# Warn if allow_inheritance isn't set and prevent
# inheritance of classes where inheritance is set to False
allow_inheritance = base._meta.get('allow_inheritance',
ALLOW_INHERITANCE)
if (allow_inheritance != True and
not base._meta.get('abstract')):
ALLOW_INHERITANCE)
if (allow_inheritance is not True and
not base._meta.get('abstract')):
raise ValueError('Document %s may not be subclassed' %
base.__name__)
base.__name__)
# Get superclasses from last base superclass
document_bases = [b for b in flattened_bases

View File

@@ -33,7 +33,7 @@ class DeReference(object):
self.max_depth = max_depth
doc_type = None
if instance and instance._fields:
if instance and isinstance(instance, (Document, TopLevelDocumentMetaclass)):
doc_type = instance._fields.get(name)
if hasattr(doc_type, 'field'):
doc_type = doc_type.field
@@ -84,7 +84,7 @@ class DeReference(object):
# Recursively find dbreferences
depth += 1
for k, item in iterator:
if hasattr(item, '_fields'):
if isinstance(item, Document):
for field_name, field in item._fields.iteritems():
v = item._data.get(field_name, None)
if isinstance(v, (DBRef)):
@@ -174,6 +174,7 @@ class DeReference(object):
if not hasattr(items, 'items'):
is_list = True
as_tuple = isinstance(items, tuple)
iterator = enumerate(items)
data = []
else:
@@ -190,7 +191,7 @@ class DeReference(object):
if k in self.object_map and not is_list:
data[k] = self.object_map[k]
elif hasattr(v, '_fields'):
elif isinstance(v, Document):
for field_name, field in v._fields.iteritems():
v = data[k]._data.get(field_name, None)
if isinstance(v, (DBRef)):
@@ -208,7 +209,7 @@ class DeReference(object):
if instance and name:
if is_list:
return BaseList(data, instance, name)
return tuple(data) if as_tuple else BaseList(data, instance, name)
return BaseDict(data, instance, name)
depth += 1
return data

View File

@@ -32,9 +32,17 @@ class MongoSession(Document):
else fields.DictField()
expire_date = fields.DateTimeField()
meta = {'collection': MONGOENGINE_SESSION_COLLECTION,
'db_alias': MONGOENGINE_SESSION_DB_ALIAS,
'allow_inheritance': False}
meta = {
'collection': MONGOENGINE_SESSION_COLLECTION,
'db_alias': MONGOENGINE_SESSION_DB_ALIAS,
'allow_inheritance': False,
'indexes': [
{
'fields': ['expire_date'],
'expireAfterSeconds': settings.SESSION_COOKIE_AGE
}
]
}
def get_decoded(self):
return SessionStore().decode(self.session_data)

View File

@@ -160,7 +160,7 @@ class Document(BaseDocument):
def save(self, safe=True, force_insert=False, validate=True, clean=True,
write_options=None, cascade=None, cascade_kwargs=None,
_refs=None):
_refs=None, **kwargs):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@@ -278,7 +278,7 @@ class Document(BaseDocument):
if id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id)
self._changed_fields = []
self._clear_changed_fields()
self._created = False
signals.post_save.send(self.__class__, document=self, created=created)
return self

View File

@@ -27,7 +27,7 @@ except ImportError:
Image = None
ImageOps = None
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
__all__ = ['StringField', 'IntField', 'LongField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField',
'DecimalField', 'ComplexDateTimeField', 'URLField', 'DynamicField',
@@ -143,7 +143,7 @@ class EmailField(StringField):
EMAIL_REGEX = re.compile(
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
)
def validate(self, value):
@@ -153,7 +153,7 @@ class EmailField(StringField):
class IntField(BaseField):
"""An integer field.
"""An 32-bit integer field.
"""
def __init__(self, min_value=None, max_value=None, **kwargs):
@@ -186,6 +186,40 @@ class IntField(BaseField):
return int(value)
class LongField(BaseField):
"""An 64-bit integer field.
"""
def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value
super(LongField, self).__init__(**kwargs)
def to_python(self, value):
try:
value = long(value)
except ValueError:
pass
return value
def validate(self, value):
try:
value = long(value)
except:
self.error('%s could not be converted to long' % value)
if self.min_value is not None and value < self.min_value:
self.error('Long value is too small')
if self.max_value is not None and value > self.max_value:
self.error('Long value is too large')
def prepare_query_value(self, op, value):
if value is None:
return value
return long(value)
class FloatField(BaseField):
"""An floating point number field.
"""

View File

@@ -609,8 +609,11 @@ class QuerySet(object):
.. versionchanged:: 0.6 - Improved db_field refrence handling
"""
queryset = self.clone()
return queryset._dereference(queryset._cursor.distinct(field), 1,
name=field, instance=queryset._document)
try:
field = self._fields_to_dbfields([field]).pop()
finally:
return self._dereference(queryset._cursor.distinct(field), 1,
name=field, instance=self._document)
def only(self, *fields):
"""Load only a subset of this document's fields. ::
@@ -696,7 +699,7 @@ class QuerySet(object):
prefixed with **+** or **-** to determine the ordering direction
"""
queryset = self.clone()
queryset._ordering = self._get_order_by(keys)
queryset._ordering = queryset._get_order_by(keys)
return queryset
def explain(self, format=False):

View File

@@ -32,7 +32,7 @@ class SimplificationVisitor(QNodeVisitor):
if combination.operation == combination.AND:
# The simplification only applies to 'simple' queries
if all(isinstance(node, Q) for node in combination.children):
queries = [node.query for node in combination.children]
queries = [n.query for n in combination.children]
return Q(**self._query_conjunction(queries))
return combination