Merge branch 'dev' into feature/update_lists

This commit is contained in:
Alistair Roche 2011-05-24 08:58:38 +01:00
commit 13935fc335
7 changed files with 476 additions and 102 deletions

View File

@ -7,16 +7,26 @@ import pymongo
import pymongo.objectid
_document_registry = {}
def get_document(name):
return _document_registry[name]
class NotRegistered(Exception):
pass
class ValidationError(Exception):
pass
_document_registry = {}
def get_document(name):
if name not in _document_registry:
raise NotRegistered("""
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip() % name)
return _document_registry[name]
class BaseField(object):
"""A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema.
@ -295,6 +305,30 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
for spec in meta['indexes']] + base_indexes
new_class._meta['indexes'] = user_indexes
unique_indexes = cls._unique_with_indexes(new_class)
new_class._meta['unique_indexes'] = unique_indexes
for field_name, field in new_class._fields.items():
# Check for custom primary key
if field.primary_key:
current_pk = new_class._meta['id_field']
if current_pk and current_pk != field_name:
raise ValueError('Cannot override primary key field')
if not current_pk:
new_class._meta['id_field'] = field_name
# Make 'Document.id' an alias to the real primary key field
new_class.id = field
if not new_class._meta['id_field']:
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id']
return new_class
@classmethod
def _unique_with_indexes(cls, new_class, namespace=""):
unique_indexes = []
for field_name, field in new_class._fields.items():
# Generate a list of indexes needed by uniqueness constraints
@ -320,28 +354,16 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
unique_fields += unique_with
# Add the new index to the list
index = [(f, pymongo.ASCENDING) for f in unique_fields]
index = [("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields]
unique_indexes.append(index)
# Check for custom primary key
if field.primary_key:
current_pk = new_class._meta['id_field']
if current_pk and current_pk != field_name:
raise ValueError('Cannot override primary key field')
# Grab any embedded document field unique indexes
if field.__class__.__name__ == "EmbeddedDocumentField":
field_namespace = "%s." % field_name
unique_indexes += cls._unique_with_indexes(field.document_type,
field_namespace)
if not current_pk:
new_class._meta['id_field'] = field_name
# Make 'Document.id' an alias to the real primary key field
new_class.id = field
new_class._meta['unique_indexes'] = unique_indexes
if not new_class._meta['id_field']:
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id']
return new_class
return unique_indexes
class BaseDocument(object):

View File

@ -56,7 +56,7 @@ class Document(BaseDocument):
__metaclass__ = TopLevelDocumentMetaclass
def save(self, safe=True, force_insert=False, validate=True):
def save(self, safe=True, force_insert=False, validate=True, write_options=None):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@ -68,16 +68,26 @@ class Document(BaseDocument):
:param force_insert: only try to create a new document, don't allow
updates of existing documents
:param validate: validates the document; set to ``False`` to skip.
:param write_options: Extra keyword arguments are passed down to
:meth:`~pymongo.collection.Collection.save` OR
:meth:`~pymongo.collection.Collection.insert`
which will be used as options for the resultant ``getLastError`` command.
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
have recorded the write and will force an fsync on each server being written to.
"""
if validate:
self.validate()
if not write_options:
write_options = {}
doc = self.to_mongo()
try:
collection = self.__class__.objects._collection
if force_insert:
object_id = collection.insert(doc, safe=safe)
object_id = collection.insert(doc, safe=safe, **write_options)
else:
object_id = collection.save(doc, safe=safe)
object_id = collection.save(doc, safe=safe, **write_options)
except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'
if u'duplicate key' in unicode(err):

View File

@ -522,6 +522,9 @@ class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily).
note: Any documents used as a generic reference must be registered in the
document registry. Importing the model will automatically register it.
.. versionadded:: 0.3
"""
@ -601,6 +604,7 @@ class GridFSProxy(object):
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
self.newfile = None # Used for partial writes
self.grid_id = grid_id # Store GridFS id for file
self.gridout = None
def __getattr__(self, name):
obj = self.get()
@ -614,8 +618,12 @@ class GridFSProxy(object):
def get(self, id=None):
if id:
self.grid_id = id
if self.grid_id is None:
return None
try:
return self.fs.get(id or self.grid_id)
if self.gridout is None:
self.gridout = self.fs.get(self.grid_id)
return self.gridout
except:
# File has been deleted
return None
@ -645,9 +653,9 @@ class GridFSProxy(object):
self.grid_id = self.newfile._id
self.newfile.writelines(lines)
def read(self):
def read(self, size=-1):
try:
return self.get().read()
return self.get().read(size)
except:
return None
@ -655,6 +663,7 @@ class GridFSProxy(object):
# Delete file from GridFS, FileField still remains
self.fs.delete(self.grid_id)
self.grid_id = None
self.gridout = None
def replace(self, file, **kwargs):
self.delete()

View File

@ -8,6 +8,7 @@ import pymongo.objectid
import re
import copy
import itertools
import operator
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
@ -280,30 +281,30 @@ class QueryFieldList(object):
ONLY = True
EXCLUDE = False
def __init__(self, fields=[], direction=ONLY, always_include=[]):
self.direction = direction
def __init__(self, fields=[], value=ONLY, always_include=[]):
self.value = value
self.fields = set(fields)
self.always_include = set(always_include)
def as_dict(self):
return dict((field, self.direction) for field in self.fields)
return dict((field, self.value) for field in self.fields)
def __add__(self, f):
if not self.fields:
self.fields = f.fields
self.direction = f.direction
elif self.direction is self.ONLY and f.direction is self.ONLY:
self.value = f.value
elif self.value is self.ONLY and f.value is self.ONLY:
self.fields = self.fields.intersection(f.fields)
elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE:
elif self.value is self.EXCLUDE and f.value is self.EXCLUDE:
self.fields = self.fields.union(f.fields)
elif self.direction is self.ONLY and f.direction is self.EXCLUDE:
elif self.value is self.ONLY and f.value is self.EXCLUDE:
self.fields -= f.fields
elif self.direction is self.EXCLUDE and f.direction is self.ONLY:
self.direction = self.ONLY
elif self.value is self.EXCLUDE and f.value is self.ONLY:
self.value = self.ONLY
self.fields = f.fields - self.fields
if self.always_include:
if self.direction is self.ONLY and self.fields:
if self.value is self.ONLY and self.fields:
self.fields = self.fields.union(self.always_include)
else:
self.fields -= self.always_include
@ -311,7 +312,7 @@ class QueryFieldList(object):
def reset(self):
self.fields = set([])
self.direction = self.ONLY
self.value = self.ONLY
def __nonzero__(self):
return bool(self.fields)
@ -551,7 +552,7 @@ class QuerySet(object):
return '.'.join(parts)
@classmethod
def _transform_query(cls, _doc_cls=None, **query):
def _transform_query(cls, _doc_cls=None, _field_operation=False, **query):
"""Transform a query from Django-style format to Mongo format.
"""
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
@ -646,7 +647,7 @@ class QuerySet(object):
raise self._document.DoesNotExist("%s matching query does not exist."
% self._document._class_name)
def get_or_create(self, *q_objs, **query):
def get_or_create(self, write_options=None, *q_objs, **query):
"""Retrieve unique object or create, if it doesn't exist. Returns a tuple of
``(object, created)``, where ``object`` is the retrieved or created object
and ``created`` is a boolean specifying whether a new object was created. Raises
@ -656,6 +657,10 @@ class QuerySet(object):
dictionary of default values for the new document may be provided as a
keyword argument called :attr:`defaults`.
:param write_options: optional extra keyword arguments used if we
have to create a new document.
Passes any write_options onto :meth:`~mongoengine.document.Document.save`
.. versionadded:: 0.3
"""
defaults = query.get('defaults', {})
@ -667,7 +672,7 @@ class QuerySet(object):
if count == 0:
query.update(defaults)
doc = self._document(**query)
doc.save()
doc.save(write_options=write_options)
return doc, True
elif count == 1:
return self.first(), False
@ -893,10 +898,8 @@ class QuerySet(object):
.. versionadded:: 0.3
"""
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY)
return self
fields = dict([(f, QueryFieldList.ONLY) for f in fields])
return self.fields(**fields)
def exclude(self, *fields):
"""Opposite to .only(), exclude some document's fields. ::
@ -905,8 +908,44 @@ class QuerySet(object):
:param fields: fields to exclude
"""
fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields])
return self.fields(**fields)
def fields(self, **kwargs):
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. Fields also
allows for a greater level of control for example:
Retrieving a Subrange of Array Elements
---------------------------------------
You can use the $slice operator to retrieve a subrange of elements in
an array ::
post = BlogPost.objects(...).fields(slice__comments=5) // first 5 comments
:param kwargs: A dictionary identifying what to include
.. versionadded:: 0.5
"""
# Check for an operator and transform to mongo-style if there is
operators = ["slice"]
cleaned_fields = []
for key, value in kwargs.items():
parts = key.split('__')
op = None
if parts[0] in operators:
op = parts.pop(0)
value = {'$' + op: value}
key = '.'.join(parts)
cleaned_fields.append((key, value))
fields = sorted(cleaned_fields, key=operator.itemgetter(1))
for value, group in itertools.groupby(fields, lambda x: x[1]):
fields = [field for field, value in group]
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE)
self._loaded_fields += QueryFieldList(fields, value=value)
return self
def all_fields(self):
@ -1062,22 +1101,27 @@ class QuerySet(object):
return mongo_update
def update(self, safe_update=True, upsert=False, **update):
def update(self, safe_update=True, upsert=False, write_options=None, **update):
"""Perform an atomic update on the fields matched by the query. When
``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning
:param update: Django-style update keyword arguments
:param safe_update: check if the operation succeeded before returning
:param upsert: Any existing document with that "_id" is overwritten.
:param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
.. versionadded:: 0.2
"""
if pymongo.version < '1.1.1':
raise OperationError('update() method requires PyMongo 1.1.1+')
if not write_options:
write_options = {}
update = QuerySet._transform_update(self._document, **update)
try:
ret = self._collection.update(self._query, update, multi=True,
upsert=upsert, safe=safe_update)
upsert=upsert, safe=safe_update,
**write_options)
if ret is not None and 'n' in ret:
return ret['n']
except pymongo.errors.OperationFailure, err:
@ -1086,22 +1130,27 @@ class QuerySet(object):
raise OperationError(message)
raise OperationError(u'Update failed (%s)' % unicode(err))
def update_one(self, safe_update=True, upsert=False, **update):
def update_one(self, safe_update=True, upsert=False, write_options=None, **update):
"""Perform an atomic update on first field matched by the query. When
``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning
:param safe_update: check if the operation succeeded before returning
:param upsert: Any existing document with that "_id" is overwritten.
:param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
:param update: Django-style update keyword arguments
.. versionadded:: 0.2
"""
if not write_options:
write_options = {}
update = QuerySet._transform_update(self._document, **update)
try:
# Explicitly provide 'multi=False' to newer versions of PyMongo
# as the default may change to 'True'
if pymongo.version >= '1.1.1':
ret = self._collection.update(self._query, update, multi=False,
upsert=upsert, safe=safe_update)
upsert=upsert, safe=safe_update,
**write_options)
else:
# Older versions of PyMongo don't support 'multi'
ret = self._collection.update(self._query, update,

View File

@ -1,11 +1,23 @@
import unittest
from datetime import datetime
import pymongo
import pickle
from mongoengine import *
from mongoengine.base import BaseField
from mongoengine.connection import _get_db
class PickleEmbedded(EmbeddedDocument):
date = DateTimeField(default=datetime.now)
class PickleTest(Document):
number = IntField()
string = StringField()
embedded = EmbeddedDocumentField(PickleEmbedded)
lists = ListField(StringField())
class DocumentTest(unittest.TestCase):
def setUp(self):
@ -200,6 +212,22 @@ class DocumentTest(unittest.TestCase):
Person.drop_collection()
self.assertFalse(collection in self.db.collection_names())
def test_collection_name_and_primary(self):
"""Ensure that a collection with a specified name may be used.
"""
class Person(Document):
name = StringField(primary_key=True)
meta = {'collection': 'app'}
user = Person(name="Test User")
user.save()
user_obj = Person.objects[0]
self.assertEqual(user_obj.name, "Test User")
Person.drop_collection()
def test_inherited_collections(self):
"""Ensure that subclassed documents don't override parents' collections.
"""
@ -334,6 +362,10 @@ class DocumentTest(unittest.TestCase):
post2 = BlogPost(title='test2', slug='test')
self.assertRaises(OperationError, post2.save)
def test_unique_with(self):
"""Ensure that unique_with constraints are applied to fields.
"""
class Date(EmbeddedDocument):
year = IntField(db_field='yr')
@ -357,6 +389,63 @@ class DocumentTest(unittest.TestCase):
BlogPost.drop_collection()
def test_unique_embedded_document(self):
"""Ensure that uniqueness constraints are applied to fields on embedded documents.
"""
class SubDocument(EmbeddedDocument):
year = IntField(db_field='yr')
slug = StringField(unique=True)
class BlogPost(Document):
title = StringField()
sub = EmbeddedDocumentField(SubDocument)
BlogPost.drop_collection()
post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test"))
post1.save()
# sub.slug is different so won't raise exception
post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug'))
post2.save()
# Now there will be two docs with the same sub.slug
post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test'))
self.assertRaises(OperationError, post3.save)
BlogPost.drop_collection()
def test_unique_with_embedded_document_and_embedded_unique(self):
"""Ensure that uniqueness constraints are applied to fields on
embedded documents. And work with unique_with as well.
"""
class SubDocument(EmbeddedDocument):
year = IntField(db_field='yr')
slug = StringField(unique=True)
class BlogPost(Document):
title = StringField(unique_with='sub.year')
sub = EmbeddedDocumentField(SubDocument)
BlogPost.drop_collection()
post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test"))
post1.save()
# sub.slug is different so won't raise exception
post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug'))
post2.save()
# Now there will be two docs with the same sub.slug
post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test'))
self.assertRaises(OperationError, post3.save)
# Now there will be two docs with the same title and year
post3 = BlogPost(title='test1', sub=SubDocument(year=2009, slug='test-1'))
self.assertRaises(OperationError, post3.save)
BlogPost.drop_collection()
def test_unique_and_indexes(self):
"""Ensure that 'unique' constraints aren't overridden by
meta.indexes.
@ -798,6 +887,25 @@ class DocumentTest(unittest.TestCase):
self.Person.drop_collection()
BlogPost.drop_collection()
def subclasses_and_unique_keys_works(self):
class A(Document):
pass
class B(A):
foo = BooleanField(unique=True)
A.drop_collection()
B.drop_collection()
A().save()
A().save()
B(foo=True).save()
self.assertEquals(A.objects.count(), 2)
self.assertEquals(B.objects.count(), 1)
A.drop_collection()
B.drop_collection()
def tearDown(self):
self.Person.drop_collection()
@ -850,6 +958,43 @@ class DocumentTest(unittest.TestCase):
self.assertTrue(u1 in all_user_set )
def test_picklable(self):
pickle_doc = PickleTest(number=1, string="OH HAI", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded()
pickle_doc.save()
pickled_doc = pickle.dumps(pickle_doc)
resurrected = pickle.loads(pickled_doc)
self.assertEquals(resurrected, pickle_doc)
resurrected.string = "Working"
resurrected.save()
pickle_doc.reload()
self.assertEquals(resurrected, pickle_doc)
def test_write_options(self):
"""Test that passing write_options works"""
self.Person.drop_collection()
write_options = {"fsync": True}
author, created = self.Person.objects.get_or_create(
name='Test User', write_options=write_options)
author.save(write_options=write_options)
self.Person.objects.update(set__name='Ross', write_options=write_options)
author = self.Person.objects.first()
self.assertEquals(author.name, 'Ross')
self.Person.objects.update_one(set__name='Test User', write_options=write_options)
author = self.Person.objects.first()
self.assertEquals(author.name, 'Test User')
if __name__ == '__main__':
unittest.main()

View File

@ -7,6 +7,7 @@ import gridfs
from mongoengine import *
from mongoengine.connection import _get_db
from mongoengine.base import _document_registry, NotRegistered
class FieldTest(unittest.TestCase):
@ -584,6 +585,39 @@ class FieldTest(unittest.TestCase):
Post.drop_collection()
User.drop_collection()
def test_generic_reference_document_not_registered(self):
"""Ensure dereferencing out of the document registry throws a
`NotRegistered` error.
"""
class Link(Document):
title = StringField()
class User(Document):
bookmarks = ListField(GenericReferenceField())
Link.drop_collection()
User.drop_collection()
link_1 = Link(title="Pitchfork")
link_1.save()
user = User(bookmarks=[link_1])
user.save()
# Mimic User and Link definitions being in a different file
# and the Link model not being imported in the User file.
del(_document_registry["Link"])
user = User.objects.first()
try:
user.bookmarks
raise AssertionError, "Link was removed from the registry"
except NotRegistered:
pass
Link.drop_collection()
User.drop_collection()
def test_binary_fields(self):
"""Ensure that binary fields can be stored and retrieved.
"""
@ -701,6 +735,12 @@ class FieldTest(unittest.TestCase):
self.assertTrue(streamfile == result)
self.assertEquals(result.file.read(), text + more_text)
self.assertEquals(result.file.content_type, content_type)
result.file.seek(0)
self.assertEquals(result.file.tell(), 0)
self.assertEquals(result.file.read(len(text)), text)
self.assertEquals(result.file.tell(), len(text))
self.assertEquals(result.file.read(len(more_text)), more_text)
self.assertEquals(result.file.tell(), len(text + more_text))
result.file.delete()
# Ensure deleted file returns None

View File

@ -597,6 +597,81 @@ class QuerySetTest(unittest.TestCase):
Email.drop_collection()
def test_slicing_fields(self):
"""Ensure that query slicing an array works.
"""
class Numbers(Document):
n = ListField(IntField())
Numbers.drop_collection()
numbers = Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1])
numbers.save()
# first three
numbers = Numbers.objects.fields(slice__n=3).get()
self.assertEquals(numbers.n, [0, 1, 2])
# last three
numbers = Numbers.objects.fields(slice__n=-3).get()
self.assertEquals(numbers.n, [-3, -2, -1])
# skip 2, limit 3
numbers = Numbers.objects.fields(slice__n=[2, 3]).get()
self.assertEquals(numbers.n, [2, 3, 4])
# skip to fifth from last, limit 4
numbers = Numbers.objects.fields(slice__n=[-5, 4]).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2])
# skip to fifth from last, limit 10
numbers = Numbers.objects.fields(slice__n=[-5, 10]).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2, -1])
# skip to fifth from last, limit 10 dict method
numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2, -1])
def test_slicing_nested_fields(self):
"""Ensure that query slicing an embedded array works.
"""
class EmbeddedNumber(EmbeddedDocument):
n = ListField(IntField())
class Numbers(Document):
embedded = EmbeddedDocumentField(EmbeddedNumber)
Numbers.drop_collection()
numbers = Numbers()
numbers.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1])
numbers.save()
# first three
numbers = Numbers.objects.fields(slice__embedded__n=3).get()
self.assertEquals(numbers.embedded.n, [0, 1, 2])
# last three
numbers = Numbers.objects.fields(slice__embedded__n=-3).get()
self.assertEquals(numbers.embedded.n, [-3, -2, -1])
# skip 2, limit 3
numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get()
self.assertEquals(numbers.embedded.n, [2, 3, 4])
# skip to fifth from last, limit 4
numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2])
# skip to fifth from last, limit 10
numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1])
# skip to fifth from last, limit 10 dict method
numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1])
def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query.
"""
@ -1294,6 +1369,7 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document):
tags = ListField(StringField())
deleted = BooleanField(default=False)
date = DateTimeField(default=datetime.now)
@queryset_manager
def objects(doc_cls, queryset):
@ -1301,7 +1377,7 @@ class QuerySetTest(unittest.TestCase):
@queryset_manager
def music_posts(doc_cls, queryset):
return queryset(tags='music', deleted=False)
return queryset(tags='music', deleted=False).order_by('-date')
BlogPost.drop_collection()
@ -1317,7 +1393,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual([p.id for p in BlogPost.objects],
[post1.id, post2.id, post3.id])
self.assertEqual([p.id for p in BlogPost.music_posts],
[post1.id, post2.id])
[post2.id, post1.id])
BlogPost.drop_collection()
@ -1760,6 +1836,25 @@ class QuerySetTest(unittest.TestCase):
Number.drop_collection()
def test_order_works_with_primary(self):
"""Ensure that order_by and primary work.
"""
class Number(Document):
n = IntField(primary_key=True)
Number.drop_collection()
Number(n=1).save()
Number(n=2).save()
Number(n=3).save()
numbers = [n.n for n in Number.objects.order_by('-n')]
self.assertEquals([3, 2, 1], numbers)
numbers = [n.n for n in Number.objects.order_by('+n')]
self.assertEquals([1, 2, 3], numbers)
Number.drop_collection()
class QTest(unittest.TestCase):
@ -1931,49 +2026,53 @@ class QueryFieldListTest(unittest.TestCase):
def test_include_include(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY)
q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'b': True})
def test_include_exclude(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY)
q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': True})
def test_exclude_exclude(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False})
def test_exclude_include(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'c': True})
def test_always_include(self):
q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
def test_reset(self):
q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
q.reset()
self.assertFalse(q)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY)
q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True})
def test_using_a_slice(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a'], value={"$slice": 5})
self.assertEqual(q.as_dict(), {'a': {"$slice": 5}})
if __name__ == '__main__':
unittest.main()