diff --git a/.gitignore b/.gitignore index 42dcc6e6..51a9ca1d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,9 @@ *.pyc .*.swp +*.egg docs/.build docs/_build build/ dist/ -mongoengine.egg-info/ \ No newline at end of file +mongoengine.egg-info/ +env/ \ No newline at end of file diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 1fd2ed57..bef19bc5 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -71,7 +71,7 @@ Available operators are as follows: * ``in`` -- value is in list (a list of values should be provided) * ``nin`` -- value is not in list (a list of values should be provided) * ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values -* ``all`` -- every item in array is in list of values provided +* ``all`` -- every item in list of values provided is in array * ``size`` -- the size of the array is * ``exists`` -- value for field exists @@ -399,6 +399,7 @@ that you may use with these methods: * ``unset`` -- delete a particular value (since MongoDB v1.3+) * ``inc`` -- increment a value by a given amount * ``dec`` -- decrement a value by a given amount +* ``pop`` -- remove the last item from a list * ``push`` -- append a value to a list * ``push_all`` -- append several values to a list * ``pull`` -- remove a value from a list diff --git a/mongoengine/base.py b/mongoengine/base.py index 086c7874..0cbd707d 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -230,12 +230,18 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): id_field = None base_indexes = [] + base_meta = {} # Subclassed documents inherit collection from superclass for base in bases: if hasattr(base, '_meta') and 'collection' in base._meta: collection = base._meta['collection'] + # Propagate index options. 
+ for key in ('index_background', 'index_drop_dups', 'index_opts'): + if key in base._meta: + base_meta[key] = base._meta[key] + id_field = id_field or base._meta.get('id_field') base_indexes += base._meta.get('indexes', []) @@ -246,7 +252,12 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): 'ordering': [], # default ordering applied at runtime 'indexes': [], # indexes to be ensured at runtime 'id_field': id_field, + 'index_background': False, + 'index_drop_dups': False, + 'index_opts': {}, + 'queryset_class': QuerySet, } + meta.update(base_meta) # Apply document-defined meta options meta.update(attrs.get('meta', {})) @@ -255,7 +266,10 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Set up collection manager, needs the class to have fields so use # DocumentMetaclass before instantiating CollectionManager object new_class = super_new(cls, name, bases, attrs) - new_class.objects = QuerySetManager() + + # Provide a default queryset unless one has been manually provided + if not hasattr(new_class, 'objects'): + new_class.objects = QuerySetManager() user_indexes = [QuerySet._build_index_spec(new_class, spec) for spec in meta['indexes']] + base_indexes @@ -266,7 +280,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Generate a list of indexes needed by uniqueness constraints if field.unique: field.required = True - unique_fields = [field_name] + unique_fields = [field.db_field] # Add any unique_with fields to the back of the index spec if field.unique_with: @@ -415,6 +429,8 @@ class BaseDocument(object): self._meta.get('allow_inheritance', True) == False): data['_cls'] = self._class_name data['_types'] = self._superclasses.keys() + [self._class_name] + if data.has_key('_id') and not data['_id']: + del data['_id'] return data @classmethod @@ -446,7 +462,9 @@ class BaseDocument(object): for field_name, field in cls._fields.items(): if field.db_field in data: - data[field_name] = field.to_python(data[field.db_field]) + value = data[field.db_field] + 
data[field_name] = (value if value is None + else field.to_python(value)) obj = cls(**data) obj._present_fields = present_fields diff --git a/mongoengine/django/auth.py b/mongoengine/django/auth.py index d4b0ff0b..da0005c8 100644 --- a/mongoengine/django/auth.py +++ b/mongoengine/django/auth.py @@ -32,6 +32,9 @@ class User(Document): last_login = DateTimeField(default=datetime.datetime.now) date_joined = DateTimeField(default=datetime.datetime.now) + def __unicode__(self): + return self.username + def get_full_name(self): """Returns the users first and last names, separated by a space. """ diff --git a/mongoengine/django/tests.py b/mongoengine/django/tests.py new file mode 100644 index 00000000..a8d7c7ff --- /dev/null +++ b/mongoengine/django/tests.py @@ -0,0 +1,21 @@ +#coding: utf-8 +from django.test import TestCase +from django.conf import settings + +from mongoengine import connect + +class MongoTestCase(TestCase): + """ + TestCase class that clears the collections between the tests + """ + db_name = 'test_%s' % settings.MONGO_DATABASE_NAME + def __init__(self, methodName='runtest'): + self.db = connect(self.db_name) + super(MongoTestCase, self).__init__(methodName) + + def _post_teardown(self): + super(MongoTestCase, self)._post_teardown() + for collection in self.db.collection_names(): + if collection == 'system.indexes': + continue + self.db.drop_collection(collection) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f84f751b..418f57cc 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -66,6 +66,9 @@ class StringField(BaseField): regex = r'%s$' elif op == 'exact': regex = r'^%s$' + + # escape unsafe characters which could lead to a re.error + value = re.escape(value) value = re.compile(regex % value, flags) return value @@ -263,7 +266,7 @@ class ListField(BaseField): raise ValidationError('Argument to ListField constructor must be ' 'a valid field') self.field = field - kwargs.setdefault('default', []) + 
kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) def __get__(self, instance, owner): @@ -356,7 +359,7 @@ class DictField(BaseField): def __init__(self, basecls=None, *args, **kwargs): self.basecls = basecls or BaseField assert issubclass(self.basecls, BaseField) - kwargs.setdefault('default', {}) + kwargs.setdefault('default', lambda: {}) super(DictField, self).__init__(*args, **kwargs) def validate(self, value): @@ -507,14 +510,19 @@ class BinaryField(BaseField): if self.max_bytes is not None and len(value) > self.max_bytes: raise ValidationError('Binary value is too long') + +class GridFSError(Exception): + pass + + class GridFSProxy(object): """Proxy object to handle writing and reading of files to and from GridFS """ - def __init__(self): + def __init__(self, grid_id=None): self.fs = gridfs.GridFS(_get_db()) # Filesystem instance self.newfile = None # Used for partial writes - self.grid_id = None # Store GridFS id for file + self.grid_id = grid_id # Store GridFS id for file def __getattr__(self, name): obj = self.get() @@ -525,21 +533,30 @@ class GridFSProxy(object): return self def get(self, id=None): - if id: self.grid_id = id - try: return self.fs.get(id or self.grid_id) - except: return None # File has been deleted + if id: + self.grid_id = id + try: + return self.fs.get(id or self.grid_id) + except: + return None # File has been deleted def new_file(self, **kwargs): self.newfile = self.fs.new_file(**kwargs) self.grid_id = self.newfile._id def put(self, file, **kwargs): + if self.grid_id: + raise GridFSError('This document alreay has a file. Either delete ' + 'it or call replace to overwrite it') self.grid_id = self.fs.put(file, **kwargs) def write(self, string): - if not self.newfile: + if self.grid_id: + if not self.newfile: + raise GridFSError('This document alreay has a file. 
Either ' + 'delete it or call replace to overwrite it') + else: self.new_file() - self.grid_id = self.newfile._id self.newfile.write(string) def writelines(self, lines): @@ -549,8 +566,10 @@ class GridFSProxy(object): self.newfile.writelines(lines) def read(self): - try: return self.get().read() - except: return None + try: + return self.get().read() + except: + return None def delete(self): # Delete file from GridFS, FileField still remains @@ -568,38 +587,52 @@ class GridFSProxy(object): msg = "The close() method is only necessary after calling write()" warnings.warn(msg) + class FileField(BaseField): """A GridFS storage field. """ def __init__(self, **kwargs): - self.gridfs = GridFSProxy() super(FileField, self).__init__(**kwargs) def __get__(self, instance, owner): if instance is None: return self - return self.gridfs + # Check if a file already exists for this model + grid_file = instance._data.get(self.name) + if grid_file: + return grid_file + return GridFSProxy() def __set__(self, instance, value): if isinstance(value, file) or isinstance(value, str): # using "FileField() = file/string" notation - self.gridfs.put(value) + grid_file = instance._data.get(self.name) + # If a file already exists, delete it + if grid_file: + try: + grid_file.delete() + except: + pass + # Create a new file with the new data + grid_file.put(value) + else: + # Create a new proxy object as we don't already have one + instance._data[self.name] = GridFSProxy() + instance._data[self.name].put(value) else: instance._data[self.name] = value def to_mongo(self, value): # Store the GridFS file id in MongoDB - if self.gridfs.grid_id is not None: - return self.gridfs.grid_id + if isinstance(value, GridFSProxy) and value.grid_id is not None: + return value.grid_id return None def to_python(self, value): - # Use stored value (id) to lookup file in GridFS - if self.gridfs.grid_id is not None: - return self.gridfs.get(id=value) - return None + if value is not None: + return GridFSProxy(value) def 
validate(self, value): if value.grid_id is not None: @@ -623,4 +656,4 @@ class GeoPointField(BaseField): raise ValidationError('Value must be a two-dimensional point.') if (not isinstance(value[0], (float, int)) and not isinstance(value[1], (float, int))): - raise ValidationError('Both values in point must be float or int.') \ No newline at end of file + raise ValidationError('Both values in point must be float or int.') diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 00a7f7a2..8b486093 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -135,7 +135,6 @@ class Q(object): # Handle DBRef if isinstance(value, pymongo.dbref.DBRef): - # this.created_user.$id == "4c4c56f8cc1831418c000000" op_js = '(this.%(field)s.$id == "%(id)s" &&'\ ' this.%(field)s.$ref == "%(ref)s")' % { 'field': key, @@ -173,7 +172,8 @@ class QuerySet(object): self._limit = None self._skip = None - def ensure_index(self, key_or_list): + def ensure_index(self, key_or_list, drop_dups=False, background=False, + **kwargs): """Ensure that the given indexes are in place. 
:param key_or_list: a single index key or a list of index keys (to @@ -181,7 +181,8 @@ class QuerySet(object): or a **-** to determine the index ordering """ index_list = QuerySet._build_index_spec(self._document, key_or_list) - self._collection.ensure_index(index_list) + self._collection.ensure_index(index_list, drop_dups=drop_dups, + background=background) return self @classmethod @@ -239,6 +240,10 @@ class QuerySet(object): """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` """ return self.__call__(*q_objs, **query) + + def all(self): + """Returns all documents.""" + return self.__call__() @property def _collection(self): @@ -247,25 +252,33 @@ class QuerySet(object): """ if not self._accessed_collection: self._accessed_collection = True + + background = self._document._meta.get('index_background', False) + drop_dups = self._document._meta.get('index_drop_dups', False) + index_opts = self._document._meta.get('index_options', {}) # Ensure document-defined indexes are created if self._document._meta['indexes']: for key_or_list in self._document._meta['indexes']: - self._collection.ensure_index(key_or_list) + self._collection.ensure_index(key_or_list, + background=background, **index_opts) # Ensure indexes created by uniqueness constraints for index in self._document._meta['unique_indexes']: - self._collection.ensure_index(index, unique=True) + self._collection.ensure_index(index, unique=True, + background=background, drop_dups=drop_dups, **index_opts) # If _types is being used (for polymorphism), it needs an index if '_types' in self._query: - self._collection.ensure_index('_types') + self._collection.ensure_index('_types', + background=background, **index_opts) # Ensure all needed field indexes are created for field in self._document._fields.values(): if field.__class__._geo_index: index_spec = [(field.db_field, pymongo.GEO2D)] - self._collection.ensure_index(index_spec) + self._collection.ensure_index(index_spec, + background=background, **index_opts) 
return self._collection_obj @@ -331,6 +344,8 @@ class QuerySet(object): mongo_query = {} for key, value in query.items(): parts = key.split('__') + indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] + parts = [part for part in parts if not part.isdigit()] # Check for an operator and transform to mongo-style if there is op = None if parts[-1] in operators + match_operators + geo_operators: @@ -368,7 +383,9 @@ class QuerySet(object): "been implemented" % op) elif op not in match_operators: value = {'$' + op: value} - + + for i, part in indices: + parts.insert(i, part) key = '.'.join(parts) if op is None or key not in mongo_query: mongo_query[key] = value @@ -667,11 +684,13 @@ class QuerySet(object): """ key_list = [] for key in keys: + if not key: continue direction = pymongo.ASCENDING if key[0] == '-': direction = pymongo.DESCENDING if key[0] in ('-', '+'): key = key[1:] + key = key.replace('__', '.') key_list.append((key, direction)) self._ordering = key_list @@ -701,8 +720,8 @@ class QuerySet(object): def _transform_update(cls, _doc_cls=None, **update): """Transform an update spec from Django-style format to Mongo format. 
""" - operators = ['set', 'unset', 'inc', 'dec', 'push', 'push_all', 'pull', - 'pull_all'] + operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all', + 'pull', 'pull_all'] mongo_update = {} for key, value in update.items(): @@ -728,7 +747,7 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] - if op in (None, 'set', 'unset', 'push', 'pull'): + if op in (None, 'set', 'unset', 'pop', 'push', 'pull'): value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] @@ -877,7 +896,7 @@ class QuerySet(object): var total = 0.0; var num = 0; db[collection].find(query).forEach(function(doc) { - if (doc[averageField]) { + if (doc[averageField] !== undefined) { total += doc[averageField]; num += 1; } @@ -973,7 +992,8 @@ class QuerySetManager(object): self._collection = db[collection] # owner is the document that contains the QuerySetManager - queryset = QuerySet(owner, self._collection) + queryset_class = owner._meta['queryset_class'] or QuerySet + queryset = queryset_class(owner, self._collection) if self._manager_func: if self._manager_func.func_code.co_argcount == 1: queryset = self._manager_func(queryset) diff --git a/tests/fields.py b/tests/fields.py index 136437b8..8c727196 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -674,7 +674,34 @@ class FieldTest(unittest.TestCase): PutFile.drop_collection() StreamFile.drop_collection() SetFile.drop_collection() + + # Make sure FileField is optional and not required + class DemoFile(Document): + file = FileField() + d = DemoFile.objects.create() + + def test_file_uniqueness(self): + """Ensure that each instance of a FileField is unique + """ + class TestFile(Document): + name = StringField() + file = FileField() + + # First instance + testfile = TestFile() + testfile.name = "Hello, World!" 
+ testfile.file.put('Hello, World!') + testfile.save() + # Second instance + testfiledupe = TestFile() + data = testfiledupe.file.read() # Should be None + + self.assertTrue(testfile.name != testfiledupe.name) + self.assertTrue(testfile.file.read() != data) + + TestFile.drop_collection() + def test_geo_indexes(self): """Ensure that indexes are created automatically for GeoPointFields. """ @@ -693,5 +720,18 @@ class FieldTest(unittest.TestCase): Event.drop_collection() + def test_ensure_unique_default_instances(self): + """Ensure that every field has its own unique default instance.""" + class D(Document): + data = DictField() + data2 = DictField(default=lambda: {}) + + d1 = D() + d1.data['foo'] = 'bar' + d1.data2['foo'] = 'bar' + d2 = D() + self.assertEqual(d2.data, {}) + self.assertEqual(d2.data2, {}) + if __name__ == '__main__': unittest.main() diff --git a/tests/queryset.py b/tests/queryset.py index 8cbd9a40..4491be8c 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -165,8 +165,49 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age__lt=30) self.assertEqual(person.name, "User A") + def test_find_array_position(self): + """Ensure that query by array position works. 
+ """ + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() - + Blog.objects.create(tags=['a', 'b']) + self.assertEqual(len(Blog.objects(tags__0='a')), 1) + self.assertEqual(len(Blog.objects(tags__0='b')), 0) + self.assertEqual(len(Blog.objects(tags__1='a')), 0) + self.assertEqual(len(Blog.objects(tags__1='b')), 1) + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + blog = Blog.objects(posts__0__comments__0__name='testa').get() + self.assertEqual(blog, blog1) + + query = Blog.objects(posts__1__comments__1__name='testb') + self.assertEqual(len(query), 2) + + query = Blog.objects(posts__1__comments__1__name='testa') + self.assertEqual(len(query), 0) + + query = Blog.objects(posts__0__comments__1__name='testa') + self.assertEqual(len(query), 0) + + Blog.drop_collection() def test_get_or_create(self): """Ensure that ``get_or_create`` returns one result or creates a new @@ -288,6 +329,13 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(obj, person) obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() self.assertEqual(obj, None) + + # Test unsafe expressions + person = self.Person(name='Guido van Rossum [.\'Geek\']') + person.save() + + obj = self.Person.objects(Q(name__icontains='[.\'Geek')).first() + self.assertEqual(obj, person) def test_filter_chaining(self): """Ensure filters can be chained together. 
@@ -664,28 +712,27 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertTrue('db' in post.tags and 'nosql' in post.tags) + tags = post.tags[:-1] + BlogPost.objects.update(pop__tags=1) + post.reload() + self.assertEqual(post.tags, tags) + BlogPost.drop_collection() def test_update_pull(self): """Ensure that the 'pull' update operation works correctly. """ - class Comment(EmbeddedDocument): - content = StringField() - class BlogPost(Document): slug = StringField() - comments = ListField(EmbeddedDocumentField(Comment)) + tags = ListField(StringField()) - comment1 = Comment(content="test1") - comment2 = Comment(content="test2") - - post = BlogPost(slug="test", comments=[comment1, comment2]) + post = BlogPost(slug="test", tags=['code', 'mongodb', 'code']) post.save() - self.assertTrue(comment2 in post.comments) - BlogPost.objects(slug="test").update(pull__comments__content="test2") + BlogPost.objects(slug="test").update(pull__tags="code") post.reload() - self.assertTrue(comment2 not in post.comments) + self.assertTrue('code' not in post.tags) + self.assertEqual(len(post.tags), 1) def test_order_by(self): """Ensure that QuerySets may be ordered. @@ -948,11 +995,14 @@ class QuerySetTest(unittest.TestCase): def test_average(self): """Ensure that field can be averaged correctly. 
""" + self.Person(name='person', age=0).save() + self.assertEqual(int(self.Person.objects.average('age')), 0) + ages = [23, 54, 12, 94, 27] for i, age in enumerate(ages): self.Person(name='test%s' % i, age=age).save() - avg = float(sum(ages)) / len(ages) + avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0 self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) self.Person(name='ageless person').save() @@ -985,10 +1035,15 @@ class QuerySetTest(unittest.TestCase): """ class BlogPost(Document): tags = ListField(StringField()) + deleted = BooleanField(default=False) + + @queryset_manager + def objects(doc_cls, queryset): + return queryset(deleted=False) @queryset_manager def music_posts(doc_cls, queryset): - return queryset(tags='music') + return queryset(tags='music', deleted=False) BlogPost.drop_collection() @@ -998,6 +1053,8 @@ class QuerySetTest(unittest.TestCase): post2.save() post3 = BlogPost(tags=['film', 'actors']) post3.save() + post4 = BlogPost(tags=['film', 'actors'], deleted=True) + post4.save() self.assertEqual([p.id for p in BlogPost.objects], [post1.id, post2.id, post3.id]) @@ -1307,6 +1364,8 @@ class QTest(unittest.TestCase): def test_q_with_dbref(self): """Ensure Q objects handle DBRefs correctly""" + connect(db='mongoenginetest') + class User(Document): pass @@ -1319,5 +1378,26 @@ class QTest(unittest.TestCase): self.assertEqual(Post.objects.filter(created_user=user).count(), 1) self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1) + def test_custom_querysets(self): + """Ensure that custom QuerySet classes may be used. 
+ """ + class CustomQuerySet(QuerySet): + def not_empty(self): + return len(self) > 0 + + class Post(Document): + meta = {'queryset_class': CustomQuerySet} + + Post.drop_collection() + + self.assertTrue(isinstance(Post.objects, CustomQuerySet)) + self.assertFalse(Post.objects.not_empty()) + + Post().save() + self.assertTrue(Post.objects.not_empty()) + + Post.drop_collection() + + if __name__ == '__main__': unittest.main()