Merge branch 'v0.4' of git://github.com/hmarr/mongoengine into v0.4

Conflicts:
	mongoengine/fields.py
	tests/fields.py
This commit is contained in:
Steve Challis 2010-08-31 00:25:10 +01:00
commit bd1bf9ba24
11 changed files with 449 additions and 74 deletions

2
.gitignore vendored
View File

@ -1,7 +1,9 @@
*.pyc *.pyc
.*.swp .*.swp
*.egg
docs/.build docs/.build
docs/_build docs/_build
build/ build/
dist/ dist/
mongoengine.egg-info/ mongoengine.egg-info/
env/

View File

@ -2,6 +2,20 @@
Changelog Changelog
========= =========
Changes in v0.4
===============
- Added ``SortedListField``
- Added ``EmailField``
- Added ``GeoPointField``
- Added ``exact`` and ``iexact`` match operators to ``QuerySet``
- Added ``get_document_or_404`` and ``get_list_or_404`` Django shortcuts
- Fixed bug in Q-objects
- Fixed document inheritance primary key issue
- Base class can now be defined for ``DictField``
- Fixed MRO error that occured on document inheritance
- Introduced ``min_length`` for ``StringField``
- Other minor fixes
Changes in v0.3 Changes in v0.3
=============== ===============
- Added MapReduce support - Added MapReduce support

View File

@ -71,7 +71,7 @@ Available operators are as follows:
* ``in`` -- value is in list (a list of values should be provided) * ``in`` -- value is in list (a list of values should be provided)
* ``nin`` -- value is not in list (a list of values should be provided) * ``nin`` -- value is not in list (a list of values should be provided)
* ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values * ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values
* ``all`` -- every item in array is in list of values provided * ``all`` -- every item in list of values provided is in array
* ``size`` -- the size of the array is * ``size`` -- the size of the array is
* ``exists`` -- value for field exists * ``exists`` -- value for field exists
@ -174,7 +174,7 @@ custom manager methods as you like::
@queryset_manager @queryset_manager
def live_posts(doc_cls, queryset): def live_posts(doc_cls, queryset):
return queryset(published=True).filter(published=True) return queryset.filter(published=True)
BlogPost(title='test1', published=False).save() BlogPost(title='test1', published=False).save()
BlogPost(title='test2', published=True).save() BlogPost(title='test2', published=True).save()
@ -399,6 +399,7 @@ that you may use with these methods:
* ``unset`` -- delete a particular value (since MongoDB v1.3+) * ``unset`` -- delete a particular value (since MongoDB v1.3+)
* ``inc`` -- increment a value by a given amount * ``inc`` -- increment a value by a given amount
* ``dec`` -- decrement a value by a given amount * ``dec`` -- decrement a value by a given amount
* ``pop`` -- remove the last item from a list
* ``push`` -- append a value to a list * ``push`` -- append a value to a list
* ``push_all`` -- append several values to a list * ``push_all`` -- append several values to a list
* ``pull`` -- remove a value from a list * ``pull`` -- remove a value from a list

View File

@ -22,6 +22,7 @@ class BaseField(object):
# Fields may have _types inserted into indexes by default # Fields may have _types inserted into indexes by default
_index_with_types = True _index_with_types = True
_geo_index = False
def __init__(self, db_field=None, name=None, required=False, default=None, def __init__(self, db_field=None, name=None, required=False, default=None,
unique=False, unique_with=None, primary_key=False, unique=False, unique_with=None, primary_key=False,
@ -229,12 +230,18 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
id_field = None id_field = None
base_indexes = [] base_indexes = []
base_meta = {}
# Subclassed documents inherit collection from superclass # Subclassed documents inherit collection from superclass
for base in bases: for base in bases:
if hasattr(base, '_meta') and 'collection' in base._meta: if hasattr(base, '_meta') and 'collection' in base._meta:
collection = base._meta['collection'] collection = base._meta['collection']
# Propagate index options.
for key in ('index_background', 'index_drop_dups', 'index_opts'):
if key in base._meta:
base_meta[key] = base._meta[key]
id_field = id_field or base._meta.get('id_field') id_field = id_field or base._meta.get('id_field')
base_indexes += base._meta.get('indexes', []) base_indexes += base._meta.get('indexes', [])
@ -245,7 +252,11 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
'ordering': [], # default ordering applied at runtime 'ordering': [], # default ordering applied at runtime
'indexes': [], # indexes to be ensured at runtime 'indexes': [], # indexes to be ensured at runtime
'id_field': id_field, 'id_field': id_field,
'index_background': False,
'index_drop_dups': False,
'index_opts': {},
} }
meta.update(base_meta)
# Apply document-defined meta options # Apply document-defined meta options
meta.update(attrs.get('meta', {})) meta.update(attrs.get('meta', {}))
@ -254,6 +265,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Set up collection manager, needs the class to have fields so use # Set up collection manager, needs the class to have fields so use
# DocumentMetaclass before instantiating CollectionManager object # DocumentMetaclass before instantiating CollectionManager object
new_class = super_new(cls, name, bases, attrs) new_class = super_new(cls, name, bases, attrs)
# Provide a default queryset unless one has been manually provided
if not hasattr(new_class, 'objects'):
new_class.objects = QuerySetManager() new_class.objects = QuerySetManager()
user_indexes = [QuerySet._build_index_spec(new_class, spec) user_indexes = [QuerySet._build_index_spec(new_class, spec)
@ -265,7 +279,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Generate a list of indexes needed by uniqueness constraints # Generate a list of indexes needed by uniqueness constraints
if field.unique: if field.unique:
field.required = True field.required = True
unique_fields = [field_name] unique_fields = [field.db_field]
# Add any unique_with fields to the back of the index spec # Add any unique_with fields to the back of the index spec
if field.unique_with: if field.unique_with:
@ -338,8 +352,8 @@ class BaseDocument(object):
try: try:
field._validate(value) field._validate(value)
except (ValueError, AttributeError, AssertionError), e: except (ValueError, AttributeError, AssertionError), e:
raise ValidationError('Invalid value for field of type "' + raise ValidationError('Invalid value for field of type "%s": %s'
field.__class__.__name__ + '"') % (field.__class__.__name__, value))
elif field.required: elif field.required:
raise ValidationError('Field "%s" is required' % field.name) raise ValidationError('Field "%s" is required' % field.name)
@ -414,6 +428,8 @@ class BaseDocument(object):
self._meta.get('allow_inheritance', True) == False): self._meta.get('allow_inheritance', True) == False):
data['_cls'] = self._class_name data['_cls'] = self._class_name
data['_types'] = self._superclasses.keys() + [self._class_name] data['_types'] = self._superclasses.keys() + [self._class_name]
if data.has_key('_id') and not data['_id']:
del data['_id']
return data return data
@classmethod @classmethod
@ -445,7 +461,9 @@ class BaseDocument(object):
for field_name, field in cls._fields.items(): for field_name, field in cls._fields.items():
if field.db_field in data: if field.db_field in data:
data[field_name] = field.to_python(data[field.db_field]) value = data[field.db_field]
data[field_name] = (value if value is None
else field.to_python(value))
obj = cls(**data) obj = cls(**data)
obj._present_fields = present_fields obj._present_fields = present_fields

View File

@ -32,6 +32,9 @@ class User(Document):
last_login = DateTimeField(default=datetime.datetime.now) last_login = DateTimeField(default=datetime.datetime.now)
date_joined = DateTimeField(default=datetime.datetime.now) date_joined = DateTimeField(default=datetime.datetime.now)
def __unicode__(self):
return self.username
def get_full_name(self): def get_full_name(self):
"""Returns the users first and last names, separated by a space. """Returns the users first and last names, separated by a space.
""" """

View File

@ -0,0 +1,21 @@
#coding: utf-8
from django.test import TestCase
from django.conf import settings
from mongoengine import connect
class MongoTestCase(TestCase):
"""
TestCase class that clear the collection between the tests
"""
db_name = 'test_%s' % settings.MONGO_DATABASE_NAME
def __init__(self, methodName='runtest'):
self.db = connect(self.db_name)
super(MongoTestCase, self).__init__(methodName)
def _post_teardown(self):
super(MongoTestCase, self)._post_teardown()
for collection in self.db.collection_names():
if collection == 'system.indexes':
continue
self.db.drop_collection(collection)

View File

@ -16,11 +16,14 @@ __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
'ObjectIdField', 'ReferenceField', 'ValidationError', 'ObjectIdField', 'ReferenceField', 'ValidationError',
'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', 'DecimalField', 'URLField', 'GenericReferenceField', 'FileField',
<<<<<<< HEAD
'BinaryField', 'SortedListField', 'EmailField', 'GeoLocationField'] 'BinaryField', 'SortedListField', 'EmailField', 'GeoLocationField']
=======
'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField']
>>>>>>> 32e66b29f44f3015be099851201241caee92054f
RECURSIVE_REFERENCE_CONSTANT = 'self' RECURSIVE_REFERENCE_CONSTANT = 'self'
class StringField(BaseField): class StringField(BaseField):
"""A unicode string field. """A unicode string field.
""" """
@ -67,6 +70,9 @@ class StringField(BaseField):
regex = r'%s$' regex = r'%s$'
elif op == 'exact': elif op == 'exact':
regex = r'^%s$' regex = r'^%s$'
# escape unsafe characters which could lead to a re.error
value = re.escape(value)
value = re.compile(regex % value, flags) value = re.compile(regex % value, flags)
return value return value
@ -264,6 +270,7 @@ class ListField(BaseField):
raise ValidationError('Argument to ListField constructor must be ' raise ValidationError('Argument to ListField constructor must be '
'a valid field') 'a valid field')
self.field = field self.field = field
kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs) super(ListField, self).__init__(**kwargs)
def __get__(self, instance, owner): def __get__(self, instance, owner):
@ -356,6 +363,7 @@ class DictField(BaseField):
def __init__(self, basecls=None, *args, **kwargs): def __init__(self, basecls=None, *args, **kwargs):
self.basecls = basecls or BaseField self.basecls = basecls or BaseField
assert issubclass(self.basecls, BaseField) assert issubclass(self.basecls, BaseField)
kwargs.setdefault('default', lambda: {})
super(DictField, self).__init__(*args, **kwargs) super(DictField, self).__init__(*args, **kwargs)
def validate(self, value): def validate(self, value):
@ -372,24 +380,6 @@ class DictField(BaseField):
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.basecls(db_field=member_name) return self.basecls(db_field=member_name)
class GeoLocationField(DictField):
"""Supports geobased fields"""
def validate(self, value):
"""Make sure that a geo-value is of type (x, y)
"""
if not isinstance(value, tuple) and not isinstance(value, list):
raise ValidationError('GeoLocationField can only hold tuples or lists of (x, y)')
if len(value) <> 2:
raise ValidationError('GeoLocationField must have exactly two elements (x, y)')
def to_mongo(self, value):
return {'x': value[0], 'y': value[1]}
def to_python(self, value):
return value.keys()
class ReferenceField(BaseField): class ReferenceField(BaseField):
"""A reference to a document that will be automatically dereferenced on """A reference to a document that will be automatically dereferenced on
access (lazily). access (lazily).
@ -456,7 +446,6 @@ class ReferenceField(BaseField):
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
class GenericReferenceField(BaseField): class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass """A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily). that will be automatically dereferenced on access (lazily).
@ -503,6 +492,7 @@ class GenericReferenceField(BaseField):
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
return self.to_mongo(value)['_ref'] return self.to_mongo(value)['_ref']
class BinaryField(BaseField): class BinaryField(BaseField):
"""A binary data field. """A binary data field.
""" """
@ -524,14 +514,25 @@ class BinaryField(BaseField):
if self.max_bytes is not None and len(value) > self.max_bytes: if self.max_bytes is not None and len(value) > self.max_bytes:
raise ValidationError('Binary value is too long') raise ValidationError('Binary value is too long')
<<<<<<< HEAD
=======
>>>>>>> 32e66b29f44f3015be099851201241caee92054f
class GridFSProxy(object): class GridFSProxy(object):
"""Proxy object to handle writing and reading of files to and from GridFS """Proxy object to handle writing and reading of files to and from GridFS
""" """
<<<<<<< HEAD
def __init__(self): def __init__(self):
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
self.newfile = None # Used for partial writes self.newfile = None # Used for partial writes
self.grid_id = None # Store GridFS id for file self.grid_id = None # Store GridFS id for file
=======
def __init__(self, grid_id=None):
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
self.newfile = None # Used for partial writes
self.grid_id = grid_id # Store GridFS id for file
>>>>>>> 32e66b29f44f3015be099851201241caee92054f
def __getattr__(self, name): def __getattr__(self, name):
obj = self.get() obj = self.get()
@ -542,8 +543,17 @@ class GridFSProxy(object):
return self return self
def get(self, id=None): def get(self, id=None):
<<<<<<< HEAD
try: return self.fs.get(id or self.grid_id) try: return self.fs.get(id or self.grid_id)
except: return None # File has been deleted except: return None # File has been deleted
=======
if id:
self.grid_id = id
try:
return self.fs.get(id or self.grid_id)
except:
return None # File has been deleted
>>>>>>> 32e66b29f44f3015be099851201241caee92054f
def new_file(self, **kwargs): def new_file(self, **kwargs):
self.newfile = self.fs.new_file(**kwargs) self.newfile = self.fs.new_file(**kwargs)
@ -565,8 +575,15 @@ class GridFSProxy(object):
self.newfile.writelines(lines) self.newfile.writelines(lines)
def read(self): def read(self):
<<<<<<< HEAD
try: return self.get().read() try: return self.get().read()
except: return None except: return None
=======
try:
return self.get().read()
except:
return None
>>>>>>> 32e66b29f44f3015be099851201241caee92054f
def delete(self): def delete(self):
# Delete file from GridFS, FileField still remains # Delete file from GridFS, FileField still remains
@ -584,29 +601,61 @@ class GridFSProxy(object):
msg = "The close() method is only necessary after calling write()" msg = "The close() method is only necessary after calling write()"
warnings.warn(msg) warnings.warn(msg)
<<<<<<< HEAD
=======
>>>>>>> 32e66b29f44f3015be099851201241caee92054f
class FileField(BaseField): class FileField(BaseField):
"""A GridFS storage field. """A GridFS storage field.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
<<<<<<< HEAD
self.gridfs = GridFSProxy() self.gridfs = GridFSProxy()
=======
>>>>>>> 32e66b29f44f3015be099851201241caee92054f
super(FileField, self).__init__(**kwargs) super(FileField, self).__init__(**kwargs)
def __get__(self, instance, owner): def __get__(self, instance, owner):
if instance is None: if instance is None:
return self return self
<<<<<<< HEAD
return self.gridfs return self.gridfs
=======
# Check if a file already exists for this model
grid_file = instance._data.get(self.name)
if grid_file:
return grid_file
return GridFSProxy()
>>>>>>> 32e66b29f44f3015be099851201241caee92054f
def __set__(self, instance, value): def __set__(self, instance, value):
if isinstance(value, file) or isinstance(value, str): if isinstance(value, file) or isinstance(value, str):
# using "FileField() = file/string" notation # using "FileField() = file/string" notation
<<<<<<< HEAD
self.gridfs.put(value) self.gridfs.put(value)
=======
grid_file = instance._data.get(self.name)
# If a file already exists, delete it
if grid_file:
try:
grid_file.delete()
except:
pass
# Create a new file with the new data
grid_file.put(value)
else:
# Create a new proxy object as we don't already have one
instance._data[self.name] = GridFSProxy()
instance._data[self.name].put(value)
>>>>>>> 32e66b29f44f3015be099851201241caee92054f
else: else:
instance._data[self.name] = value instance._data[self.name] = value
def to_mongo(self, value): def to_mongo(self, value):
# Store the GridFS file id in MongoDB # Store the GridFS file id in MongoDB
<<<<<<< HEAD
return self.gridfs.grid_id return self.gridfs.grid_id
def to_python(self, value): def to_python(self, value):
@ -617,3 +666,36 @@ class FileField(BaseField):
assert isinstance(value, GridFSProxy) assert isinstance(value, GridFSProxy)
assert isinstance(value.grid_id, pymongo.objectid.ObjectId) assert isinstance(value.grid_id, pymongo.objectid.ObjectId)
=======
if isinstance(value, GridFSProxy) and value.grid_id is not None:
return value.grid_id
return None
def to_python(self, value):
if value is not None:
return GridFSProxy(value)
def validate(self, value):
if value.grid_id is not None:
assert isinstance(value, GridFSProxy)
assert isinstance(value.grid_id, pymongo.objectid.ObjectId)
class GeoPointField(BaseField):
"""A list storing a latitude and longitude.
"""
_geo_index = True
def validate(self, value):
"""Make sure that a geo-value is of type (x, y)
"""
if not isinstance(value, (list, tuple)):
raise ValidationError('GeoPointField can only accept tuples or '
'lists of (x, y)')
if not len(value) == 2:
raise ValidationError('Value must be a two-dimensional point.')
if (not isinstance(value[0], (float, int)) and
not isinstance(value[1], (float, int))):
raise ValidationError('Both values in point must be float or int.')
>>>>>>> 32e66b29f44f3015be099851201241caee92054f

View File

@ -1,5 +1,6 @@
from connection import _get_db from connection import _get_db
import pprint
import pymongo import pymongo
import re import re
import copy import copy
@ -114,13 +115,11 @@ class Q(object):
value, field_js = self._build_op_js(op, key, value, value_name) value, field_js = self._build_op_js(op, key, value, value_name)
js_scope[value_name] = value js_scope[value_name] = value
js.append(field_js) js.append(field_js)
print ' && '.join(js)
return ' && '.join(js) return ' && '.join(js)
def _build_op_js(self, op, key, value, value_name): def _build_op_js(self, op, key, value, value_name):
"""Substitute the values in to the correct chunk of Javascript. """Substitute the values in to the correct chunk of Javascript.
""" """
print op, key, value, value_name
if isinstance(value, RE_TYPE): if isinstance(value, RE_TYPE):
# Regexes are handled specially # Regexes are handled specially
if op.strip('$') == 'ne': if op.strip('$') == 'ne':
@ -134,6 +133,16 @@ class Q(object):
if isinstance(value, pymongo.objectid.ObjectId): if isinstance(value, pymongo.objectid.ObjectId):
value = unicode(value) value = unicode(value)
# Handle DBRef
if isinstance(value, pymongo.dbref.DBRef):
op_js = '(this.%(field)s.$id == "%(id)s" &&'\
' this.%(field)s.$ref == "%(ref)s")' % {
'field': key,
'id': unicode(value.id),
'ref': unicode(value.collection)
}
value = None
# Perform the substitution # Perform the substitution
operation_js = op_js % { operation_js = op_js % {
'field': key, 'field': key,
@ -163,7 +172,8 @@ class QuerySet(object):
self._limit = None self._limit = None
self._skip = None self._skip = None
def ensure_index(self, key_or_list): def ensure_index(self, key_or_list, drop_dups=False, background=False,
**kwargs):
"""Ensure that the given indexes are in place. """Ensure that the given indexes are in place.
:param key_or_list: a single index key or a list of index keys (to :param key_or_list: a single index key or a list of index keys (to
@ -171,7 +181,8 @@ class QuerySet(object):
or a **-** to determine the index ordering or a **-** to determine the index ordering
""" """
index_list = QuerySet._build_index_spec(self._document, key_or_list) index_list = QuerySet._build_index_spec(self._document, key_or_list)
self._collection.ensure_index(index_list) self._collection.ensure_index(index_list, drop_dups=drop_dups,
background=background)
return self return self
@classmethod @classmethod
@ -230,6 +241,10 @@ class QuerySet(object):
""" """
return self.__call__(*q_objs, **query) return self.__call__(*q_objs, **query)
def all(self):
"""Returns all documents."""
return self.__call__()
@property @property
def _collection(self): def _collection(self):
"""Property that returns the collection object. This allows us to """Property that returns the collection object. This allows us to
@ -238,24 +253,33 @@ class QuerySet(object):
if not self._accessed_collection: if not self._accessed_collection:
self._accessed_collection = True self._accessed_collection = True
background = self._document._meta.get('index_background', False)
drop_dups = self._document._meta.get('index_drop_dups', False)
index_opts = self._document._meta.get('index_options', {})
# Ensure document-defined indexes are created # Ensure document-defined indexes are created
if self._document._meta['indexes']: if self._document._meta['indexes']:
for key_or_list in self._document._meta['indexes']: for key_or_list in self._document._meta['indexes']:
#self.ensure_index(key_or_list) self._collection.ensure_index(key_or_list,
self._collection.ensure_index(key_or_list) background=background, **index_opts)
# Ensure indexes created by uniqueness constraints # Ensure indexes created by uniqueness constraints
for index in self._document._meta['unique_indexes']: for index in self._document._meta['unique_indexes']:
self._collection.ensure_index(index, unique=True) self._collection.ensure_index(index, unique=True,
background=background, drop_dups=drop_dups, **index_opts)
# If _types is being used (for polymorphism), it needs an index # If _types is being used (for polymorphism), it needs an index
if '_types' in self._query: if '_types' in self._query:
self._collection.ensure_index('_types') self._collection.ensure_index('_types',
background=background, **index_opts)
# Ensure all needed field indexes are created # Ensure all needed field indexes are created
for field_name, field_instance in self._document._fields.iteritems(): for field in self._document._fields.values():
if field_instance.__class__.__name__ == 'GeoLocationField': if field.__class__._geo_index:
self._collection.ensure_index([(field_name, pymongo.GEO2D),]) index_spec = [(field.db_field, pymongo.GEO2D)]
self._collection.ensure_index(index_spec,
background=background, **index_opts)
return self._collection_obj return self._collection_obj
@property @property
@ -311,7 +335,8 @@ class QuerySet(object):
"""Transform a query from Django-style format to Mongo format. """Transform a query from Django-style format to Mongo format.
""" """
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists', 'near'] 'all', 'size', 'exists']
geo_operators = ['within_distance', 'within_box', 'near']
match_operators = ['contains', 'icontains', 'startswith', match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'istartswith', 'endswith', 'iendswith',
'exact', 'iexact'] 'exact', 'iexact']
@ -321,7 +346,7 @@ class QuerySet(object):
parts = key.split('__') parts = key.split('__')
# Check for an operator and transform to mongo-style if there is # Check for an operator and transform to mongo-style if there is
op = None op = None
if parts[-1] in operators + match_operators: if parts[-1] in operators + match_operators + geo_operators:
op = parts.pop() op = parts.pop()
if _doc_cls: if _doc_cls:
@ -335,14 +360,26 @@ class QuerySet(object):
singular_ops += match_operators singular_ops += match_operators
if op in singular_ops: if op in singular_ops:
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op in ('in', 'nin', 'all'): elif op in ('in', 'nin', 'all', 'near'):
# 'in', 'nin' and 'all' require a list of values # 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
if field.__class__.__name__ == 'GenericReferenceField': if field.__class__.__name__ == 'GenericReferenceField':
parts.append('_ref') parts.append('_ref')
if op and op not in match_operators: # if op and op not in match_operators:
if op:
if op in geo_operators:
if op == "within_distance":
value = {'$within': {'$center': value}}
elif op == "near":
value = {'$near': value}
elif op == 'within_box':
value = {'$within': {'$box': value}}
else:
raise NotImplementedError("Geo method '%s' has not "
"been implemented" % op)
elif op not in match_operators:
value = {'$' + op: value} value = {'$' + op: value}
key = '.'.join(parts) key = '.'.join(parts)
@ -402,6 +439,14 @@ class QuerySet(object):
message = u'%d items returned, instead of 1' % count message = u'%d items returned, instead of 1' % count
raise self._document.MultipleObjectsReturned(message) raise self._document.MultipleObjectsReturned(message)
def create(self, **kwargs):
"""Create new object. Returns the saved object instance.
.. versionadded:: 0.4
"""
doc = self._document(**kwargs)
doc.save()
return doc
def first(self): def first(self):
"""Retrieve the first object matching the query. """Retrieve the first object matching the query.
""" """
@ -593,6 +638,15 @@ class QuerySet(object):
elif isinstance(key, int): elif isinstance(key, int):
return self._document._from_son(self._cursor[key]) return self._document._from_son(self._cursor[key])
def distinct(self, field):
"""Return a list of distinct values for a given field.
:param field: the field to select distinct values from
.. versionadded:: 0.4
"""
return self._collection.distinct(field)
def only(self, *fields): def only(self, *fields):
"""Load only a subset of this document's fields. :: """Load only a subset of this document's fields. ::
@ -626,11 +680,13 @@ class QuerySet(object):
""" """
key_list = [] key_list = []
for key in keys: for key in keys:
if not key: continue
direction = pymongo.ASCENDING direction = pymongo.ASCENDING
if key[0] == '-': if key[0] == '-':
direction = pymongo.DESCENDING direction = pymongo.DESCENDING
if key[0] in ('-', '+'): if key[0] in ('-', '+'):
key = key[1:] key = key[1:]
key = key.replace('__', '.')
key_list.append((key, direction)) key_list.append((key, direction))
self._ordering = key_list self._ordering = key_list
@ -646,7 +702,6 @@ class QuerySet(object):
plan = self._cursor.explain() plan = self._cursor.explain()
if format: if format:
import pprint
plan = pprint.pformat(plan) plan = pprint.pformat(plan)
return plan return plan
@ -661,8 +716,8 @@ class QuerySet(object):
def _transform_update(cls, _doc_cls=None, **update): def _transform_update(cls, _doc_cls=None, **update):
"""Transform an update spec from Django-style format to Mongo format. """Transform an update spec from Django-style format to Mongo format.
""" """
operators = ['set', 'unset', 'inc', 'dec', 'push', 'push_all', 'pull', operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all',
'pull_all'] 'pull', 'pull_all']
mongo_update = {} mongo_update = {}
for key, value in update.items(): for key, value in update.items():
@ -688,7 +743,7 @@ class QuerySet(object):
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = fields[-1]
if op in (None, 'set', 'unset', 'push', 'pull'): if op in (None, 'set', 'unset', 'pop', 'push', 'pull'):
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'): elif op in ('pushAll', 'pullAll'):
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
@ -837,7 +892,7 @@ class QuerySet(object):
var total = 0.0; var total = 0.0;
var num = 0; var num = 0;
db[collection].find(query).forEach(function(doc) { db[collection].find(query).forEach(function(doc) {
if (doc[averageField]) { if (doc[averageField] !== undefined) {
total += doc[averageField]; total += doc[averageField];
num += 1; num += 1;
} }

View File

@ -264,11 +264,12 @@ class DocumentTest(unittest.TestCase):
# Indexes are lazy so use list() to perform query # Indexes are lazy so use list() to perform query
list(BlogPost.objects) list(BlogPost.objects)
info = BlogPost.objects._collection.index_information() info = BlogPost.objects._collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)]
in info.values()) in info)
self.assertTrue([('_types', 1), ('addDate', -1)] in info.values()) self.assertTrue([('_types', 1), ('addDate', -1)] in info)
# tags is a list field so it shouldn't have _types in the index # tags is a list field so it shouldn't have _types in the index
self.assertTrue([('tags', 1)] in info.values()) self.assertTrue([('tags', 1)] in info)
class ExtendedBlogPost(BlogPost): class ExtendedBlogPost(BlogPost):
title = StringField() title = StringField()
@ -278,10 +279,11 @@ class DocumentTest(unittest.TestCase):
list(ExtendedBlogPost.objects) list(ExtendedBlogPost.objects)
info = ExtendedBlogPost.objects._collection.index_information() info = ExtendedBlogPost.objects._collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)]
in info.values()) in info)
self.assertTrue([('_types', 1), ('addDate', -1)] in info.values()) self.assertTrue([('_types', 1), ('addDate', -1)] in info)
self.assertTrue([('_types', 1), ('title', 1)] in info.values()) self.assertTrue([('_types', 1), ('title', 1)] in info)
BlogPost.drop_collection() BlogPost.drop_collection()

View File

@ -228,7 +228,8 @@ class FieldTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
comments = SortedListField(EmbeddedDocumentField(Comment), ordering='order') comments = SortedListField(EmbeddedDocumentField(Comment),
ordering='order')
tags = SortedListField(StringField()) tags = SortedListField(StringField())
post = BlogPost(content='Went for a walk today...') post = BlogPost(content='Went for a walk today...')
@ -675,6 +676,63 @@ class FieldTest(unittest.TestCase):
StreamFile.drop_collection() StreamFile.drop_collection()
SetFile.drop_collection() SetFile.drop_collection()
# Make sure FileField is optional and not required
class DemoFile(Document):
file = FileField()
d = DemoFile.objects.create()
def test_file_uniqueness(self):
"""Ensure that each instance of a FileField is unique
"""
class TestFile(Document):
name = StringField()
file = FileField()
# First instance
testfile = TestFile()
testfile.name = "Hello, World!"
testfile.file.put('Hello, World!')
testfile.save()
# Second instance
testfiledupe = TestFile()
data = testfiledupe.file.read() # Should be None
self.assertTrue(testfile.name != testfiledupe.name)
self.assertTrue(testfile.file.read() != data)
TestFile.drop_collection()
def test_geo_indexes(self):
"""Ensure that indexes are created automatically for GeoPointFields.
"""
class Event(Document):
title = StringField()
location = GeoPointField()
Event.drop_collection()
event = Event(title="Coltrane Motion @ Double Door",
location=[41.909889, -87.677137])
event.save()
info = Event.objects._collection.index_information()
self.assertTrue(u'location_2d' in info)
self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')])
Event.drop_collection()
def test_ensure_unique_default_instances(self):
"""Ensure that every field has it's own unique default instance."""
class D(Document):
data = DictField()
data2 = DictField(default=lambda: {})
d1 = D()
d1.data['foo'] = 'bar'
d1.data2['foo'] = 'bar'
d2 = D()
self.assertEqual(d2.data, {})
self.assertEqual(d2.data2, {})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -289,6 +289,13 @@ class QuerySetTest(unittest.TestCase):
obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first()
self.assertEqual(obj, None) self.assertEqual(obj, None)
# Test unsafe expressions
person = self.Person(name='Guido van Rossum [.\'Geek\']')
person.save()
obj = self.Person.objects(Q(name__icontains='[.\'Geek')).first()
self.assertEqual(obj, person)
def test_filter_chaining(self): def test_filter_chaining(self):
"""Ensure filters can be chained together. """Ensure filters can be chained together.
""" """
@ -664,28 +671,27 @@ class QuerySetTest(unittest.TestCase):
post.reload() post.reload()
self.assertTrue('db' in post.tags and 'nosql' in post.tags) self.assertTrue('db' in post.tags and 'nosql' in post.tags)
tags = post.tags[:-1]
BlogPost.objects.update(pop__tags=1)
post.reload()
self.assertEqual(post.tags, tags)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_update_pull(self): def test_update_pull(self):
"""Ensure that the 'pull' update operation works correctly. """Ensure that the 'pull' update operation works correctly.
""" """
class Comment(EmbeddedDocument):
content = StringField()
class BlogPost(Document): class BlogPost(Document):
slug = StringField() slug = StringField()
comments = ListField(EmbeddedDocumentField(Comment)) tags = ListField(StringField())
comment1 = Comment(content="test1") post = BlogPost(slug="test", tags=['code', 'mongodb', 'code'])
comment2 = Comment(content="test2")
post = BlogPost(slug="test", comments=[comment1, comment2])
post.save() post.save()
self.assertTrue(comment2 in post.comments)
BlogPost.objects(slug="test").update(pull__comments__content="test2") BlogPost.objects(slug="test").update(pull__tags="code")
post.reload() post.reload()
self.assertTrue(comment2 not in post.comments) self.assertTrue('code' not in post.tags)
self.assertEqual(len(post.tags), 1)
def test_order_by(self): def test_order_by(self):
"""Ensure that QuerySets may be ordered. """Ensure that QuerySets may be ordered.
@ -948,11 +954,14 @@ class QuerySetTest(unittest.TestCase):
def test_average(self): def test_average(self):
"""Ensure that field can be averaged correctly. """Ensure that field can be averaged correctly.
""" """
self.Person(name='person', age=0).save()
self.assertEqual(int(self.Person.objects.average('age')), 0)
ages = [23, 54, 12, 94, 27] ages = [23, 54, 12, 94, 27]
for i, age in enumerate(ages): for i, age in enumerate(ages):
self.Person(name='test%s' % i, age=age).save() self.Person(name='test%s' % i, age=age).save()
avg = float(sum(ages)) / len(ages) avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0
self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) self.assertAlmostEqual(int(self.Person.objects.average('age')), avg)
self.Person(name='ageless person').save() self.Person(name='ageless person').save()
@ -970,15 +979,30 @@ class QuerySetTest(unittest.TestCase):
self.Person(name='ageless person').save() self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
def test_distinct(self):
"""Ensure that the QuerySet.distinct method works.
"""
self.Person(name='Mr Orange', age=20).save()
self.Person(name='Mr White', age=20).save()
self.Person(name='Mr Orange', age=30).save()
self.assertEqual(self.Person.objects.distinct('name'),
['Mr Orange', 'Mr White'])
self.assertEqual(self.Person.objects.distinct('age'), [20, 30])
def test_custom_manager(self): def test_custom_manager(self):
"""Ensure that custom QuerySetManager instances work as expected. """Ensure that custom QuerySetManager instances work as expected.
""" """
class BlogPost(Document): class BlogPost(Document):
tags = ListField(StringField()) tags = ListField(StringField())
deleted = BooleanField(default=False)
@queryset_manager
def objects(doc_cls, queryset):
return queryset(deleted=False)
@queryset_manager @queryset_manager
def music_posts(doc_cls, queryset): def music_posts(doc_cls, queryset):
return queryset(tags='music') return queryset(tags='music', deleted=False)
BlogPost.drop_collection() BlogPost.drop_collection()
@ -988,6 +1012,8 @@ class QuerySetTest(unittest.TestCase):
post2.save() post2.save()
post3 = BlogPost(tags=['film', 'actors']) post3 = BlogPost(tags=['film', 'actors'])
post3.save() post3.save()
post4 = BlogPost(tags=['film', 'actors'], deleted=True)
post4.save()
self.assertEqual([p.id for p in BlogPost.objects], self.assertEqual([p.id for p in BlogPost.objects],
[post1.id, post2.id, post3.id]) [post1.id, post2.id, post3.id])
@ -1087,8 +1113,9 @@ class QuerySetTest(unittest.TestCase):
# Indexes are lazy so use list() to perform query # Indexes are lazy so use list() to perform query
list(BlogPost.objects) list(BlogPost.objects)
info = BlogPost.objects._collection.index_information() info = BlogPost.objects._collection.index_information()
self.assertTrue([('_types', 1)] in info.values()) info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('_types', 1), ('date', -1)] in info.values()) self.assertTrue([('_types', 1)] in info)
self.assertTrue([('_types', 1), ('date', -1)] in info)
BlogPost.drop_collection() BlogPost.drop_collection()
@ -1164,6 +1191,81 @@ class QuerySetTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.Person.drop_collection() self.Person.drop_collection()
def test_geospatial_operators(self):
"""Ensure that geospatial queries are working.
"""
class Event(Document):
title = StringField()
date = DateTimeField()
location = GeoPointField()
def __unicode__(self):
return self.title
Event.drop_collection()
event1 = Event(title="Coltrane Motion @ Double Door",
date=datetime.now() - timedelta(days=1),
location=[41.909889, -87.677137])
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
date=datetime.now() - timedelta(days=10),
location=[37.7749295, -122.4194155])
event3 = Event(title="Coltrane Motion @ Empty Bottle",
date=datetime.now(),
location=[41.900474, -87.686638])
event1.save()
event2.save()
event3.save()
# find all events "near" pitchfork office, chicago.
# note that "near" will show the san francisco event, too,
# although it sorts to last.
events = Event.objects(location__near=[41.9120459, -87.67892])
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event1, event3, event2])
# find events within 5 miles of pitchfork office, chicago
point_and_distance = [[41.9120459, -87.67892], 5]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 2)
events = list(events)
self.assertTrue(event2 not in events)
self.assertTrue(event1 in events)
self.assertTrue(event3 in events)
# ensure ordering is respected by "near"
events = Event.objects(location__near=[41.9120459, -87.67892])
events = events.order_by("-date")
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event3, event1, event2])
# find events around san francisco
point_and_distance = [[37.7566023, -122.415579], 10]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
# find events within 1 mile of greenpoint, broolyn, nyc, ny
point_and_distance = [[40.7237134, -73.9509714], 1]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 0)
# ensure ordering is respected by "within_distance"
point_and_distance = [[41.9120459, -87.67892], 10]
events = Event.objects(location__within_distance=point_and_distance)
events = events.order_by("-date")
self.assertEqual(events.count(), 2)
self.assertEqual(events[0], event3)
# check that within_box works
box = [(35.0, -125.0), (40.0, -100.0)]
events = Event.objects(location__within_box=box)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event2.id)
Event.drop_collection()
class QTest(unittest.TestCase): class QTest(unittest.TestCase):
@ -1219,5 +1321,22 @@ class QTest(unittest.TestCase):
query = ['(', {'age__gte': 18}, '&&', {'name': 'test'}, ')'] query = ['(', {'age__gte': 18}, '&&', {'name': 'test'}, ')']
self.assertEqual((q1 & q2 & q3 & q4 & q5).query, query) self.assertEqual((q1 & q2 & q3 & q4 & q5).query, query)
def test_q_with_dbref(self):
"""Ensure Q objects handle DBRefs correctly"""
connect(db='mongoenginetest')
class User(Document):
pass
class Post(Document):
created_user = ReferenceField(User)
user = User.objects.create()
Post.objects.create(created_user=user)
self.assertEqual(Post.objects.filter(created_user=user).count(), 1)
self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()