This commit: (1) formats all of our existing code using `black`; (2) adds a note about using `black` to `CONTRIBUTING.rst`; and (3) runs `black --check` as part of CI, failing builds that aren't properly formatted.
468 lines
16 KiB
Python
468 lines
16 KiB
Python
import unittest
|
|
|
|
from mongoengine import *
|
|
from mongoengine.queryset import QueryFieldList
|
|
|
|
# Names exported by `from <this module> import *`: the two test-case classes defined in this file.
__all__ = ("QueryFieldListTest", "OnlyExcludeAllTest")
|
|
|
|
|
|
class QueryFieldListTest(unittest.TestCase):
    """Tests for merging ``QueryFieldList`` instances with ``+=``.

    Each test builds a field list step by step and checks the projection
    dictionary returned by ``as_dict()`` after every merge.
    """

    def test_empty(self):
        """A list with no fields is falsy, with or without ``always_include``."""
        q = QueryFieldList()
        self.assertFalse(q)

        q = QueryFieldList(always_include=["_cls"])
        self.assertFalse(q)

    def test_include_include(self):
        """Merging ONLY into ONLY yields the union of both field sets."""
        q = QueryFieldList()
        q += QueryFieldList(
            fields=["a", "b"], value=QueryFieldList.ONLY, _only_called=True
        )
        self.assertEqual(q.as_dict(), {"a": 1, "b": 1})
        q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY)
        self.assertEqual(q.as_dict(), {"a": 1, "b": 1, "c": 1})

    def test_include_exclude(self):
        """Merging EXCLUDE into ONLY removes the excluded fields."""
        q = QueryFieldList()
        q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.ONLY)
        self.assertEqual(q.as_dict(), {"a": 1, "b": 1})
        q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.EXCLUDE)
        self.assertEqual(q.as_dict(), {"a": 1})

    def test_exclude_exclude(self):
        """Merging EXCLUDE into EXCLUDE yields the union of excluded fields."""
        q = QueryFieldList()
        q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.EXCLUDE)
        self.assertEqual(q.as_dict(), {"a": 0, "b": 0})
        q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.EXCLUDE)
        self.assertEqual(q.as_dict(), {"a": 0, "b": 0, "c": 0})

    def test_exclude_include(self):
        """Merging ONLY into EXCLUDE keeps only the not-yet-excluded fields."""
        q = QueryFieldList()
        q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.EXCLUDE)
        self.assertEqual(q.as_dict(), {"a": 0, "b": 0})
        q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY)
        self.assertEqual(q.as_dict(), {"c": 1})

    def test_always_include(self):
        """``always_include`` fields are projected even if they were excluded."""
        q = QueryFieldList(always_include=["x", "y"])
        q += QueryFieldList(fields=["a", "b", "x"], value=QueryFieldList.EXCLUDE)
        q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY)
        self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "c": 1})

    def test_reset(self):
        """``reset()`` clears the merged fields but keeps ``always_include``."""
        q = QueryFieldList(always_include=["x", "y"])
        q += QueryFieldList(fields=["a", "b", "x"], value=QueryFieldList.EXCLUDE)
        q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY)
        self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "c": 1})
        q.reset()
        self.assertFalse(q)
        # After the reset the earlier exclusion of "b" no longer applies.
        q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY)
        self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "b": 1, "c": 1})

    def test_using_a_slice(self):
        """A ``$slice`` value is passed through to the projection unchanged."""
        q = QueryFieldList()
        q += QueryFieldList(fields=["a"], value={"$slice": 5})
        self.assertEqual(q.as_dict(), {"a": {"$slice": 5}})
|
|
|
|
class OnlyExcludeAllTest(unittest.TestCase):
    """Tests for ``only()``, ``exclude()``, ``fields()`` and ``all_fields()``.

    Some tests only inspect the projection held in ``qs._loaded_fields``;
    the others save documents to the ``mongoenginetest`` database and check
    which attributes are populated on the objects that come back.
    """

    def setUp(self):
        """Connect to the test database and define a fresh ``Person`` document."""
        connect(db="mongoenginetest")

        class Person(Document):
            name = StringField()
            age = IntField()
            meta = {"allow_inheritance": True}

        Person.drop_collection()
        self.Person = Person

    def test_mixing_only_exclude(self):
        """The same projection results regardless of only/exclude/fields order."""

        class MyDoc(Document):
            a = StringField()
            b = StringField()
            c = StringField()
            d = StringField()
            e = StringField()
            f = StringField()

        include = ["a", "b", "c", "d", "e"]
        exclude = ["d", "e"]
        only = ["b", "c"]

        # fields -> only -> exclude
        qs = MyDoc.objects.fields(**{i: 1 for i in include})
        self.assertEqual(
            qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}
        )
        qs = qs.only(*only)
        self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1})
        qs = qs.exclude(*exclude)
        self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1})

        # fields -> exclude -> only
        qs = MyDoc.objects.fields(**{i: 1 for i in include})
        qs = qs.exclude(*exclude)
        self.assertEqual(qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1})
        qs = qs.only(*only)
        self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1})

        # exclude -> fields -> only
        qs = MyDoc.objects.exclude(*exclude)
        qs = qs.fields(**{i: 1 for i in include})
        self.assertEqual(qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1})
        qs = qs.only(*only)
        self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1})

    def test_slicing(self):
        """``slice__<field>`` replaces an included field's value with ``$slice``."""

        class MyDoc(Document):
            a = ListField()
            b = ListField()
            c = ListField()
            d = ListField()
            e = ListField()
            f = ListField()

        include = ["a", "b", "c", "d", "e"]
        exclude = ["d", "e"]
        only = ["b", "c"]

        qs = MyDoc.objects.fields(**{i: 1 for i in include})
        qs = qs.exclude(*exclude)
        qs = qs.only(*only)
        qs = qs.fields(slice__b=5)
        self.assertEqual(qs._loaded_fields.as_dict(), {"b": {"$slice": 5}, "c": 1})

        qs = qs.fields(slice__c=[5, 1])
        self.assertEqual(
            qs._loaded_fields.as_dict(), {"b": {"$slice": 5}, "c": {"$slice": [5, 1]}}
        )

        qs = qs.exclude("c")
        self.assertEqual(qs._loaded_fields.as_dict(), {"b": {"$slice": 5}})

    def test_mix_slice_with_other_fields(self):
        """Include, exclude and slice can be combined in one ``fields()`` call."""

        class MyDoc(Document):
            a = ListField()
            b = ListField()
            c = ListField()

        qs = MyDoc.objects.fields(a=1, b=0, slice__c=2)
        self.assertEqual(qs._loaded_fields.as_dict(), {"c": {"$slice": 2}, "a": 1})

    def test_only(self):
        """Ensure that QuerySet.only only returns the requested fields."""
        person = self.Person(name="test", age=25)
        person.save()

        obj = self.Person.objects.only("name").get()
        self.assertEqual(obj.name, person.name)
        self.assertIsNone(obj.age)

        obj = self.Person.objects.only("age").get()
        self.assertIsNone(obj.name)
        self.assertEqual(obj.age, person.age)

        obj = self.Person.objects.only("name", "age").get()
        self.assertEqual(obj.name, person.name)
        self.assertEqual(obj.age, person.age)

        obj = self.Person.objects.only(*("id", "name")).get()
        self.assertEqual(obj.name, person.name)
        self.assertIsNone(obj.age)

        # Check polymorphism still works
        class Employee(self.Person):
            salary = IntField(db_field="wage")

        employee = Employee(name="test employee", age=40, salary=30000)
        employee.save()

        obj = self.Person.objects(id=employee.id).only("age").get()
        self.assertIsInstance(obj, Employee)

        # Check field names are looked up properly
        obj = Employee.objects(id=employee.id).only("salary").get()
        self.assertEqual(obj.salary, employee.salary)
        self.assertIsNone(obj.name)

    def test_only_with_subfields(self):
        """``only()`` accepts dotted paths into embedded documents and maps."""

        class User(EmbeddedDocument):
            name = StringField()
            email = StringField()

        class Comment(EmbeddedDocument):
            title = StringField()
            text = StringField()

        class VariousData(EmbeddedDocument):
            some = BooleanField()

        class BlogPost(Document):
            content = StringField()
            author = EmbeddedDocumentField(User)
            comments = ListField(EmbeddedDocumentField(Comment))
            various = MapField(field=EmbeddedDocumentField(VariousData))

        BlogPost.drop_collection()
        # Registered as a cleanup so the collection is dropped even when an
        # assertion below fails.
        self.addCleanup(BlogPost.drop_collection)

        post = BlogPost(
            content="Had a good coffee today...",
            various={"test_dynamic": {"some": True}},
        )
        post.author = User(name="Test User")
        post.comments = [
            Comment(title="I aggree", text="Great post!"),
            Comment(title="Coffee", text="I hate coffee"),
        ]
        post.save()

        obj = BlogPost.objects.only("author.name").get()
        self.assertIsNone(obj.content)
        self.assertIsNone(obj.author.email)
        self.assertEqual(obj.author.name, "Test User")
        self.assertEqual(obj.comments, [])

        obj = BlogPost.objects.only("various.test_dynamic.some").get()
        self.assertEqual(obj.various["test_dynamic"].some, True)

        obj = BlogPost.objects.only("content", "comments.title").get()
        self.assertEqual(obj.content, "Had a good coffee today...")
        self.assertIsNone(obj.author)
        self.assertEqual(obj.comments[0].title, "I aggree")
        self.assertEqual(obj.comments[1].title, "Coffee")
        self.assertIsNone(obj.comments[0].text)
        self.assertIsNone(obj.comments[1].text)

        obj = BlogPost.objects.only("comments").get()
        self.assertIsNone(obj.content)
        self.assertIsNone(obj.author)
        self.assertEqual(obj.comments[0].title, "I aggree")
        self.assertEqual(obj.comments[1].title, "Coffee")
        self.assertEqual(obj.comments[0].text, "Great post!")
        self.assertEqual(obj.comments[1].text, "I hate coffee")

    def test_exclude(self):
        """``exclude()`` omits whole fields and dotted sub-fields."""

        class User(EmbeddedDocument):
            name = StringField()
            email = StringField()

        class Comment(EmbeddedDocument):
            title = StringField()
            text = StringField()

        class BlogPost(Document):
            content = StringField()
            author = EmbeddedDocumentField(User)
            comments = ListField(EmbeddedDocumentField(Comment))

        BlogPost.drop_collection()
        # Registered as a cleanup so the collection is dropped even when an
        # assertion below fails.
        self.addCleanup(BlogPost.drop_collection)

        post = BlogPost(content="Had a good coffee today...")
        post.author = User(name="Test User")
        post.comments = [
            Comment(title="I aggree", text="Great post!"),
            Comment(title="Coffee", text="I hate coffee"),
        ]
        post.save()

        obj = BlogPost.objects.exclude("author", "comments.text").get()
        self.assertIsNone(obj.author)
        self.assertEqual(obj.content, "Had a good coffee today...")
        self.assertEqual(obj.comments[0].title, "I aggree")
        self.assertIsNone(obj.comments[0].text)

    def test_exclude_only_combining(self):
        """Chained ``exclude()`` and ``only()`` calls combine into one projection."""

        class Attachment(EmbeddedDocument):
            name = StringField()
            content = StringField()

        class Email(Document):
            sender = StringField()
            to = StringField()
            subject = StringField()
            body = StringField()
            content_type = StringField()
            attachments = ListField(EmbeddedDocumentField(Attachment))

        Email.drop_collection()
        # Registered as a cleanup so the collection is dropped even when an
        # assertion below fails.
        self.addCleanup(Email.drop_collection)
        email = Email(
            sender="me",
            to="you",
            subject="From Russia with Love",
            body="Hello!",
            content_type="text/plain",
        )
        email.attachments = [
            Attachment(name="file1.doc", content="ABC"),
            Attachment(name="file2.doc", content="XYZ"),
        ]
        email.save()

        obj = Email.objects.exclude("content_type").exclude("body").get()
        self.assertEqual(obj.sender, "me")
        self.assertEqual(obj.to, "you")
        self.assertEqual(obj.subject, "From Russia with Love")
        self.assertIsNone(obj.body)
        self.assertIsNone(obj.content_type)

        obj = Email.objects.only("sender", "to").exclude("body", "sender").get()
        self.assertIsNone(obj.sender)
        self.assertEqual(obj.to, "you")
        self.assertIsNone(obj.subject)
        self.assertIsNone(obj.body)
        self.assertIsNone(obj.content_type)

        obj = (
            Email.objects.exclude("attachments.content")
            .exclude("body")
            .only("to", "attachments.name")
            .get()
        )
        self.assertEqual(obj.attachments[0].name, "file1.doc")
        self.assertIsNone(obj.attachments[0].content)
        self.assertIsNone(obj.sender)
        self.assertEqual(obj.to, "you")
        self.assertIsNone(obj.subject)
        self.assertIsNone(obj.body)
        self.assertIsNone(obj.content_type)

    def test_all_fields(self):
        """``all_fields()`` discards earlier ``only()``/``exclude()`` limits."""

        class Email(Document):
            sender = StringField()
            to = StringField()
            subject = StringField()
            body = StringField()
            content_type = StringField()

        Email.drop_collection()
        # Registered as a cleanup so the collection is dropped even when an
        # assertion below fails.
        self.addCleanup(Email.drop_collection)

        email = Email(
            sender="me",
            to="you",
            subject="From Russia with Love",
            body="Hello!",
            content_type="text/plain",
        )
        email.save()

        obj = (
            Email.objects.exclude("content_type", "body")
            .only("to", "body")
            .all_fields()
            .get()
        )
        self.assertEqual(obj.sender, "me")
        self.assertEqual(obj.to, "you")
        self.assertEqual(obj.subject, "From Russia with Love")
        self.assertEqual(obj.body, "Hello!")
        self.assertEqual(obj.content_type, "text/plain")

    def test_slicing_fields(self):
        """Ensure that query slicing an array works."""

        class Numbers(Document):
            n = ListField(IntField())

        Numbers.drop_collection()

        numbers = Numbers(n=[0, 1, 2, 3, 4, 5, -5, -4, -3, -2, -1])
        numbers.save()

        # first three
        numbers = Numbers.objects.fields(slice__n=3).get()
        self.assertEqual(numbers.n, [0, 1, 2])

        # last three
        numbers = Numbers.objects.fields(slice__n=-3).get()
        self.assertEqual(numbers.n, [-3, -2, -1])

        # skip 2, limit 3
        numbers = Numbers.objects.fields(slice__n=[2, 3]).get()
        self.assertEqual(numbers.n, [2, 3, 4])

        # skip to fifth from last, limit 4
        numbers = Numbers.objects.fields(slice__n=[-5, 4]).get()
        self.assertEqual(numbers.n, [-5, -4, -3, -2])

        # skip to fifth from last, limit 10
        numbers = Numbers.objects.fields(slice__n=[-5, 10]).get()
        self.assertEqual(numbers.n, [-5, -4, -3, -2, -1])

        # skip to fifth from last, limit 10 dict method
        numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get()
        self.assertEqual(numbers.n, [-5, -4, -3, -2, -1])

    def test_slicing_nested_fields(self):
        """Ensure that query slicing an embedded array works."""

        class EmbeddedNumber(EmbeddedDocument):
            n = ListField(IntField())

        class Numbers(Document):
            embedded = EmbeddedDocumentField(EmbeddedNumber)

        Numbers.drop_collection()

        numbers = Numbers()
        numbers.embedded = EmbeddedNumber(n=[0, 1, 2, 3, 4, 5, -5, -4, -3, -2, -1])
        numbers.save()

        # first three
        numbers = Numbers.objects.fields(slice__embedded__n=3).get()
        self.assertEqual(numbers.embedded.n, [0, 1, 2])

        # last three
        numbers = Numbers.objects.fields(slice__embedded__n=-3).get()
        self.assertEqual(numbers.embedded.n, [-3, -2, -1])

        # skip 2, limit 3
        numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get()
        self.assertEqual(numbers.embedded.n, [2, 3, 4])

        # skip to fifth from last, limit 4
        numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get()
        self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2])

        # skip to fifth from last, limit 10
        numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get()
        self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1])

        # skip to fifth from last, limit 10 dict method
        numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get()
        self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1])

    def test_exclude_from_subclasses_docs(self):
        """Subclass fields can be excluded via the base; unknown names raise."""

        class Base(Document):
            username = StringField()

            meta = {"allow_inheritance": True}

        class Anon(Base):
            anon = BooleanField()

        class User(Base):
            password = StringField()
            wibble = StringField()

        Base.drop_collection()
        User(username="mongodb", password="secret").save()

        user = Base.objects().exclude("password", "wibble").first()
        self.assertIsNone(user.password)

        self.assertRaises(LookUpError, Base.objects.exclude, "made_up")
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main()