diff --git a/.gitignore b/.gitignore index 42dcc6e6..51a9ca1d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,9 @@ *.pyc .*.swp +*.egg docs/.build docs/_build build/ dist/ -mongoengine.egg-info/ \ No newline at end of file +mongoengine.egg-info/ +env/ \ No newline at end of file diff --git a/docs/changelog.rst b/docs/changelog.rst index 479ea21c..8dd5b00d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,6 +2,20 @@ Changelog ========= +Changes in v0.4 +=============== +- Added ``SortedListField`` +- Added ``EmailField`` +- Added ``GeoPointField`` +- Added ``exact`` and ``iexact`` match operators to ``QuerySet`` +- Added ``get_document_or_404`` and ``get_list_or_404`` Django shortcuts +- Fixed bug in Q-objects +- Fixed document inheritance primary key issue +- Base class can now be defined for ``DictField`` +- Fixed MRO error that occurred on document inheritance +- Introduced ``min_length`` for ``StringField`` +- Other minor fixes + Changes in v0.3 =============== - Added MapReduce support diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 113ee431..bef19bc5 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -71,7 +71,7 @@ Available operators are as follows: * ``in`` -- value is in list (a list of values should be provided) * ``nin`` -- value is not in list (a list of values should be provided) * ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values -* ``all`` -- every item in array is in list of values provided +* ``all`` -- every item in list of values provided is in array * ``size`` -- the size of the array is * ``exists`` -- value for field exists @@ -174,7 +174,7 @@ custom manager methods as you like:: @queryset_manager def live_posts(doc_cls, queryset): - return queryset(published=True).filter(published=True) + return queryset.filter(published=True) BlogPost(title='test1', published=False).save() BlogPost(title='test2', published=True).save() @@ -399,6 +399,7 @@ that you may use with
these methods: * ``unset`` -- delete a particular value (since MongoDB v1.3+) * ``inc`` -- increment a value by a given amount * ``dec`` -- decrement a value by a given amount +* ``pop`` -- remove the last item from a list * ``push`` -- append a value to a list * ``push_all`` -- append several values to a list * ``pull`` -- remove a value from a list diff --git a/mongoengine/base.py b/mongoengine/base.py index b6d5a63b..836817da 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -22,6 +22,7 @@ class BaseField(object): # Fields may have _types inserted into indexes by default _index_with_types = True + _geo_index = False def __init__(self, db_field=None, name=None, required=False, default=None, unique=False, unique_with=None, primary_key=False, @@ -229,12 +230,18 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): id_field = None base_indexes = [] + base_meta = {} # Subclassed documents inherit collection from superclass for base in bases: if hasattr(base, '_meta') and 'collection' in base._meta: collection = base._meta['collection'] + # Propagate index options. 
+ for key in ('index_background', 'index_drop_dups', 'index_opts'): + if key in base._meta: + base_meta[key] = base._meta[key] + id_field = id_field or base._meta.get('id_field') base_indexes += base._meta.get('indexes', []) @@ -245,7 +252,11 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): 'ordering': [], # default ordering applied at runtime 'indexes': [], # indexes to be ensured at runtime 'id_field': id_field, + 'index_background': False, + 'index_drop_dups': False, + 'index_opts': {}, } + meta.update(base_meta) # Apply document-defined meta options meta.update(attrs.get('meta', {})) @@ -254,7 +265,10 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Set up collection manager, needs the class to have fields so use # DocumentMetaclass before instantiating CollectionManager object new_class = super_new(cls, name, bases, attrs) - new_class.objects = QuerySetManager() + + # Provide a default queryset unless one has been manually provided + if not hasattr(new_class, 'objects'): + new_class.objects = QuerySetManager() user_indexes = [QuerySet._build_index_spec(new_class, spec) for spec in meta['indexes']] + base_indexes @@ -265,7 +279,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Generate a list of indexes needed by uniqueness constraints if field.unique: field.required = True - unique_fields = [field_name] + unique_fields = [field.db_field] # Add any unique_with fields to the back of the index spec if field.unique_with: @@ -338,8 +352,8 @@ class BaseDocument(object): try: field._validate(value) except (ValueError, AttributeError, AssertionError), e: - raise ValidationError('Invalid value for field of type "' + - field.__class__.__name__ + '"') + raise ValidationError('Invalid value for field of type "%s": %s' + % (field.__class__.__name__, value)) elif field.required: raise ValidationError('Field "%s" is required' % field.name) @@ -414,6 +428,8 @@ class BaseDocument(object): self._meta.get('allow_inheritance', True) == False): data['_cls'] 
= self._class_name data['_types'] = self._superclasses.keys() + [self._class_name] + if data.has_key('_id') and not data['_id']: + del data['_id'] return data @classmethod @@ -445,7 +461,9 @@ class BaseDocument(object): for field_name, field in cls._fields.items(): if field.db_field in data: - data[field_name] = field.to_python(data[field.db_field]) + value = data[field.db_field] + data[field_name] = (value if value is None + else field.to_python(value)) obj = cls(**data) obj._present_fields = present_fields diff --git a/mongoengine/django/auth.py b/mongoengine/django/auth.py index d4b0ff0b..da0005c8 100644 --- a/mongoengine/django/auth.py +++ b/mongoengine/django/auth.py @@ -32,6 +32,9 @@ class User(Document): last_login = DateTimeField(default=datetime.datetime.now) date_joined = DateTimeField(default=datetime.datetime.now) + def __unicode__(self): + return self.username + def get_full_name(self): """Returns the users first and last names, separated by a space. """ diff --git a/mongoengine/django/tests.py b/mongoengine/django/tests.py new file mode 100644 index 00000000..a8d7c7ff --- /dev/null +++ b/mongoengine/django/tests.py @@ -0,0 +1,21 @@ +#coding: utf-8 +from django.test import TestCase +from django.conf import settings + +from mongoengine import connect + +class MongoTestCase(TestCase): + """ + TestCase class that clear the collection between the tests + """ + db_name = 'test_%s' % settings.MONGO_DATABASE_NAME + def __init__(self, methodName='runtest'): + self.db = connect(self.db_name) + super(MongoTestCase, self).__init__(methodName) + + def _post_teardown(self): + super(MongoTestCase, self)._post_teardown() + for collection in self.db.collection_names(): + if collection == 'system.indexes': + continue + self.db.drop_collection(collection) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index a276399e..76bc4fbe 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -16,11 +16,14 @@ __all__ = ['StringField', 'IntField', 'FloatField', 
'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'ObjectIdField', 'ReferenceField', 'ValidationError', 'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', +<<<<<<< HEAD 'BinaryField', 'SortedListField', 'EmailField', 'GeoLocationField'] +======= + 'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField'] +>>>>>>> 32e66b29f44f3015be099851201241caee92054f RECURSIVE_REFERENCE_CONSTANT = 'self' - class StringField(BaseField): """A unicode string field. """ @@ -67,6 +70,9 @@ class StringField(BaseField): regex = r'%s$' elif op == 'exact': regex = r'^%s$' + + # escape unsafe characters which could lead to a re.error + value = re.escape(value) value = re.compile(regex % value, flags) return value @@ -264,6 +270,7 @@ class ListField(BaseField): raise ValidationError('Argument to ListField constructor must be ' 'a valid field') self.field = field + kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) def __get__(self, instance, owner): @@ -356,6 +363,7 @@ class DictField(BaseField): def __init__(self, basecls=None, *args, **kwargs): self.basecls = basecls or BaseField assert issubclass(self.basecls, BaseField) + kwargs.setdefault('default', lambda: {}) super(DictField, self).__init__(*args, **kwargs) def validate(self, value): @@ -372,24 +380,6 @@ class DictField(BaseField): def lookup_member(self, member_name): return self.basecls(db_field=member_name) -class GeoLocationField(DictField): - """Supports geobased fields""" - - def validate(self, value): - """Make sure that a geo-value is of type (x, y) - """ - if not isinstance(value, tuple) and not isinstance(value, list): - raise ValidationError('GeoLocationField can only hold tuples or lists of (x, y)') - - if len(value) <> 2: - raise ValidationError('GeoLocationField must have exactly two elements (x, y)') - - def to_mongo(self, value): - return {'x': value[0], 'y': value[1]} - - def to_python(self, value): - return value.keys() - class 
ReferenceField(BaseField): """A reference to a document that will be automatically dereferenced on access (lazily). @@ -456,7 +446,6 @@ class ReferenceField(BaseField): def lookup_member(self, member_name): return self.document_type._fields.get(member_name) - class GenericReferenceField(BaseField): """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). @@ -503,6 +492,7 @@ class GenericReferenceField(BaseField): def prepare_query_value(self, op, value): return self.to_mongo(value)['_ref'] + class BinaryField(BaseField): """A binary data field. """ @@ -524,14 +514,25 @@ class BinaryField(BaseField): if self.max_bytes is not None and len(value) > self.max_bytes: raise ValidationError('Binary value is too long') +<<<<<<< HEAD +======= + +>>>>>>> 32e66b29f44f3015be099851201241caee92054f class GridFSProxy(object): """Proxy object to handle writing and reading of files to and from GridFS """ +<<<<<<< HEAD def __init__(self): self.fs = gridfs.GridFS(_get_db()) # Filesystem instance self.newfile = None # Used for partial writes self.grid_id = None # Store GridFS id for file +======= + def __init__(self, grid_id=None): + self.fs = gridfs.GridFS(_get_db()) # Filesystem instance + self.newfile = None # Used for partial writes + self.grid_id = grid_id # Store GridFS id for file +>>>>>>> 32e66b29f44f3015be099851201241caee92054f def __getattr__(self, name): obj = self.get() @@ -542,8 +543,17 @@ class GridFSProxy(object): return self def get(self, id=None): +<<<<<<< HEAD try: return self.fs.get(id or self.grid_id) except: return None # File has been deleted +======= + if id: + self.grid_id = id + try: + return self.fs.get(id or self.grid_id) + except: + return None # File has been deleted +>>>>>>> 32e66b29f44f3015be099851201241caee92054f def new_file(self, **kwargs): self.newfile = self.fs.new_file(**kwargs) @@ -565,8 +575,15 @@ class GridFSProxy(object): self.newfile.writelines(lines) def read(self): 
+<<<<<<< HEAD try: return self.get().read() except: return None +======= + try: + return self.get().read() + except: + return None +>>>>>>> 32e66b29f44f3015be099851201241caee92054f def delete(self): # Delete file from GridFS, FileField still remains @@ -584,29 +601,61 @@ class GridFSProxy(object): msg = "The close() method is only necessary after calling write()" warnings.warn(msg) +<<<<<<< HEAD +======= + +>>>>>>> 32e66b29f44f3015be099851201241caee92054f class FileField(BaseField): """A GridFS storage field. """ def __init__(self, **kwargs): +<<<<<<< HEAD self.gridfs = GridFSProxy() +======= +>>>>>>> 32e66b29f44f3015be099851201241caee92054f super(FileField, self).__init__(**kwargs) def __get__(self, instance, owner): if instance is None: return self +<<<<<<< HEAD return self.gridfs +======= + # Check if a file already exists for this model + grid_file = instance._data.get(self.name) + if grid_file: + return grid_file + return GridFSProxy() +>>>>>>> 32e66b29f44f3015be099851201241caee92054f def __set__(self, instance, value): if isinstance(value, file) or isinstance(value, str): # using "FileField() = file/string" notation +<<<<<<< HEAD self.gridfs.put(value) +======= + grid_file = instance._data.get(self.name) + # If a file already exists, delete it + if grid_file: + try: + grid_file.delete() + except: + pass + # Create a new file with the new data + grid_file.put(value) + else: + # Create a new proxy object as we don't already have one + instance._data[self.name] = GridFSProxy() + instance._data[self.name].put(value) +>>>>>>> 32e66b29f44f3015be099851201241caee92054f else: instance._data[self.name] = value def to_mongo(self, value): # Store the GridFS file id in MongoDB +<<<<<<< HEAD return self.gridfs.grid_id def to_python(self, value): @@ -617,3 +666,36 @@ class FileField(BaseField): assert isinstance(value, GridFSProxy) assert isinstance(value.grid_id, pymongo.objectid.ObjectId) +======= + if isinstance(value, GridFSProxy) and value.grid_id is not None: + return 
value.grid_id + return None + + def to_python(self, value): + if value is not None: + return GridFSProxy(value) + + def validate(self, value): + if value.grid_id is not None: + assert isinstance(value, GridFSProxy) + assert isinstance(value.grid_id, pymongo.objectid.ObjectId) + +class GeoPointField(BaseField): + """A list storing a latitude and longitude. + """ + + _geo_index = True + + def validate(self, value): + """Make sure that a geo-value is of type (x, y) + """ + if not isinstance(value, (list, tuple)): + raise ValidationError('GeoPointField can only accept tuples or ' + 'lists of (x, y)') + + if not len(value) == 2: + raise ValidationError('Value must be a two-dimensional point.') + if (not isinstance(value[0], (float, int)) and + not isinstance(value[1], (float, int))): + raise ValidationError('Both values in point must be float or int.') +>>>>>>> 32e66b29f44f3015be099851201241caee92054f diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 069ab113..662fa8c3 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -1,5 +1,6 @@ from connection import _get_db +import pprint import pymongo import re import copy @@ -114,13 +115,11 @@ class Q(object): value, field_js = self._build_op_js(op, key, value, value_name) js_scope[value_name] = value js.append(field_js) - print ' && '.join(js) return ' && '.join(js) def _build_op_js(self, op, key, value, value_name): """Substitute the values in to the correct chunk of Javascript. 
""" - print op, key, value, value_name if isinstance(value, RE_TYPE): # Regexes are handled specially if op.strip('$') == 'ne': @@ -134,6 +133,16 @@ class Q(object): if isinstance(value, pymongo.objectid.ObjectId): value = unicode(value) + # Handle DBRef + if isinstance(value, pymongo.dbref.DBRef): + op_js = '(this.%(field)s.$id == "%(id)s" &&'\ + ' this.%(field)s.$ref == "%(ref)s")' % { + 'field': key, + 'id': unicode(value.id), + 'ref': unicode(value.collection) + } + value = None + + # Perform the substitution operation_js = op_js % { 'field': key, @@ -163,7 +172,8 @@ class QuerySet(object): self._limit = None self._skip = None - def ensure_index(self, key_or_list): + def ensure_index(self, key_or_list, drop_dups=False, background=False, + **kwargs): """Ensure that the given indexes are in place. :param key_or_list: a single index key or a list of index keys (to @@ -171,7 +181,8 @@ class QuerySet(object): or a **-** to determine the index ordering """ index_list = QuerySet._build_index_spec(self._document, key_or_list) - self._collection.ensure_index(index_list, drop_dups=drop_dups, + background=background) return self @classmethod @@ -229,6 +240,10 @@ class QuerySet(object): """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` """ return self.__call__(*q_objs, **query) + + def all(self): + """Returns all documents.""" + return self.__call__() @property def _collection(self): @@ -237,25 +252,34 @@ class QuerySet(object): """ if not self._accessed_collection: self._accessed_collection = True + + background = self._document._meta.get('index_background', False) + drop_dups = self._document._meta.get('index_drop_dups', False) + index_opts = self._document._meta.get('index_opts', {}) # Ensure document-defined indexes are created if self._document._meta['indexes']: for key_or_list in self._document._meta['indexes']: - #self.ensure_index(key_or_list) - self._collection.ensure_index(key_or_list) +
self._collection.ensure_index(key_or_list, + background=background, **index_opts) # Ensure indexes created by uniqueness constraints for index in self._document._meta['unique_indexes']: - self._collection.ensure_index(index, unique=True) - + self._collection.ensure_index(index, unique=True, + background=background, drop_dups=drop_dups, **index_opts) + # If _types is being used (for polymorphism), it needs an index if '_types' in self._query: - self._collection.ensure_index('_types') + self._collection.ensure_index('_types', + background=background, **index_opts) # Ensure all needed field indexes are created - for field_name, field_instance in self._document._fields.iteritems(): - if field_instance.__class__.__name__ == 'GeoLocationField': - self._collection.ensure_index([(field_name, pymongo.GEO2D),]) + for field in self._document._fields.values(): + if field.__class__._geo_index: + index_spec = [(field.db_field, pymongo.GEO2D)] + self._collection.ensure_index(index_spec, + background=background, **index_opts) + return self._collection_obj @property @@ -311,9 +335,10 @@ class QuerySet(object): """Transform a query from Django-style format to Mongo format. 
""" operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists', 'near'] - match_operators = ['contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', + 'all', 'size', 'exists'] + geo_operators = ['within_distance', 'within_box', 'near'] + match_operators = ['contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact'] mongo_query = {} @@ -321,7 +346,7 @@ class QuerySet(object): parts = key.split('__') # Check for an operator and transform to mongo-style if there is op = None - if parts[-1] in operators + match_operators: + if parts[-1] in operators + match_operators + geo_operators: op = parts.pop() if _doc_cls: @@ -335,15 +360,27 @@ class QuerySet(object): singular_ops += match_operators if op in singular_ops: value = field.prepare_query_value(op, value) - elif op in ('in', 'nin', 'all'): + elif op in ('in', 'nin', 'all', 'near'): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] if field.__class__.__name__ == 'GenericReferenceField': parts.append('_ref') - if op and op not in match_operators: - value = {'$' + op: value} + # if op and op not in match_operators: + if op: + if op in geo_operators: + if op == "within_distance": + value = {'$within': {'$center': value}} + elif op == "near": + value = {'$near': value} + elif op == 'within_box': + value = {'$within': {'$box': value}} + else: + raise NotImplementedError("Geo method '%s' has not " + "been implemented" % op) + elif op not in match_operators: + value = {'$' + op: value} key = '.'.join(parts) if op is None or key not in mongo_query: @@ -402,6 +439,14 @@ class QuerySet(object): message = u'%d items returned, instead of 1' % count raise self._document.MultipleObjectsReturned(message) + def create(self, **kwargs): + """Create new object. Returns the saved object instance. + .. 
versionadded:: 0.4 + """ + doc = self._document(**kwargs) + doc.save() + return doc + def first(self): """Retrieve the first object matching the query. """ @@ -592,6 +637,15 @@ class QuerySet(object): # Integer index provided elif isinstance(key, int): return self._document._from_son(self._cursor[key]) + + def distinct(self, field): + """Return a list of distinct values for a given field. + + :param field: the field to select distinct values from + + .. versionadded:: 0.4 + """ + return self._collection.distinct(field) def only(self, *fields): """Load only a subset of this document's fields. :: @@ -626,11 +680,13 @@ class QuerySet(object): """ key_list = [] for key in keys: + if not key: continue direction = pymongo.ASCENDING if key[0] == '-': direction = pymongo.DESCENDING if key[0] in ('-', '+'): key = key[1:] + key = key.replace('__', '.') key_list.append((key, direction)) self._ordering = key_list @@ -646,7 +702,6 @@ class QuerySet(object): plan = self._cursor.explain() if format: - import pprint plan = pprint.pformat(plan) return plan @@ -661,8 +716,8 @@ class QuerySet(object): def _transform_update(cls, _doc_cls=None, **update): """Transform an update spec from Django-style format to Mongo format. 
""" - operators = ['set', 'unset', 'inc', 'dec', 'push', 'push_all', 'pull', - 'pull_all'] + operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all', + 'pull', 'pull_all'] mongo_update = {} for key, value in update.items(): @@ -688,7 +743,7 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] - if op in (None, 'set', 'unset', 'push', 'pull'): + if op in (None, 'set', 'unset', 'pop', 'push', 'pull'): value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] @@ -837,7 +892,7 @@ class QuerySet(object): var total = 0.0; var num = 0; db[collection].find(query).forEach(function(doc) { - if (doc[averageField]) { + if (doc[averageField] !== undefined) { total += doc[averageField]; num += 1; } diff --git a/tests/document.py b/tests/document.py index 8bc907c5..1160b353 100644 --- a/tests/document.py +++ b/tests/document.py @@ -264,11 +264,12 @@ class DocumentTest(unittest.TestCase): # Indexes are lazy so use list() to perform query list(BlogPost.objects) info = BlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] - in info.values()) - self.assertTrue([('_types', 1), ('addDate', -1)] in info.values()) + in info) + self.assertTrue([('_types', 1), ('addDate', -1)] in info) # tags is a list field so it shouldn't have _types in the index - self.assertTrue([('tags', 1)] in info.values()) + self.assertTrue([('tags', 1)] in info) class ExtendedBlogPost(BlogPost): title = StringField() @@ -278,10 +279,11 @@ class DocumentTest(unittest.TestCase): list(ExtendedBlogPost.objects) info = ExtendedBlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] - in info.values()) - self.assertTrue([('_types', 1), ('addDate', -1)] in info.values()) 
- self.assertTrue([('_types', 1), ('title', 1)] in info.values()) + in info) + self.assertTrue([('_types', 1), ('addDate', -1)] in info) + self.assertTrue([('_types', 1), ('title', 1)] in info) BlogPost.drop_collection() diff --git a/tests/fields.py b/tests/fields.py index 8dddcb3e..536a9f1a 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -228,7 +228,8 @@ class FieldTest(unittest.TestCase): class BlogPost(Document): content = StringField() - comments = SortedListField(EmbeddedDocumentField(Comment), ordering='order') + comments = SortedListField(EmbeddedDocumentField(Comment), + ordering='order') tags = SortedListField(StringField()) post = BlogPost(content='Went for a walk today...') @@ -675,6 +676,63 @@ class FieldTest(unittest.TestCase): StreamFile.drop_collection() SetFile.drop_collection() + # Make sure FileField is optional and not required + class DemoFile(Document): + file = FileField() + d = DemoFile.objects.create() + + def test_file_uniqueness(self): + """Ensure that each instance of a FileField is unique + """ + class TestFile(Document): + name = StringField() + file = FileField() + + # First instance + testfile = TestFile() + testfile.name = "Hello, World!" + testfile.file.put('Hello, World!') + testfile.save() + + # Second instance + testfiledupe = TestFile() + data = testfiledupe.file.read() # Should be None + + self.assertTrue(testfile.name != testfiledupe.name) + self.assertTrue(testfile.file.read() != data) + + TestFile.drop_collection() + + def test_geo_indexes(self): + """Ensure that indexes are created automatically for GeoPointFields. 
+ """ + class Event(Document): + title = StringField() + location = GeoPointField() + + Event.drop_collection() + event = Event(title="Coltrane Motion @ Double Door", + location=[41.909889, -87.677137]) + event.save() + + info = Event.objects._collection.index_information() + self.assertTrue(u'location_2d' in info) + self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')]) + + Event.drop_collection() + + def test_ensure_unique_default_instances(self): + """Ensure that every field has it's own unique default instance.""" + class D(Document): + data = DictField() + data2 = DictField(default=lambda: {}) + + d1 = D() + d1.data['foo'] = 'bar' + d1.data2['foo'] = 'bar' + d2 = D() + self.assertEqual(d2.data, {}) + self.assertEqual(d2.data2, {}) if __name__ == '__main__': unittest.main() diff --git a/tests/queryset.py b/tests/queryset.py index 51f92993..e3912246 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -288,6 +288,13 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(obj, person) obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() self.assertEqual(obj, None) + + # Test unsafe expressions + person = self.Person(name='Guido van Rossum [.\'Geek\']') + person.save() + + obj = self.Person.objects(Q(name__icontains='[.\'Geek')).first() + self.assertEqual(obj, person) def test_filter_chaining(self): """Ensure filters can be chained together. @@ -664,28 +671,27 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertTrue('db' in post.tags and 'nosql' in post.tags) + tags = post.tags[:-1] + BlogPost.objects.update(pop__tags=1) + post.reload() + self.assertEqual(post.tags, tags) + BlogPost.drop_collection() def test_update_pull(self): """Ensure that the 'pull' update operation works correctly. 
""" - class Comment(EmbeddedDocument): - content = StringField() - class BlogPost(Document): slug = StringField() - comments = ListField(EmbeddedDocumentField(Comment)) + tags = ListField(StringField()) - comment1 = Comment(content="test1") - comment2 = Comment(content="test2") - - post = BlogPost(slug="test", comments=[comment1, comment2]) + post = BlogPost(slug="test", tags=['code', 'mongodb', 'code']) post.save() - self.assertTrue(comment2 in post.comments) - BlogPost.objects(slug="test").update(pull__comments__content="test2") + BlogPost.objects(slug="test").update(pull__tags="code") post.reload() - self.assertTrue(comment2 not in post.comments) + self.assertTrue('code' not in post.tags) + self.assertEqual(len(post.tags), 1) def test_order_by(self): """Ensure that QuerySets may be ordered. @@ -948,11 +954,14 @@ class QuerySetTest(unittest.TestCase): def test_average(self): """Ensure that field can be averaged correctly. """ + self.Person(name='person', age=0).save() + self.assertEqual(int(self.Person.objects.average('age')), 0) + ages = [23, 54, 12, 94, 27] for i, age in enumerate(ages): self.Person(name='test%s' % i, age=age).save() - avg = float(sum(ages)) / len(ages) + avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0 self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) self.Person(name='ageless person').save() @@ -970,15 +979,30 @@ class QuerySetTest(unittest.TestCase): self.Person(name='ageless person').save() self.assertEqual(int(self.Person.objects.sum('age')), sum(ages)) + def test_distinct(self): + """Ensure that the QuerySet.distinct method works. 
+ """ + self.Person(name='Mr Orange', age=20).save() + self.Person(name='Mr White', age=20).save() + self.Person(name='Mr Orange', age=30).save() + self.assertEqual(self.Person.objects.distinct('name'), + ['Mr Orange', 'Mr White']) + self.assertEqual(self.Person.objects.distinct('age'), [20, 30]) + def test_custom_manager(self): """Ensure that custom QuerySetManager instances work as expected. """ class BlogPost(Document): tags = ListField(StringField()) + deleted = BooleanField(default=False) + + @queryset_manager + def objects(doc_cls, queryset): + return queryset(deleted=False) @queryset_manager def music_posts(doc_cls, queryset): - return queryset(tags='music') + return queryset(tags='music', deleted=False) BlogPost.drop_collection() @@ -988,6 +1012,8 @@ class QuerySetTest(unittest.TestCase): post2.save() post3 = BlogPost(tags=['film', 'actors']) post3.save() + post4 = BlogPost(tags=['film', 'actors'], deleted=True) + post4.save() self.assertEqual([p.id for p in BlogPost.objects], [post1.id, post2.id, post3.id]) @@ -1087,8 +1113,9 @@ class QuerySetTest(unittest.TestCase): # Indexes are lazy so use list() to perform query list(BlogPost.objects) info = BlogPost.objects._collection.index_information() - self.assertTrue([('_types', 1)] in info.values()) - self.assertTrue([('_types', 1), ('date', -1)] in info.values()) + info = [value['key'] for key, value in info.iteritems()] + self.assertTrue([('_types', 1)] in info) + self.assertTrue([('_types', 1), ('date', -1)] in info) BlogPost.drop_collection() @@ -1164,6 +1191,81 @@ class QuerySetTest(unittest.TestCase): def tearDown(self): self.Person.drop_collection() + def test_geospatial_operators(self): + """Ensure that geospatial queries are working. 
+ """ + class Event(Document): + title = StringField() + date = DateTimeField() + location = GeoPointField() + + def __unicode__(self): + return self.title + + Event.drop_collection() + + event1 = Event(title="Coltrane Motion @ Double Door", + date=datetime.now() - timedelta(days=1), + location=[41.909889, -87.677137]) + event2 = Event(title="Coltrane Motion @ Bottom of the Hill", + date=datetime.now() - timedelta(days=10), + location=[37.7749295, -122.4194155]) + event3 = Event(title="Coltrane Motion @ Empty Bottle", + date=datetime.now(), + location=[41.900474, -87.686638]) + + event1.save() + event2.save() + event3.save() + + # find all events "near" pitchfork office, chicago. + # note that "near" will show the san francisco event, too, + # although it sorts to last. + events = Event.objects(location__near=[41.9120459, -87.67892]) + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event1, event3, event2]) + + # find events within 5 miles of pitchfork office, chicago + point_and_distance = [[41.9120459, -87.67892], 5] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 2) + events = list(events) + self.assertTrue(event2 not in events) + self.assertTrue(event1 in events) + self.assertTrue(event3 in events) + + # ensure ordering is respected by "near" + events = Event.objects(location__near=[41.9120459, -87.67892]) + events = events.order_by("-date") + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event3, event1, event2]) + + # find events around san francisco + point_and_distance = [[37.7566023, -122.415579], 10] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0], event2) + + # find events within 1 mile of greenpoint, broolyn, nyc, ny + point_and_distance = [[40.7237134, -73.9509714], 1] + events = Event.objects(location__within_distance=point_and_distance) + 
self.assertEqual(events.count(), 0) + + # ensure ordering is respected by "within_distance" + point_and_distance = [[41.9120459, -87.67892], 10] + events = Event.objects(location__within_distance=point_and_distance) + events = events.order_by("-date") + self.assertEqual(events.count(), 2) + self.assertEqual(events[0], event3) + + # check that within_box works + box = [(35.0, -125.0), (40.0, -100.0)] + events = Event.objects(location__within_box=box) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0].id, event2.id) + + Event.drop_collection() + class QTest(unittest.TestCase): @@ -1218,6 +1320,23 @@ class QTest(unittest.TestCase): query = ['(', {'age__gte': 18}, '&&', {'name': 'test'}, ')'] self.assertEqual((q1 & q2 & q3 & q4 & q5).query, query) + + def test_q_with_dbref(self): + """Ensure Q objects handle DBRefs correctly""" + connect(db='mongoenginetest') + + class User(Document): + pass + + class Post(Document): + created_user = ReferenceField(User) + + user = User.objects.create() + Post.objects.create(created_user=user) + + self.assertEqual(Post.objects.filter(created_user=user).count(), 1) + self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1) + if __name__ == '__main__': unittest.main()