Merge branch 'dev' into feature/update_lists

Conflicts:
	mongoengine/queryset.py
This commit is contained in:
Alistair Roche 2011-05-24 11:33:44 +01:00
commit fe5111743d
3 changed files with 120 additions and 13 deletions

View File

@ -298,8 +298,10 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class = super_new(cls, name, bases, attrs) new_class = super_new(cls, name, bases, attrs)
# Provide a default queryset unless one has been manually provided # Provide a default queryset unless one has been manually provided
if not hasattr(new_class, 'objects'): manager = attrs.get('objects', QuerySetManager())
new_class.objects = QuerySetManager() if hasattr(manager, 'queryset_class'):
meta['queryset_class'] = manager.queryset_class
new_class.objects = manager
user_indexes = [QuerySet._build_index_spec(new_class, spec) user_indexes = [QuerySet._build_index_spec(new_class, spec)
for spec in meta['indexes']] + base_indexes for spec in meta['indexes']] + base_indexes

View File

@ -428,8 +428,6 @@ class QuerySet(object):
querying collection querying collection
:param query: Django-style query keyword arguments :param query: Django-style query keyword arguments
""" """
#if q_obj:
#self._where_clause = q_obj.as_js(self._document)
query = Q(**query) query = Q(**query)
if q_obj: if q_obj:
query &= q_obj query &= q_obj
@ -524,9 +522,14 @@ class QuerySet(object):
fields = [] fields = []
field = None field = None
for field_name in parts: for field_name in parts:
# Handle ListField indexing:
if field_name.isdigit(): if field_name.isdigit():
fields.append(field_name) try:
field = field.field field = field.field
except AttributeError, err:
raise InvalidQueryError(
"Can't use index on unsubscriptable field (%s)" % err)
fields.append(field_name)
continue continue
if field is None: if field is None:
# Look up first field from the document # Look up first field from the document
@ -1075,11 +1078,13 @@ class QuerySet(object):
# Switch field names to proper names [set in Field(name='foo')] # Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts) fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [] parts = []
for field in fields: for field in fields:
if isinstance(field, str): if isinstance(field, str):
parts.append(field) parts.append(field)
else: else:
parts.append(field.db_field) parts.append(field.db_field)
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = fields[-1]
@ -1315,8 +1320,11 @@ class QuerySet(object):
class QuerySetManager(object): class QuerySetManager(object):
def __init__(self, manager_func=None): get_queryset = None
self._manager_func = manager_func
def __init__(self, queryset_func=None):
if queryset_func:
self.get_queryset = queryset_func
self._collections = {} self._collections = {}
def __get__(self, instance, owner): def __get__(self, instance, owner):
@ -1360,11 +1368,11 @@ class QuerySetManager(object):
# owner is the document that contains the QuerySetManager # owner is the document that contains the QuerySetManager
queryset_class = owner._meta['queryset_class'] or QuerySet queryset_class = owner._meta['queryset_class'] or QuerySet
queryset = queryset_class(owner, self._collections[(db, collection)]) queryset = queryset_class(owner, self._collections[(db, collection)])
if self._manager_func: if self.get_queryset:
if self._manager_func.func_code.co_argcount == 1: if self.get_queryset.func_code.co_argcount == 1:
queryset = self._manager_func(queryset) queryset = self.get_queryset(queryset)
else: else:
queryset = self._manager_func(owner, queryset) queryset = self.get_queryset(owner, queryset)
return queryset return queryset

View File

@ -5,8 +5,9 @@ import unittest
import pymongo import pymongo
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, from mongoengine.queryset import (QuerySet, QuerySetManager,
DoesNotExist, QueryFieldList) MultipleObjectsReturned, DoesNotExist,
QueryFieldList)
from mongoengine import * from mongoengine import *
@ -211,6 +212,55 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection() Blog.drop_collection()
def test_update_array_position(self):
"""Ensure that updating by array position works.
Check update() and update_one() can take syntax like:
set__posts__1__comments__1__name="testc"
Check that it only works for ListFields.
"""
class Comment(EmbeddedDocument):
name = StringField()
class Post(EmbeddedDocument):
comments = ListField(EmbeddedDocumentField(Comment))
class Blog(Document):
tags = ListField(StringField())
posts = ListField(EmbeddedDocumentField(Post))
Blog.drop_collection()
comment1 = Comment(name='testa')
comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
# Update all of the first comments of second posts of all blogs
blog = Blog.objects().update(set__posts__1__comments__0__name="testc")
testc_blogs = Blog.objects(posts__1__comments__0__name="testc")
self.assertEqual(len(testc_blogs), 2)
Blog.drop_collection()
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
# Update only the first blog returned by the query
blog = Blog.objects().update_one(
set__posts__1__comments__1__name="testc")
testc_blogs = Blog.objects(posts__1__comments__1__name="testc")
self.assertEqual(len(testc_blogs), 1)
# Check that using this indexing syntax on a non-list fails
def non_list_indexing():
Blog.objects().update(set__posts__1__comments__0__name__1="asdf")
self.assertRaises(InvalidQueryError, non_list_indexing)
Blog.drop_collection()
def test_get_or_create(self): def test_get_or_create(self):
"""Ensure that ``get_or_create`` returns one result or creates a new """Ensure that ``get_or_create`` returns one result or creates a new
document. document.
@ -1737,6 +1787,53 @@ class QuerySetTest(unittest.TestCase):
Post.drop_collection() Post.drop_collection()
def test_custom_querysets_set_manager_directly(self):
"""Ensure that custom QuerySet classes may be used.
"""
class CustomQuerySet(QuerySet):
def not_empty(self):
return len(self) > 0
class CustomQuerySetManager(QuerySetManager):
queryset_class = CustomQuerySet
class Post(Document):
objects = CustomQuerySetManager()
Post.drop_collection()
self.assertTrue(isinstance(Post.objects, CustomQuerySet))
self.assertFalse(Post.objects.not_empty())
Post().save()
self.assertTrue(Post.objects.not_empty())
Post.drop_collection()
def test_custom_querysets_managers_directly(self):
"""Ensure that custom QuerySet classes may be used.
"""
class CustomQuerySetManager(QuerySetManager):
@staticmethod
def get_queryset(doc_cls, queryset):
return queryset(is_published=True)
class Post(Document):
is_published = BooleanField(default=False)
published = CustomQuerySetManager()
Post.drop_collection()
Post().save()
Post(is_published=True).save()
self.assertEquals(Post.objects.count(), 2)
self.assertEquals(Post.published.count(), 1)
Post.drop_collection()
def test_call_after_limits_set(self): def test_call_after_limits_set(self):
"""Ensure that re-filtering after slicing works """Ensure that re-filtering after slicing works
""" """