Added support for querying by array position. Closes #36.
This commit is contained in:
parent
1849f75ad0
commit
2af5f3c56e
@ -344,6 +344,8 @@ class QuerySet(object):
|
|||||||
mongo_query = {}
|
mongo_query = {}
|
||||||
for key, value in query.items():
|
for key, value in query.items():
|
||||||
parts = key.split('__')
|
parts = key.split('__')
|
||||||
|
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
|
||||||
|
parts = [part for part in parts if not part.isdigit()]
|
||||||
# Check for an operator and transform to mongo-style if there is
|
# Check for an operator and transform to mongo-style if there is
|
||||||
op = None
|
op = None
|
||||||
if parts[-1] in operators + match_operators + geo_operators:
|
if parts[-1] in operators + match_operators + geo_operators:
|
||||||
@ -381,7 +383,9 @@ class QuerySet(object):
|
|||||||
"been implemented" % op)
|
"been implemented" % op)
|
||||||
elif op not in match_operators:
|
elif op not in match_operators:
|
||||||
value = {'$' + op: value}
|
value = {'$' + op: value}
|
||||||
|
|
||||||
|
for i, part in indices:
|
||||||
|
parts.insert(i, part)
|
||||||
key = '.'.join(parts)
|
key = '.'.join(parts)
|
||||||
if op is None or key not in mongo_query:
|
if op is None or key not in mongo_query:
|
||||||
mongo_query[key] = value
|
mongo_query[key] = value
|
||||||
|
@ -165,8 +165,49 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
person = self.Person.objects.get(age__lt=30)
|
person = self.Person.objects.get(age__lt=30)
|
||||||
self.assertEqual(person.name, "User A")
|
self.assertEqual(person.name, "User A")
|
||||||
|
|
||||||
|
def test_find_array_position(self):
|
||||||
|
"""Ensure that query by array position works.
|
||||||
|
"""
|
||||||
|
class Comment(EmbeddedDocument):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Post(EmbeddedDocument):
|
||||||
|
comments = ListField(EmbeddedDocumentField(Comment))
|
||||||
|
|
||||||
|
class Blog(Document):
|
||||||
|
tags = ListField(StringField())
|
||||||
|
posts = ListField(EmbeddedDocumentField(Post))
|
||||||
|
|
||||||
|
Blog.drop_collection()
|
||||||
|
|
||||||
|
Blog.objects.create(tags=['a', 'b'])
|
||||||
|
self.assertEqual(len(Blog.objects(tags__0='a')), 1)
|
||||||
|
self.assertEqual(len(Blog.objects(tags__0='b')), 0)
|
||||||
|
self.assertEqual(len(Blog.objects(tags__1='a')), 0)
|
||||||
|
self.assertEqual(len(Blog.objects(tags__1='b')), 1)
|
||||||
|
|
||||||
|
Blog.drop_collection()
|
||||||
|
|
||||||
|
comment1 = Comment(name='testa')
|
||||||
|
comment2 = Comment(name='testb')
|
||||||
|
post1 = Post(comments=[comment1, comment2])
|
||||||
|
post2 = Post(comments=[comment2, comment2])
|
||||||
|
blog1 = Blog.objects.create(posts=[post1, post2])
|
||||||
|
blog2 = Blog.objects.create(posts=[post2, post1])
|
||||||
|
|
||||||
|
blog = Blog.objects(posts__0__comments__0__name='testa').get()
|
||||||
|
self.assertEqual(blog, blog1)
|
||||||
|
|
||||||
|
query = Blog.objects(posts__1__comments__1__name='testb')
|
||||||
|
self.assertEqual(len(query), 2)
|
||||||
|
|
||||||
|
query = Blog.objects(posts__1__comments__1__name='testa')
|
||||||
|
self.assertEqual(len(query), 0)
|
||||||
|
|
||||||
|
query = Blog.objects(posts__0__comments__1__name='testa')
|
||||||
|
self.assertEqual(len(query), 0)
|
||||||
|
|
||||||
|
Blog.drop_collection()
|
||||||
|
|
||||||
def test_get_or_create(self):
|
def test_get_or_create(self):
|
||||||
"""Ensure that ``get_or_create`` returns one result or creates a new
|
"""Ensure that ``get_or_create`` returns one result or creates a new
|
||||||
|
Loading…
x
Reference in New Issue
Block a user