diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 2f800fc5..6bca1b2a 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -39,8 +39,25 @@ class StringField(BaseField): def lookup_member(self, member_name): return None + def prepare_query_value(self, op, value): + if not isinstance(op, basestring): + return value -class URLField(BaseField): + if op.lstrip('i') in ('startswith', 'endswith', 'contains'): + flags = 0 + if op.startswith('i'): + flags = re.IGNORECASE + op = op.lstrip('i') + + regex = r'%s' + if op == 'startswith': + regex = r'^%s' + elif op == 'endswith': + regex = r'%s$' + value = re.compile(regex % value, flags) + return value + +class URLField(StringField): """A field that validates input as a URL. """ diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 9592236d..4d8fb7ff 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -275,13 +275,15 @@ class QuerySet(object): """ operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 'all', 'size', 'exists'] + match_operators = ['contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith'] mongo_query = {} for key, value in query.items(): parts = key.split('__') # Check for an operator and transform to mongo-style if there is op = None - if parts[-1] in operators: + if parts[-1] in operators + match_operators: op = parts.pop() if _doc_cls: @@ -291,13 +293,15 @@ class QuerySet(object): # Convert value to proper value field = fields[-1] - if op in (None, 'ne', 'gt', 'gte', 'lt', 'lte'): + singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte'] + singular_ops += match_operators + if op in singular_ops: value = field.prepare_query_value(op, value) elif op in ('in', 'nin', 'all'): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] - if op: + if op and op not in match_operators: value = {'$' + op: value} key = '.'.join(parts) diff --git a/tests/queryset.py b/tests/queryset.py index d287efad..5b434e95 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -186,6 +186,41 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age=50) self.assertEqual(person.name, "User C") + def test_regex_query_shortcuts(self): + """Ensure that contains, startswith, endswith, etc work. + """ + person = self.Person(name='Guido van Rossum') + person.save() + + # Test contains + obj = self.Person.objects(name__contains='van').first() + self.assertEqual(obj, person) + obj = self.Person.objects(name__contains='Van').first() + self.assertEqual(obj, None) + + # Test icontains + obj = self.Person.objects(name__icontains='Van').first() + self.assertEqual(obj, person) + + # Test startswith + obj = self.Person.objects(name__startswith='Guido').first() + self.assertEqual(obj, person) + obj = self.Person.objects(name__startswith='guido').first() + self.assertEqual(obj, None) + + # Test istartswith + obj = self.Person.objects(name__istartswith='guido').first() + self.assertEqual(obj, person) + + # Test endswith + obj = self.Person.objects(name__endswith='Rossum').first() + self.assertEqual(obj, person) + obj = self.Person.objects(name__endswith='rossuM').first() + self.assertEqual(obj, None) + + # Test iendswith + obj = self.Person.objects(name__iendswith='rossuM').first() + self.assertEqual(obj, person) def test_filter_chaining(self): """Ensure filters can be chained together.