Merge branch 'dev' into pull_124

Ross Lawley
2011-05-25 09:54:56 +01:00
10 changed files with 1184 additions and 240 deletions

mongoengine/base.py

@@ -7,22 +7,32 @@ import pymongo
import pymongo.objectid
-_document_registry = {}
-def get_document(name):
-    return _document_registry[name]
+class NotRegistered(Exception):
+    pass
+class ValidationError(Exception):
+    pass
+_document_registry = {}
+def get_document(name):
+    if name not in _document_registry:
+        raise NotRegistered("""
+            `%s` has not been registered in the document registry.
+            Importing the document class automatically registers it; has it
+            been imported?
+            """.strip() % name)
+    return _document_registry[name]
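For orientation, a hedged sketch of how the registry behaves (the application module name is hypothetical):

    from mongoengine.base import get_document, NotRegistered

    try:
        BlogPost = get_document('BlogPost')
    except NotRegistered:
        import myapp.models   # hypothetical: importing the module defines and registers the class
        BlogPost = get_document('BlogPost')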
class BaseField(object):
"""A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema.
"""
    # Fields may have _types inserted into indexes by default
_index_with_types = True
_geo_index = False
@@ -32,7 +42,7 @@ class BaseField(object):
creation_counter = 0
auto_creation_counter = -1
    def __init__(self, db_field=None, name=None, required=False, default=None,
unique=False, unique_with=None, primary_key=False,
validation=None, choices=None):
self.db_field = (db_field or name) if not primary_key else '_id'
@@ -57,7 +67,7 @@ class BaseField(object):
BaseField.creation_counter += 1
def __get__(self, instance, owner):
"""Descriptor for retrieving a value from a field in a document. Do
"""Descriptor for retrieving a value from a field in a document. Do
any necessary conversion between Python and MongoDB types.
"""
if instance is None:
@@ -167,8 +177,8 @@ class DocumentMetaclass(type):
superclasses.update(base._superclasses)
if hasattr(base, '_meta'):
                # Ensure that the Document class may be subclassed -
                # inheritance may be disabled to remove dependency on
# additional fields _cls and _types
if base._meta.get('allow_inheritance', True) == False:
raise ValueError('Document %s may not be subclassed' %
@@ -211,12 +221,12 @@ class DocumentMetaclass(type):
module = attrs.get('__module__')
            base_excs = tuple(base.DoesNotExist for base in bases
if hasattr(base, 'DoesNotExist')) or (DoesNotExist,)
exc = subclass_exception('DoesNotExist', base_excs, module)
new_class.add_to_class('DoesNotExist', exc)
            base_excs = tuple(base.MultipleObjectsReturned for base in bases
if hasattr(base, 'MultipleObjectsReturned'))
base_excs = base_excs or (MultipleObjectsReturned,)
exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
@@ -238,12 +248,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
def __new__(cls, name, bases, attrs):
super_new = super(TopLevelDocumentMetaclass, cls).__new__
        # Classes defined in this package are abstract and should not have
# their own metadata with DB collection, etc.
        # __metaclass__ is only set on the class with the __metaclass__
# attribute (i.e. it is not set on subclasses). This differentiates
# 'real' documents from the 'Document' class
-        if attrs.get('__metaclass__') == TopLevelDocumentMetaclass:
+        #
+        # Also assume a class is abstract if it has abstract set to True in
+        # its meta dictionary. This allows custom Document superclasses.
+        if (attrs.get('__metaclass__') == TopLevelDocumentMetaclass or
+                ('meta' in attrs and attrs['meta'].get('abstract', False))):
+            # Make sure no base class was non-abstract
+            non_abstract_bases = [b for b in bases
+                if hasattr(b, '_meta') and not b._meta.get('abstract', False)]
+            if non_abstract_bases:
+                raise ValueError("Abstract document cannot have non-abstract base")
return super_new(cls, name, bases, attrs)
collection = name.lower()
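To illustrate the new ``abstract`` flag, a minimal sketch of the intended usage (class names are illustrative):

    from mongoengine import Document, DateTimeField, StringField

    class Dated(Document):
        created = DateTimeField()
        meta = {'abstract': True}    # no collection is bound to Dated

    class Article(Dated):            # concrete subclass gets its own collection
        title = StringField()

Declaring ``abstract`` on a subclass of a concrete document triggers the ``ValueError`` above.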
@@ -266,6 +285,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
base_indexes += base._meta.get('indexes', [])
meta = {
'abstract': False,
'collection': collection,
'max_documents': None,
'max_size': None,
@@ -289,13 +309,39 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class = super_new(cls, name, bases, attrs)
# Provide a default queryset unless one has been manually provided
-        if not hasattr(new_class, 'objects'):
-            new_class.objects = QuerySetManager()
+        manager = attrs.get('objects', QuerySetManager())
+        if hasattr(manager, 'queryset_class'):
+            meta['queryset_class'] = manager.queryset_class
+        new_class.objects = manager
user_indexes = [QuerySet._build_index_spec(new_class, spec)
for spec in meta['indexes']] + base_indexes
new_class._meta['indexes'] = user_indexes
unique_indexes = cls._unique_with_indexes(new_class)
new_class._meta['unique_indexes'] = unique_indexes
for field_name, field in new_class._fields.items():
# Check for custom primary key
if field.primary_key:
current_pk = new_class._meta['id_field']
if current_pk and current_pk != field_name:
raise ValueError('Cannot override primary key field')
if not current_pk:
new_class._meta['id_field'] = field_name
# Make 'Document.id' an alias to the real primary key field
new_class.id = field
if not new_class._meta['id_field']:
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id']
return new_class
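This lets a custom queryset class ride in on the manager assigned to ``objects``; a hedged sketch, assuming the metaclass picks up a ``queryset_class`` attribute from the manager as the hunk above shows:

    from mongoengine import Document, BooleanField
    from mongoengine.queryset import QuerySet, QuerySetManager

    class PublishedQuerySet(QuerySet):
        def published(self):
            return self(is_published=True)

    class PublishedManager(QuerySetManager):
        queryset_class = PublishedQuerySet   # stored in meta['queryset_class']

    class BlogPost(Document):
        is_published = BooleanField(default=False)
        objects = PublishedManager()

    live = BlogPost.objects.published()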
@classmethod
def _unique_with_indexes(cls, new_class, namespace=""):
unique_indexes = []
for field_name, field in new_class._fields.items():
# Generate a list of indexes needed by uniqueness constraints
@@ -321,28 +367,16 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
unique_fields += unique_with
# Add the new index to the list
-                index = [(f, pymongo.ASCENDING) for f in unique_fields]
+                index = [("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields]
unique_indexes.append(index)
-            # Check for custom primary key
-            if field.primary_key:
-                current_pk = new_class._meta['id_field']
-                if current_pk and current_pk != field_name:
-                    raise ValueError('Cannot override primary key field')
+            # Grab any embedded document field unique indexes
+            if field.__class__.__name__ == "EmbeddedDocumentField":
+                field_namespace = "%s." % field_name
+                unique_indexes += cls._unique_with_indexes(field.document_type,
+                                    field_namespace)
-                if not current_pk:
-                    new_class._meta['id_field'] = field_name
-                    # Make 'Document.id' an alias to the real primary key field
-                    new_class.id = field
-        new_class._meta['unique_indexes'] = unique_indexes
-        if not new_class._meta['id_field']:
-            new_class._meta['id_field'] = 'id'
-            new_class._fields['id'] = ObjectIdField(db_field='_id')
-            new_class.id = new_class._fields['id']
-        return new_class
+        return unique_indexes
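The recursion into :class:`EmbeddedDocumentField` means unique constraints declared on embedded documents surface as namespaced indexes on the parent collection; a sketch of the expected result:

    from mongoengine import (Document, EmbeddedDocument,
                             EmbeddedDocumentField, StringField)

    class Profile(EmbeddedDocument):
        handle = StringField(unique=True)

    class Member(Document):
        profile = EmbeddedDocumentField(Profile)

    # The walk namespaces the embedded field name, yielding roughly:
    #   [[('profile.handle', pymongo.ASCENDING)]]
    print Member._meta['unique_indexes']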
class BaseDocument(object):
@@ -366,7 +400,7 @@ class BaseDocument(object):
are present.
"""
# Get a list of tuples of field names and their current values
        fields = [(field, getattr(self, name))
for name, field in self._fields.items()]
# Ensure that each field is matched to a valid value
@@ -461,7 +495,7 @@ class BaseDocument(object):
self._meta.get('allow_inheritance', True) == False):
data['_cls'] = self._class_name
data['_types'] = self._superclasses.keys() + [self._class_name]
-        if data.has_key('_id') and not data['_id']:
+        if data.has_key('_id') and data['_id'] is None:
del data['_id']
return data

mongoengine/connection.py

@@ -1,5 +1,6 @@
from pymongo import Connection
import multiprocessing
import threading
__all__ = ['ConnectionError', 'connect']
@@ -22,17 +23,22 @@ class ConnectionError(Exception):
def _get_connection(reconnect=False):
"""Handles the connection to the database
"""
global _connection
identity = get_identity()
# Connect to the database if not already connected
if _connection.get(identity) is None or reconnect:
try:
_connection[identity] = Connection(**_connection_settings)
-        except:
-            raise ConnectionError('Cannot connect to the database')
+        except Exception, e:
+            raise ConnectionError("Cannot connect to the database:\n%s" % e)
return _connection[identity]
def _get_db(reconnect=False):
"""Handles database connections and authentication based on the current
identity
"""
global _db, _connection
identity = get_identity()
# Connect if not already connected
@@ -52,12 +58,17 @@ def _get_db(reconnect=False):
return _db[identity]
def get_identity():
"""Creates an identity key based on the current process and thread
identity.
"""
identity = multiprocessing.current_process()._identity
identity = 0 if not identity else identity[0]
identity = (identity, threading.current_thread().ident)
return identity
def connect(db, username=None, password=None, **kwargs):
"""Connect to the database specified by the 'db' argument. Connection
"""Connect to the database specified by the 'db' argument. Connection
settings may be provided here as well if the database is not running on
the default port on localhost. If authentication is needed, provide
username and password arguments as well.
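With the richer error message, connection failures now carry the underlying pymongo error; a usage sketch (host and credentials are illustrative):

    from mongoengine import connect
    from mongoengine.connection import ConnectionError

    try:
        connect('blog', host='db.example.com', port=27017,
                username='web', password='secret')
    except ConnectionError, e:
        print e   # now includes the original pymongo failure text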

View File

@@ -86,7 +86,7 @@ class User(Document):
else:
email = '@'.join([email_name, domain_part.lower()])
-        user = User(username=username, email=email, date_joined=now)
+        user = cls(username=username, email=email, date_joined=now)
user.set_password(password)
user.save()
return user

mongoengine/document.py

@@ -40,44 +40,54 @@ class Document(BaseDocument):
presence of `_cls` and `_types`, set :attr:`allow_inheritance` to
``False`` in the :attr:`meta` dictionary.
    A :class:`~mongoengine.Document` may use a **Capped Collection** by
specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta`
dictionary. :attr:`max_documents` is the maximum number of documents that
    is allowed to be stored in the collection, and :attr:`max_size` is the
    maximum size of the collection in bytes. If :attr:`max_size` is not
    specified and :attr:`max_documents` is, :attr:`max_size` defaults to
10000000 bytes (10MB).
Indexes may be created by specifying :attr:`indexes` in the :attr:`meta`
    dictionary. The value should be a list of field names or tuples of field
names. Index direction may be specified by prefixing the field names with
a **+** or **-** sign.
"""
__metaclass__ = TopLevelDocumentMetaclass
-    def save(self, safe=True, force_insert=False, validate=True):
+    def save(self, safe=True, force_insert=False, validate=True, write_options=None):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
        If ``safe=True`` and the operation is unsuccessful, an
:class:`~mongoengine.OperationError` will be raised.
:param safe: check if the operation succeeded before returning
        :param force_insert: only try to create a new document, don't allow
updates of existing documents
:param validate: validates the document; set to ``False`` to skip.
        :param write_options: Extra keyword arguments are passed down to
            :meth:`~pymongo.collection.Collection.save` or
            :meth:`~pymongo.collection.Collection.insert`, where they become
            options for the resultant ``getLastError`` command. For example,
            ``save(..., write_options={'w': 2, 'fsync': True})`` will wait until at least
            two servers have recorded the write and will force an fsync on each server
            being written to.
"""
if validate:
self.validate()
if not write_options:
write_options = {}
doc = self.to_mongo()
try:
collection = self.__class__.objects._collection
if force_insert:
-                object_id = collection.insert(doc, safe=safe)
+                object_id = collection.insert(doc, safe=safe, **write_options)
else:
-                object_id = collection.save(doc, safe=safe)
+                object_id = collection.save(doc, safe=safe, **write_options)
except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'
if u'duplicate key' in unicode(err):
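A short sketch of the new pass-through (the replication values are illustrative):

    post = BlogPost(title='Hello')
    # Forwarded to pymongo's save()/insert() as getLastError options:
    # wait for two servers to record the write and force an fsync on each.
    post.save(safe=True, write_options={'w': 2, 'fsync': True})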
@@ -131,9 +141,9 @@ class MapReduceDocument(object):
"""A document returned from a map/reduce query.
:param collection: An instance of :class:`~pymongo.Collection`
    :param key: Document/result key, often an instance of
                :class:`~pymongo.objectid.ObjectId`. If supplied as
                an ``ObjectId`` found in the given ``collection``,
the object can be accessed via the ``object`` property.
:param value: The result(s) for this key.
@@ -148,7 +158,7 @@ class MapReduceDocument(object):
@property
def object(self):
"""Lazy-load the object referenced by ``self.key``. ``self.key``
"""Lazy-load the object referenced by ``self.key``. ``self.key``
should be the ``primary_key``.
"""
id_field = self._document()._meta['id_field']

mongoengine/fields.py

@@ -17,7 +17,7 @@ import warnings
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
-           'ObjectIdField', 'ReferenceField', 'ValidationError',
+           'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField',
'DecimalField', 'URLField', 'GenericReferenceField', 'FileField',
'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField']
@@ -339,7 +339,7 @@ class ListField(BaseField):
if isinstance(self.field, ReferenceField):
referenced_type = self.field.document_type
            # Get value from document instance if available
value_list = instance._data.get(self.name)
if value_list:
deref_list = []
@@ -449,7 +449,108 @@ class DictField(BaseField):
'contain "." or "$" characters')
def lookup_member(self, member_name):
-        return self.basecls(db_field=member_name)
+        return DictField(basecls=self.basecls, db_field=member_name)
def prepare_query_value(self, op, value):
match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith',
'exact', 'iexact']
if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value)
return super(DictField,self).prepare_query_value(op, value)
class MapField(BaseField):
"""A field that maps a name to a specified field type. Similar to
a DictField, except the 'value' of each item must match the specified
field type.
.. versionadded:: 0.5
"""
def __init__(self, field=None, *args, **kwargs):
if not isinstance(field, BaseField):
raise ValidationError('Argument to MapField constructor must be '
'a valid field')
self.field = field
kwargs.setdefault('default', lambda: {})
super(MapField, self).__init__(*args, **kwargs)
def validate(self, value):
"""Make sure that a list of valid fields is being used.
"""
if not isinstance(value, dict):
            raise ValidationError('Only dictionaries may be used in a '
                                  'MapField')
if any(('.' in k or '$' in k) for k in value):
raise ValidationError('Invalid dictionary key name - keys may not '
'contain "." or "$" characters')
try:
[self.field.validate(item) for item in value.values()]
except Exception, err:
raise ValidationError('Invalid MapField item (%s)' % str(item))
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
if isinstance(self.field, ReferenceField):
referenced_type = self.field.document_type
# Get value from document instance if available
value_dict = instance._data.get(self.name)
if value_dict:
                deref_dict = {}
for key,value in value_dict.iteritems():
# Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)):
value = _get_db().dereference(value)
deref_dict[key] = referenced_type._from_son(value)
else:
deref_dict[key] = value
instance._data[self.name] = deref_dict
if isinstance(self.field, GenericReferenceField):
value_dict = instance._data.get(self.name)
if value_dict:
                deref_dict = {}
for key,value in value_dict.iteritems():
# Dereference DBRefs
if isinstance(value, (dict, pymongo.son.SON)):
deref_dict[key] = self.field.dereference(value)
else:
deref_dict[key] = value
instance._data[self.name] = deref_dict
return super(MapField, self).__get__(instance, owner)
def to_python(self, value):
return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] )
def to_mongo(self, value):
return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] )
def prepare_query_value(self, op, value):
return self.field.prepare_query_value(op, value)
def lookup_member(self, member_name):
return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
    def _get_owner_document(self):
        return self._owner_document
owner_document = property(_get_owner_document, _set_owner_document)
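Taken together, a minimal MapField usage sketch (field and class names are illustrative):

    from mongoengine import Document, MapField, StringField

    class Translation(Document):
        # every value in the dict must validate as a StringField
        labels = MapField(field=StringField())

    doc = Translation(labels={'en': 'colour', 'us': 'color'})
    doc.save()   # stored as {'labels': {'en': 'colour', 'us': 'color'}}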
class ReferenceField(BaseField):
"""A reference to a document that will be automatically dereferenced on
@@ -522,6 +623,9 @@ class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily).
note: Any documents used as a generic reference must be registered in the
document registry. Importing the model will automatically register it.
.. versionadded:: 0.3
"""
@@ -601,6 +705,7 @@ class GridFSProxy(object):
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
self.newfile = None # Used for partial writes
self.grid_id = grid_id # Store GridFS id for file
self.gridout = None
def __getattr__(self, name):
obj = self.get()
@@ -614,8 +719,12 @@ class GridFSProxy(object):
def get(self, id=None):
if id:
self.grid_id = id
if self.grid_id is None:
return None
try:
-            return self.fs.get(id or self.grid_id)
+            if self.gridout is None:
+                self.gridout = self.fs.get(self.grid_id)
+            return self.gridout
except:
# File has been deleted
return None
@@ -643,11 +752,11 @@ class GridFSProxy(object):
if not self.newfile:
self.new_file()
self.grid_id = self.newfile._id
        self.newfile.writelines(lines)
-    def read(self):
+    def read(self, size=-1):
        try:
-            return self.get().read()
+            return self.get().read(size)
except:
return None
@@ -655,6 +764,7 @@ class GridFSProxy(object):
# Delete file from GridFS, FileField still remains
self.fs.delete(self.grid_id)
self.grid_id = None
self.gridout = None
def replace(self, file, **kwargs):
self.delete()

mongoengine/queryset.py

@@ -8,6 +8,7 @@ import pymongo.objectid
import re
import copy
import itertools
import operator
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
@@ -280,30 +281,30 @@ class QueryFieldList(object):
ONLY = True
EXCLUDE = False
-    def __init__(self, fields=[], direction=ONLY, always_include=[]):
-        self.direction = direction
+    def __init__(self, fields=[], value=ONLY, always_include=[]):
+        self.value = value
self.fields = set(fields)
self.always_include = set(always_include)
def as_dict(self):
-        return dict((field, self.direction) for field in self.fields)
+        return dict((field, self.value) for field in self.fields)
def __add__(self, f):
if not self.fields:
self.fields = f.fields
-            self.direction = f.direction
-        elif self.direction is self.ONLY and f.direction is self.ONLY:
+            self.value = f.value
+        elif self.value is self.ONLY and f.value is self.ONLY:
self.fields = self.fields.intersection(f.fields)
-        elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE:
+        elif self.value is self.EXCLUDE and f.value is self.EXCLUDE:
self.fields = self.fields.union(f.fields)
-        elif self.direction is self.ONLY and f.direction is self.EXCLUDE:
+        elif self.value is self.ONLY and f.value is self.EXCLUDE:
self.fields -= f.fields
-        elif self.direction is self.EXCLUDE and f.direction is self.ONLY:
-            self.direction = self.ONLY
+        elif self.value is self.EXCLUDE and f.value is self.ONLY:
+            self.value = self.ONLY
self.fields = f.fields - self.fields
if self.always_include:
-            if self.direction is self.ONLY and self.fields:
+            if self.value is self.ONLY and self.fields:
self.fields = self.fields.union(self.always_include)
else:
self.fields -= self.always_include
@@ -311,7 +312,7 @@ class QueryFieldList(object):
def reset(self):
self.fields = set([])
-        self.direction = self.ONLY
+        self.value = self.ONLY
def __nonzero__(self):
return bool(self.fields)
@@ -334,6 +335,7 @@ class QuerySet(object):
self._ordering = []
self._snapshot = False
self._timeout = True
self._class_check = True
# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
@@ -344,11 +346,26 @@ class QuerySet(object):
self._limit = None
self._skip = None
def clone(self):
"""Creates a copy of the current :class:`~mongoengine.queryset.QuerySet`"""
c = self.__class__(self._document, self._collection_obj)
copy_props = ('_initial_query', '_query_obj', '_where_clause',
'_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_limit', '_skip')
for prop in copy_props:
val = getattr(self, prop)
setattr(c, prop, copy.deepcopy(val))
return c
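Because the copied properties are deep-copied, a base queryset can be branched without filters leaking back; a sketch (``BlogPost`` and ``last_week`` are illustrative):

    base = BlogPost.objects(is_published=True)
    recent = base.clone().filter(created__gte=last_week)
    # 'base' still carries only the is_published filter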
@property
def _query(self):
if self._mongo_query is None:
self._mongo_query = self._query_obj.to_query(self._document)
-            self._mongo_query.update(self._initial_query)
+            if self._class_check:
+                self._mongo_query.update(self._initial_query)
return self._mongo_query
def ensure_index(self, key_or_list, drop_dups=False, background=False,
@@ -399,7 +416,7 @@ class QuerySet(object):
return index_list
-    def __call__(self, q_obj=None, **query):
+    def __call__(self, q_obj=None, class_check=True, **query):
"""Filter the selected documents by calling the
:class:`~mongoengine.queryset.QuerySet` with a query.
@@ -407,16 +424,17 @@ class QuerySet(object):
the query; the :class:`~mongoengine.queryset.QuerySet` is filtered
multiple times with different :class:`~mongoengine.queryset.Q`
objects, only the last one will be used
        :param class_check: If set to ``False``, bypass the class name check
            when querying the collection
:param query: Django-style query keyword arguments
"""
-        #if q_obj:
-        #    self._where_clause = q_obj.as_js(self._document)
query = Q(**query)
if q_obj:
query &= q_obj
self._query_obj &= query
self._mongo_query = None
self._cursor_obj = None
self._class_check = class_check
return self
def filter(self, *q_objs, **query):
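The new ``class_check`` flag opts a single query out of the automatic ``_types`` filter that inheritance adds; a hedged example (``Page`` is illustrative):

    # Normally Page.objects(...) injects {'_types': ...} into the query;
    # this returns every document in the collection regardless of class.
    everything = Page.objects(class_check=False)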
@@ -440,17 +458,17 @@ class QuerySet(object):
drop_dups = self._document._meta.get('index_drop_dups', False)
index_opts = self._document._meta.get('index_options', {})
+        # Ensure indexes created by uniqueness constraints
+        for index in self._document._meta['unique_indexes']:
+            self._collection.ensure_index(index, unique=True,
+                background=background, drop_dups=drop_dups, **index_opts)
# Ensure document-defined indexes are created
if self._document._meta['indexes']:
for key_or_list in self._document._meta['indexes']:
self._collection.ensure_index(key_or_list,
background=background, **index_opts)
-        # Ensure indexes created by uniqueness constraints
-        for index in self._document._meta['unique_indexes']:
-            self._collection.ensure_index(index, unique=True,
-                background=background, drop_dups=drop_dups, **index_opts)
# If _types is being used (for polymorphism), it needs an index
if '_types' in self._query:
self._collection.ensure_index('_types',
@@ -474,7 +492,7 @@ class QuerySet(object):
}
if self._loaded_fields:
cursor_args['fields'] = self._loaded_fields.as_dict()
            self._cursor_obj = self._collection.find(self._query,
**cursor_args)
# Apply where clauses to cursor
if self._where_clause:
@@ -504,6 +522,15 @@ class QuerySet(object):
fields = []
field = None
for field_name in parts:
# Handle ListField indexing:
if field_name.isdigit():
try:
field = field.field
except AttributeError, err:
raise InvalidQueryError(
"Can't use index on unsubscriptable field (%s)" % err)
fields.append(field_name)
continue
if field is None:
# Look up first field from the document
if field_name == 'pk':
@@ -528,14 +555,14 @@ class QuerySet(object):
return '.'.join(parts)
@classmethod
-    def _transform_query(cls, _doc_cls=None, **query):
+    def _transform_query(cls, _doc_cls=None, _field_operation=False, **query):
"""Transform a query from Django-style format to Mongo format.
"""
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists', 'not']
-        geo_operators = ['within_distance', 'within_box', 'near']
+        geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'near', 'near_sphere']
        match_operators = ['contains', 'icontains', 'startswith',
                           'istartswith', 'endswith', 'iendswith',
'exact', 'iexact']
mongo_query = {}
@@ -577,8 +604,12 @@ class QuerySet(object):
if op in geo_operators:
if op == "within_distance":
value = {'$within': {'$center': value}}
elif op == "within_spherical_distance":
value = {'$within': {'$centerSphere': value}}
elif op == "near":
value = {'$near': value}
elif op == "near_sphere":
value = {'$nearSphere': value}
elif op == 'within_box':
value = {'$within': {'$box': value}}
else:
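The spherical variants map to MongoDB's ``$centerSphere`` and ``$nearSphere``, whose distances are expressed in radians; a sketch (class name and coordinates are illustrative, ``location`` is assumed to be a GeoPointField):

    Place.objects(location__near_sphere=[-122.42, 37.77])
    # within ~10 km: radius = distance / Earth radius (~6371 km)
    Place.objects(location__within_spherical_distance=[[-122.42, 37.77], 10.0 / 6371])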
@@ -620,9 +651,9 @@ class QuerySet(object):
raise self._document.DoesNotExist("%s matching query does not exist."
% self._document._class_name)
-    def get_or_create(self, *q_objs, **query):
+    def get_or_create(self, write_options=None, *q_objs, **query):
        """Retrieve a unique object or create it if it doesn't exist. Returns a tuple of
        ``(object, created)``, where ``object`` is the retrieved or created object
and ``created`` is a boolean specifying whether a new object was created. Raises
:class:`~mongoengine.queryset.MultipleObjectsReturned` or
`DocumentName.MultipleObjectsReturned` if multiple results are found.
@@ -630,6 +661,10 @@ class QuerySet(object):
dictionary of default values for the new document may be provided as a
keyword argument called :attr:`defaults`.
        :param write_options: optional extra keyword arguments used if we
            have to create a new document; passed on to
            :meth:`~mongoengine.document.Document.save`
.. versionadded:: 0.3
"""
defaults = query.get('defaults', {})
@@ -641,7 +676,7 @@ class QuerySet(object):
if count == 0:
query.update(defaults)
doc = self._document(**query)
-            doc.save()
+            doc.save(write_options=write_options)
return doc, True
elif count == 1:
return self.first(), False
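``write_options`` only takes effect on the create path; a sketch (field names are illustrative):

    post, created = BlogPost.objects.get_or_create(
        slug='hello-world',
        defaults={'title': 'Hello World'},
        write_options={'w': 2})   # forwarded to save() only when creating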
@@ -725,7 +760,7 @@ class QuerySet(object):
def __len__(self):
return self.count()
-    def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None,
+    def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None,
scope=None, keep_temp=False):
"""Perform a map/reduce query using the current query spec
and ordering. While ``map_reduce`` respects ``QuerySet`` chaining,
@@ -739,26 +774,26 @@ class QuerySet(object):
:param map_f: map function, as :class:`~pymongo.code.Code` or string
:param reduce_f: reduce function, as
:class:`~pymongo.code.Code` or string
:param output: output collection name
:param finalize_f: finalize function, an optional function that
performs any post-reduction processing.
:param scope: values to insert into map/reduce global scope. Optional.
:param limit: number of objects from current query to provide
to map/reduce method
        :param keep_temp: keep temporary table (boolean, default ``False``)
Returns an iterator yielding
:class:`~mongoengine.document.MapReduceDocument`.
-        .. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo
+        .. note:: Map/Reduce changed in server version **>= 1.7.4**. The PyMongo
            :meth:`~pymongo.collection.Collection.map_reduce` helper requires
-            PyMongo version **>= 1.2**.
+            PyMongo version **>= 1.11**.
.. versionadded:: 0.3
"""
from document import MapReduceDocument
if not hasattr(self._collection, "map_reduce"):
raise NotImplementedError("Requires MongoDB >= 1.1.1")
raise NotImplementedError("Requires MongoDB >= 1.7.1")
map_f_scope = {}
if isinstance(map_f, pymongo.code.Code):
@@ -789,8 +824,7 @@ class QuerySet(object):
if limit:
mr_args['limit'] = limit
-        results = self._collection.map_reduce(map_f, reduce_f, **mr_args)
+        results = self._collection.map_reduce(map_f, reduce_f, output, **mr_args)
results = results.find()
if self._ordering:
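Under the 1.7.4+ semantics an output collection name is required; a usage sketch (names are illustrative):

    map_f = "function () { emit(this.author, 1); }"
    reduce_f = "function (key, values) { return Array.sum(values); }"
    for result in BlogPost.objects.map_reduce(map_f, reduce_f, output='post_counts'):
        print result.key, result.value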
@@ -835,7 +869,7 @@ class QuerySet(object):
self._skip, self._limit = key.start, key.stop
except IndexError, err:
# PyMongo raises an error if key.start == key.stop, catch it,
            # bin it, kill it.
start = key.start or 0
if start >= 0 and key.stop >= 0 and key.step is None:
if start == key.stop:
@@ -868,10 +902,8 @@ class QuerySet(object):
.. versionadded:: 0.3
"""
-        fields = self._fields_to_dbfields(fields)
-        self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY)
-        return self
+        fields = dict([(f, QueryFieldList.ONLY) for f in fields])
+        return self.fields(**fields)
def exclude(self, *fields):
"""Opposite to .only(), exclude some document's fields. ::
@@ -880,8 +912,44 @@ class QuerySet(object):
:param fields: fields to exclude
"""
-        fields = self._fields_to_dbfields(fields)
-        self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE)
+        fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields])
+        return self.fields(**fields)
def fields(self, **kwargs):
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. Fields also
allows for a greater level of control for example:
Retrieving a Subrange of Array Elements
---------------------------------------
You can use the $slice operator to retrieve a subrange of elements in
an array ::
            post = BlogPost.objects(...).fields(slice__comments=5)  # first 5 comments
:param kwargs: A dictionary identifying what to include
.. versionadded:: 0.5
"""
# Check for an operator and transform to mongo-style if there is
operators = ["slice"]
cleaned_fields = []
for key, value in kwargs.items():
parts = key.split('__')
op = None
if parts[0] in operators:
op = parts.pop(0)
value = {'$' + op: value}
key = '.'.join(parts)
cleaned_fields.append((key, value))
fields = sorted(cleaned_fields, key=operator.itemgetter(1))
for value, group in itertools.groupby(fields, lambda x: x[1]):
fields = [field for field, value in group]
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, value=value)
return self
def all_fields(self):
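Besides a plain count, ``$slice`` also accepts a ``[skip, limit]`` pair, and repeated ``fields()`` calls combine through :class:`QueryFieldList`; a sketch of both forms:

    # first five comments only
    BlogPost.objects.fields(slice__comments=5)
    # skip 10 comments, then take 5
    BlogPost.objects.fields(slice__comments=[10, 5])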
@@ -917,6 +985,10 @@ class QuerySet(object):
if key[0] in ('-', '+'):
key = key[1:]
key = key.replace('__', '.')
try:
key = QuerySet._translate_field_name(self._document, key)
except:
pass
key_list.append((key, direction))
self._ordering = key_list
@@ -1007,10 +1079,17 @@ class QuerySet(object):
if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts)
-            parts = [field.db_field for field in fields]
+            parts = []
+            for field in fields:
+                if isinstance(field, str):
+                    parts.append(field)
+                else:
+                    parts.append(field.db_field)
# Convert value to proper value
field = fields[-1]
if op in (None, 'set', 'push', 'pull', 'addToSet'):
value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'):
@@ -1029,22 +1108,27 @@ class QuerySet(object):
return mongo_update
-    def update(self, safe_update=True, upsert=False, **update):
+    def update(self, safe_update=True, upsert=False, write_options=None, **update):
        """Perform an atomic update on the documents matched by the query. When
``safe_update`` is used, the number of affected documents is returned.
-        :param safe: check if the operation succeeded before returning
-        :param update: Django-style update keyword arguments
+        :param safe_update: check if the operation succeeded before returning
+        :param upsert: Any existing document with that "_id" is overwritten.
+        :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
.. versionadded:: 0.2
"""
if pymongo.version < '1.1.1':
raise OperationError('update() method requires PyMongo 1.1.1+')
if not write_options:
write_options = {}
update = QuerySet._transform_update(self._document, **update)
try:
ret = self._collection.update(self._query, update, multi=True,
-                                          upsert=upsert, safe=safe_update)
+                                          upsert=upsert, safe=safe_update,
+                                          **write_options)
if ret is not None and 'n' in ret:
return ret['n']
except pymongo.errors.OperationFailure, err:
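As with ``save()``, the dictionary is forwarded to pymongo's ``update`` as getLastError options; a sketch (the ``w`` value is illustrative):

    BlogPost.objects(is_published=False).update(
        set__is_published=True,
        write_options={'w': 2})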
@@ -1053,22 +1137,27 @@ class QuerySet(object):
raise OperationError(message)
raise OperationError(u'Update failed (%s)' % unicode(err))
-    def update_one(self, safe_update=True, upsert=False, **update):
+    def update_one(self, safe_update=True, upsert=False, write_options=None, **update):
        """Perform an atomic update on the first document matched by the query. When
``safe_update`` is used, the number of affected documents is returned.
-        :param safe: check if the operation succeeded before returning
+        :param safe_update: check if the operation succeeded before returning
+        :param upsert: Any existing document with that "_id" is overwritten.
+        :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
        :param update: Django-style update keyword arguments
.. versionadded:: 0.2
"""
if not write_options:
write_options = {}
update = QuerySet._transform_update(self._document, **update)
try:
# Explicitly provide 'multi=False' to newer versions of PyMongo
# as the default may change to 'True'
if pymongo.version >= '1.1.1':
ret = self._collection.update(self._query, update, multi=False,
-                                              upsert=upsert, safe=safe_update)
+                                              upsert=upsert, safe=safe_update,
+                                              **write_options)
else:
# Older versions of PyMongo don't support 'multi'
ret = self._collection.update(self._query, update,
@@ -1082,8 +1171,8 @@ class QuerySet(object):
return self
def _sub_js_fields(self, code):
"""When fields are specified with [~fieldname] syntax, where
*fieldname* is the Python name of a field, *fieldname* will be
"""When fields are specified with [~fieldname] syntax, where
*fieldname* is the Python name of a field, *fieldname* will be
substituted for the MongoDB name of the field (specified using the
:attr:`name` keyword argument in a field's constructor).
"""
@@ -1106,9 +1195,9 @@ class QuerySet(object):
options specified as keyword arguments.
As fields in MongoEngine may use different names in the database (set
        using the :attr:`db_field` keyword argument to a :class:`Field`
constructor), a mechanism exists for replacing MongoEngine field names
        with the database field names in Javascript code. When accessing a
field, use square-bracket notation, and prefix the MongoEngine field
name with a tilde (~).
@@ -1241,8 +1330,11 @@ class QuerySet(object):
class QuerySetManager(object):
-    def __init__(self, manager_func=None):
-        self._manager_func = manager_func
+    get_queryset = None
+    def __init__(self, queryset_func=None):
+        if queryset_func:
+            self.get_queryset = queryset_func
self._collections = {}
def __get__(self, instance, owner):
@@ -1259,7 +1351,7 @@ class QuerySetManager(object):
# Create collection as a capped collection if specified
if owner._meta['max_size'] or owner._meta['max_documents']:
# Get max document limit and max byte size from meta
            max_size = owner._meta['max_size'] or 10000000  # 10MB default
max_documents = owner._meta['max_documents']
if collection in db.collection_names():
@@ -1286,11 +1378,11 @@ class QuerySetManager(object):
# owner is the document that contains the QuerySetManager
queryset_class = owner._meta['queryset_class'] or QuerySet
queryset = queryset_class(owner, self._collections[(db, collection)])
-        if self._manager_func:
-            if self._manager_func.func_code.co_argcount == 1:
-                queryset = self._manager_func(queryset)
+        if self.get_queryset:
+            if self.get_queryset.func_code.co_argcount == 1:
+                queryset = self.get_queryset(queryset)
            else:
-                queryset = self._manager_func(owner, queryset)
+                queryset = self.get_queryset(owner, queryset)
return queryset
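The ``get_queryset`` rename keeps both decorator arities working; a closing sketch using the documented :func:`queryset_manager` decorator (class and field names are illustrative):

    from mongoengine import Document, BooleanField, queryset_manager

    class BlogPost(Document):
        is_published = BooleanField(default=False)

        @queryset_manager
        def live_posts(doc_cls, queryset):
            # two-argument form: co_argcount == 2, so the owner class
            # is passed through as the first argument
            return queryset.filter(is_published=True)

    posts = BlogPost.live_posts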