diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 4447f7da..690569a9 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -172,7 +172,7 @@ custom manager methods as you like:: @queryset_manager def live_posts(doc_cls, queryset): - return queryset.filter(published=True) + return queryset(published=True).filter(published=True) BlogPost(title='test1', published=False).save() BlogPost(title='test2', published=True).save() diff --git a/mongoengine/django/shortcuts.py b/mongoengine/django/shortcuts.py new file mode 100644 index 00000000..29bc17a8 --- /dev/null +++ b/mongoengine/django/shortcuts.py @@ -0,0 +1,45 @@ +from django.http import Http404 +from mongoengine.queryset import QuerySet +from mongoengine.base import BaseDocument + +def _get_queryset(cls): + """Inspired by django.shortcuts.*""" + if isinstance(cls, QuerySet): + return cls + else: + return cls.objects + +def get_document_or_404(cls, *args, **kwargs): + """ + Uses get() to return an document, or raises a Http404 exception if the document + does not exist. + + cls may be a Document or QuerySet object. All other passed + arguments and keyword arguments are used in the get() query. + + Note: Like with get(), an MultipleObjectsReturned will be raised if more than one + object is found. + + Inspired by django.shortcuts.* + """ + queryset = _get_queryset(cls) + try: + return queryset.get(*args, **kwargs) + except queryset._document.DoesNotExist: + raise Http404('No %s matches the given query.' % queryset._document._class_name) + +def get_list_or_404(cls, *args, **kwargs): + """ + Uses filter() to return a list of documents, or raise a Http404 exception if + the list is empty. + + cls may be a Document or QuerySet object. All other passed + arguments and keyword arguments are used in the filter() query. + + Inspired by django.shortcuts.* + """ + queryset = _get_queryset(cls) + obj_list = list(queryset.filter(*args, **kwargs)) + if not obj_list: + raise Http404('No %s matches the given query.' % queryset._document._class_name) + return obj_list diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 55e5addb..75008d6f 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -348,6 +348,11 @@ class DictField(BaseField): .. versionadded:: 0.3 """ + def __init__(self, basecls=None, *args, **kwargs): + self.basecls = basecls or BaseField + assert issubclass(self.basecls, BaseField) + super(DictField, self).__init__(*args, **kwargs) + def validate(self, value): """Make sure that a list of valid fields is being used. """ @@ -360,7 +365,7 @@ class DictField(BaseField): 'contain "." or "$" characters') def lookup_member(self, member_name): - return BaseField(db_field=member_name) + return self.basecls(db_field=member_name) class GeoLocationField(DictField): """Supports geobased fields""" diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 396a745c..72a474d7 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -59,8 +59,12 @@ class Q(object): def _combine(self, other, op): obj = Q() - obj.query = ['('] + copy.deepcopy(self.query) + [op] - obj.query += copy.deepcopy(other.query) + [')'] + if not other.query[0]: + return self + if self.query[0]: + obj.query = ['('] + copy.deepcopy(self.query) + [op] + copy.deepcopy(other.query) + [')'] + else: + obj.query = copy.deepcopy(other.query) return obj def __or__(self, other): @@ -313,7 +317,7 @@ class QuerySet(object): op = None if parts[-1] in operators + match_operators: op = parts.pop() - + if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] fields = QuerySet._lookup_field(_doc_cls, parts) diff --git a/tests/queryset.py b/tests/queryset.py index aba3bc7d..9daa73ec 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1063,6 +1063,29 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() + def test_dict_with_custom_baseclass(self): + """Ensure DictField working with custom base clases. + """ + class Test(Document): + testdict = DictField() + + t = Test(testdict={'f': 'Value'}) + t.save() + + self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 0) + self.assertEqual(len(Test.objects(testdict__f='Value')), 1) + Test.drop_collection() + + class Test(Document): + testdict = DictField(basecls=StringField) + + t = Test(testdict={'f': 'Value'}) + t.save() + + self.assertEqual(len(Test.objects(testdict__f='Value')), 1) + self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 1) + Test.drop_collection() + def test_bulk(self): """Ensure bulk querying by object id returns a proper dict. """ @@ -1136,5 +1159,20 @@ class QTest(unittest.TestCase): self.assertEqual(q._item_query_as_js(item, test_scope, 0), js) self.assertEqual(scope, test_scope) + def test_empty_q(self): + """Ensure that empty Q objects won't hurt. + """ + q1 = Q() + q2 = Q(age__gte=18) + q3 = Q() + q4 = Q(name='test') + q5 = Q() + + query = ['(', {'age__gte': 18}, '||', {'name': 'test'}, ')'] + self.assertEqual((q1 | q2 | q3 | q4 | q5).query, query) + + query = ['(', {'age__gte': 18}, '&&', {'name': 'test'}, ')'] + self.assertEqual((q1 & q2 & q3 & q4 & q5).query, query) + if __name__ == '__main__': unittest.main()