Merge branch 'master' of git://github.com/flosch/mongoengine

This commit is contained in:
Harry Marr 2010-05-22 15:59:41 +01:00
commit b8e9790de3
5 changed files with 97 additions and 5 deletions

View File

@ -172,7 +172,7 @@ custom manager methods as you like::
@queryset_manager @queryset_manager
def live_posts(doc_cls, queryset): def live_posts(doc_cls, queryset):
return queryset.filter(published=True) return queryset(published=True).filter(published=True)
BlogPost(title='test1', published=False).save() BlogPost(title='test1', published=False).save()
BlogPost(title='test2', published=True).save() BlogPost(title='test2', published=True).save()

View File

@ -0,0 +1,45 @@
from django.http import Http404
from mongoengine.queryset import QuerySet
from mongoengine.base import BaseDocument
def _get_queryset(cls):
"""Inspired by django.shortcuts.*"""
if isinstance(cls, QuerySet):
return cls
else:
return cls.objects
def get_document_or_404(cls, *args, **kwargs):
"""
Uses get() to return an document, or raises a Http404 exception if the document
does not exist.
cls may be a Document or QuerySet object. All other passed
arguments and keyword arguments are used in the get() query.
Note: Like with get(), an MultipleObjectsReturned will be raised if more than one
object is found.
Inspired by django.shortcuts.*
"""
queryset = _get_queryset(cls)
try:
return queryset.get(*args, **kwargs)
except queryset._document.DoesNotExist:
raise Http404('No %s matches the given query.' % queryset._document._class_name)
def get_list_or_404(cls, *args, **kwargs):
"""
Uses filter() to return a list of documents, or raise a Http404 exception if
the list is empty.
cls may be a Document or QuerySet object. All other passed
arguments and keyword arguments are used in the filter() query.
Inspired by django.shortcuts.*
"""
queryset = _get_queryset(cls)
obj_list = list(queryset.filter(*args, **kwargs))
if not obj_list:
raise Http404('No %s matches the given query.' % queryset._document._class_name)
return obj_list

View File

@ -348,6 +348,11 @@ class DictField(BaseField):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
def __init__(self, basecls=None, *args, **kwargs):
self.basecls = basecls or BaseField
assert issubclass(self.basecls, BaseField)
super(DictField, self).__init__(*args, **kwargs)
def validate(self, value): def validate(self, value):
"""Make sure that a list of valid fields is being used. """Make sure that a list of valid fields is being used.
""" """
@ -360,7 +365,7 @@ class DictField(BaseField):
'contain "." or "$" characters') 'contain "." or "$" characters')
def lookup_member(self, member_name): def lookup_member(self, member_name):
return BaseField(db_field=member_name) return self.basecls(db_field=member_name)
class GeoLocationField(DictField): class GeoLocationField(DictField):
"""Supports geobased fields""" """Supports geobased fields"""

View File

@ -59,8 +59,12 @@ class Q(object):
def _combine(self, other, op): def _combine(self, other, op):
obj = Q() obj = Q()
obj.query = ['('] + copy.deepcopy(self.query) + [op] if not other.query[0]:
obj.query += copy.deepcopy(other.query) + [')'] return self
if self.query[0]:
obj.query = ['('] + copy.deepcopy(self.query) + [op] + copy.deepcopy(other.query) + [')']
else:
obj.query = copy.deepcopy(other.query)
return obj return obj
def __or__(self, other): def __or__(self, other):

View File

@ -1063,6 +1063,29 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_dict_with_custom_baseclass(self):
"""Ensure DictField working with custom base clases.
"""
class Test(Document):
testdict = DictField()
t = Test(testdict={'f': 'Value'})
t.save()
self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 0)
self.assertEqual(len(Test.objects(testdict__f='Value')), 1)
Test.drop_collection()
class Test(Document):
testdict = DictField(basecls=StringField)
t = Test(testdict={'f': 'Value'})
t.save()
self.assertEqual(len(Test.objects(testdict__f='Value')), 1)
self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 1)
Test.drop_collection()
def test_bulk(self): def test_bulk(self):
"""Ensure bulk querying by object id returns a proper dict. """Ensure bulk querying by object id returns a proper dict.
""" """
@ -1136,5 +1159,20 @@ class QTest(unittest.TestCase):
self.assertEqual(q._item_query_as_js(item, test_scope, 0), js) self.assertEqual(q._item_query_as_js(item, test_scope, 0), js)
self.assertEqual(scope, test_scope) self.assertEqual(scope, test_scope)
def test_empty_q(self):
"""Ensure that empty Q objects won't hurt.
"""
q1 = Q()
q2 = Q(age__gte=18)
q3 = Q()
q4 = Q(name='test')
q5 = Q()
query = ['(', {'age__gte': 18}, '||', {'name': 'test'}, ')']
self.assertEqual((q1 | q2 | q3 | q4 | q5).query, query)
query = ['(', {'age__gte': 18}, '&&', {'name': 'test'}, ')']
self.assertEqual((q1 & q2 & q3 & q4 & q5).query, query)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()