diff --git a/.travis.yml b/.travis.yml
index f9993e79..8af73c6b 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -52,19 +52,22 @@ install:
- wget http://fastdl.mongodb.org/linux/mongodb-linux-x86_64-${MONGODB}.tgz
- tar xzf mongodb-linux-x86_64-${MONGODB}.tgz
- ${PWD}/mongodb-linux-x86_64-${MONGODB}/bin/mongod --version
- # Install python dependencies
+ # Install Python dependencies.
- pip install --upgrade pip
- pip install coveralls
- pip install flake8 flake8-import-order
- pip install tox # tox 3.11.0 has requirement virtualenv>=14.0.0
- pip install virtualenv # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32)
- # Install the tox venv
+ # Install the tox venv.
- tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test
+ # Install black for Python v3.7 only.
+ - if [[ $TRAVIS_PYTHON_VERSION == '3.7' ]]; then pip install black; fi
before_script:
- mkdir ${PWD}/mongodb-linux-x86_64-${MONGODB}/data
- ${PWD}/mongodb-linux-x86_64-${MONGODB}/bin/mongod --dbpath ${PWD}/mongodb-linux-x86_64-${MONGODB}/data --logpath ${PWD}/mongodb-linux-x86_64-${MONGODB}/mongodb.log --fork
- - if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then flake8 .; else echo "flake8 only runs on py27"; fi # Run flake8 for py27
+ - if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then flake8 .; else echo "flake8 only runs on py27"; fi # Run flake8 for Python 2.7 only
+ - if [[ $TRAVIS_PYTHON_VERSION == '3.7' ]]; then black --check .; else echo "black only runs on py37"; fi # Run black for Python 3.7 only
- mongo --eval 'db.version();' # Make sure mongo is awake
script:
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index f7b15c85..4711c1d3 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -31,12 +31,8 @@ build. You should ensure that your code is properly converted by
Style Guide
-----------
-MongoEngine aims to follow `PEP8 <http://www.python.org/dev/peps/pep-0008/>`_
-including 4 space indents. When possible we try to stick to 79 character line
-limits. However, screens got bigger and an ORM has a strong focus on
-readability and if it can help, we accept 119 as maximum line length, in a
-similar way as `django does
-<https://docs.djangoproject.com/en/dev/internals/contributing/writing-code/coding-style/>`_
+MongoEngine uses `black <https://github.com/python/black>`_ for code
+formatting.
Testing
-------
diff --git a/benchmarks/test_basic_doc_ops.py b/benchmarks/test_basic_doc_ops.py
index 06f0538b..e840f97a 100644
--- a/benchmarks/test_basic_doc_ops.py
+++ b/benchmarks/test_basic_doc_ops.py
@@ -1,11 +1,18 @@
from timeit import repeat
import mongoengine
-from mongoengine import (BooleanField, Document, EmailField, EmbeddedDocument,
- EmbeddedDocumentField, IntField, ListField,
- StringField)
+from mongoengine import (
+ BooleanField,
+ Document,
+ EmailField,
+ EmbeddedDocument,
+ EmbeddedDocumentField,
+ IntField,
+ ListField,
+ StringField,
+)
-mongoengine.connect(db='mongoengine_benchmark_test')
+mongoengine.connect(db="mongoengine_benchmark_test")
def timeit(f, n=10000):
@@ -24,46 +31,41 @@ def test_basic():
def init_book():
return Book(
- name='Always be closing',
+ name="Always be closing",
pages=100,
- tags=['self-help', 'sales'],
+ tags=["self-help", "sales"],
is_published=True,
- author_email='alec@example.com',
+ author_email="alec@example.com",
)
- print('Doc initialization: %.3fus' % (timeit(init_book, 1000) * 10**6))
+ print("Doc initialization: %.3fus" % (timeit(init_book, 1000) * 10 ** 6))
b = init_book()
- print('Doc getattr: %.3fus' % (timeit(lambda: b.name, 10000) * 10**6))
+ print("Doc getattr: %.3fus" % (timeit(lambda: b.name, 10000) * 10 ** 6))
print(
- 'Doc setattr: %.3fus' % (
- timeit(lambda: setattr(b, 'name', 'New name'), 10000) * 10**6
- )
+ "Doc setattr: %.3fus"
+ % (timeit(lambda: setattr(b, "name", "New name"), 10000) * 10 ** 6)
)
- print('Doc to mongo: %.3fus' % (timeit(b.to_mongo, 1000) * 10**6))
+ print("Doc to mongo: %.3fus" % (timeit(b.to_mongo, 1000) * 10 ** 6))
- print('Doc validation: %.3fus' % (timeit(b.validate, 1000) * 10**6))
+ print("Doc validation: %.3fus" % (timeit(b.validate, 1000) * 10 ** 6))
def save_book():
- b._mark_as_changed('name')
- b._mark_as_changed('tags')
+ b._mark_as_changed("name")
+ b._mark_as_changed("tags")
b.save()
- print('Save to database: %.3fus' % (timeit(save_book, 100) * 10**6))
+ print("Save to database: %.3fus" % (timeit(save_book, 100) * 10 ** 6))
son = b.to_mongo()
print(
- 'Load from SON: %.3fus' % (
- timeit(lambda: Book._from_son(son), 1000) * 10**6
- )
+ "Load from SON: %.3fus" % (timeit(lambda: Book._from_son(son), 1000) * 10 ** 6)
)
print(
- 'Load from database: %.3fus' % (
- timeit(lambda: Book.objects[0], 100) * 10**6
- )
+ "Load from database: %.3fus" % (timeit(lambda: Book.objects[0], 100) * 10 ** 6)
)
def create_and_delete_book():
@@ -72,9 +74,8 @@ def test_basic():
b.delete()
print(
- 'Init + save to database + delete: %.3fms' % (
- timeit(create_and_delete_book, 10) * 10**3
- )
+ "Init + save to database + delete: %.3fms"
+ % (timeit(create_and_delete_book, 10) * 10 ** 3)
)
@@ -92,42 +93,36 @@ def test_big_doc():
def init_company():
return Company(
- name='MongoDB, Inc.',
+ name="MongoDB, Inc.",
contacts=[
- Contact(
- name='Contact %d' % x,
- title='CEO',
- address='Address %d' % x,
- )
+ Contact(name="Contact %d" % x, title="CEO", address="Address %d" % x)
for x in range(1000)
- ]
+ ],
)
company = init_company()
- print('Big doc to mongo: %.3fms' % (timeit(company.to_mongo, 100) * 10**3))
+ print("Big doc to mongo: %.3fms" % (timeit(company.to_mongo, 100) * 10 ** 3))
- print('Big doc validation: %.3fms' % (timeit(company.validate, 1000) * 10**3))
+ print("Big doc validation: %.3fms" % (timeit(company.validate, 1000) * 10 ** 3))
company.save()
def save_company():
- company._mark_as_changed('name')
- company._mark_as_changed('contacts')
+ company._mark_as_changed("name")
+ company._mark_as_changed("contacts")
company.save()
- print('Save to database: %.3fms' % (timeit(save_company, 100) * 10**3))
+ print("Save to database: %.3fms" % (timeit(save_company, 100) * 10 ** 3))
son = company.to_mongo()
print(
- 'Load from SON: %.3fms' % (
- timeit(lambda: Company._from_son(son), 100) * 10**3
- )
+ "Load from SON: %.3fms"
+ % (timeit(lambda: Company._from_son(son), 100) * 10 ** 3)
)
print(
- 'Load from database: %.3fms' % (
- timeit(lambda: Company.objects[0], 100) * 10**3
- )
+ "Load from database: %.3fms"
+ % (timeit(lambda: Company.objects[0], 100) * 10 ** 3)
)
def create_and_delete_company():
@@ -136,13 +131,12 @@ def test_big_doc():
c.delete()
print(
- 'Init + save to database + delete: %.3fms' % (
- timeit(create_and_delete_company, 10) * 10**3
- )
+ "Init + save to database + delete: %.3fms"
+ % (timeit(create_and_delete_company, 10) * 10 ** 3)
)
-if __name__ == '__main__':
+if __name__ == "__main__":
test_basic()
- print('-' * 100)
+ print("-" * 100)
test_big_doc()
diff --git a/benchmarks/test_inserts.py b/benchmarks/test_inserts.py
index 8113d988..fd017bae 100644
--- a/benchmarks/test_inserts.py
+++ b/benchmarks/test_inserts.py
@@ -26,10 +26,10 @@ myNoddys = noddy.find()
[n for n in myNoddys] # iterate
"""
- print('-' * 100)
- print('PyMongo: Creating 10000 dictionaries.')
+ print("-" * 100)
+ print("PyMongo: Creating 10000 dictionaries.")
t = timeit.Timer(stmt=stmt, setup=setup)
- print('{}s'.format(t.timeit(1)))
+ print("{}s".format(t.timeit(1)))
stmt = """
from pymongo import MongoClient, WriteConcern
@@ -49,10 +49,10 @@ myNoddys = noddy.find()
[n for n in myNoddys] # iterate
"""
- print('-' * 100)
+ print("-" * 100)
print('PyMongo: Creating 10000 dictionaries (write_concern={"w": 0}).')
t = timeit.Timer(stmt=stmt, setup=setup)
- print('{}s'.format(t.timeit(1)))
+ print("{}s".format(t.timeit(1)))
setup = """
from pymongo import MongoClient
@@ -78,10 +78,10 @@ myNoddys = Noddy.objects()
[n for n in myNoddys] # iterate
"""
- print('-' * 100)
- print('MongoEngine: Creating 10000 dictionaries.')
+ print("-" * 100)
+ print("MongoEngine: Creating 10000 dictionaries.")
t = timeit.Timer(stmt=stmt, setup=setup)
- print('{}s'.format(t.timeit(1)))
+ print("{}s".format(t.timeit(1)))
stmt = """
for i in range(10000):
@@ -96,10 +96,10 @@ myNoddys = Noddy.objects()
[n for n in myNoddys] # iterate
"""
- print('-' * 100)
- print('MongoEngine: Creating 10000 dictionaries (using a single field assignment).')
+ print("-" * 100)
+ print("MongoEngine: Creating 10000 dictionaries (using a single field assignment).")
t = timeit.Timer(stmt=stmt, setup=setup)
- print('{}s'.format(t.timeit(1)))
+ print("{}s".format(t.timeit(1)))
stmt = """
for i in range(10000):
@@ -112,10 +112,10 @@ myNoddys = Noddy.objects()
[n for n in myNoddys] # iterate
"""
- print('-' * 100)
+ print("-" * 100)
print('MongoEngine: Creating 10000 dictionaries (write_concern={"w": 0}).')
t = timeit.Timer(stmt=stmt, setup=setup)
- print('{}s'.format(t.timeit(1)))
+ print("{}s".format(t.timeit(1)))
stmt = """
for i in range(10000):
@@ -128,10 +128,12 @@ myNoddys = Noddy.objects()
[n for n in myNoddys] # iterate
"""
- print('-' * 100)
- print('MongoEngine: Creating 10000 dictionaries (write_concern={"w": 0}, validate=False).')
+ print("-" * 100)
+ print(
+ 'MongoEngine: Creating 10000 dictionaries (write_concern={"w": 0}, validate=False).'
+ )
t = timeit.Timer(stmt=stmt, setup=setup)
- print('{}s'.format(t.timeit(1)))
+ print("{}s".format(t.timeit(1)))
stmt = """
for i in range(10000):
@@ -144,10 +146,12 @@ myNoddys = Noddy.objects()
[n for n in myNoddys] # iterate
"""
- print('-' * 100)
- print('MongoEngine: Creating 10000 dictionaries (force_insert=True, write_concern={"w": 0}, validate=False).')
+ print("-" * 100)
+ print(
+ 'MongoEngine: Creating 10000 dictionaries (force_insert=True, write_concern={"w": 0}, validate=False).'
+ )
t = timeit.Timer(stmt=stmt, setup=setup)
- print('{}s'.format(t.timeit(1)))
+ print("{}s".format(t.timeit(1)))
if __name__ == "__main__":
diff --git a/docs/code/tumblelog.py b/docs/code/tumblelog.py
index 796336e6..3ca2384c 100644
--- a/docs/code/tumblelog.py
+++ b/docs/code/tumblelog.py
@@ -1,16 +1,19 @@
from mongoengine import *
-connect('tumblelog')
+connect("tumblelog")
+
class Comment(EmbeddedDocument):
content = StringField()
name = StringField(max_length=120)
+
class User(Document):
email = StringField(required=True)
first_name = StringField(max_length=50)
last_name = StringField(max_length=50)
+
class Post(Document):
title = StringField(max_length=120, required=True)
author = ReferenceField(User)
@@ -18,54 +21,57 @@ class Post(Document):
comments = ListField(EmbeddedDocumentField(Comment))
# bugfix
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class TextPost(Post):
content = StringField()
+
class ImagePost(Post):
image_path = StringField()
+
class LinkPost(Post):
link_url = StringField()
+
Post.drop_collection()
-john = User(email='jdoe@example.com', first_name='John', last_name='Doe')
+john = User(email="jdoe@example.com", first_name="John", last_name="Doe")
john.save()
-post1 = TextPost(title='Fun with MongoEngine', author=john)
-post1.content = 'Took a look at MongoEngine today, looks pretty cool.'
-post1.tags = ['mongodb', 'mongoengine']
+post1 = TextPost(title="Fun with MongoEngine", author=john)
+post1.content = "Took a look at MongoEngine today, looks pretty cool."
+post1.tags = ["mongodb", "mongoengine"]
post1.save()
-post2 = LinkPost(title='MongoEngine Documentation', author=john)
-post2.link_url = 'http://tractiondigital.com/labs/mongoengine/docs'
-post2.tags = ['mongoengine']
+post2 = LinkPost(title="MongoEngine Documentation", author=john)
+post2.link_url = "http://tractiondigital.com/labs/mongoengine/docs"
+post2.tags = ["mongoengine"]
post2.save()
-print('ALL POSTS')
+print("ALL POSTS")
print()
for post in Post.objects:
print(post.title)
- #print '=' * post.title.count()
+ # print '=' * post.title.count()
print("=" * 20)
if isinstance(post, TextPost):
print(post.content)
if isinstance(post, LinkPost):
- print('Link:', post.link_url)
+ print("Link:", post.link_url)
print()
print()
-print('POSTS TAGGED \'MONGODB\'')
+print("POSTS TAGGED 'MONGODB'")
print()
-for post in Post.objects(tags='mongodb'):
+for post in Post.objects(tags="mongodb"):
print(post.title)
print()
-num_posts = Post.objects(tags='mongodb').count()
+num_posts = Post.objects(tags="mongodb").count()
print('Found %d posts with tag "mongodb"' % num_posts)
diff --git a/docs/conf.py b/docs/conf.py
index 468e71e0..0d642e0c 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -20,29 +20,29 @@ import mongoengine
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath('..'))
+sys.path.insert(0, os.path.abspath(".."))
# -- General configuration -----------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
-extensions = ['sphinx.ext.autodoc', 'sphinx.ext.todo']
+extensions = ["sphinx.ext.autodoc", "sphinx.ext.todo"]
# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
# The suffix of source filenames.
-source_suffix = '.rst'
+source_suffix = ".rst"
# The encoding of source files.
-#source_encoding = 'utf-8'
+# source_encoding = 'utf-8'
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# General information about the project.
-project = u'MongoEngine'
-copyright = u'2009, MongoEngine Authors'
+project = u"MongoEngine"
+copyright = u"2009, MongoEngine Authors"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@@ -55,68 +55,66 @@ release = mongoengine.get_version()
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
-#language = None
+# language = None
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
-#today = ''
+# today = ''
# Else, today_fmt is used as the format for a strftime call.
-#today_fmt = '%B %d, %Y'
+# today_fmt = '%B %d, %Y'
# List of documents that shouldn't be included in the build.
-#unused_docs = []
+# unused_docs = []
# List of directories, relative to source directory, that shouldn't be searched
# for source files.
-exclude_trees = ['_build']
+exclude_trees = ["_build"]
# The reST default role (used for this markup: `text`) to use for all documents.
-#default_role = None
+# default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
-#add_function_parentheses = True
+# add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
-#add_module_names = True
+# add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
-#show_authors = False
+# show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"
# A list of ignored prefixes for module index sorting.
-#modindex_common_prefix = []
+# modindex_common_prefix = []
# -- Options for HTML output ---------------------------------------------------
# The theme to use for HTML and HTML Help pages. Major themes that come with
# Sphinx are currently 'default' and 'sphinxdoc'.
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
-html_theme_options = {
- 'canonical_url': 'http://docs.mongoengine.org/en/latest/'
-}
+html_theme_options = {"canonical_url": "http://docs.mongoengine.org/en/latest/"}
# Add any paths that contain custom themes here, relative to this directory.
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
-#html_title = None
+# html_title = None
# A shorter title for the navigation bar. Default is the same as html_title.
-#html_short_title = None
+# html_short_title = None
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
-#html_logo = None
+# html_logo = None
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
@@ -126,11 +124,11 @@ html_favicon = "favicon.ico"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-#html_static_path = ['_static']
+# html_static_path = ['_static']
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
-#html_last_updated_fmt = '%b %d, %Y'
+# html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
@@ -138,69 +136,68 @@ html_use_smartypants = True
# Custom sidebar templates, maps document names to template names.
html_sidebars = {
- 'index': ['globaltoc.html', 'searchbox.html'],
- '**': ['localtoc.html', 'relations.html', 'searchbox.html']
+ "index": ["globaltoc.html", "searchbox.html"],
+ "**": ["localtoc.html", "relations.html", "searchbox.html"],
}
# Additional templates that should be rendered to pages, maps page names to
# template names.
-#html_additional_pages = {}
+# html_additional_pages = {}
# If false, no module index is generated.
-#html_use_modindex = True
+# html_use_modindex = True
# If false, no index is generated.
-#html_use_index = True
+# html_use_index = True
# If true, the index is split into individual pages for each letter.
-#html_split_index = False
+# html_split_index = False
# If true, links to the reST sources are added to the pages.
-#html_show_sourcelink = True
+# html_show_sourcelink = True
# If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
-#html_use_opensearch = ''
+# html_use_opensearch = ''
# If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml").
-#html_file_suffix = ''
+# html_file_suffix = ''
# Output file base name for HTML help builder.
-htmlhelp_basename = 'MongoEnginedoc'
+htmlhelp_basename = "MongoEnginedoc"
# -- Options for LaTeX output --------------------------------------------------
# The paper size ('letter' or 'a4').
-latex_paper_size = 'a4'
+latex_paper_size = "a4"
# The font size ('10pt', '11pt' or '12pt').
-#latex_font_size = '10pt'
+# latex_font_size = '10pt'
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [
- ('index', 'MongoEngine.tex', 'MongoEngine Documentation',
- 'Ross Lawley', 'manual'),
+ ("index", "MongoEngine.tex", "MongoEngine Documentation", "Ross Lawley", "manual")
]
# The name of an image file (relative to this directory) to place at the top of
# the title page.
-#latex_logo = None
+# latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
-#latex_use_parts = False
+# latex_use_parts = False
# Additional stuff for the LaTeX preamble.
-#latex_preamble = ''
+# latex_preamble = ''
# Documents to append as an appendix to all manuals.
-#latex_appendices = []
+# latex_appendices = []
# If false, no module index is generated.
-#latex_use_modindex = True
+# latex_use_modindex = True
-autoclass_content = 'both'
+autoclass_content = "both"
diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py
index bb7a4e57..d7093d28 100644
--- a/mongoengine/__init__.py
+++ b/mongoengine/__init__.py
@@ -18,9 +18,14 @@ from mongoengine.queryset import *
from mongoengine.signals import *
-__all__ = (list(document.__all__) + list(fields.__all__) +
- list(connection.__all__) + list(queryset.__all__) +
- list(signals.__all__) + list(errors.__all__))
+__all__ = (
+ list(document.__all__)
+ + list(fields.__all__)
+ + list(connection.__all__)
+ + list(queryset.__all__)
+ + list(signals.__all__)
+ + list(errors.__all__)
+)
VERSION = (0, 18, 2)
@@ -31,7 +36,7 @@ def get_version():
For example, if `VERSION == (0, 10, 7)`, return '0.10.7'.
"""
- return '.'.join(map(str, VERSION))
+ return ".".join(map(str, VERSION))
__version__ = get_version()
diff --git a/mongoengine/base/__init__.py b/mongoengine/base/__init__.py
index e069a147..dca0c4bb 100644
--- a/mongoengine/base/__init__.py
+++ b/mongoengine/base/__init__.py
@@ -12,17 +12,22 @@ from mongoengine.base.metaclasses import *
__all__ = (
# common
- 'UPDATE_OPERATORS', '_document_registry', 'get_document',
-
+ "UPDATE_OPERATORS",
+ "_document_registry",
+ "get_document",
# datastructures
- 'BaseDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference',
-
+ "BaseDict",
+ "BaseList",
+ "EmbeddedDocumentList",
+ "LazyReference",
# document
- 'BaseDocument',
-
+ "BaseDocument",
# fields
- 'BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField',
-
+ "BaseField",
+ "ComplexBaseField",
+ "ObjectIdField",
+ "GeoJsonBaseField",
# metaclasses
- 'DocumentMetaclass', 'TopLevelDocumentMetaclass'
+ "DocumentMetaclass",
+ "TopLevelDocumentMetaclass",
)
diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py
index 999fd23a..85897324 100644
--- a/mongoengine/base/common.py
+++ b/mongoengine/base/common.py
@@ -1,12 +1,25 @@
from mongoengine.errors import NotRegistered
-__all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry')
+__all__ = ("UPDATE_OPERATORS", "get_document", "_document_registry")
-UPDATE_OPERATORS = {'set', 'unset', 'inc', 'dec', 'mul',
- 'pop', 'push', 'push_all', 'pull',
- 'pull_all', 'add_to_set', 'set_on_insert',
- 'min', 'max', 'rename'}
+UPDATE_OPERATORS = {
+ "set",
+ "unset",
+ "inc",
+ "dec",
+ "mul",
+ "pop",
+ "push",
+ "push_all",
+ "pull",
+ "pull_all",
+ "add_to_set",
+ "set_on_insert",
+ "min",
+ "max",
+ "rename",
+}
_document_registry = {}
@@ -17,25 +30,33 @@ def get_document(name):
doc = _document_registry.get(name, None)
if not doc:
# Possible old style name
- single_end = name.split('.')[-1]
- compound_end = '.%s' % single_end
- possible_match = [k for k in _document_registry
- if k.endswith(compound_end) or k == single_end]
+ single_end = name.split(".")[-1]
+ compound_end = ".%s" % single_end
+ possible_match = [
+ k for k in _document_registry if k.endswith(compound_end) or k == single_end
+ ]
if len(possible_match) == 1:
doc = _document_registry.get(possible_match.pop(), None)
if not doc:
- raise NotRegistered("""
+ raise NotRegistered(
+ """
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
- """.strip() % name)
+ """.strip()
+ % name
+ )
return doc
def _get_documents_by_db(connection_alias, default_connection_alias):
"""Get all registered Documents class attached to a given database"""
- def get_doc_alias(doc_cls):
- return doc_cls._meta.get('db_alias', default_connection_alias)
- return [doc_cls for doc_cls in _document_registry.values()
- if get_doc_alias(doc_cls) == connection_alias]
+ def get_doc_alias(doc_cls):
+ return doc_cls._meta.get("db_alias", default_connection_alias)
+
+ return [
+ doc_cls
+ for doc_cls in _document_registry.values()
+ if get_doc_alias(doc_cls) == connection_alias
+ ]
diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py
index cce71846..d1b5ae76 100644
--- a/mongoengine/base/datastructures.py
+++ b/mongoengine/base/datastructures.py
@@ -7,26 +7,36 @@ from six import iteritems
from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned
-__all__ = ('BaseDict', 'StrictDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference')
+__all__ = (
+ "BaseDict",
+ "StrictDict",
+ "BaseList",
+ "EmbeddedDocumentList",
+ "LazyReference",
+)
def mark_as_changed_wrapper(parent_method):
"""Decorator that ensures _mark_as_changed method gets called."""
+
def wrapper(self, *args, **kwargs):
# Can't use super() in the decorator.
result = parent_method(self, *args, **kwargs)
self._mark_as_changed()
return result
+
return wrapper
def mark_key_as_changed_wrapper(parent_method):
"""Decorator that ensures _mark_as_changed method gets called with the key argument"""
+
def wrapper(self, key, *args, **kwargs):
# Can't use super() in the decorator.
result = parent_method(self, key, *args, **kwargs)
self._mark_as_changed(key)
return result
+
return wrapper
@@ -38,7 +48,7 @@ class BaseDict(dict):
_name = None
def __init__(self, dict_items, instance, name):
- BaseDocument = _import_class('BaseDocument')
+ BaseDocument = _import_class("BaseDocument")
if isinstance(instance, BaseDocument):
self._instance = weakref.proxy(instance)
@@ -55,15 +65,15 @@ class BaseDict(dict):
def __getitem__(self, key):
value = super(BaseDict, self).__getitem__(key)
- EmbeddedDocument = _import_class('EmbeddedDocument')
+ EmbeddedDocument = _import_class("EmbeddedDocument")
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = self._instance
elif isinstance(value, dict) and not isinstance(value, BaseDict):
- value = BaseDict(value, None, '%s.%s' % (self._name, key))
+ value = BaseDict(value, None, "%s.%s" % (self._name, key))
super(BaseDict, self).__setitem__(key, value)
value._instance = self._instance
elif isinstance(value, list) and not isinstance(value, BaseList):
- value = BaseList(value, None, '%s.%s' % (self._name, key))
+ value = BaseList(value, None, "%s.%s" % (self._name, key))
super(BaseDict, self).__setitem__(key, value)
value._instance = self._instance
return value
@@ -87,9 +97,9 @@ class BaseDict(dict):
setdefault = mark_as_changed_wrapper(dict.setdefault)
def _mark_as_changed(self, key=None):
- if hasattr(self._instance, '_mark_as_changed'):
+ if hasattr(self._instance, "_mark_as_changed"):
if key:
- self._instance._mark_as_changed('%s.%s' % (self._name, key))
+ self._instance._mark_as_changed("%s.%s" % (self._name, key))
else:
self._instance._mark_as_changed(self._name)
@@ -102,7 +112,7 @@ class BaseList(list):
_name = None
def __init__(self, list_items, instance, name):
- BaseDocument = _import_class('BaseDocument')
+ BaseDocument = _import_class("BaseDocument")
if isinstance(instance, BaseDocument):
self._instance = weakref.proxy(instance)
@@ -117,17 +127,17 @@ class BaseList(list):
# to parent's instance. This is buggy for now but would require more work to be handled properly
return value
- EmbeddedDocument = _import_class('EmbeddedDocument')
+ EmbeddedDocument = _import_class("EmbeddedDocument")
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = self._instance
elif isinstance(value, dict) and not isinstance(value, BaseDict):
# Replace dict by BaseDict
- value = BaseDict(value, None, '%s.%s' % (self._name, key))
+ value = BaseDict(value, None, "%s.%s" % (self._name, key))
super(BaseList, self).__setitem__(key, value)
value._instance = self._instance
elif isinstance(value, list) and not isinstance(value, BaseList):
# Replace list by BaseList
- value = BaseList(value, None, '%s.%s' % (self._name, key))
+ value = BaseList(value, None, "%s.%s" % (self._name, key))
super(BaseList, self).__setitem__(key, value)
value._instance = self._instance
return value
@@ -181,17 +191,14 @@ class BaseList(list):
return self.__getitem__(slice(i, j))
def _mark_as_changed(self, key=None):
- if hasattr(self._instance, '_mark_as_changed'):
+ if hasattr(self._instance, "_mark_as_changed"):
if key:
- self._instance._mark_as_changed(
- '%s.%s' % (self._name, key % len(self))
- )
+ self._instance._mark_as_changed("%s.%s" % (self._name, key % len(self)))
else:
self._instance._mark_as_changed(self._name)
class EmbeddedDocumentList(BaseList):
-
def __init__(self, list_items, instance, name):
super(EmbeddedDocumentList, self).__init__(list_items, instance, name)
self._instance = instance
@@ -276,12 +283,10 @@ class EmbeddedDocumentList(BaseList):
"""
values = self.__only_matches(self, kwargs)
if len(values) == 0:
- raise DoesNotExist(
- '%s matching query does not exist.' % self._name
- )
+ raise DoesNotExist("%s matching query does not exist." % self._name)
elif len(values) > 1:
raise MultipleObjectsReturned(
- '%d items returned, instead of 1' % len(values)
+ "%d items returned, instead of 1" % len(values)
)
return values[0]
@@ -362,7 +367,7 @@ class EmbeddedDocumentList(BaseList):
class StrictDict(object):
__slots__ = ()
- _special_fields = {'get', 'pop', 'iteritems', 'items', 'keys', 'create'}
+ _special_fields = {"get", "pop", "iteritems", "items", "keys", "create"}
_classes = {}
def __init__(self, **kwargs):
@@ -370,14 +375,14 @@ class StrictDict(object):
setattr(self, k, v)
def __getitem__(self, key):
- key = '_reserved_' + key if key in self._special_fields else key
+ key = "_reserved_" + key if key in self._special_fields else key
try:
return getattr(self, key)
except AttributeError:
raise KeyError(key)
def __setitem__(self, key, value):
- key = '_reserved_' + key if key in self._special_fields else key
+ key = "_reserved_" + key if key in self._special_fields else key
return setattr(self, key, value)
def __contains__(self, key):
@@ -424,27 +429,32 @@ class StrictDict(object):
@classmethod
def create(cls, allowed_keys):
- allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys)
+ allowed_keys_tuple = tuple(
+ ("_reserved_" + k if k in cls._special_fields else k) for k in allowed_keys
+ )
allowed_keys = frozenset(allowed_keys_tuple)
if allowed_keys not in cls._classes:
+
class SpecificStrictDict(cls):
__slots__ = allowed_keys_tuple
def __repr__(self):
- return '{%s}' % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
+ return "{%s}" % ", ".join(
+ '"{0!s}": {1!r}'.format(k, v) for k, v in self.items()
+ )
cls._classes[allowed_keys] = SpecificStrictDict
return cls._classes[allowed_keys]
class LazyReference(DBRef):
- __slots__ = ('_cached_doc', 'passthrough', 'document_type')
+ __slots__ = ("_cached_doc", "passthrough", "document_type")
def fetch(self, force=False):
if not self._cached_doc or force:
self._cached_doc = self.document_type.objects.get(pk=self.pk)
if not self._cached_doc:
- raise DoesNotExist('Trying to dereference unknown document %s' % (self))
+ raise DoesNotExist("Trying to dereference unknown document %s" % (self))
return self._cached_doc
@property
@@ -455,7 +465,9 @@ class LazyReference(DBRef):
self.document_type = document_type
self._cached_doc = cached_doc
self.passthrough = passthrough
- super(LazyReference, self).__init__(self.document_type._get_collection_name(), pk)
+ super(LazyReference, self).__init__(
+ self.document_type._get_collection_name(), pk
+ )
def __getitem__(self, name):
if not self.passthrough:
@@ -464,7 +476,7 @@ class LazyReference(DBRef):
return document[name]
def __getattr__(self, name):
- if not object.__getattribute__(self, 'passthrough'):
+ if not object.__getattribute__(self, "passthrough"):
raise AttributeError()
document = self.fetch()
try:
diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py
index 047d50a4..928a00c2 100644
--- a/mongoengine/base/document.py
+++ b/mongoengine/base/document.py
@@ -9,19 +9,27 @@ from six import iteritems
from mongoengine import signals
from mongoengine.base.common import get_document
-from mongoengine.base.datastructures import (BaseDict, BaseList,
- EmbeddedDocumentList,
- LazyReference,
- StrictDict)
+from mongoengine.base.datastructures import (
+ BaseDict,
+ BaseList,
+ EmbeddedDocumentList,
+ LazyReference,
+ StrictDict,
+)
from mongoengine.base.fields import ComplexBaseField
from mongoengine.common import _import_class
-from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError,
- LookUpError, OperationError, ValidationError)
+from mongoengine.errors import (
+ FieldDoesNotExist,
+ InvalidDocumentError,
+ LookUpError,
+ OperationError,
+ ValidationError,
+)
from mongoengine.python_support import Hashable
-__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
+__all__ = ("BaseDocument", "NON_FIELD_ERRORS")
-NON_FIELD_ERRORS = '__all__'
+NON_FIELD_ERRORS = "__all__"
class BaseDocument(object):
@@ -35,9 +43,16 @@ class BaseDocument(object):
# field is primarily set via `_from_son` or `_clear_changed_fields`,
# though there are also other methods that manipulate it.
# 4. The codebase is littered with `hasattr` calls for `_changed_fields`.
- __slots__ = ('_changed_fields', '_initialised', '_created', '_data',
- '_dynamic_fields', '_auto_id_field', '_db_field_map',
- '__weakref__')
+ __slots__ = (
+ "_changed_fields",
+ "_initialised",
+ "_created",
+ "_data",
+ "_dynamic_fields",
+ "_auto_id_field",
+ "_db_field_map",
+ "__weakref__",
+ )
_dynamic = False
_dynamic_lock = True
@@ -61,27 +76,28 @@ class BaseDocument(object):
if args:
raise TypeError(
- 'Instantiating a document with positional arguments is not '
- 'supported. Please use `field_name=value` keyword arguments.'
+ "Instantiating a document with positional arguments is not "
+ "supported. Please use `field_name=value` keyword arguments."
)
- __auto_convert = values.pop('__auto_convert', True)
+ __auto_convert = values.pop("__auto_convert", True)
- __only_fields = set(values.pop('__only_fields', values))
+ __only_fields = set(values.pop("__only_fields", values))
- _created = values.pop('_created', True)
+ _created = values.pop("_created", True)
signals.pre_init.send(self.__class__, document=self, values=values)
# Check if there are undefined fields supplied to the constructor,
# if so raise an Exception.
- if not self._dynamic and (self._meta.get('strict', True) or _created):
+ if not self._dynamic and (self._meta.get("strict", True) or _created):
_undefined_fields = set(values.keys()) - set(
- self._fields.keys() + ['id', 'pk', '_cls', '_text_score'])
+ self._fields.keys() + ["id", "pk", "_cls", "_text_score"]
+ )
if _undefined_fields:
- msg = (
- 'The fields "{0}" do not exist on the document "{1}"'
- ).format(_undefined_fields, self._class_name)
+ msg = ('The fields "{0}" do not exist on the document "{1}"').format(
+ _undefined_fields, self._class_name
+ )
raise FieldDoesNotExist(msg)
if self.STRICT and not self._dynamic:
@@ -100,22 +116,22 @@ class BaseDocument(object):
value = getattr(self, key, None)
setattr(self, key, value)
- if '_cls' not in values:
+ if "_cls" not in values:
self._cls = self._class_name
# Set passed values after initialisation
if self._dynamic:
dynamic_data = {}
for key, value in iteritems(values):
- if key in self._fields or key == '_id':
+ if key in self._fields or key == "_id":
setattr(self, key, value)
else:
dynamic_data[key] = value
else:
- FileField = _import_class('FileField')
+ FileField = _import_class("FileField")
for key, value in iteritems(values):
key = self._reverse_db_field_map.get(key, key)
- if key in self._fields or key in ('id', 'pk', '_cls'):
+ if key in self._fields or key in ("id", "pk", "_cls"):
if __auto_convert and value is not None:
field = self._fields.get(key)
if field and not isinstance(field, FileField):
@@ -153,20 +169,20 @@ class BaseDocument(object):
# Handle dynamic data only if an initialised dynamic document
if self._dynamic and not self._dynamic_lock:
- if not hasattr(self, name) and not name.startswith('_'):
- DynamicField = _import_class('DynamicField')
+ if not hasattr(self, name) and not name.startswith("_"):
+ DynamicField = _import_class("DynamicField")
field = DynamicField(db_field=name, null=True)
field.name = name
self._dynamic_fields[name] = field
self._fields_ordered += (name,)
- if not name.startswith('_'):
+ if not name.startswith("_"):
value = self.__expand_dynamic_values(name, value)
# Handle marking data as changed
if name in self._dynamic_fields:
self._data[name] = value
- if hasattr(self, '_changed_fields'):
+ if hasattr(self, "_changed_fields"):
self._mark_as_changed(name)
try:
self__created = self._created
@@ -174,12 +190,12 @@ class BaseDocument(object):
self__created = True
if (
- self._is_document and
- not self__created and
- name in self._meta.get('shard_key', tuple()) and
- self._data.get(name) != value
+ self._is_document
+ and not self__created
+ and name in self._meta.get("shard_key", tuple())
+ and self._data.get(name) != value
):
- msg = 'Shard Keys are immutable. Tried to update %s' % name
+ msg = "Shard Keys are immutable. Tried to update %s" % name
raise OperationError(msg)
try:
@@ -187,38 +203,52 @@ class BaseDocument(object):
except AttributeError:
self__initialised = False
# Check if the user has created a new instance of a class
- if (self._is_document and self__initialised and
- self__created and name == self._meta.get('id_field')):
- super(BaseDocument, self).__setattr__('_created', False)
+ if (
+ self._is_document
+ and self__initialised
+ and self__created
+ and name == self._meta.get("id_field")
+ ):
+ super(BaseDocument, self).__setattr__("_created", False)
super(BaseDocument, self).__setattr__(name, value)
def __getstate__(self):
data = {}
- for k in ('_changed_fields', '_initialised', '_created',
- '_dynamic_fields', '_fields_ordered'):
+ for k in (
+ "_changed_fields",
+ "_initialised",
+ "_created",
+ "_dynamic_fields",
+ "_fields_ordered",
+ ):
if hasattr(self, k):
data[k] = getattr(self, k)
- data['_data'] = self.to_mongo()
+ data["_data"] = self.to_mongo()
return data
def __setstate__(self, data):
- if isinstance(data['_data'], SON):
- data['_data'] = self.__class__._from_son(data['_data'])._data
- for k in ('_changed_fields', '_initialised', '_created', '_data',
- '_dynamic_fields'):
+ if isinstance(data["_data"], SON):
+ data["_data"] = self.__class__._from_son(data["_data"])._data
+ for k in (
+ "_changed_fields",
+ "_initialised",
+ "_created",
+ "_data",
+ "_dynamic_fields",
+ ):
if k in data:
setattr(self, k, data[k])
- if '_fields_ordered' in data:
+ if "_fields_ordered" in data:
if self._dynamic:
- setattr(self, '_fields_ordered', data['_fields_ordered'])
+ setattr(self, "_fields_ordered", data["_fields_ordered"])
else:
_super_fields_ordered = type(self)._fields_ordered
- setattr(self, '_fields_ordered', _super_fields_ordered)
+ setattr(self, "_fields_ordered", _super_fields_ordered)
- dynamic_fields = data.get('_dynamic_fields') or SON()
+ dynamic_fields = data.get("_dynamic_fields") or SON()
for k in dynamic_fields.keys():
- setattr(self, k, data['_data'].get(k))
+ setattr(self, k, data["_data"].get(k))
def __iter__(self):
return iter(self._fields_ordered)
@@ -255,24 +285,30 @@ class BaseDocument(object):
try:
u = self.__str__()
except (UnicodeEncodeError, UnicodeDecodeError):
- u = '[Bad Unicode data]'
+ u = "[Bad Unicode data]"
repr_type = str if u is None else type(u)
- return repr_type('<%s: %s>' % (self.__class__.__name__, u))
+ return repr_type("<%s: %s>" % (self.__class__.__name__, u))
def __str__(self):
# TODO this could be simpler?
- if hasattr(self, '__unicode__'):
+ if hasattr(self, "__unicode__"):
if six.PY3:
return self.__unicode__()
else:
- return six.text_type(self).encode('utf-8')
- return six.text_type('%s object' % self.__class__.__name__)
+ return six.text_type(self).encode("utf-8")
+ return six.text_type("%s object" % self.__class__.__name__)
def __eq__(self, other):
- if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None:
+ if (
+ isinstance(other, self.__class__)
+ and hasattr(other, "id")
+ and other.id is not None
+ ):
return self.id == other.id
if isinstance(other, DBRef):
- return self._get_collection_name() == other.collection and self.id == other.id
+ return (
+ self._get_collection_name() == other.collection and self.id == other.id
+ )
if self.id is None:
return self is other
return False
@@ -295,10 +331,12 @@ class BaseDocument(object):
Get text score from text query
"""
- if '_text_score' not in self._data:
- raise InvalidDocumentError('This document is not originally built from a text query')
+ if "_text_score" not in self._data:
+ raise InvalidDocumentError(
+ "This document is not originally built from a text query"
+ )
- return self._data['_text_score']
+ return self._data["_text_score"]
def to_mongo(self, use_db_field=True, fields=None):
"""
@@ -307,11 +345,11 @@ class BaseDocument(object):
fields = fields or []
data = SON()
- data['_id'] = None
- data['_cls'] = self._class_name
+ data["_id"] = None
+ data["_cls"] = self._class_name
# only root fields ['test1.a', 'test2'] => ['test1', 'test2']
- root_fields = {f.split('.')[0] for f in fields}
+ root_fields = {f.split(".")[0] for f in fields}
for field_name in self:
if root_fields and field_name not in root_fields:
@@ -326,16 +364,16 @@ class BaseDocument(object):
if value is not None:
f_inputs = field.to_mongo.__code__.co_varnames
ex_vars = {}
- if fields and 'fields' in f_inputs:
- key = '%s.' % field_name
+ if fields and "fields" in f_inputs:
+ key = "%s." % field_name
embedded_fields = [
- i.replace(key, '') for i in fields
- if i.startswith(key)]
+ i.replace(key, "") for i in fields if i.startswith(key)
+ ]
- ex_vars['fields'] = embedded_fields
+ ex_vars["fields"] = embedded_fields
- if 'use_db_field' in f_inputs:
- ex_vars['use_db_field'] = use_db_field
+ if "use_db_field" in f_inputs:
+ ex_vars["use_db_field"] = use_db_field
value = field.to_mongo(value, **ex_vars)
@@ -351,8 +389,8 @@ class BaseDocument(object):
data[field.name] = value
# Only add _cls if allow_inheritance is True
- if not self._meta.get('allow_inheritance'):
- data.pop('_cls')
+ if not self._meta.get("allow_inheritance"):
+ data.pop("_cls")
return data
@@ -372,18 +410,23 @@ class BaseDocument(object):
errors[NON_FIELD_ERRORS] = error
# Get a list of tuples of field names and their current values
- fields = [(self._fields.get(name, self._dynamic_fields.get(name)),
- self._data.get(name)) for name in self._fields_ordered]
+ fields = [
+ (
+ self._fields.get(name, self._dynamic_fields.get(name)),
+ self._data.get(name),
+ )
+ for name in self._fields_ordered
+ ]
- EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
- GenericEmbeddedDocumentField = _import_class(
- 'GenericEmbeddedDocumentField')
+ EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
+ GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField")
for field, value in fields:
if value is not None:
try:
- if isinstance(field, (EmbeddedDocumentField,
- GenericEmbeddedDocumentField)):
+ if isinstance(
+ field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)
+ ):
field._validate(value, clean=clean)
else:
field._validate(value)
@@ -391,17 +434,18 @@ class BaseDocument(object):
errors[field.name] = error.errors or error
except (ValueError, AttributeError, AssertionError) as error:
errors[field.name] = error
- elif field.required and not getattr(field, '_auto_gen', False):
- errors[field.name] = ValidationError('Field is required',
- field_name=field.name)
+ elif field.required and not getattr(field, "_auto_gen", False):
+ errors[field.name] = ValidationError(
+ "Field is required", field_name=field.name
+ )
if errors:
- pk = 'None'
- if hasattr(self, 'pk'):
+ pk = "None"
+ if hasattr(self, "pk"):
pk = self.pk
- elif self._instance and hasattr(self._instance, 'pk'):
+ elif self._instance and hasattr(self._instance, "pk"):
pk = self._instance.pk
- message = 'ValidationError (%s:%s) ' % (self._class_name, pk)
+ message = "ValidationError (%s:%s) " % (self._class_name, pk)
raise ValidationError(message, errors=errors)
def to_json(self, *args, **kwargs):
@@ -411,7 +455,7 @@ class BaseDocument(object):
MongoDB (as opposed to attribute names on this document).
Defaults to True.
"""
- use_db_field = kwargs.pop('use_db_field', True)
+ use_db_field = kwargs.pop("use_db_field", True)
return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)
@classmethod
@@ -434,22 +478,18 @@ class BaseDocument(object):
# If the value is a dict with '_cls' in it, turn it into a document
is_dict = isinstance(value, dict)
- if is_dict and '_cls' in value:
- cls = get_document(value['_cls'])
+ if is_dict and "_cls" in value:
+ cls = get_document(value["_cls"])
return cls(**value)
if is_dict:
- value = {
- k: self.__expand_dynamic_values(k, v)
- for k, v in value.items()
- }
+ value = {k: self.__expand_dynamic_values(k, v) for k, v in value.items()}
else:
value = [self.__expand_dynamic_values(name, v) for v in value]
# Convert lists / values so we can watch for any changes on them
- EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
- if (isinstance(value, (list, tuple)) and
- not isinstance(value, BaseList)):
+ EmbeddedDocumentListField = _import_class("EmbeddedDocumentListField")
+ if isinstance(value, (list, tuple)) and not isinstance(value, BaseList):
if issubclass(type(self), EmbeddedDocumentListField):
value = EmbeddedDocumentList(value, self, name)
else:
@@ -464,26 +504,26 @@ class BaseDocument(object):
if not key:
return
- if not hasattr(self, '_changed_fields'):
+ if not hasattr(self, "_changed_fields"):
return
- if '.' in key:
- key, rest = key.split('.', 1)
+ if "." in key:
+ key, rest = key.split(".", 1)
key = self._db_field_map.get(key, key)
- key = '%s.%s' % (key, rest)
+ key = "%s.%s" % (key, rest)
else:
key = self._db_field_map.get(key, key)
if key not in self._changed_fields:
- levels, idx = key.split('.'), 1
+ levels, idx = key.split("."), 1
while idx <= len(levels):
- if '.'.join(levels[:idx]) in self._changed_fields:
+ if ".".join(levels[:idx]) in self._changed_fields:
break
idx += 1
else:
self._changed_fields.append(key)
# remove lower level changed fields
- level = '.'.join(levels[:idx]) + '.'
+ level = ".".join(levels[:idx]) + "."
remove = self._changed_fields.remove
for field in self._changed_fields[:]:
if field.startswith(level):
@@ -494,7 +534,7 @@ class BaseDocument(object):
are marked as changed.
"""
for changed in self._get_changed_fields():
- parts = changed.split('.')
+ parts = changed.split(".")
data = self
for part in parts:
if isinstance(data, list):
@@ -507,8 +547,10 @@ class BaseDocument(object):
else:
data = getattr(data, part, None)
- if not isinstance(data, LazyReference) and hasattr(data, '_changed_fields'):
- if getattr(data, '_is_document', False):
+ if not isinstance(data, LazyReference) and hasattr(
+ data, "_changed_fields"
+ ):
+ if getattr(data, "_is_document", False):
continue
data._changed_fields = []
@@ -524,39 +566,38 @@ class BaseDocument(object):
"""
# Loop list / dict fields as they contain documents
# Determine the iterator to use
- if not hasattr(data, 'items'):
+ if not hasattr(data, "items"):
iterator = enumerate(data)
else:
iterator = iteritems(data)
for index_or_key, value in iterator:
- item_key = '%s%s.' % (base_key, index_or_key)
+ item_key = "%s%s." % (base_key, index_or_key)
# don't check anything lower if this key is already marked
# as changed.
if item_key[:-1] in changed_fields:
continue
- if hasattr(value, '_get_changed_fields'):
+ if hasattr(value, "_get_changed_fields"):
changed = value._get_changed_fields()
- changed_fields += ['%s%s' % (item_key, k) for k in changed if k]
+ changed_fields += ["%s%s" % (item_key, k) for k in changed if k]
elif isinstance(value, (list, tuple, dict)):
- self._nestable_types_changed_fields(
- changed_fields, item_key, value)
+ self._nestable_types_changed_fields(changed_fields, item_key, value)
def _get_changed_fields(self):
"""Return a list of all fields that have explicitly been changed.
"""
- EmbeddedDocument = _import_class('EmbeddedDocument')
- ReferenceField = _import_class('ReferenceField')
- GenericReferenceField = _import_class('GenericReferenceField')
- SortedListField = _import_class('SortedListField')
+ EmbeddedDocument = _import_class("EmbeddedDocument")
+ ReferenceField = _import_class("ReferenceField")
+ GenericReferenceField = _import_class("GenericReferenceField")
+ SortedListField = _import_class("SortedListField")
changed_fields = []
- changed_fields += getattr(self, '_changed_fields', [])
+ changed_fields += getattr(self, "_changed_fields", [])
for field_name in self._fields_ordered:
db_field_name = self._db_field_map.get(field_name, field_name)
- key = '%s.' % db_field_name
+ key = "%s." % db_field_name
data = self._data.get(field_name, None)
field = self._fields.get(field_name)
@@ -564,16 +605,17 @@ class BaseDocument(object):
# Whole field already marked as changed, no need to go further
continue
- if isinstance(field, ReferenceField): # Don't follow referenced documents
+ if isinstance(field, ReferenceField): # Don't follow referenced documents
continue
if isinstance(data, EmbeddedDocument):
# Find all embedded fields that have been changed
changed = data._get_changed_fields()
- changed_fields += ['%s%s' % (key, k) for k in changed if k]
+ changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif isinstance(data, (list, tuple, dict)):
- if (hasattr(field, 'field') and
- isinstance(field.field, (ReferenceField, GenericReferenceField))):
+ if hasattr(field, "field") and isinstance(
+ field.field, (ReferenceField, GenericReferenceField)
+ ):
continue
elif isinstance(field, SortedListField) and field._ordering:
# if ordering is affected whole list is changed
@@ -581,8 +623,7 @@ class BaseDocument(object):
changed_fields.append(db_field_name)
continue
- self._nestable_types_changed_fields(
- changed_fields, key, data)
+ self._nestable_types_changed_fields(changed_fields, key, data)
return changed_fields
def _delta(self):
@@ -594,11 +635,11 @@ class BaseDocument(object):
set_fields = self._get_changed_fields()
unset_data = {}
- if hasattr(self, '_changed_fields'):
+ if hasattr(self, "_changed_fields"):
set_data = {}
# Fetch each set item from its path
for path in set_fields:
- parts = path.split('.')
+ parts = path.split(".")
d = doc
new_path = []
for p in parts:
@@ -608,26 +649,27 @@ class BaseDocument(object):
elif isinstance(d, list) and p.isdigit():
# An item of a list (identified by its index) is updated
d = d[int(p)]
- elif hasattr(d, 'get'):
+ elif hasattr(d, "get"):
# dict-like (dict, embedded document)
d = d.get(p)
new_path.append(p)
- path = '.'.join(new_path)
+ path = ".".join(new_path)
set_data[path] = d
else:
set_data = doc
- if '_id' in set_data:
- del set_data['_id']
+ if "_id" in set_data:
+ del set_data["_id"]
# Determine if any changed items were actually unset.
for path, value in set_data.items():
- if value or isinstance(value, (numbers.Number, bool)): # Account for 0 and True that are truthy
+ if value or isinstance(
+ value, (numbers.Number, bool)
+ ): # Account for 0 and True that are truthy
continue
- parts = path.split('.')
+ parts = path.split(".")
- if (self._dynamic and len(parts) and parts[0] in
- self._dynamic_fields):
+ if self._dynamic and len(parts) and parts[0] in self._dynamic_fields:
del set_data[path]
unset_data[path] = 1
continue
@@ -642,16 +684,16 @@ class BaseDocument(object):
for p in parts:
if isinstance(d, list) and p.isdigit():
d = d[int(p)]
- elif (hasattr(d, '__getattribute__') and
- not isinstance(d, dict)):
+ elif hasattr(d, "__getattribute__") and not isinstance(d, dict):
real_path = d._reverse_db_field_map.get(p, p)
d = getattr(d, real_path)
else:
d = d.get(p)
- if hasattr(d, '_fields'):
- field_name = d._reverse_db_field_map.get(db_field_name,
- db_field_name)
+ if hasattr(d, "_fields"):
+ field_name = d._reverse_db_field_map.get(
+ db_field_name, db_field_name
+ )
if field_name in d._fields:
default = d._fields.get(field_name).default
else:
@@ -672,7 +714,7 @@ class BaseDocument(object):
"""Return the collection name for this class. None for abstract
class.
"""
- return cls._meta.get('collection', None)
+ return cls._meta.get("collection", None)
@classmethod
def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False):
@@ -685,7 +727,7 @@ class BaseDocument(object):
# Get the class name from the document, falling back to the given
# class if unavailable
- class_name = son.get('_cls', cls._class_name)
+ class_name = son.get("_cls", cls._class_name)
# Convert SON to a data dict, making sure each key is a string and
# corresponds to the right db field.
@@ -710,18 +752,20 @@ class BaseDocument(object):
if field.db_field in data:
value = data[field.db_field]
try:
- data[field_name] = (value if value is None
- else field.to_python(value))
+ data[field_name] = (
+ value if value is None else field.to_python(value)
+ )
if field_name != field.db_field:
del data[field.db_field]
except (AttributeError, ValueError) as e:
errors_dict[field_name] = e
if errors_dict:
- errors = '\n'.join(['%s - %s' % (k, v)
- for k, v in errors_dict.items()])
- msg = ('Invalid data to create a `%s` instance.\n%s'
- % (cls._class_name, errors))
+ errors = "\n".join(["%s - %s" % (k, v) for k, v in errors_dict.items()])
+ msg = "Invalid data to create a `%s` instance.\n%s" % (
+ cls._class_name,
+ errors,
+ )
raise InvalidDocumentError(msg)
# In STRICT documents, remove any keys that aren't in cls._fields
@@ -729,10 +773,7 @@ class BaseDocument(object):
data = {k: v for k, v in iteritems(data) if k in cls._fields}
obj = cls(
- __auto_convert=False,
- _created=created,
- __only_fields=only_fields,
- **data
+ __auto_convert=False, _created=created, __only_fields=only_fields, **data
)
obj._changed_fields = []
if not _auto_dereference:
@@ -754,15 +795,13 @@ class BaseDocument(object):
# Create a map of index fields to index spec. We're converting
# the fields from a list to a tuple so that it's hashable.
- spec_fields = {
- tuple(index['fields']): index for index in index_specs
- }
+ spec_fields = {tuple(index["fields"]): index for index in index_specs}
# For each new index, if there's an existing index with the same
# fields list, update the existing spec with all data from the
# new spec.
for new_index in indices:
- candidate = spec_fields.get(tuple(new_index['fields']))
+ candidate = spec_fields.get(tuple(new_index["fields"]))
if candidate is None:
index_specs.append(new_index)
else:
@@ -779,9 +818,9 @@ class BaseDocument(object):
def _build_index_spec(cls, spec):
"""Build a PyMongo index spec from a MongoEngine index spec."""
if isinstance(spec, six.string_types):
- spec = {'fields': [spec]}
+ spec = {"fields": [spec]}
elif isinstance(spec, (list, tuple)):
- spec = {'fields': list(spec)}
+ spec = {"fields": list(spec)}
elif isinstance(spec, dict):
spec = dict(spec)
@@ -789,19 +828,21 @@ class BaseDocument(object):
direction = None
# Check to see if we need to include _cls
- allow_inheritance = cls._meta.get('allow_inheritance')
+ allow_inheritance = cls._meta.get("allow_inheritance")
include_cls = (
- allow_inheritance and
- not spec.get('sparse', False) and
- spec.get('cls', True) and
- '_cls' not in spec['fields']
+ allow_inheritance
+ and not spec.get("sparse", False)
+ and spec.get("cls", True)
+ and "_cls" not in spec["fields"]
)
# 733: don't include cls if index_cls is False unless there is an explicit cls with the index
- include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True))
- if 'cls' in spec:
- spec.pop('cls')
- for key in spec['fields']:
+ include_cls = include_cls and (
+ spec.get("cls", False) or cls._meta.get("index_cls", True)
+ )
+ if "cls" in spec:
+ spec.pop("cls")
+ for key in spec["fields"]:
# If inherited spec continue
if isinstance(key, (list, tuple)):
continue
@@ -814,51 +855,54 @@ class BaseDocument(object):
# GEOHAYSTACK from )
# GEO2D from *
direction = pymongo.ASCENDING
- if key.startswith('-'):
+ if key.startswith("-"):
direction = pymongo.DESCENDING
- elif key.startswith('$'):
+ elif key.startswith("$"):
direction = pymongo.TEXT
- elif key.startswith('#'):
+ elif key.startswith("#"):
direction = pymongo.HASHED
- elif key.startswith('('):
+ elif key.startswith("("):
direction = pymongo.GEOSPHERE
- elif key.startswith(')'):
+ elif key.startswith(")"):
direction = pymongo.GEOHAYSTACK
- elif key.startswith('*'):
+ elif key.startswith("*"):
direction = pymongo.GEO2D
- if key.startswith(('+', '-', '*', '$', '#', '(', ')')):
+ if key.startswith(("+", "-", "*", "$", "#", "(", ")")):
key = key[1:]
# Use real field name, do it manually because we need field
# objects for the next part (list field checking)
- parts = key.split('.')
- if parts in (['pk'], ['id'], ['_id']):
- key = '_id'
+ parts = key.split(".")
+ if parts in (["pk"], ["id"], ["_id"]):
+ key = "_id"
else:
fields = cls._lookup_field(parts)
parts = []
for field in fields:
try:
- if field != '_id':
+ if field != "_id":
field = field.db_field
except AttributeError:
pass
parts.append(field)
- key = '.'.join(parts)
+ key = ".".join(parts)
index_list.append((key, direction))
# Don't add cls to a geo index
if include_cls and direction not in (
- pymongo.GEO2D, pymongo.GEOHAYSTACK, pymongo.GEOSPHERE):
- index_list.insert(0, ('_cls', 1))
+ pymongo.GEO2D,
+ pymongo.GEOHAYSTACK,
+ pymongo.GEOSPHERE,
+ ):
+ index_list.insert(0, ("_cls", 1))
if index_list:
- spec['fields'] = index_list
+ spec["fields"] = index_list
return spec
@classmethod
- def _unique_with_indexes(cls, namespace=''):
+ def _unique_with_indexes(cls, namespace=""):
"""Find unique indexes in the document schema and return them."""
unique_indexes = []
for field_name, field in cls._fields.items():
@@ -876,36 +920,39 @@ class BaseDocument(object):
# Convert unique_with field names to real field names
unique_with = []
for other_name in field.unique_with:
- parts = other_name.split('.')
+ parts = other_name.split(".")
# Lookup real name
parts = cls._lookup_field(parts)
name_parts = [part.db_field for part in parts]
- unique_with.append('.'.join(name_parts))
+ unique_with.append(".".join(name_parts))
# Unique field should be required
parts[-1].required = True
- sparse = (not sparse and
- parts[-1].name not in cls.__dict__)
+ sparse = not sparse and parts[-1].name not in cls.__dict__
unique_fields += unique_with
# Add the new index to the list
fields = [
- ('%s%s' % (namespace, f), pymongo.ASCENDING)
- for f in unique_fields
+ ("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields
]
- index = {'fields': fields, 'unique': True, 'sparse': sparse}
+ index = {"fields": fields, "unique": True, "sparse": sparse}
unique_indexes.append(index)
- if field.__class__.__name__ in {'EmbeddedDocumentListField',
- 'ListField', 'SortedListField'}:
+ if field.__class__.__name__ in {
+ "EmbeddedDocumentListField",
+ "ListField",
+ "SortedListField",
+ }:
field = field.field
# Grab any embedded document field unique indexes
- if (field.__class__.__name__ == 'EmbeddedDocumentField' and
- field.document_type != cls):
- field_namespace = '%s.' % field_name
+ if (
+ field.__class__.__name__ == "EmbeddedDocumentField"
+ and field.document_type != cls
+ ):
+ field_namespace = "%s." % field_name
doc_cls = field.document_type
unique_indexes += doc_cls._unique_with_indexes(field_namespace)
@@ -917,32 +964,36 @@ class BaseDocument(object):
geo_indices = []
inspected.append(cls)
- geo_field_type_names = ('EmbeddedDocumentField', 'GeoPointField',
- 'PointField', 'LineStringField',
- 'PolygonField')
+ geo_field_type_names = (
+ "EmbeddedDocumentField",
+ "GeoPointField",
+ "PointField",
+ "LineStringField",
+ "PolygonField",
+ )
- geo_field_types = tuple([_import_class(field)
- for field in geo_field_type_names])
+ geo_field_types = tuple(
+ [_import_class(field) for field in geo_field_type_names]
+ )
for field in cls._fields.values():
if not isinstance(field, geo_field_types):
continue
- if hasattr(field, 'document_type'):
+ if hasattr(field, "document_type"):
field_cls = field.document_type
if field_cls in inspected:
continue
- if hasattr(field_cls, '_geo_indices'):
+ if hasattr(field_cls, "_geo_indices"):
geo_indices += field_cls._geo_indices(
- inspected, parent_field=field.db_field)
+ inspected, parent_field=field.db_field
+ )
elif field._geo_index:
field_name = field.db_field
if parent_field:
- field_name = '%s.%s' % (parent_field, field_name)
- geo_indices.append({
- 'fields': [(field_name, field._geo_index)]
- })
+ field_name = "%s.%s" % (parent_field, field_name)
+ geo_indices.append({"fields": [(field_name, field._geo_index)]})
return geo_indices
@@ -983,8 +1034,8 @@ class BaseDocument(object):
# TODO this method is WAY too complicated. Simplify it.
# TODO don't think returning a string for embedded non-existent fields is desired
- ListField = _import_class('ListField')
- DynamicField = _import_class('DynamicField')
+ ListField = _import_class("ListField")
+ DynamicField = _import_class("DynamicField")
if not isinstance(parts, (list, tuple)):
parts = [parts]
@@ -1000,15 +1051,17 @@ class BaseDocument(object):
# Look up first field from the document
if field is None:
- if field_name == 'pk':
+ if field_name == "pk":
# Deal with "primary key" alias
- field_name = cls._meta['id_field']
+ field_name = cls._meta["id_field"]
if field_name in cls._fields:
field = cls._fields[field_name]
elif cls._dynamic:
field = DynamicField(db_field=field_name)
- elif cls._meta.get('allow_inheritance') or cls._meta.get('abstract', False):
+ elif cls._meta.get("allow_inheritance") or cls._meta.get(
+ "abstract", False
+ ):
# 744: in case the field is defined in a subclass
for subcls in cls.__subclasses__():
try:
@@ -1023,38 +1076,41 @@ class BaseDocument(object):
else:
raise LookUpError('Cannot resolve field "%s"' % field_name)
else:
- ReferenceField = _import_class('ReferenceField')
- GenericReferenceField = _import_class('GenericReferenceField')
+ ReferenceField = _import_class("ReferenceField")
+ GenericReferenceField = _import_class("GenericReferenceField")
# If previous field was a reference, throw an error (we
# cannot look up fields that are on references).
if isinstance(field, (ReferenceField, GenericReferenceField)):
- raise LookUpError('Cannot perform join in mongoDB: %s' %
- '__'.join(parts))
+ raise LookUpError(
+ "Cannot perform join in mongoDB: %s" % "__".join(parts)
+ )
# If the parent field has a "field" attribute which has a
# lookup_member method, call it to find the field
# corresponding to this iteration.
- if hasattr(getattr(field, 'field', None), 'lookup_member'):
+ if hasattr(getattr(field, "field", None), "lookup_member"):
new_field = field.field.lookup_member(field_name)
# If the parent field is a DynamicField or if it's part of
# a DynamicDocument, mark current field as a DynamicField
# with db_name equal to the field name.
- elif cls._dynamic and (isinstance(field, DynamicField) or
- getattr(getattr(field, 'document_type', None), '_dynamic', None)):
+ elif cls._dynamic and (
+ isinstance(field, DynamicField)
+ or getattr(getattr(field, "document_type", None), "_dynamic", None)
+ ):
new_field = DynamicField(db_field=field_name)
# Else, try to use the parent field's lookup_member method
# to find the subfield.
- elif hasattr(field, 'lookup_member'):
+ elif hasattr(field, "lookup_member"):
new_field = field.lookup_member(field_name)
# Raise a LookUpError if all the other conditions failed.
else:
raise LookUpError(
- 'Cannot resolve subfield or operator {} '
- 'on the field {}'.format(field_name, field.name)
+ "Cannot resolve subfield or operator {} "
+ "on the field {}".format(field_name, field.name)
)
# If current field still wasn't found and the parent field
@@ -1073,23 +1129,24 @@ class BaseDocument(object):
return fields
@classmethod
- def _translate_field_name(cls, field, sep='.'):
+ def _translate_field_name(cls, field, sep="."):
"""Translate a field attribute name to a database field name.
"""
parts = field.split(sep)
parts = [f.db_field for f in cls._lookup_field(parts)]
- return '.'.join(parts)
+ return ".".join(parts)
def __set_field_display(self):
"""For each field that specifies choices, create a
get__display method.
"""
- fields_with_choices = [(n, f) for n, f in self._fields.items()
- if f.choices]
+ fields_with_choices = [(n, f) for n, f in self._fields.items() if f.choices]
for attr_name, field in fields_with_choices:
- setattr(self,
- 'get_%s_display' % attr_name,
- partial(self.__get_field_display, field=field))
+ setattr(
+ self,
+ "get_%s_display" % attr_name,
+ partial(self.__get_field_display, field=field),
+ )
def __get_field_display(self, field):
"""Return the display value for a choice field"""
@@ -1097,9 +1154,16 @@ class BaseDocument(object):
if field.choices and isinstance(field.choices[0], (list, tuple)):
if value is None:
return None
- sep = getattr(field, 'display_sep', ' ')
- values = value if field.__class__.__name__ in ('ListField', 'SortedListField') else [value]
- return sep.join([
- six.text_type(dict(field.choices).get(val, val))
- for val in values or []])
+ sep = getattr(field, "display_sep", " ")
+ values = (
+ value
+ if field.__class__.__name__ in ("ListField", "SortedListField")
+ else [value]
+ )
+ return sep.join(
+ [
+ six.text_type(dict(field.choices).get(val, val))
+ for val in values or []
+ ]
+ )
return value
diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py
index 9ce426c9..cd1039cb 100644
--- a/mongoengine/base/fields.py
+++ b/mongoengine/base/fields.py
@@ -8,13 +8,11 @@ import six
from six import iteritems
from mongoengine.base.common import UPDATE_OPERATORS
-from mongoengine.base.datastructures import (BaseDict, BaseList,
- EmbeddedDocumentList)
+from mongoengine.base.datastructures import BaseDict, BaseList, EmbeddedDocumentList
from mongoengine.common import _import_class
from mongoengine.errors import DeprecatedError, ValidationError
-__all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField',
- 'GeoJsonBaseField')
+__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
class BaseField(object):
@@ -23,6 +21,7 @@ class BaseField(object):
.. versionchanged:: 0.5 - added verbose and help text
"""
+
name = None
_geo_index = False
_auto_gen = False # Call `generate` to generate a value
@@ -34,10 +33,21 @@ class BaseField(object):
creation_counter = 0
auto_creation_counter = -1
- def __init__(self, db_field=None, name=None, required=False, default=None,
- unique=False, unique_with=None, primary_key=False,
- validation=None, choices=None, null=False, sparse=False,
- **kwargs):
+ def __init__(
+ self,
+ db_field=None,
+ name=None,
+ required=False,
+ default=None,
+ unique=False,
+ unique_with=None,
+ primary_key=False,
+ validation=None,
+ choices=None,
+ null=False,
+ sparse=False,
+ **kwargs
+ ):
"""
:param db_field: The database field to store this field in
(defaults to the name of the field)
@@ -65,7 +75,7 @@ class BaseField(object):
existing attributes. Common metadata includes `verbose_name` and
`help_text`.
"""
- self.db_field = (db_field or name) if not primary_key else '_id'
+ self.db_field = (db_field or name) if not primary_key else "_id"
if name:
msg = 'Field\'s "name" attribute deprecated in favour of "db_field"'
@@ -82,17 +92,16 @@ class BaseField(object):
self._owner_document = None
# Make sure db_field is a string (if it's explicitly defined).
- if (
- self.db_field is not None and
- not isinstance(self.db_field, six.string_types)
+ if self.db_field is not None and not isinstance(
+ self.db_field, six.string_types
):
- raise TypeError('db_field should be a string.')
+ raise TypeError("db_field should be a string.")
# Make sure db_field doesn't contain any forbidden characters.
if isinstance(self.db_field, six.string_types) and (
- '.' in self.db_field or
- '\0' in self.db_field or
- self.db_field.startswith('$')
+ "." in self.db_field
+ or "\0" in self.db_field
+ or self.db_field.startswith("$")
):
raise ValueError(
'field names cannot contain dots (".") or null characters '
@@ -102,15 +111,17 @@ class BaseField(object):
# Detect and report conflicts between metadata and base properties.
conflicts = set(dir(self)) & set(kwargs)
if conflicts:
- raise TypeError('%s already has attribute(s): %s' % (
- self.__class__.__name__, ', '.join(conflicts)))
+ raise TypeError(
+ "%s already has attribute(s): %s"
+ % (self.__class__.__name__, ", ".join(conflicts))
+ )
# Assign metadata to the instance
# This efficient method is available because no __slots__ are defined.
self.__dict__.update(kwargs)
# Adjust the appropriate creation counter, and save our local copy.
- if self.db_field == '_id':
+ if self.db_field == "_id":
self.creation_counter = BaseField.auto_creation_counter
BaseField.auto_creation_counter -= 1
else:
@@ -142,8 +153,8 @@ class BaseField(object):
if instance._initialised:
try:
value_has_changed = (
- self.name not in instance._data or
- instance._data[self.name] != value
+ self.name not in instance._data
+ or instance._data[self.name] != value
)
if value_has_changed:
instance._mark_as_changed(self.name)
@@ -153,7 +164,7 @@ class BaseField(object):
# Mark the field as changed in such cases.
instance._mark_as_changed(self.name)
- EmbeddedDocument = _import_class('EmbeddedDocument')
+ EmbeddedDocument = _import_class("EmbeddedDocument")
if isinstance(value, EmbeddedDocument):
value._instance = weakref.proxy(instance)
elif isinstance(value, (list, tuple)):
@@ -163,7 +174,7 @@ class BaseField(object):
instance._data[self.name] = value
- def error(self, message='', errors=None, field_name=None):
+ def error(self, message="", errors=None, field_name=None):
"""Raise a ValidationError."""
field_name = field_name if field_name else self.name
raise ValidationError(message, errors=errors, field_name=field_name)
@@ -180,11 +191,11 @@ class BaseField(object):
"""Helper method to call to_mongo with proper inputs."""
f_inputs = self.to_mongo.__code__.co_varnames
ex_vars = {}
- if 'fields' in f_inputs:
- ex_vars['fields'] = fields
+ if "fields" in f_inputs:
+ ex_vars["fields"] = fields
- if 'use_db_field' in f_inputs:
- ex_vars['use_db_field'] = use_db_field
+ if "use_db_field" in f_inputs:
+ ex_vars["use_db_field"] = use_db_field
return self.to_mongo(value, **ex_vars)
@@ -199,8 +210,8 @@ class BaseField(object):
pass
def _validate_choices(self, value):
- Document = _import_class('Document')
- EmbeddedDocument = _import_class('EmbeddedDocument')
+ Document = _import_class("Document")
+ EmbeddedDocument = _import_class("EmbeddedDocument")
choice_list = self.choices
if isinstance(next(iter(choice_list)), (list, tuple)):
@@ -211,15 +222,13 @@ class BaseField(object):
if isinstance(value, (Document, EmbeddedDocument)):
if not any(isinstance(value, c) for c in choice_list):
self.error(
- 'Value must be an instance of %s' % (
- six.text_type(choice_list)
- )
+ "Value must be an instance of %s" % (six.text_type(choice_list))
)
# Choices which are types other than Documents
else:
values = value if isinstance(value, (list, tuple)) else [value]
if len(set(values) - set(choice_list)):
- self.error('Value must be one of %s' % six.text_type(choice_list))
+ self.error("Value must be one of %s" % six.text_type(choice_list))
def _validate(self, value, **kwargs):
# Check the Choices Constraint
@@ -235,13 +244,17 @@ class BaseField(object):
# in favor of having validation raising a ValidationError
ret = self.validation(value)
if ret is not None:
- raise DeprecatedError('validation argument for `%s` must not return anything, '
- 'it should raise a ValidationError if validation fails' % self.name)
+ raise DeprecatedError(
+ "validation argument for `%s` must not return anything, "
+ "it should raise a ValidationError if validation fails"
+ % self.name
+ )
except ValidationError as ex:
self.error(str(ex))
else:
- raise ValueError('validation argument for `"%s"` must be a '
- 'callable.' % self.name)
+ raise ValueError(
+ 'validation argument for `"%s"` must be a ' "callable." % self.name
+ )
self.validate(value, **kwargs)
@@ -275,35 +288,41 @@ class ComplexBaseField(BaseField):
# Document class being used rather than a document object
return self
- ReferenceField = _import_class('ReferenceField')
- GenericReferenceField = _import_class('GenericReferenceField')
- EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
+ ReferenceField = _import_class("ReferenceField")
+ GenericReferenceField = _import_class("GenericReferenceField")
+ EmbeddedDocumentListField = _import_class("EmbeddedDocumentListField")
auto_dereference = instance._fields[self.name]._auto_dereference
- dereference = (auto_dereference and
- (self.field is None or isinstance(self.field,
- (GenericReferenceField, ReferenceField))))
+ dereference = auto_dereference and (
+ self.field is None
+ or isinstance(self.field, (GenericReferenceField, ReferenceField))
+ )
- _dereference = _import_class('DeReference')()
+ _dereference = _import_class("DeReference")()
- if (instance._initialised and
- dereference and
- instance._data.get(self.name) and
- not getattr(instance._data[self.name], '_dereferenced', False)):
+ if (
+ instance._initialised
+ and dereference
+ and instance._data.get(self.name)
+ and not getattr(instance._data[self.name], "_dereferenced", False)
+ ):
instance._data[self.name] = _dereference(
- instance._data.get(self.name), max_depth=1, instance=instance,
- name=self.name
+ instance._data.get(self.name),
+ max_depth=1,
+ instance=instance,
+ name=self.name,
)
- if hasattr(instance._data[self.name], '_dereferenced'):
+ if hasattr(instance._data[self.name], "_dereferenced"):
instance._data[self.name]._dereferenced = True
value = super(ComplexBaseField, self).__get__(instance, owner)
# Convert lists / values so we can watch for any changes on them
if isinstance(value, (list, tuple)):
- if (issubclass(type(self), EmbeddedDocumentListField) and
- not isinstance(value, EmbeddedDocumentList)):
+ if issubclass(type(self), EmbeddedDocumentListField) and not isinstance(
+ value, EmbeddedDocumentList
+ ):
value = EmbeddedDocumentList(value, instance, self.name)
elif not isinstance(value, BaseList):
value = BaseList(value, instance, self.name)
@@ -312,12 +331,13 @@ class ComplexBaseField(BaseField):
value = BaseDict(value, instance, self.name)
instance._data[self.name] = value
- if (auto_dereference and instance._initialised and
- isinstance(value, (BaseList, BaseDict)) and
- not value._dereferenced):
- value = _dereference(
- value, max_depth=1, instance=instance, name=self.name
- )
+ if (
+ auto_dereference
+ and instance._initialised
+ and isinstance(value, (BaseList, BaseDict))
+ and not value._dereferenced
+ ):
+ value = _dereference(value, max_depth=1, instance=instance, name=self.name)
value._dereferenced = True
instance._data[self.name] = value
@@ -328,16 +348,16 @@ class ComplexBaseField(BaseField):
if isinstance(value, six.string_types):
return value
- if hasattr(value, 'to_python'):
+ if hasattr(value, "to_python"):
return value.to_python()
- BaseDocument = _import_class('BaseDocument')
+ BaseDocument = _import_class("BaseDocument")
if isinstance(value, BaseDocument):
# Something is wrong, return the value as it is
return value
is_list = False
- if not hasattr(value, 'items'):
+ if not hasattr(value, "items"):
try:
is_list = True
value = {idx: v for idx, v in enumerate(value)}
@@ -346,50 +366,54 @@ class ComplexBaseField(BaseField):
if self.field:
self.field._auto_dereference = self._auto_dereference
- value_dict = {key: self.field.to_python(item)
- for key, item in value.items()}
+ value_dict = {
+ key: self.field.to_python(item) for key, item in value.items()
+ }
else:
- Document = _import_class('Document')
+ Document = _import_class("Document")
value_dict = {}
for k, v in value.items():
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
- self.error('You can only reference documents once they'
- ' have been saved to the database')
+ self.error(
+ "You can only reference documents once they"
+ " have been saved to the database"
+ )
collection = v._get_collection_name()
value_dict[k] = DBRef(collection, v.pk)
- elif hasattr(v, 'to_python'):
+ elif hasattr(v, "to_python"):
value_dict[k] = v.to_python()
else:
value_dict[k] = self.to_python(v)
if is_list: # Convert back to a list
- return [v for _, v in sorted(value_dict.items(),
- key=operator.itemgetter(0))]
+ return [
+ v for _, v in sorted(value_dict.items(), key=operator.itemgetter(0))
+ ]
return value_dict
def to_mongo(self, value, use_db_field=True, fields=None):
"""Convert a Python type to a MongoDB-compatible type."""
- Document = _import_class('Document')
- EmbeddedDocument = _import_class('EmbeddedDocument')
- GenericReferenceField = _import_class('GenericReferenceField')
+ Document = _import_class("Document")
+ EmbeddedDocument = _import_class("EmbeddedDocument")
+ GenericReferenceField = _import_class("GenericReferenceField")
if isinstance(value, six.string_types):
return value
- if hasattr(value, 'to_mongo'):
+ if hasattr(value, "to_mongo"):
if isinstance(value, Document):
return GenericReferenceField().to_mongo(value)
cls = value.__class__
val = value.to_mongo(use_db_field, fields)
# If it's a document that is not inherited add _cls
if isinstance(value, EmbeddedDocument):
- val['_cls'] = cls.__name__
+ val["_cls"] = cls.__name__
return val
is_list = False
- if not hasattr(value, 'items'):
+ if not hasattr(value, "items"):
try:
is_list = True
value = {k: v for k, v in enumerate(value)}
@@ -407,39 +431,42 @@ class ComplexBaseField(BaseField):
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
- self.error('You can only reference documents once they'
- ' have been saved to the database')
+ self.error(
+ "You can only reference documents once they"
+ " have been saved to the database"
+ )
# If its a document that is not inheritable it won't have
# any _cls data so make it a generic reference allows
# us to dereference
- meta = getattr(v, '_meta', {})
- allow_inheritance = meta.get('allow_inheritance')
+ meta = getattr(v, "_meta", {})
+ allow_inheritance = meta.get("allow_inheritance")
if not allow_inheritance and not self.field:
value_dict[k] = GenericReferenceField().to_mongo(v)
else:
collection = v._get_collection_name()
value_dict[k] = DBRef(collection, v.pk)
- elif hasattr(v, 'to_mongo'):
+ elif hasattr(v, "to_mongo"):
cls = v.__class__
val = v.to_mongo(use_db_field, fields)
# If it's a document that is not inherited add _cls
if isinstance(v, (Document, EmbeddedDocument)):
- val['_cls'] = cls.__name__
+ val["_cls"] = cls.__name__
value_dict[k] = val
else:
value_dict[k] = self.to_mongo(v, use_db_field, fields)
if is_list: # Convert back to a list
- return [v for _, v in sorted(value_dict.items(),
- key=operator.itemgetter(0))]
+ return [
+ v for _, v in sorted(value_dict.items(), key=operator.itemgetter(0))
+ ]
return value_dict
def validate(self, value):
"""If field is provided ensure the value is valid."""
errors = {}
if self.field:
- if hasattr(value, 'iteritems') or hasattr(value, 'items'):
+ if hasattr(value, "iteritems") or hasattr(value, "items"):
sequence = iteritems(value)
else:
sequence = enumerate(value)
@@ -453,11 +480,10 @@ class ComplexBaseField(BaseField):
if errors:
field_class = self.field.__class__.__name__
- self.error('Invalid %s item (%s)' % (field_class, value),
- errors=errors)
+ self.error("Invalid %s item (%s)" % (field_class, value), errors=errors)
# Don't allow empty values if required
if self.required and not value:
- self.error('Field is required and cannot be empty')
+ self.error("Field is required and cannot be empty")
def prepare_query_value(self, op, value):
return self.to_mongo(value)
@@ -500,7 +526,7 @@ class ObjectIdField(BaseField):
try:
ObjectId(six.text_type(value))
except Exception:
- self.error('Invalid Object ID')
+ self.error("Invalid Object ID")
class GeoJsonBaseField(BaseField):
@@ -510,14 +536,14 @@ class GeoJsonBaseField(BaseField):
"""
_geo_index = pymongo.GEOSPHERE
- _type = 'GeoBase'
+ _type = "GeoBase"
def __init__(self, auto_index=True, *args, **kwargs):
"""
:param bool auto_index: Automatically create a '2dsphere' index.\
Defaults to `True`.
"""
- self._name = '%sField' % self._type
+ self._name = "%sField" % self._type
if not auto_index:
self._geo_index = False
super(GeoJsonBaseField, self).__init__(*args, **kwargs)
@@ -525,57 +551,58 @@ class GeoJsonBaseField(BaseField):
def validate(self, value):
"""Validate the GeoJson object based on its type."""
if isinstance(value, dict):
- if set(value.keys()) == {'type', 'coordinates'}:
- if value['type'] != self._type:
- self.error('%s type must be "%s"' %
- (self._name, self._type))
- return self.validate(value['coordinates'])
+ if set(value.keys()) == {"type", "coordinates"}:
+ if value["type"] != self._type:
+ self.error('%s type must be "%s"' % (self._name, self._type))
+ return self.validate(value["coordinates"])
else:
- self.error('%s can only accept a valid GeoJson dictionary'
- ' or lists of (x, y)' % self._name)
+ self.error(
+ "%s can only accept a valid GeoJson dictionary"
+ " or lists of (x, y)" % self._name
+ )
return
elif not isinstance(value, (list, tuple)):
- self.error('%s can only accept lists of [x, y]' % self._name)
+ self.error("%s can only accept lists of [x, y]" % self._name)
return
- validate = getattr(self, '_validate_%s' % self._type.lower())
+ validate = getattr(self, "_validate_%s" % self._type.lower())
error = validate(value)
if error:
self.error(error)
def _validate_polygon(self, value, top_level=True):
if not isinstance(value, (list, tuple)):
- return 'Polygons must contain list of linestrings'
+ return "Polygons must contain list of linestrings"
# Quick and dirty validator
try:
value[0][0][0]
except (TypeError, IndexError):
- return 'Invalid Polygon must contain at least one valid linestring'
+ return "Invalid Polygon must contain at least one valid linestring"
errors = []
for val in value:
error = self._validate_linestring(val, False)
if not error and val[0] != val[-1]:
- error = 'LineStrings must start and end at the same point'
+ error = "LineStrings must start and end at the same point"
if error and error not in errors:
errors.append(error)
if errors:
if top_level:
- return 'Invalid Polygon:\n%s' % ', '.join(errors)
+ return "Invalid Polygon:\n%s" % ", ".join(errors)
else:
- return '%s' % ', '.join(errors)
+ return "%s" % ", ".join(errors)
def _validate_linestring(self, value, top_level=True):
"""Validate a linestring."""
if not isinstance(value, (list, tuple)):
- return 'LineStrings must contain list of coordinate pairs'
+ return "LineStrings must contain list of coordinate pairs"
# Quick and dirty validator
try:
value[0][0]
except (TypeError, IndexError):
- return 'Invalid LineString must contain at least one valid point'
+ return "Invalid LineString must contain at least one valid point"
errors = []
for val in value:
@@ -584,29 +611,30 @@ class GeoJsonBaseField(BaseField):
errors.append(error)
if errors:
if top_level:
- return 'Invalid LineString:\n%s' % ', '.join(errors)
+ return "Invalid LineString:\n%s" % ", ".join(errors)
else:
- return '%s' % ', '.join(errors)
+ return "%s" % ", ".join(errors)
def _validate_point(self, value):
"""Validate each set of coords"""
if not isinstance(value, (list, tuple)):
- return 'Points must be a list of coordinate pairs'
+ return "Points must be a list of coordinate pairs"
elif not len(value) == 2:
- return 'Value (%s) must be a two-dimensional point' % repr(value)
- elif (not isinstance(value[0], (float, int)) or
- not isinstance(value[1], (float, int))):
- return 'Both values (%s) in point must be float or int' % repr(value)
+ return "Value (%s) must be a two-dimensional point" % repr(value)
+ elif not isinstance(value[0], (float, int)) or not isinstance(
+ value[1], (float, int)
+ ):
+ return "Both values (%s) in point must be float or int" % repr(value)
def _validate_multipoint(self, value):
if not isinstance(value, (list, tuple)):
- return 'MultiPoint must be a list of Point'
+ return "MultiPoint must be a list of Point"
# Quick and dirty validator
try:
value[0][0]
except (TypeError, IndexError):
- return 'Invalid MultiPoint must contain at least one valid point'
+ return "Invalid MultiPoint must contain at least one valid point"
errors = []
for point in value:
@@ -615,17 +643,17 @@ class GeoJsonBaseField(BaseField):
errors.append(error)
if errors:
- return '%s' % ', '.join(errors)
+ return "%s" % ", ".join(errors)
def _validate_multilinestring(self, value, top_level=True):
if not isinstance(value, (list, tuple)):
- return 'MultiLineString must be a list of LineString'
+ return "MultiLineString must be a list of LineString"
# Quick and dirty validator
try:
value[0][0][0]
except (TypeError, IndexError):
- return 'Invalid MultiLineString must contain at least one valid linestring'
+ return "Invalid MultiLineString must contain at least one valid linestring"
errors = []
for linestring in value:
@@ -635,19 +663,19 @@ class GeoJsonBaseField(BaseField):
if errors:
if top_level:
- return 'Invalid MultiLineString:\n%s' % ', '.join(errors)
+ return "Invalid MultiLineString:\n%s" % ", ".join(errors)
else:
- return '%s' % ', '.join(errors)
+ return "%s" % ", ".join(errors)
def _validate_multipolygon(self, value):
if not isinstance(value, (list, tuple)):
- return 'MultiPolygon must be a list of Polygon'
+ return "MultiPolygon must be a list of Polygon"
# Quick and dirty validator
try:
value[0][0][0][0]
except (TypeError, IndexError):
- return 'Invalid MultiPolygon must contain at least one valid Polygon'
+ return "Invalid MultiPolygon must contain at least one valid Polygon"
errors = []
for polygon in value:
@@ -656,9 +684,9 @@ class GeoJsonBaseField(BaseField):
errors.append(error)
if errors:
- return 'Invalid MultiPolygon:\n%s' % ', '.join(errors)
+ return "Invalid MultiPolygon:\n%s" % ", ".join(errors)
def to_mongo(self, value):
if isinstance(value, dict):
return value
- return SON([('type', self._type), ('coordinates', value)])
+ return SON([("type", self._type), ("coordinates", value)])
diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py
index c3ced5bb..e4d26811 100644
--- a/mongoengine/base/metaclasses.py
+++ b/mongoengine/base/metaclasses.py
@@ -8,12 +8,15 @@ from mongoengine.base.common import _document_registry
from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField
from mongoengine.common import _import_class
from mongoengine.errors import InvalidDocumentError
-from mongoengine.queryset import (DO_NOTHING, DoesNotExist,
- MultipleObjectsReturned,
- QuerySetManager)
+from mongoengine.queryset import (
+ DO_NOTHING,
+ DoesNotExist,
+ MultipleObjectsReturned,
+ QuerySetManager,
+)
-__all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass')
+__all__ = ("DocumentMetaclass", "TopLevelDocumentMetaclass")
class DocumentMetaclass(type):
@@ -25,44 +28,46 @@ class DocumentMetaclass(type):
super_new = super(DocumentMetaclass, mcs).__new__
# If a base class just call super
- metaclass = attrs.get('my_metaclass')
+ metaclass = attrs.get("my_metaclass")
if metaclass and issubclass(metaclass, DocumentMetaclass):
return super_new(mcs, name, bases, attrs)
- attrs['_is_document'] = attrs.get('_is_document', False)
- attrs['_cached_reference_fields'] = []
+ attrs["_is_document"] = attrs.get("_is_document", False)
+ attrs["_cached_reference_fields"] = []
# EmbeddedDocuments could have meta data for inheritance
- if 'meta' in attrs:
- attrs['_meta'] = attrs.pop('meta')
+ if "meta" in attrs:
+ attrs["_meta"] = attrs.pop("meta")
# EmbeddedDocuments should inherit meta data
- if '_meta' not in attrs:
+ if "_meta" not in attrs:
meta = MetaDict()
for base in flattened_bases[::-1]:
# Add any mixin metadata from plain objects
- if hasattr(base, 'meta'):
+ if hasattr(base, "meta"):
meta.merge(base.meta)
- elif hasattr(base, '_meta'):
+ elif hasattr(base, "_meta"):
meta.merge(base._meta)
- attrs['_meta'] = meta
- attrs['_meta']['abstract'] = False # 789: EmbeddedDocument shouldn't inherit abstract
+ attrs["_meta"] = meta
+ attrs["_meta"][
+ "abstract"
+ ] = False # 789: EmbeddedDocument shouldn't inherit abstract
# If allow_inheritance is True, add a "_cls" string field to the attrs
- if attrs['_meta'].get('allow_inheritance'):
- StringField = _import_class('StringField')
- attrs['_cls'] = StringField()
+ if attrs["_meta"].get("allow_inheritance"):
+ StringField = _import_class("StringField")
+ attrs["_cls"] = StringField()
# Handle document Fields
# Merge all fields from subclasses
doc_fields = {}
for base in flattened_bases[::-1]:
- if hasattr(base, '_fields'):
+ if hasattr(base, "_fields"):
doc_fields.update(base._fields)
# Standard object mixin - merge in any Fields
- if not hasattr(base, '_meta'):
+ if not hasattr(base, "_meta"):
base_fields = {}
for attr_name, attr_value in iteritems(base.__dict__):
if not isinstance(attr_value, BaseField):
@@ -85,27 +90,31 @@ class DocumentMetaclass(type):
doc_fields[attr_name] = attr_value
# Count names to ensure no db_field redefinitions
- field_names[attr_value.db_field] = field_names.get(
- attr_value.db_field, 0) + 1
+ field_names[attr_value.db_field] = (
+ field_names.get(attr_value.db_field, 0) + 1
+ )
# Ensure no duplicate db_fields
duplicate_db_fields = [k for k, v in field_names.items() if v > 1]
if duplicate_db_fields:
- msg = ('Multiple db_fields defined for: %s ' %
- ', '.join(duplicate_db_fields))
+ msg = "Multiple db_fields defined for: %s " % ", ".join(duplicate_db_fields)
raise InvalidDocumentError(msg)
# Set _fields and db_field maps
- attrs['_fields'] = doc_fields
- attrs['_db_field_map'] = {k: getattr(v, 'db_field', k)
- for k, v in doc_fields.items()}
- attrs['_reverse_db_field_map'] = {
- v: k for k, v in attrs['_db_field_map'].items()
+ attrs["_fields"] = doc_fields
+ attrs["_db_field_map"] = {
+ k: getattr(v, "db_field", k) for k, v in doc_fields.items()
+ }
+ attrs["_reverse_db_field_map"] = {
+ v: k for k, v in attrs["_db_field_map"].items()
}
- attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
- (v.creation_counter, v.name)
- for v in itervalues(doc_fields)))
+ attrs["_fields_ordered"] = tuple(
+ i[1]
+ for i in sorted(
+ (v.creation_counter, v.name) for v in itervalues(doc_fields)
+ )
+ )
#
# Set document hierarchy
@@ -113,32 +122,34 @@ class DocumentMetaclass(type):
superclasses = ()
class_name = [name]
for base in flattened_bases:
- if (not getattr(base, '_is_base_cls', True) and
- not getattr(base, '_meta', {}).get('abstract', True)):
+ if not getattr(base, "_is_base_cls", True) and not getattr(
+ base, "_meta", {}
+ ).get("abstract", True):
# Collate hierarchy for _cls and _subclasses
class_name.append(base.__name__)
- if hasattr(base, '_meta'):
+ if hasattr(base, "_meta"):
# Warn if allow_inheritance isn't set and prevent
# inheritance of classes where inheritance is set to False
- allow_inheritance = base._meta.get('allow_inheritance')
- if not allow_inheritance and not base._meta.get('abstract'):
- raise ValueError('Document %s may not be subclassed. '
- 'To enable inheritance, use the "allow_inheritance" meta attribute.' %
- base.__name__)
+ allow_inheritance = base._meta.get("allow_inheritance")
+ if not allow_inheritance and not base._meta.get("abstract"):
+ raise ValueError(
+ "Document %s may not be subclassed. "
+ 'To enable inheritance, use the "allow_inheritance" meta attribute.'
+ % base.__name__
+ )
# Get superclasses from last base superclass
- document_bases = [b for b in flattened_bases
- if hasattr(b, '_class_name')]
+ document_bases = [b for b in flattened_bases if hasattr(b, "_class_name")]
if document_bases:
superclasses = document_bases[0]._superclasses
- superclasses += (document_bases[0]._class_name, )
+ superclasses += (document_bases[0]._class_name,)
- _cls = '.'.join(reversed(class_name))
- attrs['_class_name'] = _cls
- attrs['_superclasses'] = superclasses
- attrs['_subclasses'] = (_cls, )
- attrs['_types'] = attrs['_subclasses'] # TODO depreciate _types
+ _cls = ".".join(reversed(class_name))
+ attrs["_class_name"] = _cls
+ attrs["_superclasses"] = superclasses
+ attrs["_subclasses"] = (_cls,)
+ attrs["_types"] = attrs["_subclasses"] # TODO depreciate _types
# Create the new_class
new_class = super_new(mcs, name, bases, attrs)
@@ -149,8 +160,12 @@ class DocumentMetaclass(type):
base._subclasses += (_cls,)
base._types = base._subclasses # TODO depreciate _types
- (Document, EmbeddedDocument, DictField,
- CachedReferenceField) = mcs._import_classes()
+ (
+ Document,
+ EmbeddedDocument,
+ DictField,
+ CachedReferenceField,
+ ) = mcs._import_classes()
if issubclass(new_class, Document):
new_class._collection = None
@@ -169,52 +184,55 @@ class DocumentMetaclass(type):
for val in new_class.__dict__.values():
if isinstance(val, classmethod):
f = val.__get__(new_class)
- if hasattr(f, '__func__') and not hasattr(f, 'im_func'):
- f.__dict__.update({'im_func': getattr(f, '__func__')})
- if hasattr(f, '__self__') and not hasattr(f, 'im_self'):
- f.__dict__.update({'im_self': getattr(f, '__self__')})
+ if hasattr(f, "__func__") and not hasattr(f, "im_func"):
+ f.__dict__.update({"im_func": getattr(f, "__func__")})
+ if hasattr(f, "__self__") and not hasattr(f, "im_self"):
+ f.__dict__.update({"im_self": getattr(f, "__self__")})
# Handle delete rules
for field in itervalues(new_class._fields):
f = field
if f.owner_document is None:
f.owner_document = new_class
- delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING)
+ delete_rule = getattr(f, "reverse_delete_rule", DO_NOTHING)
if isinstance(f, CachedReferenceField):
if issubclass(new_class, EmbeddedDocument):
- raise InvalidDocumentError('CachedReferenceFields is not '
- 'allowed in EmbeddedDocuments')
+ raise InvalidDocumentError(
+ "CachedReferenceFields is not allowed in EmbeddedDocuments"
+ )
if f.auto_sync:
f.start_listener()
f.document_type._cached_reference_fields.append(f)
- if isinstance(f, ComplexBaseField) and hasattr(f, 'field'):
- delete_rule = getattr(f.field,
- 'reverse_delete_rule',
- DO_NOTHING)
+ if isinstance(f, ComplexBaseField) and hasattr(f, "field"):
+ delete_rule = getattr(f.field, "reverse_delete_rule", DO_NOTHING)
if isinstance(f, DictField) and delete_rule != DO_NOTHING:
- msg = ('Reverse delete rules are not supported '
- 'for %s (field: %s)' %
- (field.__class__.__name__, field.name))
+ msg = (
+ "Reverse delete rules are not supported "
+ "for %s (field: %s)" % (field.__class__.__name__, field.name)
+ )
raise InvalidDocumentError(msg)
f = field.field
if delete_rule != DO_NOTHING:
if issubclass(new_class, EmbeddedDocument):
- msg = ('Reverse delete rules are not supported for '
- 'EmbeddedDocuments (field: %s)' % field.name)
+ msg = (
+ "Reverse delete rules are not supported for "
+ "EmbeddedDocuments (field: %s)" % field.name
+ )
raise InvalidDocumentError(msg)
- f.document_type.register_delete_rule(new_class,
- field.name, delete_rule)
+ f.document_type.register_delete_rule(new_class, field.name, delete_rule)
- if (field.name and hasattr(Document, field.name) and
- EmbeddedDocument not in new_class.mro()):
- msg = ('%s is a document method and not a valid '
- 'field name' % field.name)
+ if (
+ field.name
+ and hasattr(Document, field.name)
+ and EmbeddedDocument not in new_class.mro()
+ ):
+ msg = "%s is a document method and not a valid field name" % field.name
raise InvalidDocumentError(msg)
return new_class
@@ -239,10 +257,10 @@ class DocumentMetaclass(type):
@classmethod
def _import_classes(mcs):
- Document = _import_class('Document')
- EmbeddedDocument = _import_class('EmbeddedDocument')
- DictField = _import_class('DictField')
- CachedReferenceField = _import_class('CachedReferenceField')
+ Document = _import_class("Document")
+ EmbeddedDocument = _import_class("EmbeddedDocument")
+ DictField = _import_class("DictField")
+ CachedReferenceField = _import_class("CachedReferenceField")
return Document, EmbeddedDocument, DictField, CachedReferenceField
@@ -256,65 +274,67 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
super_new = super(TopLevelDocumentMetaclass, mcs).__new__
# Set default _meta data if base class, otherwise get user defined meta
- if attrs.get('my_metaclass') == TopLevelDocumentMetaclass:
+ if attrs.get("my_metaclass") == TopLevelDocumentMetaclass:
# defaults
- attrs['_meta'] = {
- 'abstract': True,
- 'max_documents': None,
- 'max_size': None,
- 'ordering': [], # default ordering applied at runtime
- 'indexes': [], # indexes to be ensured at runtime
- 'id_field': None,
- 'index_background': False,
- 'index_drop_dups': False,
- 'index_opts': None,
- 'delete_rules': None,
-
+ attrs["_meta"] = {
+ "abstract": True,
+ "max_documents": None,
+ "max_size": None,
+ "ordering": [], # default ordering applied at runtime
+ "indexes": [], # indexes to be ensured at runtime
+ "id_field": None,
+ "index_background": False,
+ "index_drop_dups": False,
+ "index_opts": None,
+ "delete_rules": None,
# allow_inheritance can be True, False, and None. True means
# "allow inheritance", False means "don't allow inheritance",
# None means "do whatever your parent does, or don't allow
# inheritance if you're a top-level class".
- 'allow_inheritance': None,
+ "allow_inheritance": None,
}
- attrs['_is_base_cls'] = True
- attrs['_meta'].update(attrs.get('meta', {}))
+ attrs["_is_base_cls"] = True
+ attrs["_meta"].update(attrs.get("meta", {}))
else:
- attrs['_meta'] = attrs.get('meta', {})
+ attrs["_meta"] = attrs.get("meta", {})
# Explicitly set abstract to false unless set
- attrs['_meta']['abstract'] = attrs['_meta'].get('abstract', False)
- attrs['_is_base_cls'] = False
+ attrs["_meta"]["abstract"] = attrs["_meta"].get("abstract", False)
+ attrs["_is_base_cls"] = False
# Set flag marking as document class - as opposed to an object mixin
- attrs['_is_document'] = True
+ attrs["_is_document"] = True
# Ensure queryset_class is inherited
- if 'objects' in attrs:
- manager = attrs['objects']
- if hasattr(manager, 'queryset_class'):
- attrs['_meta']['queryset_class'] = manager.queryset_class
+ if "objects" in attrs:
+ manager = attrs["objects"]
+ if hasattr(manager, "queryset_class"):
+ attrs["_meta"]["queryset_class"] = manager.queryset_class
# Clean up top level meta
- if 'meta' in attrs:
- del attrs['meta']
+ if "meta" in attrs:
+ del attrs["meta"]
# Find the parent document class
- parent_doc_cls = [b for b in flattened_bases
- if b.__class__ == TopLevelDocumentMetaclass]
+ parent_doc_cls = [
+ b for b in flattened_bases if b.__class__ == TopLevelDocumentMetaclass
+ ]
parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0]
# Prevent classes setting collection different to their parents
# If parent wasn't an abstract class
- if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and
- not parent_doc_cls._meta.get('abstract', True)):
- msg = 'Trying to set a collection on a subclass (%s)' % name
+ if (
+ parent_doc_cls
+ and "collection" in attrs.get("_meta", {})
+ and not parent_doc_cls._meta.get("abstract", True)
+ ):
+ msg = "Trying to set a collection on a subclass (%s)" % name
warnings.warn(msg, SyntaxWarning)
- del attrs['_meta']['collection']
+ del attrs["_meta"]["collection"]
# Ensure abstract documents have abstract bases
- if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'):
- if (parent_doc_cls and
- not parent_doc_cls._meta.get('abstract', False)):
- msg = 'Abstract document cannot have non-abstract base'
+ if attrs.get("_is_base_cls") or attrs["_meta"].get("abstract"):
+ if parent_doc_cls and not parent_doc_cls._meta.get("abstract", False):
+ msg = "Abstract document cannot have non-abstract base"
raise ValueError(msg)
return super_new(mcs, name, bases, attrs)
@@ -323,38 +343,43 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
meta = MetaDict()
for base in flattened_bases[::-1]:
# Add any mixin metadata from plain objects
- if hasattr(base, 'meta'):
+ if hasattr(base, "meta"):
meta.merge(base.meta)
- elif hasattr(base, '_meta'):
+ elif hasattr(base, "_meta"):
meta.merge(base._meta)
# Set collection in the meta if its callable
- if (getattr(base, '_is_document', False) and
- not base._meta.get('abstract')):
- collection = meta.get('collection', None)
+ if getattr(base, "_is_document", False) and not base._meta.get("abstract"):
+ collection = meta.get("collection", None)
if callable(collection):
- meta['collection'] = collection(base)
+ meta["collection"] = collection(base)
- meta.merge(attrs.get('_meta', {})) # Top level meta
+ meta.merge(attrs.get("_meta", {})) # Top level meta
# Only simple classes (i.e. direct subclasses of Document) may set
# allow_inheritance to False. If the base Document allows inheritance,
# none of its subclasses can override allow_inheritance to False.
- simple_class = all([b._meta.get('abstract')
- for b in flattened_bases if hasattr(b, '_meta')])
+ simple_class = all(
+ [b._meta.get("abstract") for b in flattened_bases if hasattr(b, "_meta")]
+ )
if (
- not simple_class and
- meta['allow_inheritance'] is False and
- not meta['abstract']
+ not simple_class
+ and meta["allow_inheritance"] is False
+ and not meta["abstract"]
):
- raise ValueError('Only direct subclasses of Document may set '
- '"allow_inheritance" to False')
+ raise ValueError(
+ "Only direct subclasses of Document may set "
+ '"allow_inheritance" to False'
+ )
# Set default collection name
- if 'collection' not in meta:
- meta['collection'] = ''.join('_%s' % c if c.isupper() else c
- for c in name).strip('_').lower()
- attrs['_meta'] = meta
+ if "collection" not in meta:
+ meta["collection"] = (
+ "".join("_%s" % c if c.isupper() else c for c in name)
+ .strip("_")
+ .lower()
+ )
+ attrs["_meta"] = meta
# Call super and get the new class
new_class = super_new(mcs, name, bases, attrs)
@@ -362,36 +387,36 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
meta = new_class._meta
# Set index specifications
- meta['index_specs'] = new_class._build_index_specs(meta['indexes'])
+ meta["index_specs"] = new_class._build_index_specs(meta["indexes"])
# If collection is a callable - call it and set the value
- collection = meta.get('collection')
+ collection = meta.get("collection")
if callable(collection):
- new_class._meta['collection'] = collection(new_class)
+ new_class._meta["collection"] = collection(new_class)
# Provide a default queryset unless exists or one has been set
- if 'objects' not in dir(new_class):
+ if "objects" not in dir(new_class):
new_class.objects = QuerySetManager()
# Validate the fields and set primary key if needed
for field_name, field in iteritems(new_class._fields):
if field.primary_key:
# Ensure only one primary key is set
- current_pk = new_class._meta.get('id_field')
+ current_pk = new_class._meta.get("id_field")
if current_pk and current_pk != field_name:
- raise ValueError('Cannot override primary key field')
+ raise ValueError("Cannot override primary key field")
# Set primary key
if not current_pk:
- new_class._meta['id_field'] = field_name
+ new_class._meta["id_field"] = field_name
new_class.id = field
# If the document doesn't explicitly define a primary key field, create
# one. Make it an ObjectIdField and give it a non-clashing name ("id"
# by default, but can be different if that one's taken).
- if not new_class._meta.get('id_field'):
+ if not new_class._meta.get("id_field"):
id_name, id_db_name = mcs.get_auto_id_names(new_class)
- new_class._meta['id_field'] = id_name
+ new_class._meta["id_field"] = id_name
new_class._fields[id_name] = ObjectIdField(db_field=id_db_name)
new_class._fields[id_name].name = id_name
new_class.id = new_class._fields[id_name]
@@ -400,22 +425,20 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Prepend the ID field to _fields_ordered (so that it's *always*
# the first field).
- new_class._fields_ordered = (id_name, ) + new_class._fields_ordered
+ new_class._fields_ordered = (id_name,) + new_class._fields_ordered
# Merge in exceptions with parent hierarchy.
exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned)
- module = attrs.get('__module__')
+ module = attrs.get("__module__")
for exc in exceptions_to_merge:
name = exc.__name__
parents = tuple(
- getattr(base, name)
- for base in flattened_bases
- if hasattr(base, name)
+ getattr(base, name) for base in flattened_bases if hasattr(base, name)
) or (exc,)
# Create a new exception and set it as an attribute on the new
# class.
- exception = type(name, parents, {'__module__': module})
+ exception = type(name, parents, {"__module__": module})
setattr(new_class, name, exception)
return new_class
@@ -431,23 +454,17 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
Defaults to ('id', '_id'), or generates a non-clashing name in the form
of ('auto_id_X', '_auto_id_X') if the default name is already taken.
"""
- id_name, id_db_name = ('id', '_id')
+ id_name, id_db_name = ("id", "_id")
existing_fields = {field_name for field_name in new_class._fields}
existing_db_fields = {v.db_field for v in new_class._fields.values()}
- if (
- id_name not in existing_fields and
- id_db_name not in existing_db_fields
- ):
+ if id_name not in existing_fields and id_db_name not in existing_db_fields:
return id_name, id_db_name
- id_basename, id_db_basename, i = ('auto_id', '_auto_id', 0)
+ id_basename, id_db_basename, i = ("auto_id", "_auto_id", 0)
for i in itertools.count():
- id_name = '{0}_{1}'.format(id_basename, i)
- id_db_name = '{0}_{1}'.format(id_db_basename, i)
- if (
- id_name not in existing_fields and
- id_db_name not in existing_db_fields
- ):
+ id_name = "{0}_{1}".format(id_basename, i)
+ id_db_name = "{0}_{1}".format(id_db_basename, i)
+ if id_name not in existing_fields and id_db_name not in existing_db_fields:
return id_name, id_db_name
@@ -455,7 +472,8 @@ class MetaDict(dict):
"""Custom dictionary for meta classes.
Handles the merging of set indexes
"""
- _merge_options = ('indexes',)
+
+ _merge_options = ("indexes",)
def merge(self, new_options):
for k, v in iteritems(new_options):
@@ -467,4 +485,5 @@ class MetaDict(dict):
class BasesTuple(tuple):
"""Special class to handle introspection of bases tuple in __new__"""
+
pass
diff --git a/mongoengine/common.py b/mongoengine/common.py
index bcdea194..640384ec 100644
--- a/mongoengine/common.py
+++ b/mongoengine/common.py
@@ -19,34 +19,44 @@ def _import_class(cls_name):
if cls_name in _class_registry_cache:
return _class_registry_cache.get(cls_name)
- doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument',
- 'MapReduceDocument')
+ doc_classes = (
+ "Document",
+ "DynamicEmbeddedDocument",
+ "EmbeddedDocument",
+ "MapReduceDocument",
+ )
# Field Classes
if not _field_list_cache:
from mongoengine.fields import __all__ as fields
+
_field_list_cache.extend(fields)
from mongoengine.base.fields import __all__ as fields
+
_field_list_cache.extend(fields)
field_classes = _field_list_cache
- deref_classes = ('DeReference',)
+ deref_classes = ("DeReference",)
- if cls_name == 'BaseDocument':
+ if cls_name == "BaseDocument":
from mongoengine.base import document as module
- import_classes = ['BaseDocument']
+
+ import_classes = ["BaseDocument"]
elif cls_name in doc_classes:
from mongoengine import document as module
+
import_classes = doc_classes
elif cls_name in field_classes:
from mongoengine import fields as module
+
import_classes = field_classes
elif cls_name in deref_classes:
from mongoengine import dereference as module
+
import_classes = deref_classes
else:
- raise ValueError('No import set for: %s' % cls_name)
+ raise ValueError("No import set for: %s" % cls_name)
for cls in import_classes:
_class_registry_cache[cls] = getattr(module, cls)
diff --git a/mongoengine/connection.py b/mongoengine/connection.py
index 6a613a42..ef0dd27c 100644
--- a/mongoengine/connection.py
+++ b/mongoengine/connection.py
@@ -3,21 +3,21 @@ from pymongo.database import _check_name
import six
__all__ = [
- 'DEFAULT_CONNECTION_NAME',
- 'DEFAULT_DATABASE_NAME',
- 'MongoEngineConnectionError',
- 'connect',
- 'disconnect',
- 'disconnect_all',
- 'get_connection',
- 'get_db',
- 'register_connection',
+ "DEFAULT_CONNECTION_NAME",
+ "DEFAULT_DATABASE_NAME",
+ "MongoEngineConnectionError",
+ "connect",
+ "disconnect",
+ "disconnect_all",
+ "get_connection",
+ "get_db",
+ "register_connection",
]
-DEFAULT_CONNECTION_NAME = 'default'
-DEFAULT_DATABASE_NAME = 'test'
-DEFAULT_HOST = 'localhost'
+DEFAULT_CONNECTION_NAME = "default"
+DEFAULT_DATABASE_NAME = "test"
+DEFAULT_HOST = "localhost"
DEFAULT_PORT = 27017
_connection_settings = {}
@@ -31,6 +31,7 @@ class MongoEngineConnectionError(Exception):
"""Error raised when the database connection can't be established or
when a connection with a requested alias can't be retrieved.
"""
+
pass
@@ -39,18 +40,23 @@ def _check_db_name(name):
This functionality is copied from pymongo Database class constructor.
"""
if not isinstance(name, six.string_types):
- raise TypeError('name must be an instance of %s' % six.string_types)
- elif name != '$external':
+ raise TypeError("name must be an instance of %s" % six.string_types)
+ elif name != "$external":
_check_name(name)
def _get_connection_settings(
- db=None, name=None, host=None, port=None,
- read_preference=READ_PREFERENCE,
- username=None, password=None,
- authentication_source=None,
- authentication_mechanism=None,
- **kwargs):
+ db=None,
+ name=None,
+ host=None,
+ port=None,
+ read_preference=READ_PREFERENCE,
+ username=None,
+ password=None,
+ authentication_source=None,
+ authentication_mechanism=None,
+ **kwargs
+):
"""Get the connection settings as a dict
: param db: the name of the database to use, for compatibility with connect
@@ -73,18 +79,18 @@ def _get_connection_settings(
.. versionchanged:: 0.10.6 - added mongomock support
"""
conn_settings = {
- 'name': name or db or DEFAULT_DATABASE_NAME,
- 'host': host or DEFAULT_HOST,
- 'port': port or DEFAULT_PORT,
- 'read_preference': read_preference,
- 'username': username,
- 'password': password,
- 'authentication_source': authentication_source,
- 'authentication_mechanism': authentication_mechanism
+ "name": name or db or DEFAULT_DATABASE_NAME,
+ "host": host or DEFAULT_HOST,
+ "port": port or DEFAULT_PORT,
+ "read_preference": read_preference,
+ "username": username,
+ "password": password,
+ "authentication_source": authentication_source,
+ "authentication_mechanism": authentication_mechanism,
}
- _check_db_name(conn_settings['name'])
- conn_host = conn_settings['host']
+ _check_db_name(conn_settings["name"])
+ conn_host = conn_settings["host"]
# Host can be a list or a string, so if string, force to a list.
if isinstance(conn_host, six.string_types):
@@ -94,32 +100,32 @@ def _get_connection_settings(
for entity in conn_host:
# Handle Mongomock
- if entity.startswith('mongomock://'):
- conn_settings['is_mock'] = True
+ if entity.startswith("mongomock://"):
+ conn_settings["is_mock"] = True
# `mongomock://` is not a valid url prefix and must be replaced by `mongodb://`
- resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1))
+ resolved_hosts.append(entity.replace("mongomock://", "mongodb://", 1))
# Handle URI style connections, only updating connection params which
# were explicitly specified in the URI.
- elif '://' in entity:
+ elif "://" in entity:
uri_dict = uri_parser.parse_uri(entity)
resolved_hosts.append(entity)
- if uri_dict.get('database'):
- conn_settings['name'] = uri_dict.get('database')
+ if uri_dict.get("database"):
+ conn_settings["name"] = uri_dict.get("database")
- for param in ('read_preference', 'username', 'password'):
+ for param in ("read_preference", "username", "password"):
if uri_dict.get(param):
conn_settings[param] = uri_dict[param]
- uri_options = uri_dict['options']
- if 'replicaset' in uri_options:
- conn_settings['replicaSet'] = uri_options['replicaset']
- if 'authsource' in uri_options:
- conn_settings['authentication_source'] = uri_options['authsource']
- if 'authmechanism' in uri_options:
- conn_settings['authentication_mechanism'] = uri_options['authmechanism']
- if 'readpreference' in uri_options:
+ uri_options = uri_dict["options"]
+ if "replicaset" in uri_options:
+ conn_settings["replicaSet"] = uri_options["replicaset"]
+ if "authsource" in uri_options:
+ conn_settings["authentication_source"] = uri_options["authsource"]
+ if "authmechanism" in uri_options:
+ conn_settings["authentication_mechanism"] = uri_options["authmechanism"]
+ if "readpreference" in uri_options:
read_preferences = (
ReadPreference.NEAREST,
ReadPreference.PRIMARY,
@@ -133,34 +139,41 @@ def _get_connection_settings(
# int (e.g. 3).
# TODO simplify the code below once we drop support for
# PyMongo v3.4.
- read_pf_mode = uri_options['readpreference']
+ read_pf_mode = uri_options["readpreference"]
if isinstance(read_pf_mode, six.string_types):
read_pf_mode = read_pf_mode.lower()
for preference in read_preferences:
if (
- preference.name.lower() == read_pf_mode or
- preference.mode == read_pf_mode
+ preference.name.lower() == read_pf_mode
+ or preference.mode == read_pf_mode
):
- conn_settings['read_preference'] = preference
+ conn_settings["read_preference"] = preference
break
else:
resolved_hosts.append(entity)
- conn_settings['host'] = resolved_hosts
+ conn_settings["host"] = resolved_hosts
# Deprecated parameters that should not be passed on
- kwargs.pop('slaves', None)
- kwargs.pop('is_slave', None)
+ kwargs.pop("slaves", None)
+ kwargs.pop("is_slave", None)
conn_settings.update(kwargs)
return conn_settings
-def register_connection(alias, db=None, name=None, host=None, port=None,
- read_preference=READ_PREFERENCE,
- username=None, password=None,
- authentication_source=None,
- authentication_mechanism=None,
- **kwargs):
+def register_connection(
+ alias,
+ db=None,
+ name=None,
+ host=None,
+ port=None,
+ read_preference=READ_PREFERENCE,
+ username=None,
+ password=None,
+ authentication_source=None,
+ authentication_mechanism=None,
+ **kwargs
+):
"""Register the connection settings.
: param alias: the name that will be used to refer to this connection
@@ -185,12 +198,17 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
.. versionchanged:: 0.10.6 - added mongomock support
"""
conn_settings = _get_connection_settings(
- db=db, name=name, host=host, port=port,
+ db=db,
+ name=name,
+ host=host,
+ port=port,
read_preference=read_preference,
- username=username, password=password,
+ username=username,
+ password=password,
authentication_source=authentication_source,
authentication_mechanism=authentication_mechanism,
- **kwargs)
+ **kwargs
+ )
_connection_settings[alias] = conn_settings
@@ -206,7 +224,7 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME):
if alias in _dbs:
# Detach all cached collections in Documents
for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME):
- if issubclass(doc_cls, Document): # Skip EmbeddedDocument
+ if issubclass(doc_cls, Document): # Skip EmbeddedDocument
doc_cls._disconnect()
del _dbs[alias]
@@ -237,19 +255,21 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
# Raise MongoEngineConnectionError if it doesn't.
if alias not in _connection_settings:
if alias == DEFAULT_CONNECTION_NAME:
- msg = 'You have not defined a default connection'
+ msg = "You have not defined a default connection"
else:
msg = 'Connection with alias "%s" has not been defined' % alias
raise MongoEngineConnectionError(msg)
def _clean_settings(settings_dict):
irrelevant_fields_set = {
- 'name', 'username', 'password',
- 'authentication_source', 'authentication_mechanism'
+ "name",
+ "username",
+ "password",
+ "authentication_source",
+ "authentication_mechanism",
}
return {
- k: v for k, v in settings_dict.items()
- if k not in irrelevant_fields_set
+ k: v for k, v in settings_dict.items() if k not in irrelevant_fields_set
}
raw_conn_settings = _connection_settings[alias].copy()
@@ -260,13 +280,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn_settings = _clean_settings(raw_conn_settings)
# Determine if we should use PyMongo's or mongomock's MongoClient.
- is_mock = conn_settings.pop('is_mock', False)
+ is_mock = conn_settings.pop("is_mock", False)
if is_mock:
try:
import mongomock
except ImportError:
- raise RuntimeError('You need mongomock installed to mock '
- 'MongoEngine.')
+ raise RuntimeError("You need mongomock installed to mock MongoEngine.")
connection_class = mongomock.MongoClient
else:
connection_class = MongoClient
@@ -277,9 +296,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
connection = existing_connection
else:
connection = _create_connection(
- alias=alias,
- connection_class=connection_class,
- **conn_settings
+ alias=alias, connection_class=connection_class, **conn_settings
)
_connections[alias] = connection
return _connections[alias]
@@ -294,7 +311,8 @@ def _create_connection(alias, connection_class, **connection_settings):
return connection_class(**connection_settings)
except Exception as e:
raise MongoEngineConnectionError(
- 'Cannot connect to database %s :\n%s' % (alias, e))
+ "Cannot connect to database %s :\n%s" % (alias, e)
+ )
def _find_existing_connection(connection_settings):
@@ -316,7 +334,7 @@ def _find_existing_connection(connection_settings):
# Only remove the name but it's important to
# keep the username/password/authentication_source/authentication_mechanism
# to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047)
- return {k: v for k, v in settings_dict.items() if k != 'name'}
+ return {k: v for k, v in settings_dict.items() if k != "name"}
cleaned_conn_settings = _clean_settings(connection_settings)
for db_alias, connection_settings in connection_settings_bis:
@@ -332,14 +350,18 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
if alias not in _dbs:
conn = get_connection(alias)
conn_settings = _connection_settings[alias]
- db = conn[conn_settings['name']]
- auth_kwargs = {'source': conn_settings['authentication_source']}
- if conn_settings['authentication_mechanism'] is not None:
- auth_kwargs['mechanism'] = conn_settings['authentication_mechanism']
+ db = conn[conn_settings["name"]]
+ auth_kwargs = {"source": conn_settings["authentication_source"]}
+ if conn_settings["authentication_mechanism"] is not None:
+ auth_kwargs["mechanism"] = conn_settings["authentication_mechanism"]
# Authenticate if necessary
- if conn_settings['username'] and (conn_settings['password'] or
- conn_settings['authentication_mechanism'] == 'MONGODB-X509'):
- db.authenticate(conn_settings['username'], conn_settings['password'], **auth_kwargs)
+ if conn_settings["username"] and (
+ conn_settings["password"]
+ or conn_settings["authentication_mechanism"] == "MONGODB-X509"
+ ):
+ db.authenticate(
+ conn_settings["username"], conn_settings["password"], **auth_kwargs
+ )
_dbs[alias] = db
return _dbs[alias]
@@ -368,8 +390,8 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs):
if new_conn_settings != prev_conn_setting:
err_msg = (
- u'A different connection with alias `{}` was already '
- u'registered. Use disconnect() first'
+ u"A different connection with alias `{}` was already "
+ u"registered. Use disconnect() first"
).format(alias)
raise MongoEngineConnectionError(err_msg)
else:
diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py
index 98bd897b..3424a5d5 100644
--- a/mongoengine/context_managers.py
+++ b/mongoengine/context_managers.py
@@ -7,8 +7,14 @@ from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.pymongo_support import count_documents
-__all__ = ('switch_db', 'switch_collection', 'no_dereference',
- 'no_sub_classes', 'query_counter', 'set_write_concern')
+__all__ = (
+ "switch_db",
+ "switch_collection",
+ "no_dereference",
+ "no_sub_classes",
+ "query_counter",
+ "set_write_concern",
+)
class switch_db(object):
@@ -38,17 +44,17 @@ class switch_db(object):
self.cls = cls
self.collection = cls._get_collection()
self.db_alias = db_alias
- self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)
+ self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
def __enter__(self):
"""Change the db_alias and clear the cached collection."""
- self.cls._meta['db_alias'] = self.db_alias
+ self.cls._meta["db_alias"] = self.db_alias
self.cls._collection = None
return self.cls
def __exit__(self, t, value, traceback):
"""Reset the db_alias and collection."""
- self.cls._meta['db_alias'] = self.ori_db_alias
+ self.cls._meta["db_alias"] = self.ori_db_alias
self.cls._collection = self.collection
@@ -111,14 +117,15 @@ class no_dereference(object):
"""
self.cls = cls
- ReferenceField = _import_class('ReferenceField')
- GenericReferenceField = _import_class('GenericReferenceField')
- ComplexBaseField = _import_class('ComplexBaseField')
+ ReferenceField = _import_class("ReferenceField")
+ GenericReferenceField = _import_class("GenericReferenceField")
+ ComplexBaseField = _import_class("ComplexBaseField")
- self.deref_fields = [k for k, v in iteritems(self.cls._fields)
- if isinstance(v, (ReferenceField,
- GenericReferenceField,
- ComplexBaseField))]
+ self.deref_fields = [
+ k
+ for k, v in iteritems(self.cls._fields)
+ if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField))
+ ]
def __enter__(self):
"""Change the objects default and _auto_dereference values."""
@@ -180,15 +187,12 @@ class query_counter(object):
"""
self.db = get_db()
self.initial_profiling_level = None
- self._ctx_query_counter = 0 # number of queries issued by the context
+ self._ctx_query_counter = 0 # number of queries issued by the context
self._ignored_query = {
- 'ns':
- {'$ne': '%s.system.indexes' % self.db.name},
- 'op': # MONGODB < 3.2
- {'$ne': 'killcursors'},
- 'command.killCursors': # MONGODB >= 3.2
- {'$exists': False}
+ "ns": {"$ne": "%s.system.indexes" % self.db.name},
+ "op": {"$ne": "killcursors"}, # MONGODB < 3.2
+ "command.killCursors": {"$exists": False}, # MONGODB >= 3.2
}
def _turn_on_profiling(self):
@@ -238,8 +242,13 @@ class query_counter(object):
and substracting the queries issued by this context. In fact everytime this is called, 1 query is
issued so we need to balance that
"""
- count = count_documents(self.db.system.profile, self._ignored_query) - self._ctx_query_counter
- self._ctx_query_counter += 1 # Account for the query we just issued to gather the information
+ count = (
+ count_documents(self.db.system.profile, self._ignored_query)
+ - self._ctx_query_counter
+ )
+ self._ctx_query_counter += (
+ 1
+ ) # Account for the query we just issued to gather the information
return count
diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py
index eaebb56f..9e75f353 100644
--- a/mongoengine/dereference.py
+++ b/mongoengine/dereference.py
@@ -2,8 +2,13 @@ from bson import DBRef, SON
import six
from six import iteritems
-from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList,
- TopLevelDocumentMetaclass, get_document)
+from mongoengine.base import (
+ BaseDict,
+ BaseList,
+ EmbeddedDocumentList,
+ TopLevelDocumentMetaclass,
+ get_document,
+)
from mongoengine.base.datastructures import LazyReference
from mongoengine.connection import get_db
from mongoengine.document import Document, EmbeddedDocument
@@ -36,21 +41,23 @@ class DeReference(object):
self.max_depth = max_depth
doc_type = None
- if instance and isinstance(instance, (Document, EmbeddedDocument,
- TopLevelDocumentMetaclass)):
+ if instance and isinstance(
+ instance, (Document, EmbeddedDocument, TopLevelDocumentMetaclass)
+ ):
doc_type = instance._fields.get(name)
- while hasattr(doc_type, 'field'):
+ while hasattr(doc_type, "field"):
doc_type = doc_type.field
if isinstance(doc_type, ReferenceField):
field = doc_type
doc_type = doc_type.document_type
- is_list = not hasattr(items, 'items')
+ is_list = not hasattr(items, "items")
if is_list and all([i.__class__ == doc_type for i in items]):
return items
elif not is_list and all(
- [i.__class__ == doc_type for i in items.values()]):
+ [i.__class__ == doc_type for i in items.values()]
+ ):
return items
elif not field.dbref:
# We must turn the ObjectIds into DBRefs
@@ -83,7 +90,7 @@ class DeReference(object):
new_items[k] = value
return new_items
- if not hasattr(items, 'items'):
+ if not hasattr(items, "items"):
items = _get_items_from_list(items)
else:
items = _get_items_from_dict(items)
@@ -120,13 +127,19 @@ class DeReference(object):
continue
elif isinstance(v, DBRef):
reference_map.setdefault(field.document_type, set()).add(v.id)
- elif isinstance(v, (dict, SON)) and '_ref' in v:
- reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id)
+ elif isinstance(v, (dict, SON)) and "_ref" in v:
+ reference_map.setdefault(get_document(v["_cls"]), set()).add(
+ v["_ref"].id
+ )
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
- field_cls = getattr(getattr(field, 'field', None), 'document_type', None)
+ field_cls = getattr(
+ getattr(field, "field", None), "document_type", None
+ )
references = self._find_references(v, depth)
for key, refs in iteritems(references):
- if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)):
+ if isinstance(
+ field_cls, (Document, TopLevelDocumentMetaclass)
+ ):
key = field_cls
reference_map.setdefault(key, set()).update(refs)
elif isinstance(item, LazyReference):
@@ -134,8 +147,10 @@ class DeReference(object):
continue
elif isinstance(item, DBRef):
reference_map.setdefault(item.collection, set()).add(item.id)
- elif isinstance(item, (dict, SON)) and '_ref' in item:
- reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id)
+ elif isinstance(item, (dict, SON)) and "_ref" in item:
+ reference_map.setdefault(get_document(item["_cls"]), set()).add(
+ item["_ref"].id
+ )
elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
references = self._find_references(item, depth - 1)
for key, refs in iteritems(references):
@@ -151,12 +166,13 @@ class DeReference(object):
# we use getattr instead of hasattr because hasattr swallows any exception under python2
# so it could hide nasty things without raising exceptions (cfr bug #1688))
- ref_document_cls_exists = (getattr(collection, 'objects', None) is not None)
+ ref_document_cls_exists = getattr(collection, "objects", None) is not None
if ref_document_cls_exists:
col_name = collection._get_collection_name()
- refs = [dbref for dbref in dbrefs
- if (col_name, dbref) not in object_map]
+ refs = [
+ dbref for dbref in dbrefs if (col_name, dbref) not in object_map
+ ]
references = collection.objects.in_bulk(refs)
for key, doc in iteritems(references):
object_map[(col_name, key)] = doc
@@ -164,23 +180,26 @@ class DeReference(object):
if isinstance(doc_type, (ListField, DictField, MapField)):
continue
- refs = [dbref for dbref in dbrefs
- if (collection, dbref) not in object_map]
+ refs = [
+ dbref for dbref in dbrefs if (collection, dbref) not in object_map
+ ]
if doc_type:
- references = doc_type._get_db()[collection].find({'_id': {'$in': refs}})
+ references = doc_type._get_db()[collection].find(
+ {"_id": {"$in": refs}}
+ )
for ref in references:
doc = doc_type._from_son(ref)
object_map[(collection, doc.id)] = doc
else:
- references = get_db()[collection].find({'_id': {'$in': refs}})
+ references = get_db()[collection].find({"_id": {"$in": refs}})
for ref in references:
- if '_cls' in ref:
- doc = get_document(ref['_cls'])._from_son(ref)
+ if "_cls" in ref:
+ doc = get_document(ref["_cls"])._from_son(ref)
elif doc_type is None:
doc = get_document(
- ''.join(x.capitalize()
- for x in collection.split('_')))._from_son(ref)
+ "".join(x.capitalize() for x in collection.split("_"))
+ )._from_son(ref)
else:
doc = doc_type._from_son(ref)
object_map[(collection, doc.id)] = doc
@@ -208,19 +227,20 @@ class DeReference(object):
return BaseList(items, instance, name)
if isinstance(items, (dict, SON)):
- if '_ref' in items:
+ if "_ref" in items:
return self.object_map.get(
- (items['_ref'].collection, items['_ref'].id), items)
- elif '_cls' in items:
- doc = get_document(items['_cls'])._from_son(items)
- _cls = doc._data.pop('_cls', None)
- del items['_cls']
+ (items["_ref"].collection, items["_ref"].id), items
+ )
+ elif "_cls" in items:
+ doc = get_document(items["_cls"])._from_son(items)
+ _cls = doc._data.pop("_cls", None)
+ del items["_cls"]
doc._data = self._attach_objects(doc._data, depth, doc, None)
if _cls is not None:
- doc._data['_cls'] = _cls
+ doc._data["_cls"] = _cls
return doc
- if not hasattr(items, 'items'):
+ if not hasattr(items, "items"):
is_list = True
list_type = BaseList
if isinstance(items, EmbeddedDocumentList):
@@ -247,17 +267,25 @@ class DeReference(object):
v = data[k]._data.get(field_name, None)
if isinstance(v, DBRef):
data[k]._data[field_name] = self.object_map.get(
- (v.collection, v.id), v)
- elif isinstance(v, (dict, SON)) and '_ref' in v:
+ (v.collection, v.id), v
+ )
+ elif isinstance(v, (dict, SON)) and "_ref" in v:
data[k]._data[field_name] = self.object_map.get(
- (v['_ref'].collection, v['_ref'].id), v)
+ (v["_ref"].collection, v["_ref"].id), v
+ )
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
- item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name)
- data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
+ item_name = six.text_type("{0}.{1}.{2}").format(
+ name, k, field_name
+ )
+ data[k]._data[field_name] = self._attach_objects(
+ v, depth, instance=instance, name=item_name
+ )
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
- item_name = '%s.%s' % (name, k) if name else name
- data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name)
- elif isinstance(v, DBRef) and hasattr(v, 'id'):
+ item_name = "%s.%s" % (name, k) if name else name
+ data[k] = self._attach_objects(
+ v, depth - 1, instance=instance, name=item_name
+ )
+ elif isinstance(v, DBRef) and hasattr(v, "id"):
data[k] = self.object_map.get((v.collection, v.id), v)
if instance and name:
diff --git a/mongoengine/document.py b/mongoengine/document.py
index 520de5bf..41166df4 100644
--- a/mongoengine/document.py
+++ b/mongoengine/document.py
@@ -8,23 +8,36 @@ import six
from six import iteritems
from mongoengine import signals
-from mongoengine.base import (BaseDict, BaseDocument, BaseList,
- DocumentMetaclass, EmbeddedDocumentList,
- TopLevelDocumentMetaclass, get_document)
+from mongoengine.base import (
+ BaseDict,
+ BaseDocument,
+ BaseList,
+ DocumentMetaclass,
+ EmbeddedDocumentList,
+ TopLevelDocumentMetaclass,
+ get_document,
+)
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
-from mongoengine.context_managers import (set_write_concern,
- switch_collection,
- switch_db)
-from mongoengine.errors import (InvalidDocumentError, InvalidQueryError,
- SaveConditionError)
+from mongoengine.context_managers import set_write_concern, switch_collection, switch_db
+from mongoengine.errors import (
+ InvalidDocumentError,
+ InvalidQueryError,
+ SaveConditionError,
+)
from mongoengine.pymongo_support import list_collection_names
-from mongoengine.queryset import (NotUniqueError, OperationError,
- QuerySet, transform)
+from mongoengine.queryset import NotUniqueError, OperationError, QuerySet, transform
-__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
- 'DynamicEmbeddedDocument', 'OperationError',
- 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument')
+__all__ = (
+ "Document",
+ "EmbeddedDocument",
+ "DynamicDocument",
+ "DynamicEmbeddedDocument",
+ "OperationError",
+ "InvalidCollectionError",
+ "NotUniqueError",
+ "MapReduceDocument",
+)
def includes_cls(fields):
@@ -35,7 +48,7 @@ def includes_cls(fields):
first_field = fields[0]
elif isinstance(fields[0], (list, tuple)) and len(fields[0]):
first_field = fields[0][0]
- return first_field == '_cls'
+ return first_field == "_cls"
class InvalidCollectionError(Exception):
@@ -56,7 +69,7 @@ class EmbeddedDocument(six.with_metaclass(DocumentMetaclass, BaseDocument)):
:attr:`meta` dictionary.
"""
- __slots__ = ('_instance', )
+ __slots__ = ("_instance",)
# The __metaclass__ attribute is removed by 2to3 when running with Python3
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
@@ -85,8 +98,8 @@ class EmbeddedDocument(six.with_metaclass(DocumentMetaclass, BaseDocument)):
data = super(EmbeddedDocument, self).to_mongo(*args, **kwargs)
# remove _id from the SON if it's in it and it's None
- if '_id' in data and data['_id'] is None:
- del data['_id']
+ if "_id" in data and data["_id"] is None:
+ del data["_id"]
return data
@@ -147,19 +160,19 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
my_metaclass = TopLevelDocumentMetaclass
- __slots__ = ('__objects',)
+ __slots__ = ("__objects",)
@property
def pk(self):
"""Get the primary key."""
- if 'id_field' not in self._meta:
+ if "id_field" not in self._meta:
return None
- return getattr(self, self._meta['id_field'])
+ return getattr(self, self._meta["id_field"])
@pk.setter
def pk(self, value):
"""Set the primary key."""
- return setattr(self, self._meta['id_field'], value)
+ return setattr(self, self._meta["id_field"], value)
def __hash__(self):
"""Return the hash based on the PK of this document. If it's new
@@ -173,7 +186,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
@classmethod
def _get_db(cls):
"""Some Model using other db_alias"""
- return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME))
+ return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))
@classmethod
def _disconnect(cls):
@@ -190,9 +203,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
2. Creates indexes defined in this document's :attr:`meta` dictionary.
This happens only if `auto_create_index` is True.
"""
- if not hasattr(cls, '_collection') or cls._collection is None:
+ if not hasattr(cls, "_collection") or cls._collection is None:
# Get the collection, either capped or regular.
- if cls._meta.get('max_size') or cls._meta.get('max_documents'):
+ if cls._meta.get("max_size") or cls._meta.get("max_documents"):
cls._collection = cls._get_capped_collection()
else:
db = cls._get_db()
@@ -203,8 +216,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# set to False.
# Also there is no need to ensure indexes on slave.
db = cls._get_db()
- if cls._meta.get('auto_create_index', True) and\
- db.client.is_primary:
+ if cls._meta.get("auto_create_index", True) and db.client.is_primary:
cls.ensure_indexes()
return cls._collection
@@ -216,8 +228,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
collection_name = cls._get_collection_name()
# Get max document limit and max byte size from meta.
- max_size = cls._meta.get('max_size') or 10 * 2 ** 20 # 10MB default
- max_documents = cls._meta.get('max_documents')
+ max_size = cls._meta.get("max_size") or 10 * 2 ** 20 # 10MB default
+ max_documents = cls._meta.get("max_documents")
# MongoDB will automatically raise the size to make it a multiple of
# 256 bytes. We raise it here ourselves to be able to reliably compare
@@ -227,24 +239,23 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# If the collection already exists and has different options
# (i.e. isn't capped or has different max/size), raise an error.
- if collection_name in list_collection_names(db, include_system_collections=True):
+ if collection_name in list_collection_names(
+ db, include_system_collections=True
+ ):
collection = db[collection_name]
options = collection.options()
- if (
- options.get('max') != max_documents or
- options.get('size') != max_size
- ):
+ if options.get("max") != max_documents or options.get("size") != max_size:
raise InvalidCollectionError(
'Cannot create collection "{}" as a capped '
- 'collection as it already exists'.format(cls._collection)
+ "collection as it already exists".format(cls._collection)
)
return collection
# Create a new capped collection.
- opts = {'capped': True, 'size': max_size}
+ opts = {"capped": True, "size": max_size}
if max_documents:
- opts['max'] = max_documents
+ opts["max"] = max_documents
return db.create_collection(collection_name, **opts)
@@ -253,11 +264,11 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# If '_id' is None, try and set it from self._data. If that
# doesn't exist either, remove '_id' from the SON completely.
- if data['_id'] is None:
- if self._data.get('id') is None:
- del data['_id']
+ if data["_id"] is None:
+ if self._data.get("id") is None:
+ del data["_id"]
else:
- data['_id'] = self._data['id']
+ data["_id"] = self._data["id"]
return data
@@ -279,15 +290,17 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
query = {}
if self.pk is None:
- raise InvalidDocumentError('The document does not have a primary key.')
+ raise InvalidDocumentError("The document does not have a primary key.")
- id_field = self._meta['id_field']
+ id_field = self._meta["id_field"]
query = query.copy() if isinstance(query, dict) else query.to_query(self)
if id_field not in query:
query[id_field] = self.pk
elif query[id_field] != self.pk:
- raise InvalidQueryError('Invalid document modify query: it must modify only this document.')
+ raise InvalidQueryError(
+ "Invalid document modify query: it must modify only this document."
+ )
# Need to add shard key to query, or you get an error
query.update(self._object_key)
@@ -304,9 +317,19 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
return True
- def save(self, force_insert=False, validate=True, clean=True,
- write_concern=None, cascade=None, cascade_kwargs=None,
- _refs=None, save_condition=None, signal_kwargs=None, **kwargs):
+ def save(
+ self,
+ force_insert=False,
+ validate=True,
+ clean=True,
+ write_concern=None,
+ cascade=None,
+ cascade_kwargs=None,
+ _refs=None,
+ save_condition=None,
+ signal_kwargs=None,
+ **kwargs
+ ):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@@ -360,8 +383,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""
signal_kwargs = signal_kwargs or {}
- if self._meta.get('abstract'):
- raise InvalidDocumentError('Cannot save an abstract document.')
+ if self._meta.get("abstract"):
+ raise InvalidDocumentError("Cannot save an abstract document.")
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
@@ -371,15 +394,16 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
if write_concern is None:
write_concern = {}
- doc_id = self.to_mongo(fields=[self._meta['id_field']])
- created = ('_id' not in doc_id or self._created or force_insert)
+ doc_id = self.to_mongo(fields=[self._meta["id_field"]])
+ created = "_id" not in doc_id or self._created or force_insert
- signals.pre_save_post_validation.send(self.__class__, document=self,
- created=created, **signal_kwargs)
+ signals.pre_save_post_validation.send(
+ self.__class__, document=self, created=created, **signal_kwargs
+ )
# it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
doc = self.to_mongo()
- if self._meta.get('auto_create_index', True):
+ if self._meta.get("auto_create_index", True):
self.ensure_indexes()
try:
@@ -387,44 +411,45 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
if created:
object_id = self._save_create(doc, force_insert, write_concern)
else:
- object_id, created = self._save_update(doc, save_condition,
- write_concern)
+ object_id, created = self._save_update(
+ doc, save_condition, write_concern
+ )
if cascade is None:
- cascade = (self._meta.get('cascade', False) or
- cascade_kwargs is not None)
+ cascade = self._meta.get("cascade", False) or cascade_kwargs is not None
if cascade:
kwargs = {
- 'force_insert': force_insert,
- 'validate': validate,
- 'write_concern': write_concern,
- 'cascade': cascade
+ "force_insert": force_insert,
+ "validate": validate,
+ "write_concern": write_concern,
+ "cascade": cascade,
}
if cascade_kwargs: # Allow granular control over cascades
kwargs.update(cascade_kwargs)
- kwargs['_refs'] = _refs
+ kwargs["_refs"] = _refs
self.cascade_save(**kwargs)
except pymongo.errors.DuplicateKeyError as err:
- message = u'Tried to save duplicate unique keys (%s)'
+ message = u"Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % six.text_type(err))
except pymongo.errors.OperationFailure as err:
- message = 'Could not save document (%s)'
- if re.match('^E1100[01] duplicate key', six.text_type(err)):
+ message = "Could not save document (%s)"
+ if re.match("^E1100[01] duplicate key", six.text_type(err)):
# E11000 - duplicate key error index
# E11001 - duplicate key on update
- message = u'Tried to save duplicate unique keys (%s)'
+ message = u"Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err))
# Make sure we store the PK on this document now that it's saved
- id_field = self._meta['id_field']
- if created or id_field not in self._meta.get('shard_key', []):
+ id_field = self._meta["id_field"]
+ if created or id_field not in self._meta.get("shard_key", []):
self[id_field] = self._fields[id_field].to_python(object_id)
- signals.post_save.send(self.__class__, document=self,
- created=created, **signal_kwargs)
+ signals.post_save.send(
+ self.__class__, document=self, created=created, **signal_kwargs
+ )
self._clear_changed_fields()
self._created = False
@@ -442,11 +467,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
return wc_collection.insert_one(doc).inserted_id
# insert_one will provoke UniqueError alongside save does not
# therefore, it need to catch and call replace_one.
- if '_id' in doc:
+ if "_id" in doc:
raw_object = wc_collection.find_one_and_replace(
- {'_id': doc['_id']}, doc)
+ {"_id": doc["_id"]}, doc
+ )
if raw_object:
- return doc['_id']
+ return doc["_id"]
object_id = wc_collection.insert_one(doc).inserted_id
@@ -461,9 +487,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
update_doc = {}
if updates:
- update_doc['$set'] = updates
+ update_doc["$set"] = updates
if removals:
- update_doc['$unset'] = removals
+ update_doc["$unset"] = removals
return update_doc
@@ -473,39 +499,38 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
Helper method, should only be used inside save().
"""
collection = self._get_collection()
- object_id = doc['_id']
+ object_id = doc["_id"]
created = False
select_dict = {}
if save_condition is not None:
select_dict = transform.query(self.__class__, **save_condition)
- select_dict['_id'] = object_id
+ select_dict["_id"] = object_id
# Need to add shard key to query, or you get an error
- shard_key = self._meta.get('shard_key', tuple())
+ shard_key = self._meta.get("shard_key", tuple())
for k in shard_key:
- path = self._lookup_field(k.split('.'))
+ path = self._lookup_field(k.split("."))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
- select_dict['.'.join(actual_key)] = val
+ select_dict[".".join(actual_key)] = val
update_doc = self._get_update_doc()
if update_doc:
upsert = save_condition is None
with set_write_concern(collection, write_concern) as wc_collection:
last_error = wc_collection.update_one(
- select_dict,
- update_doc,
- upsert=upsert
+ select_dict, update_doc, upsert=upsert
).raw_result
- if not upsert and last_error['n'] == 0:
- raise SaveConditionError('Race condition preventing'
- ' document update detected')
+ if not upsert and last_error["n"] == 0:
+ raise SaveConditionError(
+ "Race condition preventing document update detected"
+ )
if last_error is not None:
- updated_existing = last_error.get('updatedExisting')
+ updated_existing = last_error.get("updatedExisting")
if updated_existing is False:
created = True
# !!! This is bad, means we accidentally created a new,
@@ -518,21 +543,20 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""Recursively save any references and generic references on the
document.
"""
- _refs = kwargs.get('_refs') or []
+ _refs = kwargs.get("_refs") or []
- ReferenceField = _import_class('ReferenceField')
- GenericReferenceField = _import_class('GenericReferenceField')
+ ReferenceField = _import_class("ReferenceField")
+ GenericReferenceField = _import_class("GenericReferenceField")
for name, cls in self._fields.items():
- if not isinstance(cls, (ReferenceField,
- GenericReferenceField)):
+ if not isinstance(cls, (ReferenceField, GenericReferenceField)):
continue
ref = self._data.get(name)
if not ref or isinstance(ref, DBRef):
continue
- if not getattr(ref, '_changed_fields', True):
+ if not getattr(ref, "_changed_fields", True):
continue
ref_id = "%s,%s" % (ref.__class__.__name__, str(ref._data))
@@ -545,7 +569,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
@property
def _qs(self):
"""Return the default queryset corresponding to this document."""
- if not hasattr(self, '__objects'):
+ if not hasattr(self, "__objects"):
self.__objects = QuerySet(self, self._get_collection())
return self.__objects
@@ -558,15 +582,15 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
a sharded collection with a compound shard key, it can contain a more
complex query.
"""
- select_dict = {'pk': self.pk}
- shard_key = self.__class__._meta.get('shard_key', tuple())
+ select_dict = {"pk": self.pk}
+ shard_key = self.__class__._meta.get("shard_key", tuple())
for k in shard_key:
- path = self._lookup_field(k.split('.'))
+ path = self._lookup_field(k.split("."))
actual_key = [p.db_field for p in path]
val = self
for ak in actual_key:
val = getattr(val, ak)
- select_dict['__'.join(actual_key)] = val
+ select_dict["__".join(actual_key)] = val
return select_dict
def update(self, **kwargs):
@@ -577,14 +601,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
been saved.
"""
if self.pk is None:
- if kwargs.get('upsert', False):
+ if kwargs.get("upsert", False):
query = self.to_mongo()
- if '_cls' in query:
- del query['_cls']
+ if "_cls" in query:
+ del query["_cls"]
return self._qs.filter(**query).update_one(**kwargs)
else:
- raise OperationError(
- 'attempt to update a document not yet saved')
+ raise OperationError("attempt to update a document not yet saved")
# Need to add shard key to query, or you get an error
return self._qs.filter(**self._object_key).update_one(**kwargs)
@@ -608,16 +631,17 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
signals.pre_delete.send(self.__class__, document=self, **signal_kwargs)
# Delete FileFields separately
- FileField = _import_class('FileField')
+ FileField = _import_class("FileField")
for name, field in iteritems(self._fields):
if isinstance(field, FileField):
getattr(self, name).delete()
try:
- self._qs.filter(
- **self._object_key).delete(write_concern=write_concern, _from_doc_delete=True)
+ self._qs.filter(**self._object_key).delete(
+ write_concern=write_concern, _from_doc_delete=True
+ )
except pymongo.errors.OperationFailure as err:
- message = u'Could not delete document (%s)' % err.message
+ message = u"Could not delete document (%s)" % err.message
raise OperationError(message)
signals.post_delete.send(self.__class__, document=self, **signal_kwargs)
@@ -686,7 +710,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
.. versionadded:: 0.5
"""
- DeReference = _import_class('DeReference')
+ DeReference = _import_class("DeReference")
DeReference()([self], max_depth + 1)
return self
@@ -704,20 +728,24 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
if fields and isinstance(fields[0], int):
max_depth = fields[0]
fields = fields[1:]
- elif 'max_depth' in kwargs:
- max_depth = kwargs['max_depth']
+ elif "max_depth" in kwargs:
+ max_depth = kwargs["max_depth"]
if self.pk is None:
- raise self.DoesNotExist('Document does not exist')
+ raise self.DoesNotExist("Document does not exist")
- obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
- **self._object_key).only(*fields).limit(
- 1).select_related(max_depth=max_depth)
+ obj = (
+ self._qs.read_preference(ReadPreference.PRIMARY)
+ .filter(**self._object_key)
+ .only(*fields)
+ .limit(1)
+ .select_related(max_depth=max_depth)
+ )
if obj:
obj = obj[0]
else:
- raise self.DoesNotExist('Document does not exist')
+ raise self.DoesNotExist("Document does not exist")
for field in obj._data:
if not fields or field in fields:
try:
@@ -733,9 +761,11 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# i.e. obj.update(unset__field=1) followed by obj.reload()
delattr(self, field)
- self._changed_fields = list(
- set(self._changed_fields) - set(fields)
- ) if fields else obj._changed_fields
+ self._changed_fields = (
+ list(set(self._changed_fields) - set(fields))
+ if fields
+ else obj._changed_fields
+ )
self._created = False
return self
@@ -761,7 +791,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""Returns an instance of :class:`~bson.dbref.DBRef` useful in
`__raw__` queries."""
if self.pk is None:
- msg = 'Only saved documents can have a valid dbref'
+ msg = "Only saved documents can have a valid dbref"
raise OperationError(msg)
return DBRef(self.__class__._get_collection_name(), self.pk)
@@ -770,18 +800,22 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""This method registers the delete rules to apply when removing this
object.
"""
- classes = [get_document(class_name)
- for class_name in cls._subclasses
- if class_name != cls.__name__] + [cls]
- documents = [get_document(class_name)
- for class_name in document_cls._subclasses
- if class_name != document_cls.__name__] + [document_cls]
+ classes = [
+ get_document(class_name)
+ for class_name in cls._subclasses
+ if class_name != cls.__name__
+ ] + [cls]
+ documents = [
+ get_document(class_name)
+ for class_name in document_cls._subclasses
+ if class_name != document_cls.__name__
+ ] + [document_cls]
for klass in classes:
for document_cls in documents:
- delete_rules = klass._meta.get('delete_rules') or {}
+ delete_rules = klass._meta.get("delete_rules") or {}
delete_rules[(document_cls, field_name)] = rule
- klass._meta['delete_rules'] = delete_rules
+ klass._meta["delete_rules"] = delete_rules
@classmethod
def drop_collection(cls):
@@ -796,8 +830,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""
coll_name = cls._get_collection_name()
if not coll_name:
- raise OperationError('Document %s has no collection defined '
- '(is it abstract ?)' % cls)
+ raise OperationError(
+ "Document %s has no collection defined (is it abstract ?)" % cls
+ )
cls._collection = None
db = cls._get_db()
db.drop_collection(coll_name)
@@ -813,19 +848,18 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""
index_spec = cls._build_index_spec(keys)
index_spec = index_spec.copy()
- fields = index_spec.pop('fields')
- drop_dups = kwargs.get('drop_dups', False)
+ fields = index_spec.pop("fields")
+ drop_dups = kwargs.get("drop_dups", False)
if drop_dups:
- msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
+ msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
- index_spec['background'] = background
+ index_spec["background"] = background
index_spec.update(kwargs)
return cls._get_collection().create_index(fields, **index_spec)
@classmethod
- def ensure_index(cls, key_or_list, drop_dups=False, background=False,
- **kwargs):
+ def ensure_index(cls, key_or_list, drop_dups=False, background=False, **kwargs):
"""Ensure that the given indexes are in place. Deprecated in favour
of create_index.
@@ -837,7 +871,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
will be removed if PyMongo3+ is used
"""
if drop_dups:
- msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
+ msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
return cls.create_index(key_or_list, background=background, **kwargs)
@@ -850,12 +884,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
.. note:: You can disable automatic index creation by setting
`auto_create_index` to False in the documents meta data
"""
- background = cls._meta.get('index_background', False)
- drop_dups = cls._meta.get('index_drop_dups', False)
- index_opts = cls._meta.get('index_opts') or {}
- index_cls = cls._meta.get('index_cls', True)
+ background = cls._meta.get("index_background", False)
+ drop_dups = cls._meta.get("index_drop_dups", False)
+ index_opts = cls._meta.get("index_opts") or {}
+ index_cls = cls._meta.get("index_cls", True)
if drop_dups:
- msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
+ msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
collection = cls._get_collection()
@@ -871,40 +905,39 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
cls_indexed = False
# Ensure document-defined indexes are created
- if cls._meta['index_specs']:
- index_spec = cls._meta['index_specs']
+ if cls._meta["index_specs"]:
+ index_spec = cls._meta["index_specs"]
for spec in index_spec:
spec = spec.copy()
- fields = spec.pop('fields')
+ fields = spec.pop("fields")
cls_indexed = cls_indexed or includes_cls(fields)
opts = index_opts.copy()
opts.update(spec)
# we shouldn't pass 'cls' to the collection.ensureIndex options
# because of https://jira.mongodb.org/browse/SERVER-769
- if 'cls' in opts:
- del opts['cls']
+ if "cls" in opts:
+ del opts["cls"]
collection.create_index(fields, background=background, **opts)
# If _cls is being used (for polymorphism), it needs an index,
# only if another index doesn't begin with _cls
- if index_cls and not cls_indexed and cls._meta.get('allow_inheritance'):
+ if index_cls and not cls_indexed and cls._meta.get("allow_inheritance"):
# we shouldn't pass 'cls' to the collection.ensureIndex options
# because of https://jira.mongodb.org/browse/SERVER-769
- if 'cls' in index_opts:
- del index_opts['cls']
+ if "cls" in index_opts:
+ del index_opts["cls"]
- collection.create_index('_cls', background=background,
- **index_opts)
+ collection.create_index("_cls", background=background, **index_opts)
@classmethod
def list_indexes(cls):
""" Lists all of the indexes that should be created for given
collection. It includes all the indexes from super- and sub-classes.
"""
- if cls._meta.get('abstract'):
+ if cls._meta.get("abstract"):
return []
# get all the base classes, subclasses and siblings
@@ -912,22 +945,27 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
def get_classes(cls):
- if (cls not in classes and
- isinstance(cls, TopLevelDocumentMetaclass)):
+ if cls not in classes and isinstance(cls, TopLevelDocumentMetaclass):
classes.append(cls)
for base_cls in cls.__bases__:
- if (isinstance(base_cls, TopLevelDocumentMetaclass) and
- base_cls != Document and
- not base_cls._meta.get('abstract') and
- base_cls._get_collection().full_name == cls._get_collection().full_name and
- base_cls not in classes):
+ if (
+ isinstance(base_cls, TopLevelDocumentMetaclass)
+ and base_cls != Document
+ and not base_cls._meta.get("abstract")
+ and base_cls._get_collection().full_name
+ == cls._get_collection().full_name
+ and base_cls not in classes
+ ):
classes.append(base_cls)
get_classes(base_cls)
for subclass in cls.__subclasses__():
- if (isinstance(base_cls, TopLevelDocumentMetaclass) and
- subclass._get_collection().full_name == cls._get_collection().full_name and
- subclass not in classes):
+ if (
+ isinstance(base_cls, TopLevelDocumentMetaclass)
+ and subclass._get_collection().full_name
+ == cls._get_collection().full_name
+ and subclass not in classes
+ ):
classes.append(subclass)
get_classes(subclass)
@@ -937,11 +975,11 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
def get_indexes_spec(cls):
indexes = []
- if cls._meta['index_specs']:
- index_spec = cls._meta['index_specs']
+ if cls._meta["index_specs"]:
+ index_spec = cls._meta["index_specs"]
for spec in index_spec:
spec = spec.copy()
- fields = spec.pop('fields')
+ fields = spec.pop("fields")
indexes.append(fields)
return indexes
@@ -952,10 +990,10 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
indexes.append(index)
# finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed
- if [(u'_id', 1)] not in indexes:
- indexes.append([(u'_id', 1)])
- if cls._meta.get('index_cls', True) and cls._meta.get('allow_inheritance'):
- indexes.append([(u'_cls', 1)])
+ if [(u"_id", 1)] not in indexes:
+ indexes.append([(u"_id", 1)])
+ if cls._meta.get("index_cls", True) and cls._meta.get("allow_inheritance"):
+ indexes.append([(u"_cls", 1)])
return indexes
@@ -969,27 +1007,26 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
existing = []
for info in cls._get_collection().index_information().values():
- if '_fts' in info['key'][0]:
- index_type = info['key'][0][1]
- text_index_fields = info.get('weights').keys()
- existing.append(
- [(key, index_type) for key in text_index_fields])
+ if "_fts" in info["key"][0]:
+ index_type = info["key"][0][1]
+ text_index_fields = info.get("weights").keys()
+ existing.append([(key, index_type) for key in text_index_fields])
else:
- existing.append(info['key'])
+ existing.append(info["key"])
missing = [index for index in required if index not in existing]
extra = [index for index in existing if index not in required]
# if { _cls: 1 } is missing, make sure it's *really* necessary
- if [(u'_cls', 1)] in missing:
+ if [(u"_cls", 1)] in missing:
cls_obsolete = False
for index in existing:
if includes_cls(index) and index not in extra:
cls_obsolete = True
break
if cls_obsolete:
- missing.remove([(u'_cls', 1)])
+ missing.remove([(u"_cls", 1)])
- return {'missing': missing, 'extra': extra}
+ return {"missing": missing, "extra": extra}
class DynamicDocument(six.with_metaclass(TopLevelDocumentMetaclass, Document)):
@@ -1074,17 +1111,16 @@ class MapReduceDocument(object):
"""Lazy-load the object referenced by ``self.key``. ``self.key``
should be the ``primary_key``.
"""
- id_field = self._document()._meta['id_field']
+ id_field = self._document()._meta["id_field"]
id_field_type = type(id_field)
if not isinstance(self.key, id_field_type):
try:
self.key = id_field_type(self.key)
except Exception:
- raise Exception('Could not cast key as %s' %
- id_field_type.__name__)
+ raise Exception("Could not cast key as %s" % id_field_type.__name__)
- if not hasattr(self, '_key_object'):
+ if not hasattr(self, "_key_object"):
self._key_object = self._document.objects.with_id(self.key)
return self._key_object
return self._key_object
diff --git a/mongoengine/errors.py b/mongoengine/errors.py
index bea1d3dc..9852f2a1 100644
--- a/mongoengine/errors.py
+++ b/mongoengine/errors.py
@@ -3,10 +3,20 @@ from collections import defaultdict
import six
from six import iteritems
-__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError',
- 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError',
- 'OperationError', 'NotUniqueError', 'FieldDoesNotExist',
- 'ValidationError', 'SaveConditionError', 'DeprecatedError')
+__all__ = (
+ "NotRegistered",
+ "InvalidDocumentError",
+ "LookUpError",
+ "DoesNotExist",
+ "MultipleObjectsReturned",
+ "InvalidQueryError",
+ "OperationError",
+ "NotUniqueError",
+ "FieldDoesNotExist",
+ "ValidationError",
+ "SaveConditionError",
+ "DeprecatedError",
+)
class NotRegistered(Exception):
@@ -71,25 +81,25 @@ class ValidationError(AssertionError):
field_name = None
_message = None
- def __init__(self, message='', **kwargs):
+ def __init__(self, message="", **kwargs):
super(ValidationError, self).__init__(message)
- self.errors = kwargs.get('errors', {})
- self.field_name = kwargs.get('field_name')
+ self.errors = kwargs.get("errors", {})
+ self.field_name = kwargs.get("field_name")
self.message = message
def __str__(self):
return six.text_type(self.message)
def __repr__(self):
- return '%s(%s,)' % (self.__class__.__name__, self.message)
+ return "%s(%s,)" % (self.__class__.__name__, self.message)
def __getattribute__(self, name):
message = super(ValidationError, self).__getattribute__(name)
- if name == 'message':
+ if name == "message":
if self.field_name:
- message = '%s' % message
+ message = "%s" % message
if self.errors:
- message = '%s(%s)' % (message, self._format_errors())
+ message = "%s(%s)" % (message, self._format_errors())
return message
def _get_message(self):
@@ -128,22 +138,22 @@ class ValidationError(AssertionError):
def _format_errors(self):
"""Returns a string listing all errors within a document"""
- def generate_key(value, prefix=''):
+ def generate_key(value, prefix=""):
if isinstance(value, list):
- value = ' '.join([generate_key(k) for k in value])
+ value = " ".join([generate_key(k) for k in value])
elif isinstance(value, dict):
- value = ' '.join(
- [generate_key(v, k) for k, v in iteritems(value)])
+ value = " ".join([generate_key(v, k) for k, v in iteritems(value)])
- results = '%s.%s' % (prefix, value) if prefix else value
+ results = "%s.%s" % (prefix, value) if prefix else value
return results
error_dict = defaultdict(list)
for k, v in iteritems(self.to_dict()):
error_dict[generate_key(v)].append(k)
- return ' '.join(['%s: %s' % (k, v) for k, v in iteritems(error_dict)])
+ return " ".join(["%s: %s" % (k, v) for k, v in iteritems(error_dict)])
class DeprecatedError(Exception):
"""Raise when a user uses a feature that has been Deprecated"""
+
pass
diff --git a/mongoengine/fields.py b/mongoengine/fields.py
index 2a4a2ad8..7ab2276d 100644
--- a/mongoengine/fields.py
+++ b/mongoengine/fields.py
@@ -27,9 +27,15 @@ except ImportError:
Int64 = long
-from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField,
- GeoJsonBaseField, LazyReference, ObjectIdField,
- get_document)
+from mongoengine.base import (
+ BaseDocument,
+ BaseField,
+ ComplexBaseField,
+ GeoJsonBaseField,
+ LazyReference,
+ ObjectIdField,
+ get_document,
+)
from mongoengine.base.utils import LazyRegexCompiler
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
@@ -53,21 +59,51 @@ if six.PY3:
__all__ = (
- 'StringField', 'URLField', 'EmailField', 'IntField', 'LongField',
- 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', 'DateField',
- 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField',
- 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField',
- 'SortedListField', 'EmbeddedDocumentListField', 'DictField',
- 'MapField', 'ReferenceField', 'CachedReferenceField',
- 'LazyReferenceField', 'GenericLazyReferenceField',
- 'GenericReferenceField', 'BinaryField', 'GridFSError', 'GridFSProxy',
- 'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField',
- 'GeoPointField', 'PointField', 'LineStringField', 'PolygonField',
- 'SequenceField', 'UUIDField', 'MultiPointField', 'MultiLineStringField',
- 'MultiPolygonField', 'GeoJsonBaseField'
+ "StringField",
+ "URLField",
+ "EmailField",
+ "IntField",
+ "LongField",
+ "FloatField",
+ "DecimalField",
+ "BooleanField",
+ "DateTimeField",
+ "DateField",
+ "ComplexDateTimeField",
+ "EmbeddedDocumentField",
+ "ObjectIdField",
+ "GenericEmbeddedDocumentField",
+ "DynamicField",
+ "ListField",
+ "SortedListField",
+ "EmbeddedDocumentListField",
+ "DictField",
+ "MapField",
+ "ReferenceField",
+ "CachedReferenceField",
+ "LazyReferenceField",
+ "GenericLazyReferenceField",
+ "GenericReferenceField",
+ "BinaryField",
+ "GridFSError",
+ "GridFSProxy",
+ "FileField",
+ "ImageGridFsProxy",
+ "ImproperlyConfigured",
+ "ImageField",
+ "GeoPointField",
+ "PointField",
+ "LineStringField",
+ "PolygonField",
+ "SequenceField",
+ "UUIDField",
+ "MultiPointField",
+ "MultiLineStringField",
+ "MultiPolygonField",
+ "GeoJsonBaseField",
)
-RECURSIVE_REFERENCE_CONSTANT = 'self'
+RECURSIVE_REFERENCE_CONSTANT = "self"
class StringField(BaseField):
@@ -83,23 +119,23 @@ class StringField(BaseField):
if isinstance(value, six.text_type):
return value
try:
- value = value.decode('utf-8')
+ value = value.decode("utf-8")
except Exception:
pass
return value
def validate(self, value):
if not isinstance(value, six.string_types):
- self.error('StringField only accepts string values')
+ self.error("StringField only accepts string values")
if self.max_length is not None and len(value) > self.max_length:
- self.error('String value is too long')
+ self.error("String value is too long")
if self.min_length is not None and len(value) < self.min_length:
- self.error('String value is too short')
+ self.error("String value is too short")
if self.regex is not None and self.regex.match(value) is None:
- self.error('String value did not match validation regex')
+ self.error("String value did not match validation regex")
def lookup_member(self, member_name):
return None
@@ -109,18 +145,18 @@ class StringField(BaseField):
return value
if op in STRING_OPERATORS:
- case_insensitive = op.startswith('i')
- op = op.lstrip('i')
+ case_insensitive = op.startswith("i")
+ op = op.lstrip("i")
flags = re.IGNORECASE if case_insensitive else 0
- regex = r'%s'
- if op == 'startswith':
- regex = r'^%s'
- elif op == 'endswith':
- regex = r'%s$'
- elif op == 'exact':
- regex = r'^%s$'
+ regex = r"%s"
+ if op == "startswith":
+ regex = r"^%s"
+ elif op == "endswith":
+ regex = r"%s$"
+ elif op == "exact":
+ regex = r"^%s$"
# escape unsafe characters which could lead to a re.error
value = re.escape(value)
@@ -135,14 +171,16 @@ class URLField(StringField):
"""
_URL_REGEX = LazyRegexCompiler(
- r'^(?:[a-z0-9\.\-]*)://' # scheme is validated separately
- r'(?:(?:[A-Z0-9](?:[A-Z0-9-_]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(? self.max_value:
- self.error('Integer value is too large')
+ self.error("Integer value is too large")
def prepare_query_value(self, op, value):
if value is None:
@@ -319,13 +365,13 @@ class LongField(BaseField):
try:
value = long(value)
except (TypeError, ValueError):
- self.error('%s could not be converted to long' % value)
+ self.error("%s could not be converted to long" % value)
if self.min_value is not None and value < self.min_value:
- self.error('Long value is too small')
+ self.error("Long value is too small")
if self.max_value is not None and value > self.max_value:
- self.error('Long value is too large')
+ self.error("Long value is too large")
def prepare_query_value(self, op, value):
if value is None:
@@ -353,16 +399,16 @@ class FloatField(BaseField):
try:
value = float(value)
except OverflowError:
- self.error('The value is too large to be converted to float')
+ self.error("The value is too large to be converted to float")
if not isinstance(value, float):
- self.error('FloatField only accepts float and integer values')
+ self.error("FloatField only accepts float and integer values")
if self.min_value is not None and value < self.min_value:
- self.error('Float value is too small')
+ self.error("Float value is too small")
if self.max_value is not None and value > self.max_value:
- self.error('Float value is too large')
+ self.error("Float value is too large")
def prepare_query_value(self, op, value):
if value is None:
@@ -379,8 +425,15 @@ class DecimalField(BaseField):
.. versionadded:: 0.3
"""
- def __init__(self, min_value=None, max_value=None, force_string=False,
- precision=2, rounding=decimal.ROUND_HALF_UP, **kwargs):
+ def __init__(
+ self,
+ min_value=None,
+ max_value=None,
+ force_string=False,
+ precision=2,
+ rounding=decimal.ROUND_HALF_UP,
+ **kwargs
+ ):
"""
:param min_value: Validation rule for the minimum acceptable value.
:param max_value: Validation rule for the maximum acceptable value.
@@ -416,10 +469,12 @@ class DecimalField(BaseField):
# Convert to string for python 2.6 before casting to Decimal
try:
- value = decimal.Decimal('%s' % value)
+ value = decimal.Decimal("%s" % value)
except (TypeError, ValueError, decimal.InvalidOperation):
return value
- return value.quantize(decimal.Decimal('.%s' % ('0' * self.precision)), rounding=self.rounding)
+ return value.quantize(
+ decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding
+ )
def to_mongo(self, value):
if value is None:
@@ -435,13 +490,13 @@ class DecimalField(BaseField):
try:
value = decimal.Decimal(value)
except (TypeError, ValueError, decimal.InvalidOperation) as exc:
- self.error('Could not convert value to decimal: %s' % exc)
+ self.error("Could not convert value to decimal: %s" % exc)
if self.min_value is not None and value < self.min_value:
- self.error('Decimal value is too small')
+ self.error("Decimal value is too small")
if self.max_value is not None and value > self.max_value:
- self.error('Decimal value is too large')
+ self.error("Decimal value is too large")
def prepare_query_value(self, op, value):
return super(DecimalField, self).prepare_query_value(op, self.to_mongo(value))
@@ -462,7 +517,7 @@ class BooleanField(BaseField):
def validate(self, value):
if not isinstance(value, bool):
- self.error('BooleanField only accepts boolean values')
+ self.error("BooleanField only accepts boolean values")
class DateTimeField(BaseField):
@@ -514,26 +569,29 @@ class DateTimeField(BaseField):
return None
# split usecs, because they are not recognized by strptime.
- if '.' in value:
+ if "." in value:
try:
- value, usecs = value.split('.')
+ value, usecs = value.split(".")
usecs = int(usecs)
except ValueError:
return None
else:
usecs = 0
- kwargs = {'microsecond': usecs}
+ kwargs = {"microsecond": usecs}
try: # Seconds are optional, so try converting seconds first.
- return datetime.datetime(*time.strptime(value,
- '%Y-%m-%d %H:%M:%S')[:6], **kwargs)
+ return datetime.datetime(
+ *time.strptime(value, "%Y-%m-%d %H:%M:%S")[:6], **kwargs
+ )
except ValueError:
try: # Try without seconds.
- return datetime.datetime(*time.strptime(value,
- '%Y-%m-%d %H:%M')[:5], **kwargs)
+ return datetime.datetime(
+ *time.strptime(value, "%Y-%m-%d %H:%M")[:5], **kwargs
+ )
except ValueError: # Try without hour/minutes/seconds.
try:
- return datetime.datetime(*time.strptime(value,
- '%Y-%m-%d')[:3], **kwargs)
+ return datetime.datetime(
+ *time.strptime(value, "%Y-%m-%d")[:3], **kwargs
+ )
except ValueError:
return None
@@ -578,12 +636,12 @@ class ComplexDateTimeField(StringField):
.. versionadded:: 0.5
"""
- def __init__(self, separator=',', **kwargs):
+ def __init__(self, separator=",", **kwargs):
"""
:param separator: Allows to customize the separator used for storage (default ``,``)
"""
self.separator = separator
- self.format = separator.join(['%Y', '%m', '%d', '%H', '%M', '%S', '%f'])
+ self.format = separator.join(["%Y", "%m", "%d", "%H", "%M", "%S", "%f"])
super(ComplexDateTimeField, self).__init__(**kwargs)
def _convert_from_datetime(self, val):
@@ -630,8 +688,7 @@ class ComplexDateTimeField(StringField):
def validate(self, value):
value = self.to_python(value)
if not isinstance(value, datetime.datetime):
- self.error('Only datetime objects may used in a '
- 'ComplexDateTimeField')
+ self.error("Only datetime objects may used in a ComplexDateTimeField")
def to_python(self, value):
original_value = value
@@ -645,7 +702,9 @@ class ComplexDateTimeField(StringField):
return self._convert_from_datetime(value)
def prepare_query_value(self, op, value):
- return super(ComplexDateTimeField, self).prepare_query_value(op, self._convert_from_datetime(value))
+ return super(ComplexDateTimeField, self).prepare_query_value(
+ op, self._convert_from_datetime(value)
+ )
class EmbeddedDocumentField(BaseField):
@@ -656,11 +715,13 @@ class EmbeddedDocumentField(BaseField):
def __init__(self, document_type, **kwargs):
# XXX ValidationError raised outside of the "validate" method.
if not (
- isinstance(document_type, six.string_types) or
- issubclass(document_type, EmbeddedDocument)
+ isinstance(document_type, six.string_types)
+ or issubclass(document_type, EmbeddedDocument)
):
- self.error('Invalid embedded document class provided to an '
- 'EmbeddedDocumentField')
+ self.error(
+ "Invalid embedded document class provided to an "
+ "EmbeddedDocumentField"
+ )
self.document_type_obj = document_type
super(EmbeddedDocumentField, self).__init__(**kwargs)
@@ -676,15 +737,19 @@ class EmbeddedDocumentField(BaseField):
if not issubclass(resolved_document_type, EmbeddedDocument):
# Due to the late resolution of the document_type
# There is a chance that it won't be an EmbeddedDocument (#1661)
- self.error('Invalid embedded document class provided to an '
- 'EmbeddedDocumentField')
+ self.error(
+ "Invalid embedded document class provided to an "
+ "EmbeddedDocumentField"
+ )
self.document_type_obj = resolved_document_type
return self.document_type_obj
def to_python(self, value):
if not isinstance(value, self.document_type):
- return self.document_type._from_son(value, _auto_dereference=self._auto_dereference)
+ return self.document_type._from_son(
+ value, _auto_dereference=self._auto_dereference
+ )
return value
def to_mongo(self, value, use_db_field=True, fields=None):
@@ -698,8 +763,10 @@ class EmbeddedDocumentField(BaseField):
"""
# Using isinstance also works for subclasses of self.document
if not isinstance(value, self.document_type):
- self.error('Invalid embedded document instance provided to an '
- 'EmbeddedDocumentField')
+ self.error(
+ "Invalid embedded document instance provided to an "
+ "EmbeddedDocumentField"
+ )
self.document_type.validate(value, clean)
def lookup_member(self, member_name):
@@ -714,8 +781,10 @@ class EmbeddedDocumentField(BaseField):
try:
value = self.document_type._from_son(value)
except ValueError:
- raise InvalidQueryError("Querying the embedded document '%s' failed, due to an invalid query value" %
- (self.document_type._class_name,))
+ raise InvalidQueryError(
+ "Querying the embedded document '%s' failed, due to an invalid query value"
+ % (self.document_type._class_name,)
+ )
super(EmbeddedDocumentField, self).prepare_query_value(op, value)
return self.to_mongo(value)
@@ -732,11 +801,13 @@ class GenericEmbeddedDocumentField(BaseField):
"""
def prepare_query_value(self, op, value):
- return super(GenericEmbeddedDocumentField, self).prepare_query_value(op, self.to_mongo(value))
+ return super(GenericEmbeddedDocumentField, self).prepare_query_value(
+ op, self.to_mongo(value)
+ )
def to_python(self, value):
if isinstance(value, dict):
- doc_cls = get_document(value['_cls'])
+ doc_cls = get_document(value["_cls"])
value = doc_cls._from_son(value)
return value
@@ -744,12 +815,14 @@ class GenericEmbeddedDocumentField(BaseField):
def validate(self, value, clean=True):
if self.choices and isinstance(value, SON):
for choice in self.choices:
- if value['_cls'] == choice._class_name:
+ if value["_cls"] == choice._class_name:
return True
if not isinstance(value, EmbeddedDocument):
- self.error('Invalid embedded document instance provided to an '
- 'GenericEmbeddedDocumentField')
+ self.error(
+ "Invalid embedded document instance provided to an "
+ "GenericEmbeddedDocumentField"
+ )
value.validate(clean=clean)
@@ -766,8 +839,8 @@ class GenericEmbeddedDocumentField(BaseField):
if document is None:
return None
data = document.to_mongo(use_db_field, fields)
- if '_cls' not in data:
- data['_cls'] = document._class_name
+ if "_cls" not in data:
+ data["_cls"] = document._class_name
return data
@@ -784,21 +857,21 @@ class DynamicField(BaseField):
if isinstance(value, six.string_types):
return value
- if hasattr(value, 'to_mongo'):
+ if hasattr(value, "to_mongo"):
cls = value.__class__
val = value.to_mongo(use_db_field, fields)
# If we its a document thats not inherited add _cls
if isinstance(value, Document):
- val = {'_ref': value.to_dbref(), '_cls': cls.__name__}
+ val = {"_ref": value.to_dbref(), "_cls": cls.__name__}
if isinstance(value, EmbeddedDocument):
- val['_cls'] = cls.__name__
+ val["_cls"] = cls.__name__
return val
if not isinstance(value, (dict, list, tuple)):
return value
is_list = False
- if not hasattr(value, 'items'):
+ if not hasattr(value, "items"):
is_list = True
value = {k: v for k, v in enumerate(value)}
@@ -812,10 +885,10 @@ class DynamicField(BaseField):
return value
def to_python(self, value):
- if isinstance(value, dict) and '_cls' in value:
- doc_cls = get_document(value['_cls'])
- if '_ref' in value:
- value = doc_cls._get_db().dereference(value['_ref'])
+ if isinstance(value, dict) and "_cls" in value:
+ doc_cls = get_document(value["_cls"])
+ if "_ref" in value:
+ value = doc_cls._get_db().dereference(value["_ref"])
return doc_cls._from_son(value)
return super(DynamicField, self).to_python(value)
@@ -829,7 +902,7 @@ class DynamicField(BaseField):
return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value))
def validate(self, value, clean=True):
- if hasattr(value, 'validate'):
+ if hasattr(value, "validate"):
value.validate(clean=clean)
@@ -845,7 +918,7 @@ class ListField(ComplexBaseField):
def __init__(self, field=None, **kwargs):
self.field = field
- kwargs.setdefault('default', lambda: [])
+ kwargs.setdefault("default", lambda: [])
super(ListField, self).__init__(**kwargs)
def __get__(self, instance, owner):
@@ -853,16 +926,19 @@ class ListField(ComplexBaseField):
# Document class being used rather than a document object
return self
value = instance._data.get(self.name)
- LazyReferenceField = _import_class('LazyReferenceField')
- GenericLazyReferenceField = _import_class('GenericLazyReferenceField')
- if isinstance(self.field, (LazyReferenceField, GenericLazyReferenceField)) and value:
+ LazyReferenceField = _import_class("LazyReferenceField")
+ GenericLazyReferenceField = _import_class("GenericLazyReferenceField")
+ if (
+ isinstance(self.field, (LazyReferenceField, GenericLazyReferenceField))
+ and value
+ ):
instance._data[self.name] = [self.field.build_lazyref(x) for x in value]
return super(ListField, self).__get__(instance, owner)
def validate(self, value):
"""Make sure that a list of valid fields is being used."""
if not isinstance(value, (list, tuple, BaseQuerySet)):
- self.error('Only lists and tuples may be used in a list field')
+ self.error("Only lists and tuples may be used in a list field")
super(ListField, self).validate(value)
def prepare_query_value(self, op, value):
@@ -871,10 +947,10 @@ class ListField(ComplexBaseField):
# If the value is iterable and it's not a string nor a
# BaseDocument, call prepare_query_value for each of its items.
if (
- op in ('set', 'unset', None) and
- hasattr(value, '__iter__') and
- not isinstance(value, six.string_types) and
- not isinstance(value, BaseDocument)
+ op in ("set", "unset", None)
+ and hasattr(value, "__iter__")
+ and not isinstance(value, six.string_types)
+ and not isinstance(value, BaseDocument)
):
return [self.field.prepare_query_value(op, v) for v in value]
@@ -925,17 +1001,18 @@ class SortedListField(ListField):
_order_reverse = False
def __init__(self, field, **kwargs):
- if 'ordering' in kwargs.keys():
- self._ordering = kwargs.pop('ordering')
- if 'reverse' in kwargs.keys():
- self._order_reverse = kwargs.pop('reverse')
+ if "ordering" in kwargs.keys():
+ self._ordering = kwargs.pop("ordering")
+ if "reverse" in kwargs.keys():
+ self._order_reverse = kwargs.pop("reverse")
super(SortedListField, self).__init__(field, **kwargs)
def to_mongo(self, value, use_db_field=True, fields=None):
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
if self._ordering is not None:
- return sorted(value, key=itemgetter(self._ordering),
- reverse=self._order_reverse)
+ return sorted(
+ value, key=itemgetter(self._ordering), reverse=self._order_reverse
+ )
return sorted(value, reverse=self._order_reverse)
@@ -944,7 +1021,9 @@ def key_not_string(d):
dictionary is not a string.
"""
for k, v in d.items():
- if not isinstance(k, six.string_types) or (isinstance(v, dict) and key_not_string(v)):
+ if not isinstance(k, six.string_types) or (
+ isinstance(v, dict) and key_not_string(v)
+ ):
return True
@@ -953,7 +1032,9 @@ def key_has_dot_or_dollar(d):
dictionary contains a dot or a dollar sign.
"""
for k, v in d.items():
- if ('.' in k or k.startswith('$')) or (isinstance(v, dict) and key_has_dot_or_dollar(v)):
+ if ("." in k or k.startswith("$")) or (
+ isinstance(v, dict) and key_has_dot_or_dollar(v)
+ ):
return True
@@ -972,39 +1053,48 @@ class DictField(ComplexBaseField):
self.field = field
self._auto_dereference = False
- kwargs.setdefault('default', lambda: {})
+ kwargs.setdefault("default", lambda: {})
super(DictField, self).__init__(*args, **kwargs)
def validate(self, value):
"""Make sure that a list of valid fields is being used."""
if not isinstance(value, dict):
- self.error('Only dictionaries may be used in a DictField')
+ self.error("Only dictionaries may be used in a DictField")
if key_not_string(value):
- msg = ('Invalid dictionary key - documents must '
- 'have only string keys')
+ msg = "Invalid dictionary key - documents must have only string keys"
self.error(msg)
if key_has_dot_or_dollar(value):
- self.error('Invalid dictionary key name - keys may not contain "."'
- ' or startswith "$" characters')
+ self.error(
+ 'Invalid dictionary key name - keys may not contain "."'
+ ' or startswith "$" characters'
+ )
super(DictField, self).validate(value)
def lookup_member(self, member_name):
return DictField(db_field=member_name)
def prepare_query_value(self, op, value):
- match_operators = ['contains', 'icontains', 'startswith',
- 'istartswith', 'endswith', 'iendswith',
- 'exact', 'iexact']
+ match_operators = [
+ "contains",
+ "icontains",
+ "startswith",
+ "istartswith",
+ "endswith",
+ "iendswith",
+ "exact",
+ "iexact",
+ ]
if op in match_operators and isinstance(value, six.string_types):
return StringField().prepare_query_value(op, value)
- if hasattr(self.field, 'field'): # Used for instance when using DictField(ListField(IntField()))
- if op in ('set', 'unset') and isinstance(value, dict):
+ if hasattr(
+ self.field, "field"
+ ): # Used for instance when using DictField(ListField(IntField()))
+ if op in ("set", "unset") and isinstance(value, dict):
return {
- k: self.field.prepare_query_value(op, v)
- for k, v in value.items()
+ k: self.field.prepare_query_value(op, v) for k, v in value.items()
}
return self.field.prepare_query_value(op, value)
@@ -1022,8 +1112,7 @@ class MapField(DictField):
def __init__(self, field=None, *args, **kwargs):
# XXX ValidationError raised outside of the "validate" method.
if not isinstance(field, BaseField):
- self.error('Argument to MapField constructor must be a valid '
- 'field')
+ self.error("Argument to MapField constructor must be a valid field")
super(MapField, self).__init__(field=field, *args, **kwargs)
@@ -1069,8 +1158,9 @@ class ReferenceField(BaseField):
.. versionchanged:: 0.5 added `reverse_delete_rule`
"""
- def __init__(self, document_type, dbref=False,
- reverse_delete_rule=DO_NOTHING, **kwargs):
+ def __init__(
+ self, document_type, dbref=False, reverse_delete_rule=DO_NOTHING, **kwargs
+ ):
"""Initialises the Reference Field.
:param dbref: Store the reference as :class:`~pymongo.dbref.DBRef`
@@ -1083,12 +1173,13 @@ class ReferenceField(BaseField):
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
"""
# XXX ValidationError raised outside of the "validate" method.
- if (
- not isinstance(document_type, six.string_types) and
- not issubclass(document_type, Document)
+ if not isinstance(document_type, six.string_types) and not issubclass(
+ document_type, Document
):
- self.error('Argument to ReferenceField constructor must be a '
- 'document class or a string')
+ self.error(
+ "Argument to ReferenceField constructor must be a "
+ "document class or a string"
+ )
self.dbref = dbref
self.document_type_obj = document_type
@@ -1115,14 +1206,14 @@ class ReferenceField(BaseField):
auto_dereference = instance._fields[self.name]._auto_dereference
# Dereference DBRefs
if auto_dereference and isinstance(value, DBRef):
- if hasattr(value, 'cls'):
+ if hasattr(value, "cls"):
# Dereference using the class type specified in the reference
cls = get_document(value.cls)
else:
cls = self.document_type
dereferenced = cls._get_db().dereference(value)
if dereferenced is None:
- raise DoesNotExist('Trying to dereference unknown document %s' % value)
+ raise DoesNotExist("Trying to dereference unknown document %s" % value)
else:
instance._data[self.name] = cls._from_son(dereferenced)
@@ -1140,8 +1231,10 @@ class ReferenceField(BaseField):
# XXX ValidationError raised outside of the "validate" method.
if id_ is None:
- self.error('You can only reference documents once they have'
- ' been saved to the database')
+ self.error(
+ "You can only reference documents once they have"
+ " been saved to the database"
+ )
# Use the attributes from the document instance, so that they
# override the attributes of this field's document type
@@ -1150,11 +1243,11 @@ class ReferenceField(BaseField):
id_ = document
cls = self.document_type
- id_field_name = cls._meta['id_field']
+ id_field_name = cls._meta["id_field"]
id_field = cls._fields[id_field_name]
id_ = id_field.to_mongo(id_)
- if self.document_type._meta.get('abstract'):
+ if self.document_type._meta.get("abstract"):
collection = cls._get_collection_name()
return DBRef(collection, id_, cls=cls._class_name)
elif self.dbref:
@@ -1165,8 +1258,9 @@ class ReferenceField(BaseField):
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type."""
- if (not self.dbref and
- not isinstance(value, (DBRef, Document, EmbeddedDocument))):
+ if not self.dbref and not isinstance(
+ value, (DBRef, Document, EmbeddedDocument)
+ ):
collection = self.document_type._get_collection_name()
value = DBRef(collection, self.document_type.id.to_python(value))
return value
@@ -1179,11 +1273,15 @@ class ReferenceField(BaseField):
def validate(self, value):
if not isinstance(value, (self.document_type, LazyReference, DBRef, ObjectId)):
- self.error('A ReferenceField only accepts DBRef, LazyReference, ObjectId or documents')
+ self.error(
+ "A ReferenceField only accepts DBRef, LazyReference, ObjectId or documents"
+ )
if isinstance(value, Document) and value.id is None:
- self.error('You can only reference documents once they have been '
- 'saved to the database')
+ self.error(
+ "You can only reference documents once they have been "
+ "saved to the database"
+ )
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
@@ -1206,12 +1304,13 @@ class CachedReferenceField(BaseField):
fields = []
# XXX ValidationError raised outside of the "validate" method.
- if (
- not isinstance(document_type, six.string_types) and
- not issubclass(document_type, Document)
+ if not isinstance(document_type, six.string_types) and not issubclass(
+ document_type, Document
):
- self.error('Argument to CachedReferenceField constructor must be a'
- ' document class or a string')
+ self.error(
+ "Argument to CachedReferenceField constructor must be a"
+ " document class or a string"
+ )
self.auto_sync = auto_sync
self.document_type_obj = document_type
@@ -1221,15 +1320,14 @@ class CachedReferenceField(BaseField):
def start_listener(self):
from mongoengine import signals
- signals.post_save.connect(self.on_document_pre_save,
- sender=self.document_type)
+ signals.post_save.connect(self.on_document_pre_save, sender=self.document_type)
def on_document_pre_save(self, sender, document, created, **kwargs):
if created:
return None
update_kwargs = {
- 'set__%s__%s' % (self.name, key): val
+ "set__%s__%s" % (self.name, key): val
for key, val in document._delta()[0].items()
if key in self.fields
}
@@ -1237,15 +1335,15 @@ class CachedReferenceField(BaseField):
filter_kwargs = {}
filter_kwargs[self.name] = document
- self.owner_document.objects(
- **filter_kwargs).update(**update_kwargs)
+ self.owner_document.objects(**filter_kwargs).update(**update_kwargs)
def to_python(self, value):
if isinstance(value, dict):
collection = self.document_type._get_collection_name()
- value = DBRef(
- collection, self.document_type.id.to_python(value['_id']))
- return self.document_type._from_son(self.document_type._get_db().dereference(value))
+ value = DBRef(collection, self.document_type.id.to_python(value["_id"]))
+ return self.document_type._from_son(
+ self.document_type._get_db().dereference(value)
+ )
return value
@@ -1271,14 +1369,14 @@ class CachedReferenceField(BaseField):
if auto_dereference and isinstance(value, DBRef):
dereferenced = self.document_type._get_db().dereference(value)
if dereferenced is None:
- raise DoesNotExist('Trying to dereference unknown document %s' % value)
+ raise DoesNotExist("Trying to dereference unknown document %s" % value)
else:
instance._data[self.name] = self.document_type._from_son(dereferenced)
return super(CachedReferenceField, self).__get__(instance, owner)
def to_mongo(self, document, use_db_field=True, fields=None):
- id_field_name = self.document_type._meta['id_field']
+ id_field_name = self.document_type._meta["id_field"]
id_field = self.document_type._fields[id_field_name]
# XXX ValidationError raised outside of the "validate" method.
@@ -1286,14 +1384,14 @@ class CachedReferenceField(BaseField):
# We need the id from the saved object to create the DBRef
id_ = document.pk
if id_ is None:
- self.error('You can only reference documents once they have'
- ' been saved to the database')
+ self.error(
+ "You can only reference documents once they have"
+ " been saved to the database"
+ )
else:
- self.error('Only accept a document object')
+ self.error("Only accept a document object")
- value = SON((
- ('_id', id_field.to_mongo(id_)),
- ))
+ value = SON((("_id", id_field.to_mongo(id_)),))
if fields:
new_fields = [f for f in self.fields if f in fields]
@@ -1310,9 +1408,11 @@ class CachedReferenceField(BaseField):
# XXX ValidationError raised outside of the "validate" method.
if isinstance(value, Document):
if value.pk is None:
- self.error('You can only reference documents once they have'
- ' been saved to the database')
- value_dict = {'_id': value.pk}
+ self.error(
+ "You can only reference documents once they have"
+ " been saved to the database"
+ )
+ value_dict = {"_id": value.pk}
for field in self.fields:
value_dict.update({field: value[field]})
@@ -1322,11 +1422,13 @@ class CachedReferenceField(BaseField):
def validate(self, value):
if not isinstance(value, self.document_type):
- self.error('A CachedReferenceField only accepts documents')
+ self.error("A CachedReferenceField only accepts documents")
if isinstance(value, Document) and value.id is None:
- self.error('You can only reference documents once they have been '
- 'saved to the database')
+ self.error(
+ "You can only reference documents once they have been "
+ "saved to the database"
+ )
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
@@ -1336,7 +1438,7 @@ class CachedReferenceField(BaseField):
Sync all cached fields on demand.
Caution: this operation may be slower.
"""
- update_key = 'set__%s' % self.name
+ update_key = "set__%s" % self.name
for doc in self.document_type.objects:
filter_kwargs = {}
@@ -1345,8 +1447,7 @@ class CachedReferenceField(BaseField):
update_kwargs = {}
update_kwargs[update_key] = doc
- self.owner_document.objects(
- **filter_kwargs).update(**update_kwargs)
+ self.owner_document.objects(**filter_kwargs).update(**update_kwargs)
class GenericReferenceField(BaseField):
@@ -1370,7 +1471,7 @@ class GenericReferenceField(BaseField):
"""
def __init__(self, *args, **kwargs):
- choices = kwargs.pop('choices', None)
+ choices = kwargs.pop("choices", None)
super(GenericReferenceField, self).__init__(*args, **kwargs)
self.choices = []
# Keep the choices as a list of allowed Document class names
@@ -1383,14 +1484,16 @@ class GenericReferenceField(BaseField):
else:
# XXX ValidationError raised outside of the "validate"
# method.
- self.error('Invalid choices provided: must be a list of'
- 'Document subclasses and/or six.string_typess')
+ self.error(
+ "Invalid choices provided: must be a list of"
+ "Document subclasses and/or six.string_typess"
+ )
def _validate_choices(self, value):
if isinstance(value, dict):
# If the field has not been dereferenced, it is still a dict
# of class and DBRef
- value = value.get('_cls')
+ value = value.get("_cls")
elif isinstance(value, Document):
value = value._class_name
super(GenericReferenceField, self)._validate_choices(value)
@@ -1405,7 +1508,7 @@ class GenericReferenceField(BaseField):
if auto_dereference and isinstance(value, (dict, SON)):
dereferenced = self.dereference(value)
if dereferenced is None:
- raise DoesNotExist('Trying to dereference unknown document %s' % value)
+ raise DoesNotExist("Trying to dereference unknown document %s" % value)
else:
instance._data[self.name] = dereferenced
@@ -1413,20 +1516,22 @@ class GenericReferenceField(BaseField):
def validate(self, value):
if not isinstance(value, (Document, DBRef, dict, SON)):
- self.error('GenericReferences can only contain documents')
+ self.error("GenericReferences can only contain documents")
if isinstance(value, (dict, SON)):
- if '_ref' not in value or '_cls' not in value:
- self.error('GenericReferences can only contain documents')
+ if "_ref" not in value or "_cls" not in value:
+ self.error("GenericReferences can only contain documents")
# We need the id from the saved object to create the DBRef
elif isinstance(value, Document) and value.id is None:
- self.error('You can only reference documents once they have been'
- ' saved to the database')
+ self.error(
+ "You can only reference documents once they have been"
+ " saved to the database"
+ )
def dereference(self, value):
- doc_cls = get_document(value['_cls'])
- reference = value['_ref']
+ doc_cls = get_document(value["_cls"])
+ reference = value["_ref"]
doc = doc_cls._get_db().dereference(reference)
if doc is not None:
doc = doc_cls._from_son(doc)
@@ -1439,7 +1544,7 @@ class GenericReferenceField(BaseField):
if isinstance(document, (dict, SON, ObjectId, DBRef)):
return document
- id_field_name = document.__class__._meta['id_field']
+ id_field_name = document.__class__._meta["id_field"]
id_field = document.__class__._fields[id_field_name]
if isinstance(document, Document):
@@ -1447,18 +1552,17 @@ class GenericReferenceField(BaseField):
id_ = document.id
if id_ is None:
# XXX ValidationError raised outside of the "validate" method.
- self.error('You can only reference documents once they have'
- ' been saved to the database')
+ self.error(
+ "You can only reference documents once they have"
+ " been saved to the database"
+ )
else:
id_ = document
id_ = id_field.to_mongo(id_)
collection = document._get_collection_name()
ref = DBRef(collection, id_)
- return SON((
- ('_cls', document._class_name),
- ('_ref', ref)
- ))
+ return SON((("_cls", document._class_name), ("_ref", ref)))
def prepare_query_value(self, op, value):
if value is None:
@@ -1485,18 +1589,18 @@ class BinaryField(BaseField):
def validate(self, value):
if not isinstance(value, (six.binary_type, Binary)):
- self.error('BinaryField only accepts instances of '
- '(%s, %s, Binary)' % (
- six.binary_type.__name__, Binary.__name__))
+ self.error(
+ "BinaryField only accepts instances of "
+ "(%s, %s, Binary)" % (six.binary_type.__name__, Binary.__name__)
+ )
if self.max_bytes is not None and len(value) > self.max_bytes:
- self.error('Binary value is too long')
+ self.error("Binary value is too long")
def prepare_query_value(self, op, value):
if value is None:
return value
- return super(BinaryField, self).prepare_query_value(
- op, self.to_mongo(value))
+ return super(BinaryField, self).prepare_query_value(op, self.to_mongo(value))
class GridFSError(Exception):
@@ -1513,10 +1617,14 @@ class GridFSProxy(object):
_fs = None
- def __init__(self, grid_id=None, key=None,
- instance=None,
- db_alias=DEFAULT_CONNECTION_NAME,
- collection_name='fs'):
+ def __init__(
+ self,
+ grid_id=None,
+ key=None,
+ instance=None,
+ db_alias=DEFAULT_CONNECTION_NAME,
+ collection_name="fs",
+ ):
self.grid_id = grid_id # Store GridFS id for file
self.key = key
self.instance = instance
@@ -1526,8 +1634,16 @@ class GridFSProxy(object):
self.gridout = None
def __getattr__(self, name):
- attrs = ('_fs', 'grid_id', 'key', 'instance', 'db_alias',
- 'collection_name', 'newfile', 'gridout')
+ attrs = (
+ "_fs",
+ "grid_id",
+ "key",
+ "instance",
+ "db_alias",
+ "collection_name",
+ "newfile",
+ "gridout",
+ )
if name in attrs:
return self.__getattribute__(name)
obj = self.get()
@@ -1545,7 +1661,7 @@ class GridFSProxy(object):
def __getstate__(self):
self_dict = self.__dict__
- self_dict['_fs'] = None
+ self_dict["_fs"] = None
return self_dict
def __copy__(self):
@@ -1557,18 +1673,20 @@ class GridFSProxy(object):
return self.__copy__()
def __repr__(self):
- return '<%s: %s>' % (self.__class__.__name__, self.grid_id)
+ return "<%s: %s>" % (self.__class__.__name__, self.grid_id)
def __str__(self):
gridout = self.get()
- filename = getattr(gridout, 'filename') if gridout else ''
- return '<%s: %s (%s)>' % (self.__class__.__name__, filename, self.grid_id)
+ filename = getattr(gridout, "filename") if gridout else ""
+ return "<%s: %s (%s)>" % (self.__class__.__name__, filename, self.grid_id)
def __eq__(self, other):
if isinstance(other, GridFSProxy):
- return ((self.grid_id == other.grid_id) and
- (self.collection_name == other.collection_name) and
- (self.db_alias == other.db_alias))
+ return (
+ (self.grid_id == other.grid_id)
+ and (self.collection_name == other.collection_name)
+ and (self.db_alias == other.db_alias)
+ )
else:
return False
@@ -1578,8 +1696,7 @@ class GridFSProxy(object):
@property
def fs(self):
if not self._fs:
- self._fs = gridfs.GridFS(
- get_db(self.db_alias), self.collection_name)
+ self._fs = gridfs.GridFS(get_db(self.db_alias), self.collection_name)
return self._fs
def get(self, grid_id=None):
@@ -1604,16 +1721,20 @@ class GridFSProxy(object):
def put(self, file_obj, **kwargs):
if self.grid_id:
- raise GridFSError('This document already has a file. Either delete '
- 'it or call replace to overwrite it')
+ raise GridFSError(
+ "This document already has a file. Either delete "
+ "it or call replace to overwrite it"
+ )
self.grid_id = self.fs.put(file_obj, **kwargs)
self._mark_as_changed()
def write(self, string):
if self.grid_id:
if not self.newfile:
- raise GridFSError('This document already has a file. Either '
- 'delete it or call replace to overwrite it')
+ raise GridFSError(
+ "This document already has a file. Either "
+ "delete it or call replace to overwrite it"
+ )
else:
self.new_file()
self.newfile.write(string)
@@ -1632,7 +1753,7 @@ class GridFSProxy(object):
try:
return gridout.read(size)
except Exception:
- return ''
+ return ""
def delete(self):
# Delete file from GridFS, FileField still remains
@@ -1662,10 +1783,12 @@ class FileField(BaseField):
.. versionchanged:: 0.5 added optional size param for read
.. versionchanged:: 0.6 added db_alias for multidb support
"""
+
proxy_class = GridFSProxy
- def __init__(self, db_alias=DEFAULT_CONNECTION_NAME, collection_name='fs',
- **kwargs):
+ def __init__(
+ self, db_alias=DEFAULT_CONNECTION_NAME, collection_name="fs", **kwargs
+ ):
super(FileField, self).__init__(**kwargs)
self.collection_name = collection_name
self.db_alias = db_alias
@@ -1688,9 +1811,8 @@ class FileField(BaseField):
def __set__(self, instance, value):
key = self.name
if (
- (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or
- isinstance(value, (six.binary_type, six.string_types))
- ):
+ hasattr(value, "read") and not isinstance(value, GridFSProxy)
+ ) or isinstance(value, (six.binary_type, six.string_types)):
# using "FileField() = file/string" notation
grid_file = instance._data.get(self.name)
# If a file already exists, delete it
@@ -1701,8 +1823,7 @@ class FileField(BaseField):
pass
# Create a new proxy object as we don't already have one
- instance._data[key] = self.get_proxy_obj(
- key=key, instance=instance)
+ instance._data[key] = self.get_proxy_obj(key=key, instance=instance)
instance._data[key].put(value)
else:
instance._data[key] = value
@@ -1715,9 +1836,12 @@ class FileField(BaseField):
if collection_name is None:
collection_name = self.collection_name
- return self.proxy_class(key=key, instance=instance,
- db_alias=db_alias,
- collection_name=collection_name)
+ return self.proxy_class(
+ key=key,
+ instance=instance,
+ db_alias=db_alias,
+ collection_name=collection_name,
+ )
def to_mongo(self, value):
# Store the GridFS file id in MongoDB
@@ -1727,16 +1851,16 @@ class FileField(BaseField):
def to_python(self, value):
if value is not None:
- return self.proxy_class(value,
- collection_name=self.collection_name,
- db_alias=self.db_alias)
+ return self.proxy_class(
+ value, collection_name=self.collection_name, db_alias=self.db_alias
+ )
def validate(self, value):
if value.grid_id is not None:
if not isinstance(value, self.proxy_class):
- self.error('FileField only accepts GridFSProxy values')
+ self.error("FileField only accepts GridFSProxy values")
if not isinstance(value.grid_id, ObjectId):
- self.error('Invalid GridFSProxy value')
+ self.error("Invalid GridFSProxy value")
class ImageGridFsProxy(GridFSProxy):
@@ -1753,52 +1877,51 @@ class ImageGridFsProxy(GridFSProxy):
"""
field = self.instance._fields[self.key]
# Handle nested fields
- if hasattr(field, 'field') and isinstance(field.field, FileField):
+ if hasattr(field, "field") and isinstance(field.field, FileField):
field = field.field
try:
img = Image.open(file_obj)
img_format = img.format
except Exception as e:
- raise ValidationError('Invalid image: %s' % e)
+ raise ValidationError("Invalid image: %s" % e)
# Progressive JPEG
# TODO: fixme, at least unused, at worst bad implementation
- progressive = img.info.get('progressive') or False
+ progressive = img.info.get("progressive") or False
- if (kwargs.get('progressive') and
- isinstance(kwargs.get('progressive'), bool) and
- img_format == 'JPEG'):
+ if (
+ kwargs.get("progressive")
+ and isinstance(kwargs.get("progressive"), bool)
+ and img_format == "JPEG"
+ ):
progressive = True
else:
progressive = False
- if (field.size and (img.size[0] > field.size['width'] or
- img.size[1] > field.size['height'])):
+ if field.size and (
+ img.size[0] > field.size["width"] or img.size[1] > field.size["height"]
+ ):
size = field.size
- if size['force']:
- img = ImageOps.fit(img,
- (size['width'],
- size['height']),
- Image.ANTIALIAS)
+ if size["force"]:
+ img = ImageOps.fit(
+ img, (size["width"], size["height"]), Image.ANTIALIAS
+ )
else:
- img.thumbnail((size['width'],
- size['height']),
- Image.ANTIALIAS)
+ img.thumbnail((size["width"], size["height"]), Image.ANTIALIAS)
thumbnail = None
if field.thumbnail_size:
size = field.thumbnail_size
- if size['force']:
+ if size["force"]:
thumbnail = ImageOps.fit(
- img, (size['width'], size['height']), Image.ANTIALIAS)
+ img, (size["width"], size["height"]), Image.ANTIALIAS
+ )
else:
thumbnail = img.copy()
- thumbnail.thumbnail((size['width'],
- size['height']),
- Image.ANTIALIAS)
+ thumbnail.thumbnail((size["width"], size["height"]), Image.ANTIALIAS)
if thumbnail:
thumb_id = self._put_thumbnail(thumbnail, img_format, progressive)
@@ -1811,12 +1934,9 @@ class ImageGridFsProxy(GridFSProxy):
img.save(io, img_format, progressive=progressive)
io.seek(0)
- return super(ImageGridFsProxy, self).put(io,
- width=w,
- height=h,
- format=img_format,
- thumbnail_id=thumb_id,
- **kwargs)
+ return super(ImageGridFsProxy, self).put(
+ io, width=w, height=h, format=img_format, thumbnail_id=thumb_id, **kwargs
+ )
def delete(self, *args, **kwargs):
# deletes thumbnail
@@ -1833,10 +1953,7 @@ class ImageGridFsProxy(GridFSProxy):
thumbnail.save(io, format, progressive=progressive)
io.seek(0)
- return self.fs.put(io, width=w,
- height=h,
- format=format,
- **kwargs)
+ return self.fs.put(io, width=w, height=h, format=format, **kwargs)
@property
def size(self):
@@ -1888,32 +2005,30 @@ class ImageField(FileField):
.. versionadded:: 0.6
"""
+
proxy_class = ImageGridFsProxy
- def __init__(self, size=None, thumbnail_size=None,
- collection_name='images', **kwargs):
+ def __init__(
+ self, size=None, thumbnail_size=None, collection_name="images", **kwargs
+ ):
if not Image:
- raise ImproperlyConfigured('PIL library was not found')
+ raise ImproperlyConfigured("PIL library was not found")
- params_size = ('width', 'height', 'force')
- extra_args = {
- 'size': size,
- 'thumbnail_size': thumbnail_size
- }
+ params_size = ("width", "height", "force")
+ extra_args = {"size": size, "thumbnail_size": thumbnail_size}
for att_name, att in extra_args.items():
value = None
if isinstance(att, (tuple, list)):
if six.PY3:
- value = dict(itertools.zip_longest(params_size, att,
- fillvalue=None))
+ value = dict(
+ itertools.zip_longest(params_size, att, fillvalue=None)
+ )
else:
value = dict(map(None, params_size, att))
setattr(self, att_name, value)
- super(ImageField, self).__init__(
- collection_name=collection_name,
- **kwargs)
+ super(ImageField, self).__init__(collection_name=collection_name, **kwargs)
class SequenceField(BaseField):
@@ -1947,15 +2062,24 @@ class SequenceField(BaseField):
"""
_auto_gen = True
- COLLECTION_NAME = 'mongoengine.counters'
+ COLLECTION_NAME = "mongoengine.counters"
VALUE_DECORATOR = int
- def __init__(self, collection_name=None, db_alias=None, sequence_name=None,
- value_decorator=None, *args, **kwargs):
+ def __init__(
+ self,
+ collection_name=None,
+ db_alias=None,
+ sequence_name=None,
+ value_decorator=None,
+ *args,
+ **kwargs
+ ):
self.collection_name = collection_name or self.COLLECTION_NAME
self.db_alias = db_alias or DEFAULT_CONNECTION_NAME
self.sequence_name = sequence_name
- self.value_decorator = value_decorator if callable(value_decorator) else self.VALUE_DECORATOR
+ self.value_decorator = (
+ value_decorator if callable(value_decorator) else self.VALUE_DECORATOR
+ )
super(SequenceField, self).__init__(*args, **kwargs)
def generate(self):
@@ -1963,15 +2087,16 @@ class SequenceField(BaseField):
Generate and Increment the counter
"""
sequence_name = self.get_sequence_name()
- sequence_id = '%s.%s' % (sequence_name, self.name)
+ sequence_id = "%s.%s" % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
counter = collection.find_one_and_update(
- filter={'_id': sequence_id},
- update={'$inc': {'next': 1}},
+ filter={"_id": sequence_id},
+ update={"$inc": {"next": 1}},
return_document=ReturnDocument.AFTER,
- upsert=True)
- return self.value_decorator(counter['next'])
+ upsert=True,
+ )
+ return self.value_decorator(counter["next"])
def set_next_value(self, value):
"""Helper method to set the next sequence value"""
@@ -1982,8 +2107,9 @@ class SequenceField(BaseField):
filter={"_id": sequence_id},
update={"$set": {"next": value}},
return_document=ReturnDocument.AFTER,
- upsert=True)
- return self.value_decorator(counter['next'])
+ upsert=True,
+ )
+ return self.value_decorator(counter["next"])
def get_next_value(self):
"""Helper method to get the next value for previewing.
@@ -1992,12 +2118,12 @@ class SequenceField(BaseField):
as it is only fixed on set.
"""
sequence_name = self.get_sequence_name()
- sequence_id = '%s.%s' % (sequence_name, self.name)
+ sequence_id = "%s.%s" % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
- data = collection.find_one({'_id': sequence_id})
+ data = collection.find_one({"_id": sequence_id})
if data:
- return self.value_decorator(data['next'] + 1)
+ return self.value_decorator(data["next"] + 1)
return self.value_decorator(1)
@@ -2005,11 +2131,14 @@ class SequenceField(BaseField):
if self.sequence_name:
return self.sequence_name
owner = self.owner_document
- if issubclass(owner, Document) and not owner._meta.get('abstract'):
+ if issubclass(owner, Document) and not owner._meta.get("abstract"):
return owner._get_collection_name()
else:
- return ''.join('_%s' % c if c.isupper() else c
- for c in owner._class_name).strip('_').lower()
+ return (
+ "".join("_%s" % c if c.isupper() else c for c in owner._class_name)
+ .strip("_")
+ .lower()
+ )
def __get__(self, instance, owner):
value = super(SequenceField, self).__get__(instance, owner)
@@ -2046,6 +2175,7 @@ class UUIDField(BaseField):
.. versionadded:: 0.6
"""
+
_binary = None
def __init__(self, binary=True, **kwargs):
@@ -2090,7 +2220,7 @@ class UUIDField(BaseField):
try:
uuid.UUID(value)
except (ValueError, TypeError, AttributeError) as exc:
- self.error('Could not convert to UUID: %s' % exc)
+ self.error("Could not convert to UUID: %s" % exc)
class GeoPointField(BaseField):
@@ -2109,16 +2239,14 @@ class GeoPointField(BaseField):
def validate(self, value):
"""Make sure that a geo-value is of type (x, y)"""
if not isinstance(value, (list, tuple)):
- self.error('GeoPointField can only accept tuples or lists '
- 'of (x, y)')
+ self.error("GeoPointField can only accept tuples or lists of (x, y)")
if not len(value) == 2:
- self.error('Value (%s) must be a two-dimensional point' %
- repr(value))
- elif (not isinstance(value[0], (float, int)) or
- not isinstance(value[1], (float, int))):
- self.error(
- 'Both values (%s) in point must be float or int' % repr(value))
+ self.error("Value (%s) must be a two-dimensional point" % repr(value))
+ elif not isinstance(value[0], (float, int)) or not isinstance(
+ value[1], (float, int)
+ ):
+ self.error("Both values (%s) in point must be float or int" % repr(value))
class PointField(GeoJsonBaseField):
@@ -2138,7 +2266,8 @@ class PointField(GeoJsonBaseField):
.. versionadded:: 0.8
"""
- _type = 'Point'
+
+ _type = "Point"
class LineStringField(GeoJsonBaseField):
@@ -2157,7 +2286,8 @@ class LineStringField(GeoJsonBaseField):
.. versionadded:: 0.8
"""
- _type = 'LineString'
+
+ _type = "LineString"
class PolygonField(GeoJsonBaseField):
@@ -2179,7 +2309,8 @@ class PolygonField(GeoJsonBaseField):
.. versionadded:: 0.8
"""
- _type = 'Polygon'
+
+ _type = "Polygon"
class MultiPointField(GeoJsonBaseField):
@@ -2199,7 +2330,8 @@ class MultiPointField(GeoJsonBaseField):
.. versionadded:: 0.9
"""
- _type = 'MultiPoint'
+
+ _type = "MultiPoint"
class MultiLineStringField(GeoJsonBaseField):
@@ -2219,7 +2351,8 @@ class MultiLineStringField(GeoJsonBaseField):
.. versionadded:: 0.9
"""
- _type = 'MultiLineString'
+
+ _type = "MultiLineString"
class MultiPolygonField(GeoJsonBaseField):
@@ -2246,7 +2379,8 @@ class MultiPolygonField(GeoJsonBaseField):
.. versionadded:: 0.9
"""
- _type = 'MultiPolygon'
+
+ _type = "MultiPolygon"
class LazyReferenceField(BaseField):
@@ -2260,8 +2394,14 @@ class LazyReferenceField(BaseField):
.. versionadded:: 0.15
"""
- def __init__(self, document_type, passthrough=False, dbref=False,
- reverse_delete_rule=DO_NOTHING, **kwargs):
+ def __init__(
+ self,
+ document_type,
+ passthrough=False,
+ dbref=False,
+ reverse_delete_rule=DO_NOTHING,
+ **kwargs
+ ):
"""Initialises the Reference Field.
:param dbref: Store the reference as :class:`~pymongo.dbref.DBRef`
@@ -2274,12 +2414,13 @@ class LazyReferenceField(BaseField):
document. Note this only work getting field (not setting or deleting).
"""
# XXX ValidationError raised outside of the "validate" method.
- if (
- not isinstance(document_type, six.string_types) and
- not issubclass(document_type, Document)
+ if not isinstance(document_type, six.string_types) and not issubclass(
+ document_type, Document
):
- self.error('Argument to LazyReferenceField constructor must be a '
- 'document class or a string')
+ self.error(
+ "Argument to LazyReferenceField constructor must be a "
+ "document class or a string"
+ )
self.dbref = dbref
self.passthrough = passthrough
@@ -2299,15 +2440,23 @@ class LazyReferenceField(BaseField):
def build_lazyref(self, value):
if isinstance(value, LazyReference):
if value.passthrough != self.passthrough:
- value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough)
+ value = LazyReference(
+ value.document_type, value.pk, passthrough=self.passthrough
+ )
elif value is not None:
if isinstance(value, self.document_type):
- value = LazyReference(self.document_type, value.pk, passthrough=self.passthrough)
+ value = LazyReference(
+ self.document_type, value.pk, passthrough=self.passthrough
+ )
elif isinstance(value, DBRef):
- value = LazyReference(self.document_type, value.id, passthrough=self.passthrough)
+ value = LazyReference(
+ self.document_type, value.id, passthrough=self.passthrough
+ )
else:
# value is the primary key of the referenced document
- value = LazyReference(self.document_type, value, passthrough=self.passthrough)
+ value = LazyReference(
+ self.document_type, value, passthrough=self.passthrough
+ )
return value
def __get__(self, instance, owner):
@@ -2332,7 +2481,7 @@ class LazyReferenceField(BaseField):
else:
# value is the primary key of the referenced document
pk = value
- id_field_name = self.document_type._meta['id_field']
+ id_field_name = self.document_type._meta["id_field"]
id_field = self.document_type._fields[id_field_name]
pk = id_field.to_mongo(pk)
if self.dbref:
@@ -2343,7 +2492,7 @@ class LazyReferenceField(BaseField):
def validate(self, value):
if isinstance(value, LazyReference):
if value.collection != self.document_type._get_collection_name():
- self.error('Reference must be on a `%s` document.' % self.document_type)
+ self.error("Reference must be on a `%s` document." % self.document_type)
pk = value.pk
elif isinstance(value, self.document_type):
pk = value.pk
@@ -2355,7 +2504,7 @@ class LazyReferenceField(BaseField):
pk = value.id
else:
# value is the primary key of the referenced document
- id_field_name = self.document_type._meta['id_field']
+ id_field_name = self.document_type._meta["id_field"]
id_field = getattr(self.document_type, id_field_name)
pk = value
try:
@@ -2364,11 +2513,15 @@ class LazyReferenceField(BaseField):
self.error(
"value should be `{0}` document, LazyReference or DBRef on `{0}` "
"or `{0}`'s primary key (i.e. `{1}`)".format(
- self.document_type.__name__, type(id_field).__name__))
+ self.document_type.__name__, type(id_field).__name__
+ )
+ )
if pk is None:
- self.error('You can only reference documents once they have been '
- 'saved to the database')
+ self.error(
+ "You can only reference documents once they have been "
+ "saved to the database"
+ )
def prepare_query_value(self, op, value):
if value is None:
@@ -2399,7 +2552,7 @@ class GenericLazyReferenceField(GenericReferenceField):
"""
def __init__(self, *args, **kwargs):
- self.passthrough = kwargs.pop('passthrough', False)
+ self.passthrough = kwargs.pop("passthrough", False)
super(GenericLazyReferenceField, self).__init__(*args, **kwargs)
def _validate_choices(self, value):
@@ -2410,12 +2563,20 @@ class GenericLazyReferenceField(GenericReferenceField):
def build_lazyref(self, value):
if isinstance(value, LazyReference):
if value.passthrough != self.passthrough:
- value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough)
+ value = LazyReference(
+ value.document_type, value.pk, passthrough=self.passthrough
+ )
elif value is not None:
if isinstance(value, (dict, SON)):
- value = LazyReference(get_document(value['_cls']), value['_ref'].id, passthrough=self.passthrough)
+ value = LazyReference(
+ get_document(value["_cls"]),
+ value["_ref"].id,
+ passthrough=self.passthrough,
+ )
elif isinstance(value, Document):
- value = LazyReference(type(value), value.pk, passthrough=self.passthrough)
+ value = LazyReference(
+ type(value), value.pk, passthrough=self.passthrough
+ )
return value
def __get__(self, instance, owner):
@@ -2430,8 +2591,10 @@ class GenericLazyReferenceField(GenericReferenceField):
def validate(self, value):
if isinstance(value, LazyReference) and value.pk is None:
- self.error('You can only reference documents once they have been'
- ' saved to the database')
+ self.error(
+ "You can only reference documents once they have been"
+ " saved to the database"
+ )
return super(GenericLazyReferenceField, self).validate(value)
def to_mongo(self, document):
@@ -2439,9 +2602,16 @@ class GenericLazyReferenceField(GenericReferenceField):
return None
if isinstance(document, LazyReference):
- return SON((
- ('_cls', document.document_type._class_name),
- ('_ref', DBRef(document.document_type._get_collection_name(), document.pk))
- ))
+ return SON(
+ (
+ ("_cls", document.document_type._class_name),
+ (
+ "_ref",
+ DBRef(
+ document.document_type._get_collection_name(), document.pk
+ ),
+ ),
+ )
+ )
else:
return super(GenericLazyReferenceField, self).to_mongo(document)
diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py
index b20ebc1e..5d437fef 100644
--- a/mongoengine/mongodb_support.py
+++ b/mongoengine/mongodb_support.py
@@ -15,5 +15,5 @@ def get_mongodb_version():
:return: tuple(int, int)
"""
- version_list = get_connection().server_info()['versionArray'][:2] # e.g: (3, 2)
+ version_list = get_connection().server_info()["versionArray"][:2] # e.g: (3, 2)
return tuple(version_list)
diff --git a/mongoengine/pymongo_support.py b/mongoengine/pymongo_support.py
index f66c038e..80c0661b 100644
--- a/mongoengine/pymongo_support.py
+++ b/mongoengine/pymongo_support.py
@@ -27,6 +27,6 @@ def list_collection_names(db, include_system_collections=False):
collections = db.collection_names()
if not include_system_collections:
- collections = [c for c in collections if not c.startswith('system.')]
+ collections = [c for c in collections if not c.startswith("system.")]
return collections
diff --git a/mongoengine/queryset/__init__.py b/mongoengine/queryset/__init__.py
index 5219c39e..f041d07b 100644
--- a/mongoengine/queryset/__init__.py
+++ b/mongoengine/queryset/__init__.py
@@ -7,11 +7,22 @@ from mongoengine.queryset.visitor import *
# Expose just the public subset of all imported objects and constants.
__all__ = (
- 'QuerySet', 'QuerySetNoCache', 'Q', 'queryset_manager', 'QuerySetManager',
- 'QueryFieldList', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL',
-
+ "QuerySet",
+ "QuerySetNoCache",
+ "Q",
+ "queryset_manager",
+ "QuerySetManager",
+ "QueryFieldList",
+ "DO_NOTHING",
+ "NULLIFY",
+ "CASCADE",
+ "DENY",
+ "PULL",
# Errors that might be related to a queryset, mostly here for backward
# compatibility
- 'DoesNotExist', 'InvalidQueryError', 'MultipleObjectsReturned',
- 'NotUniqueError', 'OperationError',
+ "DoesNotExist",
+ "InvalidQueryError",
+ "MultipleObjectsReturned",
+ "NotUniqueError",
+ "OperationError",
)
diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py
index 85616c85..78e85399 100644
--- a/mongoengine/queryset/base.py
+++ b/mongoengine/queryset/base.py
@@ -20,14 +20,18 @@ from mongoengine.base import get_document
from mongoengine.common import _import_class
from mongoengine.connection import get_db
from mongoengine.context_managers import set_write_concern, switch_db
-from mongoengine.errors import (InvalidQueryError, LookUpError,
- NotUniqueError, OperationError)
+from mongoengine.errors import (
+ InvalidQueryError,
+ LookUpError,
+ NotUniqueError,
+ OperationError,
+)
from mongoengine.queryset import transform
from mongoengine.queryset.field_list import QueryFieldList
from mongoengine.queryset.visitor import Q, QNode
-__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL')
+__all__ = ("BaseQuerySet", "DO_NOTHING", "NULLIFY", "CASCADE", "DENY", "PULL")
# Delete rules
DO_NOTHING = 0
@@ -41,6 +45,7 @@ class BaseQuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor,
providing :class:`~mongoengine.Document` objects as the results.
"""
+
__dereference = False
_auto_dereference = True
@@ -66,13 +71,12 @@ class BaseQuerySet(object):
# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
- if document._meta.get('allow_inheritance') is True:
+ if document._meta.get("allow_inheritance") is True:
if len(self._document._subclasses) == 1:
- self._initial_query = {'_cls': self._document._subclasses[0]}
+ self._initial_query = {"_cls": self._document._subclasses[0]}
else:
- self._initial_query = {
- '_cls': {'$in': self._document._subclasses}}
- self._loaded_fields = QueryFieldList(always_include=['_cls'])
+ self._initial_query = {"_cls": {"$in": self._document._subclasses}}
+ self._loaded_fields = QueryFieldList(always_include=["_cls"])
self._cursor_obj = None
self._limit = None
@@ -83,8 +87,7 @@ class BaseQuerySet(object):
self._max_time_ms = None
self._comment = None
- def __call__(self, q_obj=None, class_check=True, read_preference=None,
- **query):
+ def __call__(self, q_obj=None, class_check=True, read_preference=None, **query):
"""Filter the selected documents by calling the
:class:`~mongoengine.queryset.QuerySet` with a query.
@@ -102,8 +105,10 @@ class BaseQuerySet(object):
if q_obj:
# make sure proper query object is passed
if not isinstance(q_obj, QNode):
- msg = ('Not a query object: %s. '
- 'Did you intend to use key=value?' % q_obj)
+ msg = (
+ "Not a query object: %s. "
+ "Did you intend to use key=value?" % q_obj
+ )
raise InvalidQueryError(msg)
query &= q_obj
@@ -130,10 +135,10 @@ class BaseQuerySet(object):
obj_dict = self.__dict__.copy()
# don't picke collection, instead pickle collection params
- obj_dict.pop('_collection_obj')
+ obj_dict.pop("_collection_obj")
# don't pickle cursor
- obj_dict['_cursor_obj'] = None
+ obj_dict["_cursor_obj"] = None
return obj_dict
@@ -144,7 +149,7 @@ class BaseQuerySet(object):
See https://github.com/MongoEngine/mongoengine/issues/442
"""
- obj_dict['_collection_obj'] = obj_dict['_document']._get_collection()
+ obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection()
# update attributes
self.__dict__.update(obj_dict)
@@ -182,7 +187,7 @@ class BaseQuerySet(object):
queryset._document._from_son(
queryset._cursor[key],
_auto_dereference=self._auto_dereference,
- only_fields=self.only_fields
+ only_fields=self.only_fields,
)
)
@@ -192,10 +197,10 @@ class BaseQuerySet(object):
return queryset._document._from_son(
queryset._cursor[key],
_auto_dereference=self._auto_dereference,
- only_fields=self.only_fields
+ only_fields=self.only_fields,
)
- raise TypeError('Provide a slice or an integer index')
+ raise TypeError("Provide a slice or an integer index")
def __iter__(self):
raise NotImplementedError
@@ -235,14 +240,13 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
if queryset._search_text:
- raise OperationError(
- 'It is not possible to use search_text two times.')
+ raise OperationError("It is not possible to use search_text two times.")
- query_kwargs = SON({'$search': text})
+ query_kwargs = SON({"$search": text})
if language:
- query_kwargs['$language'] = language
+ query_kwargs["$language"] = language
- queryset._query_obj &= Q(__raw__={'$text': query_kwargs})
+ queryset._query_obj &= Q(__raw__={"$text": query_kwargs})
queryset._mongo_query = None
queryset._cursor_obj = None
queryset._search_text = text
@@ -265,8 +269,7 @@ class BaseQuerySet(object):
try:
result = six.next(queryset)
except StopIteration:
- msg = ('%s matching query does not exist.'
- % queryset._document._class_name)
+ msg = "%s matching query does not exist." % queryset._document._class_name
raise queryset._document.DoesNotExist(msg)
try:
six.next(queryset)
@@ -276,7 +279,7 @@ class BaseQuerySet(object):
# If we were able to retrieve the 2nd doc, rewind the cursor and
# raise the MultipleObjectsReturned exception.
queryset.rewind()
- message = u'%d items returned, instead of 1' % queryset.count()
+ message = u"%d items returned, instead of 1" % queryset.count()
raise queryset._document.MultipleObjectsReturned(message)
def create(self, **kwargs):
@@ -295,8 +298,9 @@ class BaseQuerySet(object):
result = None
return result
- def insert(self, doc_or_docs, load_bulk=True,
- write_concern=None, signal_kwargs=None):
+ def insert(
+ self, doc_or_docs, load_bulk=True, write_concern=None, signal_kwargs=None
+ ):
"""bulk insert documents
:param doc_or_docs: a document or list of documents to be inserted
@@ -319,7 +323,7 @@ class BaseQuerySet(object):
.. versionchanged:: 0.10.7
Add signal_kwargs argument
"""
- Document = _import_class('Document')
+ Document = _import_class("Document")
if write_concern is None:
write_concern = {}
@@ -332,16 +336,16 @@ class BaseQuerySet(object):
for doc in docs:
if not isinstance(doc, self._document):
- msg = ("Some documents inserted aren't instances of %s"
- % str(self._document))
+ msg = "Some documents inserted aren't instances of %s" % str(
+ self._document
+ )
raise OperationError(msg)
if doc.pk and not doc._created:
- msg = 'Some documents have ObjectIds, use doc.update() instead'
+ msg = "Some documents have ObjectIds, use doc.update() instead"
raise OperationError(msg)
signal_kwargs = signal_kwargs or {}
- signals.pre_bulk_insert.send(self._document,
- documents=docs, **signal_kwargs)
+ signals.pre_bulk_insert.send(self._document, documents=docs, **signal_kwargs)
raw = [doc.to_mongo() for doc in docs]
@@ -353,21 +357,25 @@ class BaseQuerySet(object):
try:
inserted_result = insert_func(raw)
- ids = [inserted_result.inserted_id] if return_one else inserted_result.inserted_ids
+ ids = (
+ [inserted_result.inserted_id]
+ if return_one
+ else inserted_result.inserted_ids
+ )
except pymongo.errors.DuplicateKeyError as err:
- message = 'Could not save document (%s)'
+ message = "Could not save document (%s)"
raise NotUniqueError(message % six.text_type(err))
except pymongo.errors.BulkWriteError as err:
# inserting documents that already have an _id field will
# give huge performance debt or raise
- message = u'Document must not have _id value before bulk write (%s)'
+ message = u"Document must not have _id value before bulk write (%s)"
raise NotUniqueError(message % six.text_type(err))
except pymongo.errors.OperationFailure as err:
- message = 'Could not save document (%s)'
- if re.match('^E1100[01] duplicate key', six.text_type(err)):
+ message = "Could not save document (%s)"
+ if re.match("^E1100[01] duplicate key", six.text_type(err)):
# E11000 - duplicate key error index
# E11001 - duplicate key on update
- message = u'Tried to save duplicate unique keys (%s)'
+ message = u"Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err))
@@ -377,13 +385,15 @@ class BaseQuerySet(object):
if not load_bulk:
signals.post_bulk_insert.send(
- self._document, documents=docs, loaded=False, **signal_kwargs)
+ self._document, documents=docs, loaded=False, **signal_kwargs
+ )
return ids[0] if return_one else ids
documents = self.in_bulk(ids)
results = [documents.get(obj_id) for obj_id in ids]
signals.post_bulk_insert.send(
- self._document, documents=results, loaded=True, **signal_kwargs)
+ self._document, documents=results, loaded=True, **signal_kwargs
+ )
return results[0] if return_one else results
def count(self, with_limit_and_skip=False):
@@ -399,8 +409,7 @@ class BaseQuerySet(object):
self._cursor_obj = None
return count
- def delete(self, write_concern=None, _from_doc_delete=False,
- cascade_refs=None):
+ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None):
"""Delete the documents matched by the query.
:param write_concern: Extra keyword arguments are passed down which
@@ -423,12 +432,13 @@ class BaseQuerySet(object):
# Handle deletes where skips or limits have been applied or
# there is an untriggered delete signal
has_delete_signal = signals.signals_available and (
- signals.pre_delete.has_receivers_for(doc) or
- signals.post_delete.has_receivers_for(doc)
+ signals.pre_delete.has_receivers_for(doc)
+ or signals.post_delete.has_receivers_for(doc)
)
- call_document_delete = (queryset._skip or queryset._limit or
- has_delete_signal) and not _from_doc_delete
+ call_document_delete = (
+ queryset._skip or queryset._limit or has_delete_signal
+ ) and not _from_doc_delete
if call_document_delete:
cnt = 0
@@ -437,28 +447,28 @@ class BaseQuerySet(object):
cnt += 1
return cnt
- delete_rules = doc._meta.get('delete_rules') or {}
+ delete_rules = doc._meta.get("delete_rules") or {}
delete_rules = list(delete_rules.items())
# Check for DENY rules before actually deleting/nullifying any other
# references
for rule_entry, rule in delete_rules:
document_cls, field_name = rule_entry
- if document_cls._meta.get('abstract'):
+ if document_cls._meta.get("abstract"):
continue
if rule == DENY:
- refs = document_cls.objects(**{field_name + '__in': self})
+ refs = document_cls.objects(**{field_name + "__in": self})
if refs.limit(1).count() > 0:
raise OperationError(
- 'Could not delete document (%s.%s refers to it)'
+ "Could not delete document (%s.%s refers to it)"
% (document_cls.__name__, field_name)
)
# Check all the other rules
for rule_entry, rule in delete_rules:
document_cls, field_name = rule_entry
- if document_cls._meta.get('abstract'):
+ if document_cls._meta.get("abstract"):
continue
if rule == CASCADE:
@@ -467,19 +477,19 @@ class BaseQuerySet(object):
if doc._collection == document_cls._collection:
for ref in queryset:
cascade_refs.add(ref.id)
- refs = document_cls.objects(**{field_name + '__in': self,
- 'pk__nin': cascade_refs})
+ refs = document_cls.objects(
+ **{field_name + "__in": self, "pk__nin": cascade_refs}
+ )
if refs.count() > 0:
- refs.delete(write_concern=write_concern,
- cascade_refs=cascade_refs)
+ refs.delete(write_concern=write_concern, cascade_refs=cascade_refs)
elif rule == NULLIFY:
- document_cls.objects(**{field_name + '__in': self}).update(
- write_concern=write_concern,
- **{'unset__%s' % field_name: 1})
+ document_cls.objects(**{field_name + "__in": self}).update(
+ write_concern=write_concern, **{"unset__%s" % field_name: 1}
+ )
elif rule == PULL:
- document_cls.objects(**{field_name + '__in': self}).update(
- write_concern=write_concern,
- **{'pull_all__%s' % field_name: self})
+ document_cls.objects(**{field_name + "__in": self}).update(
+ write_concern=write_concern, **{"pull_all__%s" % field_name: self}
+ )
with set_write_concern(queryset._collection, write_concern) as collection:
result = collection.delete_many(queryset._query)
@@ -490,8 +500,9 @@ class BaseQuerySet(object):
if result.acknowledged:
return result.deleted_count
- def update(self, upsert=False, multi=True, write_concern=None,
- full_result=False, **update):
+ def update(
+ self, upsert=False, multi=True, write_concern=None, full_result=False, **update
+ ):
"""Perform an atomic update on the fields matched by the query.
:param upsert: insert if document doesn't exist (default ``False``)
@@ -511,7 +522,7 @@ class BaseQuerySet(object):
.. versionadded:: 0.2
"""
if not update and not upsert:
- raise OperationError('No update parameters, would remove data')
+ raise OperationError("No update parameters, would remove data")
if write_concern is None:
write_concern = {}
@@ -522,11 +533,11 @@ class BaseQuerySet(object):
# If doing an atomic upsert on an inheritable class
# then ensure we add _cls to the update operation
- if upsert and '_cls' in query:
- if '$set' in update:
- update['$set']['_cls'] = queryset._document._class_name
+ if upsert and "_cls" in query:
+ if "$set" in update:
+ update["$set"]["_cls"] = queryset._document._class_name
else:
- update['$set'] = {'_cls': queryset._document._class_name}
+ update["$set"] = {"_cls": queryset._document._class_name}
try:
with set_write_concern(queryset._collection, write_concern) as collection:
update_func = collection.update_one
@@ -536,14 +547,14 @@ class BaseQuerySet(object):
if full_result:
return result
elif result.raw_result:
- return result.raw_result['n']
+ return result.raw_result["n"]
except pymongo.errors.DuplicateKeyError as err:
- raise NotUniqueError(u'Update failed (%s)' % six.text_type(err))
+ raise NotUniqueError(u"Update failed (%s)" % six.text_type(err))
except pymongo.errors.OperationFailure as err:
- if six.text_type(err) == u'multi not coded yet':
- message = u'update() method requires MongoDB 1.1.3+'
+ if six.text_type(err) == u"multi not coded yet":
+ message = u"update() method requires MongoDB 1.1.3+"
raise OperationError(message)
- raise OperationError(u'Update failed (%s)' % six.text_type(err))
+ raise OperationError(u"Update failed (%s)" % six.text_type(err))
def upsert_one(self, write_concern=None, **update):
"""Overwrite or add the first document matched by the query.
@@ -561,11 +572,15 @@ class BaseQuerySet(object):
.. versionadded:: 0.10.2
"""
- atomic_update = self.update(multi=False, upsert=True,
- write_concern=write_concern,
- full_result=True, **update)
+ atomic_update = self.update(
+ multi=False,
+ upsert=True,
+ write_concern=write_concern,
+ full_result=True,
+ **update
+ )
- if atomic_update.raw_result['updatedExisting']:
+ if atomic_update.raw_result["updatedExisting"]:
document = self.get()
else:
document = self._document.objects.with_id(atomic_update.upserted_id)
@@ -594,9 +609,12 @@ class BaseQuerySet(object):
multi=False,
write_concern=write_concern,
full_result=full_result,
- **update)
+ **update
+ )
- def modify(self, upsert=False, full_response=False, remove=False, new=False, **update):
+ def modify(
+ self, upsert=False, full_response=False, remove=False, new=False, **update
+ ):
"""Update and return the updated document.
Returns either the document before or after modification based on `new`
@@ -621,11 +639,10 @@ class BaseQuerySet(object):
"""
if remove and new:
- raise OperationError('Conflicting parameters: remove and new')
+ raise OperationError("Conflicting parameters: remove and new")
if not update and not upsert and not remove:
- raise OperationError(
- 'No update parameters, must either update or remove')
+ raise OperationError("No update parameters, must either update or remove")
queryset = self.clone()
query = queryset._query
@@ -635,27 +652,35 @@ class BaseQuerySet(object):
try:
if full_response:
- msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
+ msg = "With PyMongo 3+, it is not possible anymore to get the full response."
warnings.warn(msg, DeprecationWarning)
if remove:
result = queryset._collection.find_one_and_delete(
- query, sort=sort, **self._cursor_args)
+ query, sort=sort, **self._cursor_args
+ )
else:
if new:
return_doc = ReturnDocument.AFTER
else:
return_doc = ReturnDocument.BEFORE
result = queryset._collection.find_one_and_update(
- query, update, upsert=upsert, sort=sort, return_document=return_doc,
- **self._cursor_args)
+ query,
+ update,
+ upsert=upsert,
+ sort=sort,
+ return_document=return_doc,
+ **self._cursor_args
+ )
except pymongo.errors.DuplicateKeyError as err:
- raise NotUniqueError(u'Update failed (%s)' % err)
+ raise NotUniqueError(u"Update failed (%s)" % err)
except pymongo.errors.OperationFailure as err:
- raise OperationError(u'Update failed (%s)' % err)
+ raise OperationError(u"Update failed (%s)" % err)
if full_response:
- if result['value'] is not None:
- result['value'] = self._document._from_son(result['value'], only_fields=self.only_fields)
+ if result["value"] is not None:
+ result["value"] = self._document._from_son(
+ result["value"], only_fields=self.only_fields
+ )
else:
if result is not None:
result = self._document._from_son(result, only_fields=self.only_fields)
@@ -673,7 +698,7 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
if not queryset._query_obj.empty:
- msg = 'Cannot use a filter whilst using `with_id`'
+ msg = "Cannot use a filter whilst using `with_id`"
raise InvalidQueryError(msg)
return queryset.filter(pk=object_id).first()
@@ -688,21 +713,22 @@ class BaseQuerySet(object):
"""
doc_map = {}
- docs = self._collection.find({'_id': {'$in': object_ids}},
- **self._cursor_args)
+ docs = self._collection.find({"_id": {"$in": object_ids}}, **self._cursor_args)
if self._scalar:
for doc in docs:
- doc_map[doc['_id']] = self._get_scalar(
- self._document._from_son(doc, only_fields=self.only_fields))
+ doc_map[doc["_id"]] = self._get_scalar(
+ self._document._from_son(doc, only_fields=self.only_fields)
+ )
elif self._as_pymongo:
for doc in docs:
- doc_map[doc['_id']] = doc
+ doc_map[doc["_id"]] = doc
else:
for doc in docs:
- doc_map[doc['_id']] = self._document._from_son(
+ doc_map[doc["_id"]] = self._document._from_son(
doc,
only_fields=self.only_fields,
- _auto_dereference=self._auto_dereference)
+ _auto_dereference=self._auto_dereference,
+ )
return doc_map
@@ -717,8 +743,8 @@ class BaseQuerySet(object):
Do NOT return any inherited documents.
"""
- if self._document._meta.get('allow_inheritance') is True:
- self._initial_query = {'_cls': self._document._class_name}
+ if self._document._meta.get("allow_inheritance") is True:
+ self._initial_query = {"_cls": self._document._class_name}
return self
@@ -747,15 +773,35 @@ class BaseQuerySet(object):
"""
if not isinstance(new_qs, BaseQuerySet):
raise OperationError(
- '%s is not a subclass of BaseQuerySet' % new_qs.__name__)
+ "%s is not a subclass of BaseQuerySet" % new_qs.__name__
+ )
- copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj',
- '_where_clause', '_loaded_fields', '_ordering',
- '_snapshot', '_timeout', '_class_check', '_slave_okay',
- '_read_preference', '_iter', '_scalar', '_as_pymongo',
- '_limit', '_skip', '_hint', '_auto_dereference',
- '_search_text', 'only_fields', '_max_time_ms',
- '_comment', '_batch_size')
+ copy_props = (
+ "_mongo_query",
+ "_initial_query",
+ "_none",
+ "_query_obj",
+ "_where_clause",
+ "_loaded_fields",
+ "_ordering",
+ "_snapshot",
+ "_timeout",
+ "_class_check",
+ "_slave_okay",
+ "_read_preference",
+ "_iter",
+ "_scalar",
+ "_as_pymongo",
+ "_limit",
+ "_skip",
+ "_hint",
+ "_auto_dereference",
+ "_search_text",
+ "only_fields",
+ "_max_time_ms",
+ "_comment",
+ "_batch_size",
+ )
for prop in copy_props:
val = getattr(self, prop)
@@ -868,37 +914,43 @@ class BaseQuerySet(object):
except LookUpError:
pass
- distinct = self._dereference(queryset._cursor.distinct(field), 1,
- name=field, instance=self._document)
+ distinct = self._dereference(
+ queryset._cursor.distinct(field), 1, name=field, instance=self._document
+ )
- doc_field = self._document._fields.get(field.split('.', 1)[0])
+ doc_field = self._document._fields.get(field.split(".", 1)[0])
instance = None
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField)
- EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
- ListField = _import_class('ListField')
- GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField')
+ EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
+ ListField = _import_class("ListField")
+ GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField")
if isinstance(doc_field, ListField):
- doc_field = getattr(doc_field, 'field', doc_field)
+ doc_field = getattr(doc_field, "field", doc_field)
if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
- instance = getattr(doc_field, 'document_type', None)
+ instance = getattr(doc_field, "document_type", None)
# handle distinct on subdocuments
- if '.' in field:
- for field_part in field.split('.')[1:]:
+ if "." in field:
+ for field_part in field.split(".")[1:]:
# if looping on embedded document, get the document type instance
- if instance and isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
+ if instance and isinstance(
+ doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)
+ ):
doc_field = instance
# now get the subdocument
doc_field = getattr(doc_field, field_part, doc_field)
# We may need to cast to the correct type eg. ListField(EmbeddedDocumentField)
if isinstance(doc_field, ListField):
- doc_field = getattr(doc_field, 'field', doc_field)
- if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)):
- instance = getattr(doc_field, 'document_type', None)
+ doc_field = getattr(doc_field, "field", doc_field)
+ if isinstance(
+ doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)
+ ):
+ instance = getattr(doc_field, "document_type", None)
- if instance and isinstance(doc_field, (EmbeddedDocumentField,
- GenericEmbeddedDocumentField)):
+ if instance and isinstance(
+ doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)
+ ):
distinct = [instance(**doc) for doc in distinct]
return distinct
@@ -970,14 +1022,14 @@ class BaseQuerySet(object):
"""
# Check for an operator and transform to mongo-style if there is
- operators = ['slice']
+ operators = ["slice"]
cleaned_fields = []
for key, value in kwargs.items():
- parts = key.split('__')
+ parts = key.split("__")
if parts[0] in operators:
op = parts.pop(0)
- value = {'$' + op: value}
- key = '.'.join(parts)
+ value = {"$" + op: value}
+ key = ".".join(parts)
cleaned_fields.append((key, value))
# Sort fields by their values, explicitly excluded fields first, then
@@ -998,7 +1050,8 @@ class BaseQuerySet(object):
fields = [field for field, value in group]
fields = queryset._fields_to_dbfields(fields)
queryset._loaded_fields += QueryFieldList(
- fields, value=value, _only_called=_only_called)
+ fields, value=value, _only_called=_only_called
+ )
return queryset
@@ -1012,7 +1065,8 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
queryset._loaded_fields = QueryFieldList(
- always_include=queryset._loaded_fields.always_include)
+ always_include=queryset._loaded_fields.always_include
+ )
return queryset
def order_by(self, *keys):
@@ -1053,7 +1107,7 @@ class BaseQuerySet(object):
See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment
for details.
"""
- return self._chainable_method('comment', text)
+ return self._chainable_method("comment", text)
def explain(self, format=False):
"""Return an explain plan record for the
@@ -1066,8 +1120,10 @@ class BaseQuerySet(object):
# TODO remove this option completely - it's useless. If somebody
# wants to pretty-print the output, they easily can.
if format:
- msg = ('"format" param of BaseQuerySet.explain has been '
- 'deprecated and will be removed in future versions.')
+ msg = (
+ '"format" param of BaseQuerySet.explain has been '
+ "deprecated and will be removed in future versions."
+ )
warnings.warn(msg, DeprecationWarning)
plan = pprint.pformat(plan)
@@ -1082,7 +1138,7 @@ class BaseQuerySet(object):
..versionchanged:: 0.5 - made chainable
.. deprecated:: Ignored with PyMongo 3+
"""
- msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
+ msg = "snapshot is deprecated as it has no impact when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._snapshot = enabled
@@ -1107,7 +1163,7 @@ class BaseQuerySet(object):
.. deprecated:: Ignored with PyMongo 3+
"""
- msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
+ msg = "slave_okay is deprecated as it has no impact when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._slave_okay = enabled
@@ -1119,10 +1175,12 @@ class BaseQuerySet(object):
:param read_preference: override ReplicaSetConnection-level
preference.
"""
- validate_read_preference('read_preference', read_preference)
+ validate_read_preference("read_preference", read_preference)
queryset = self.clone()
queryset._read_preference = read_preference
- queryset._cursor_obj = None # we need to re-create the cursor object whenever we apply read_preference
+ queryset._cursor_obj = (
+ None
+ ) # we need to re-create the cursor object whenever we apply read_preference
return queryset
def scalar(self, *fields):
@@ -1168,7 +1226,7 @@ class BaseQuerySet(object):
:param ms: the number of milliseconds before killing the query on the server
"""
- return self._chainable_method('max_time_ms', ms)
+ return self._chainable_method("max_time_ms", ms)
# JSON Helpers
@@ -1179,7 +1237,10 @@ class BaseQuerySet(object):
def from_json(self, json_data):
"""Converts json data to unsaved objects"""
son_data = json_util.loads(json_data)
- return [self._document._from_son(data, only_fields=self.only_fields) for data in son_data]
+ return [
+ self._document._from_son(data, only_fields=self.only_fields)
+ for data in son_data
+ ]
def aggregate(self, *pipeline, **kwargs):
"""
@@ -1192,32 +1253,34 @@ class BaseQuerySet(object):
initial_pipeline = []
if self._query:
- initial_pipeline.append({'$match': self._query})
+ initial_pipeline.append({"$match": self._query})
if self._ordering:
- initial_pipeline.append({'$sort': dict(self._ordering)})
+ initial_pipeline.append({"$sort": dict(self._ordering)})
if self._limit is not None:
# As per MongoDB Documentation (https://docs.mongodb.com/manual/reference/operator/aggregation/limit/),
# keeping limit stage right after sort stage is more efficient. But this leads to wrong set of documents
# for a skip stage that might succeed these. So we need to maintain more documents in memory in such a
# case (https://stackoverflow.com/a/24161461).
- initial_pipeline.append({'$limit': self._limit + (self._skip or 0)})
+ initial_pipeline.append({"$limit": self._limit + (self._skip or 0)})
if self._skip is not None:
- initial_pipeline.append({'$skip': self._skip})
+ initial_pipeline.append({"$skip": self._skip})
pipeline = initial_pipeline + list(pipeline)
if self._read_preference is not None:
- return self._collection.with_options(read_preference=self._read_preference) \
- .aggregate(pipeline, cursor={}, **kwargs)
+ return self._collection.with_options(
+ read_preference=self._read_preference
+ ).aggregate(pipeline, cursor={}, **kwargs)
return self._collection.aggregate(pipeline, cursor={}, **kwargs)
# JS functionality
- def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None,
- scope=None):
+ def map_reduce(
+ self, map_f, reduce_f, output, finalize_f=None, limit=None, scope=None
+ ):
"""Perform a map/reduce query using the current query spec
and ordering. While ``map_reduce`` respects ``QuerySet`` chaining,
it must be the last call made, as it does not return a maleable
@@ -1257,10 +1320,10 @@ class BaseQuerySet(object):
"""
queryset = self.clone()
- MapReduceDocument = _import_class('MapReduceDocument')
+ MapReduceDocument = _import_class("MapReduceDocument")
- if not hasattr(self._collection, 'map_reduce'):
- raise NotImplementedError('Requires MongoDB >= 1.7.1')
+ if not hasattr(self._collection, "map_reduce"):
+ raise NotImplementedError("Requires MongoDB >= 1.7.1")
map_f_scope = {}
if isinstance(map_f, Code):
@@ -1275,7 +1338,7 @@ class BaseQuerySet(object):
reduce_f_code = queryset._sub_js_fields(reduce_f)
reduce_f = Code(reduce_f_code, reduce_f_scope)
- mr_args = {'query': queryset._query}
+ mr_args = {"query": queryset._query}
if finalize_f:
finalize_f_scope = {}
@@ -1284,39 +1347,39 @@ class BaseQuerySet(object):
finalize_f = six.text_type(finalize_f)
finalize_f_code = queryset._sub_js_fields(finalize_f)
finalize_f = Code(finalize_f_code, finalize_f_scope)
- mr_args['finalize'] = finalize_f
+ mr_args["finalize"] = finalize_f
if scope:
- mr_args['scope'] = scope
+ mr_args["scope"] = scope
if limit:
- mr_args['limit'] = limit
+ mr_args["limit"] = limit
- if output == 'inline' and not queryset._ordering:
- map_reduce_function = 'inline_map_reduce'
+ if output == "inline" and not queryset._ordering:
+ map_reduce_function = "inline_map_reduce"
else:
- map_reduce_function = 'map_reduce'
+ map_reduce_function = "map_reduce"
if isinstance(output, six.string_types):
- mr_args['out'] = output
+ mr_args["out"] = output
elif isinstance(output, dict):
ordered_output = []
- for part in ('replace', 'merge', 'reduce'):
+ for part in ("replace", "merge", "reduce"):
value = output.get(part)
if value:
ordered_output.append((part, value))
break
else:
- raise OperationError('actionData not specified for output')
+ raise OperationError("actionData not specified for output")
- db_alias = output.get('db_alias')
- remaing_args = ['db', 'sharded', 'nonAtomic']
+ db_alias = output.get("db_alias")
+ remaing_args = ["db", "sharded", "nonAtomic"]
if db_alias:
- ordered_output.append(('db', get_db(db_alias).name))
+ ordered_output.append(("db", get_db(db_alias).name))
del remaing_args[0]
for part in remaing_args:
@@ -1324,20 +1387,22 @@ class BaseQuerySet(object):
if value:
ordered_output.append((part, value))
- mr_args['out'] = SON(ordered_output)
+ mr_args["out"] = SON(ordered_output)
results = getattr(queryset._collection, map_reduce_function)(
- map_f, reduce_f, **mr_args)
+ map_f, reduce_f, **mr_args
+ )
- if map_reduce_function == 'map_reduce':
+ if map_reduce_function == "map_reduce":
results = results.find()
if queryset._ordering:
results = results.sort(queryset._ordering)
for doc in results:
- yield MapReduceDocument(queryset._document, queryset._collection,
- doc['_id'], doc['value'])
+ yield MapReduceDocument(
+ queryset._document, queryset._collection, doc["_id"], doc["value"]
+ )
def exec_js(self, code, *fields, **options):
"""Execute a Javascript function on the server. A list of fields may be
@@ -1368,16 +1433,13 @@ class BaseQuerySet(object):
fields = [queryset._document._translate_field_name(f) for f in fields]
collection = queryset._document._get_collection_name()
- scope = {
- 'collection': collection,
- 'options': options or {},
- }
+ scope = {"collection": collection, "options": options or {}}
query = queryset._query
if queryset._where_clause:
- query['$where'] = queryset._where_clause
+ query["$where"] = queryset._where_clause
- scope['query'] = query
+ scope["query"] = query
code = Code(code, scope=scope)
db = queryset._document._get_db()
@@ -1407,22 +1469,22 @@ class BaseQuerySet(object):
"""
db_field = self._fields_to_dbfields([field]).pop()
pipeline = [
- {'$match': self._query},
- {'$group': {'_id': 'sum', 'total': {'$sum': '$' + db_field}}}
+ {"$match": self._query},
+ {"$group": {"_id": "sum", "total": {"$sum": "$" + db_field}}},
]
# if we're performing a sum over a list field, we sum up all the
# elements in the list, hence we need to $unwind the arrays first
- ListField = _import_class('ListField')
- field_parts = field.split('.')
+ ListField = _import_class("ListField")
+ field_parts = field.split(".")
field_instances = self._document._lookup_field(field_parts)
if isinstance(field_instances[-1], ListField):
- pipeline.insert(1, {'$unwind': '$' + field})
+ pipeline.insert(1, {"$unwind": "$" + field})
result = tuple(self._document._get_collection().aggregate(pipeline))
if result:
- return result[0]['total']
+ return result[0]["total"]
return 0
def average(self, field):
@@ -1433,22 +1495,22 @@ class BaseQuerySet(object):
"""
db_field = self._fields_to_dbfields([field]).pop()
pipeline = [
- {'$match': self._query},
- {'$group': {'_id': 'avg', 'total': {'$avg': '$' + db_field}}}
+ {"$match": self._query},
+ {"$group": {"_id": "avg", "total": {"$avg": "$" + db_field}}},
]
# if we're performing an average over a list field, we average out
# all the elements in the list, hence we need to $unwind the arrays
# first
- ListField = _import_class('ListField')
- field_parts = field.split('.')
+ ListField = _import_class("ListField")
+ field_parts = field.split(".")
field_instances = self._document._lookup_field(field_parts)
if isinstance(field_instances[-1], ListField):
- pipeline.insert(1, {'$unwind': '$' + field})
+ pipeline.insert(1, {"$unwind": "$" + field})
result = tuple(self._document._get_collection().aggregate(pipeline))
if result:
- return result[0]['total']
+ return result[0]["total"]
return 0
def item_frequencies(self, field, normalize=False, map_reduce=True):
@@ -1474,8 +1536,7 @@ class BaseQuerySet(object):
document lookups
"""
if map_reduce:
- return self._item_frequencies_map_reduce(field,
- normalize=normalize)
+ return self._item_frequencies_map_reduce(field, normalize=normalize)
return self._item_frequencies_exec_js(field, normalize=normalize)
# Iterator helpers
@@ -1492,15 +1553,17 @@ class BaseQuerySet(object):
return raw_doc
doc = self._document._from_son(
- raw_doc, _auto_dereference=self._auto_dereference,
- only_fields=self.only_fields)
+ raw_doc,
+ _auto_dereference=self._auto_dereference,
+ only_fields=self.only_fields,
+ )
if self._scalar:
return self._get_scalar(doc)
return doc
- next = __next__ # For Python2 support
+ next = __next__ # For Python2 support
def rewind(self):
"""Rewind the cursor to its unevaluated state.
@@ -1521,15 +1584,13 @@ class BaseQuerySet(object):
@property
def _cursor_args(self):
- fields_name = 'projection'
+ fields_name = "projection"
# snapshot is not handled at all by PyMongo 3+
# TODO: evaluate similar possibilities using modifiers
if self._snapshot:
- msg = 'The snapshot option is not anymore available with PyMongo 3+'
+ msg = "The snapshot option is not anymore available with PyMongo 3+"
warnings.warn(msg, DeprecationWarning)
- cursor_args = {
- 'no_cursor_timeout': not self._timeout
- }
+ cursor_args = {"no_cursor_timeout": not self._timeout}
if self._loaded_fields:
cursor_args[fields_name] = self._loaded_fields.as_dict()
@@ -1538,7 +1599,7 @@ class BaseQuerySet(object):
if fields_name not in cursor_args:
cursor_args[fields_name] = {}
- cursor_args[fields_name]['_text_score'] = {'$meta': 'textScore'}
+ cursor_args[fields_name]["_text_score"] = {"$meta": "textScore"}
return cursor_args
@@ -1555,12 +1616,11 @@ class BaseQuerySet(object):
# level, not a cursor level. Thus, we need to get a cloned collection
# object using `with_options` first.
if self._read_preference is not None:
- self._cursor_obj = self._collection\
- .with_options(read_preference=self._read_preference)\
- .find(self._query, **self._cursor_args)
+ self._cursor_obj = self._collection.with_options(
+ read_preference=self._read_preference
+ ).find(self._query, **self._cursor_args)
else:
- self._cursor_obj = self._collection.find(self._query,
- **self._cursor_args)
+ self._cursor_obj = self._collection.find(self._query, **self._cursor_args)
# Apply "where" clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
@@ -1576,9 +1636,9 @@ class BaseQuerySet(object):
if self._ordering:
# explicit ordering
self._cursor_obj.sort(self._ordering)
- elif self._ordering is None and self._document._meta['ordering']:
+ elif self._ordering is None and self._document._meta["ordering"]:
# default ordering
- order = self._get_order_by(self._document._meta['ordering'])
+ order = self._get_order_by(self._document._meta["ordering"])
self._cursor_obj.sort(order)
if self._limit is not None:
@@ -1607,8 +1667,10 @@ class BaseQuerySet(object):
if self._mongo_query is None:
self._mongo_query = self._query_obj.to_query(self._document)
if self._class_check and self._initial_query:
- if '_cls' in self._mongo_query:
- self._mongo_query = {'$and': [self._initial_query, self._mongo_query]}
+ if "_cls" in self._mongo_query:
+ self._mongo_query = {
+ "$and": [self._initial_query, self._mongo_query]
+ }
else:
self._mongo_query.update(self._initial_query)
return self._mongo_query
@@ -1616,7 +1678,7 @@ class BaseQuerySet(object):
@property
def _dereference(self):
if not self.__dereference:
- self.__dereference = _import_class('DeReference')()
+ self.__dereference = _import_class("DeReference")()
return self.__dereference
def no_dereference(self):
@@ -1649,7 +1711,9 @@ class BaseQuerySet(object):
emit(null, 1);
}
}
- """ % {'field': field}
+ """ % {
+ "field": field
+ }
reduce_func = """
function(key, values) {
var total = 0;
@@ -1660,7 +1724,7 @@ class BaseQuerySet(object):
return total;
}
"""
- values = self.map_reduce(map_func, reduce_func, 'inline')
+ values = self.map_reduce(map_func, reduce_func, "inline")
frequencies = {}
for f in values:
key = f.key
@@ -1671,8 +1735,7 @@ class BaseQuerySet(object):
if normalize:
count = sum(frequencies.values())
- frequencies = {k: float(v) / count
- for k, v in frequencies.items()}
+ frequencies = {k: float(v) / count for k, v in frequencies.items()}
return frequencies
@@ -1742,15 +1805,14 @@ class BaseQuerySet(object):
def _fields_to_dbfields(self, fields):
"""Translate fields' paths to their db equivalents."""
subclasses = []
- if self._document._meta['allow_inheritance']:
- subclasses = [get_document(x)
- for x in self._document._subclasses][1:]
+ if self._document._meta["allow_inheritance"]:
+ subclasses = [get_document(x) for x in self._document._subclasses][1:]
db_field_paths = []
for field in fields:
- field_parts = field.split('.')
+ field_parts = field.split(".")
try:
- field = '.'.join(
+ field = ".".join(
f if isinstance(f, six.string_types) else f.db_field
for f in self._document._lookup_field(field_parts)
)
@@ -1762,7 +1824,7 @@ class BaseQuerySet(object):
# through its subclasses and see if it exists on any of them.
for subdoc in subclasses:
try:
- subfield = '.'.join(
+ subfield = ".".join(
f if isinstance(f, six.string_types) else f.db_field
for f in subdoc._lookup_field(field_parts)
)
@@ -1790,18 +1852,18 @@ class BaseQuerySet(object):
if not key:
continue
- if key == '$text_score':
- key_list.append(('_text_score', {'$meta': 'textScore'}))
+ if key == "$text_score":
+ key_list.append(("_text_score", {"$meta": "textScore"}))
continue
direction = pymongo.ASCENDING
- if key[0] == '-':
+ if key[0] == "-":
direction = pymongo.DESCENDING
- if key[0] in ('-', '+'):
+ if key[0] in ("-", "+"):
key = key[1:]
- key = key.replace('__', '.')
+ key = key.replace("__", ".")
try:
key = self._document._translate_field_name(key)
except Exception:
@@ -1813,9 +1875,8 @@ class BaseQuerySet(object):
return key_list
def _get_scalar(self, doc):
-
def lookup(obj, name):
- chunks = name.split('__')
+ chunks = name.split("__")
for chunk in chunks:
obj = getattr(obj, chunk)
return obj
@@ -1835,21 +1896,20 @@ class BaseQuerySet(object):
def field_sub(match):
# Extract just the field name, and look up the field objects
- field_name = match.group(1).split('.')
+ field_name = match.group(1).split(".")
fields = self._document._lookup_field(field_name)
# Substitute the correct name for the field into the javascript
return u'["%s"]' % fields[-1].db_field
def field_path_sub(match):
# Extract just the field name, and look up the field objects
- field_name = match.group(1).split('.')
+ field_name = match.group(1).split(".")
fields = self._document._lookup_field(field_name)
# Substitute the correct name for the field into the javascript
- return '.'.join([f.db_field for f in fields])
+ return ".".join([f.db_field for f in fields])
- code = re.sub(r'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code)
- code = re.sub(r'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub,
- code)
+ code = re.sub(r"\[\s*~([A-z_][A-z_0-9.]+?)\s*\]", field_sub, code)
+ code = re.sub(r"\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}", field_path_sub, code)
return code
def _chainable_method(self, method_name, val):
@@ -1866,22 +1926,26 @@ class BaseQuerySet(object):
getattr(cursor, method_name)(val)
# Cache the value on the queryset._{method_name}
- setattr(queryset, '_' + method_name, val)
+ setattr(queryset, "_" + method_name, val)
return queryset
# Deprecated
def ensure_index(self, **kwargs):
"""Deprecated use :func:`Document.ensure_index`"""
- msg = ('Doc.objects()._ensure_index() is deprecated. '
- 'Use Doc.ensure_index() instead.')
+ msg = (
+ "Doc.objects()._ensure_index() is deprecated. "
+ "Use Doc.ensure_index() instead."
+ )
warnings.warn(msg, DeprecationWarning)
self._document.__class__.ensure_index(**kwargs)
return self
def _ensure_indexes(self):
"""Deprecated use :func:`~Document.ensure_indexes`"""
- msg = ('Doc.objects()._ensure_indexes() is deprecated. '
- 'Use Doc.ensure_indexes() instead.')
+ msg = (
+ "Doc.objects()._ensure_indexes() is deprecated. "
+ "Use Doc.ensure_indexes() instead."
+ )
warnings.warn(msg, DeprecationWarning)
self._document.__class__.ensure_indexes()
diff --git a/mongoengine/queryset/field_list.py b/mongoengine/queryset/field_list.py
index dba724af..5c3ff222 100644
--- a/mongoengine/queryset/field_list.py
+++ b/mongoengine/queryset/field_list.py
@@ -1,12 +1,15 @@
-__all__ = ('QueryFieldList',)
+__all__ = ("QueryFieldList",)
class QueryFieldList(object):
"""Object that handles combinations of .only() and .exclude() calls"""
+
ONLY = 1
EXCLUDE = 0
- def __init__(self, fields=None, value=ONLY, always_include=None, _only_called=False):
+ def __init__(
+ self, fields=None, value=ONLY, always_include=None, _only_called=False
+ ):
"""The QueryFieldList builder
:param fields: A list of fields used in `.only()` or `.exclude()`
@@ -49,7 +52,7 @@ class QueryFieldList(object):
self.fields = f.fields - self.fields
self._clean_slice()
- if '_id' in f.fields:
+ if "_id" in f.fields:
self._id = f.value
if self.always_include:
@@ -59,7 +62,7 @@ class QueryFieldList(object):
else:
self.fields -= self.always_include
- if getattr(f, '_only_called', False):
+ if getattr(f, "_only_called", False):
self._only_called = True
return self
@@ -73,7 +76,7 @@ class QueryFieldList(object):
if self.slice:
field_list.update(self.slice)
if self._id is not None:
- field_list['_id'] = self._id
+ field_list["_id"] = self._id
return field_list
def reset(self):
diff --git a/mongoengine/queryset/manager.py b/mongoengine/queryset/manager.py
index f93dbb43..5067ffbf 100644
--- a/mongoengine/queryset/manager.py
+++ b/mongoengine/queryset/manager.py
@@ -1,7 +1,7 @@
from functools import partial
from mongoengine.queryset.queryset import QuerySet
-__all__ = ('queryset_manager', 'QuerySetManager')
+__all__ = ("queryset_manager", "QuerySetManager")
class QuerySetManager(object):
@@ -33,7 +33,7 @@ class QuerySetManager(object):
return self
# owner is the document that contains the QuerySetManager
- queryset_class = owner._meta.get('queryset_class', self.default)
+ queryset_class = owner._meta.get("queryset_class", self.default)
queryset = queryset_class(owner, owner._get_collection())
if self.get_queryset:
arg_count = self.get_queryset.__code__.co_argcount
diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py
index c7c593b1..4ba62d46 100644
--- a/mongoengine/queryset/queryset.py
+++ b/mongoengine/queryset/queryset.py
@@ -1,11 +1,24 @@
import six
from mongoengine.errors import OperationError
-from mongoengine.queryset.base import (BaseQuerySet, CASCADE, DENY, DO_NOTHING,
- NULLIFY, PULL)
+from mongoengine.queryset.base import (
+ BaseQuerySet,
+ CASCADE,
+ DENY,
+ DO_NOTHING,
+ NULLIFY,
+ PULL,
+)
-__all__ = ('QuerySet', 'QuerySetNoCache', 'DO_NOTHING', 'NULLIFY', 'CASCADE',
- 'DENY', 'PULL')
+__all__ = (
+ "QuerySet",
+ "QuerySetNoCache",
+ "DO_NOTHING",
+ "NULLIFY",
+ "CASCADE",
+ "DENY",
+ "PULL",
+)
# The maximum number of items to display in a QuerySet.__repr__
REPR_OUTPUT_SIZE = 20
@@ -57,12 +70,12 @@ class QuerySet(BaseQuerySet):
def __repr__(self):
"""Provide a string representation of the QuerySet"""
if self._iter:
- return '.. queryset mid-iteration ..'
+ return ".. queryset mid-iteration .."
self._populate_cache()
- data = self._result_cache[:REPR_OUTPUT_SIZE + 1]
+ data = self._result_cache[: REPR_OUTPUT_SIZE + 1]
if len(data) > REPR_OUTPUT_SIZE:
- data[-1] = '...(remaining elements truncated)...'
+ data[-1] = "...(remaining elements truncated)..."
return repr(data)
def _iter_results(self):
@@ -143,10 +156,9 @@ class QuerySet(BaseQuerySet):
.. versionadded:: 0.8.3 Convert to non caching queryset
"""
if self._result_cache is not None:
- raise OperationError('QuerySet already cached')
+ raise OperationError("QuerySet already cached")
- return self._clone_into(QuerySetNoCache(self._document,
- self._collection))
+ return self._clone_into(QuerySetNoCache(self._document, self._collection))
class QuerySetNoCache(BaseQuerySet):
@@ -165,7 +177,7 @@ class QuerySetNoCache(BaseQuerySet):
.. versionchanged:: 0.6.13 Now doesnt modify the cursor
"""
if self._iter:
- return '.. queryset mid-iteration ..'
+ return ".. queryset mid-iteration .."
data = []
for _ in six.moves.range(REPR_OUTPUT_SIZE + 1):
@@ -175,7 +187,7 @@ class QuerySetNoCache(BaseQuerySet):
break
if len(data) > REPR_OUTPUT_SIZE:
- data[-1] = '...(remaining elements truncated)...'
+ data[-1] = "...(remaining elements truncated)..."
self.rewind()
return repr(data)
diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py
index 128a4e44..0b73e99b 100644
--- a/mongoengine/queryset/transform.py
+++ b/mongoengine/queryset/transform.py
@@ -10,21 +10,54 @@ from mongoengine.base import UPDATE_OPERATORS
from mongoengine.common import _import_class
from mongoengine.errors import InvalidQueryError
-__all__ = ('query', 'update')
+__all__ = ("query", "update")
-COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
- 'all', 'size', 'exists', 'not', 'elemMatch', 'type')
-GEO_OPERATORS = ('within_distance', 'within_spherical_distance',
- 'within_box', 'within_polygon', 'near', 'near_sphere',
- 'max_distance', 'min_distance', 'geo_within', 'geo_within_box',
- 'geo_within_polygon', 'geo_within_center',
- 'geo_within_sphere', 'geo_intersects')
-STRING_OPERATORS = ('contains', 'icontains', 'startswith',
- 'istartswith', 'endswith', 'iendswith',
- 'exact', 'iexact')
-CUSTOM_OPERATORS = ('match',)
-MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
- STRING_OPERATORS + CUSTOM_OPERATORS)
+COMPARISON_OPERATORS = (
+ "ne",
+ "gt",
+ "gte",
+ "lt",
+ "lte",
+ "in",
+ "nin",
+ "mod",
+ "all",
+ "size",
+ "exists",
+ "not",
+ "elemMatch",
+ "type",
+)
+GEO_OPERATORS = (
+ "within_distance",
+ "within_spherical_distance",
+ "within_box",
+ "within_polygon",
+ "near",
+ "near_sphere",
+ "max_distance",
+ "min_distance",
+ "geo_within",
+ "geo_within_box",
+ "geo_within_polygon",
+ "geo_within_center",
+ "geo_within_sphere",
+ "geo_intersects",
+)
+STRING_OPERATORS = (
+ "contains",
+ "icontains",
+ "startswith",
+ "istartswith",
+ "endswith",
+ "iendswith",
+ "exact",
+ "iexact",
+)
+CUSTOM_OPERATORS = ("match",)
+MATCH_OPERATORS = (
+ COMPARISON_OPERATORS + GEO_OPERATORS + STRING_OPERATORS + CUSTOM_OPERATORS
+)
# TODO make this less complex
@@ -33,11 +66,11 @@ def query(_doc_cls=None, **kwargs):
mongo_query = {}
merge_query = defaultdict(list)
for key, value in sorted(kwargs.items()):
- if key == '__raw__':
+ if key == "__raw__":
mongo_query.update(value)
continue
- parts = key.rsplit('__')
+ parts = key.rsplit("__")
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
parts = [part for part in parts if not part.isdigit()]
# Check for an operator and transform to mongo-style if there is
@@ -46,11 +79,11 @@ def query(_doc_cls=None, **kwargs):
op = parts.pop()
# Allow to escape operator-like field name by __
- if len(parts) > 1 and parts[-1] == '':
+ if len(parts) > 1 and parts[-1] == "":
parts.pop()
negate = False
- if len(parts) > 1 and parts[-1] == 'not':
+ if len(parts) > 1 and parts[-1] == "not":
parts.pop()
negate = True
@@ -62,8 +95,8 @@ def query(_doc_cls=None, **kwargs):
raise InvalidQueryError(e)
parts = []
- CachedReferenceField = _import_class('CachedReferenceField')
- GenericReferenceField = _import_class('GenericReferenceField')
+ CachedReferenceField = _import_class("CachedReferenceField")
+ GenericReferenceField = _import_class("GenericReferenceField")
cleaned_fields = []
for field in fields:
@@ -73,7 +106,7 @@ def query(_doc_cls=None, **kwargs):
append_field = False
# is last and CachedReferenceField
elif isinstance(field, CachedReferenceField) and fields[-1] == field:
- parts.append('%s._id' % field.db_field)
+ parts.append("%s._id" % field.db_field)
else:
parts.append(field.db_field)
@@ -83,15 +116,15 @@ def query(_doc_cls=None, **kwargs):
# Convert value to proper value
field = cleaned_fields[-1]
- singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
+ singular_ops = [None, "ne", "gt", "gte", "lt", "lte", "not"]
singular_ops += STRING_OPERATORS
if op in singular_ops:
value = field.prepare_query_value(op, value)
if isinstance(field, CachedReferenceField) and value:
- value = value['_id']
+ value = value["_id"]
- elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
+ elif op in ("in", "nin", "all", "near") and not isinstance(value, dict):
# Raise an error if the in/nin/all/near param is not iterable.
value = _prepare_query_for_iterable(field, op, value)
@@ -101,39 +134,40 @@ def query(_doc_cls=None, **kwargs):
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
if isinstance(field, GenericReferenceField):
if isinstance(value, DBRef):
- parts[-1] += '._ref'
+ parts[-1] += "._ref"
elif isinstance(value, ObjectId):
- parts[-1] += '._ref.$id'
+ parts[-1] += "._ref.$id"
# if op and op not in COMPARISON_OPERATORS:
if op:
if op in GEO_OPERATORS:
value = _geo_operator(field, op, value)
- elif op in ('match', 'elemMatch'):
- ListField = _import_class('ListField')
- EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
+ elif op in ("match", "elemMatch"):
+ ListField = _import_class("ListField")
+ EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
if (
- isinstance(value, dict) and
- isinstance(field, ListField) and
- isinstance(field.field, EmbeddedDocumentField)
+ isinstance(value, dict)
+ and isinstance(field, ListField)
+ and isinstance(field.field, EmbeddedDocumentField)
):
value = query(field.field.document_type, **value)
else:
value = field.prepare_query_value(op, value)
- value = {'$elemMatch': value}
+ value = {"$elemMatch": value}
elif op in CUSTOM_OPERATORS:
- NotImplementedError('Custom method "%s" has not '
- 'been implemented' % op)
+ NotImplementedError(
+ 'Custom method "%s" has not ' "been implemented" % op
+ )
elif op not in STRING_OPERATORS:
- value = {'$' + op: value}
+ value = {"$" + op: value}
if negate:
- value = {'$not': value}
+ value = {"$not": value}
for i, part in indices:
parts.insert(i, part)
- key = '.'.join(parts)
+ key = ".".join(parts)
if op is None or key not in mongo_query:
mongo_query[key] = value
@@ -142,30 +176,35 @@ def query(_doc_cls=None, **kwargs):
mongo_query[key].update(value)
# $max/minDistance needs to come last - convert to SON
value_dict = mongo_query[key]
- if ('$maxDistance' in value_dict or '$minDistance' in value_dict) and \
- ('$near' in value_dict or '$nearSphere' in value_dict):
+ if ("$maxDistance" in value_dict or "$minDistance" in value_dict) and (
+ "$near" in value_dict or "$nearSphere" in value_dict
+ ):
value_son = SON()
for k, v in iteritems(value_dict):
- if k == '$maxDistance' or k == '$minDistance':
+ if k == "$maxDistance" or k == "$minDistance":
continue
value_son[k] = v
# Required for MongoDB >= 2.6, may fail when combining
# PyMongo 3+ and MongoDB < 2.6
near_embedded = False
- for near_op in ('$near', '$nearSphere'):
+ for near_op in ("$near", "$nearSphere"):
if isinstance(value_dict.get(near_op), dict):
value_son[near_op] = SON(value_son[near_op])
- if '$maxDistance' in value_dict:
- value_son[near_op]['$maxDistance'] = value_dict['$maxDistance']
- if '$minDistance' in value_dict:
- value_son[near_op]['$minDistance'] = value_dict['$minDistance']
+ if "$maxDistance" in value_dict:
+ value_son[near_op]["$maxDistance"] = value_dict[
+ "$maxDistance"
+ ]
+ if "$minDistance" in value_dict:
+ value_son[near_op]["$minDistance"] = value_dict[
+ "$minDistance"
+ ]
near_embedded = True
if not near_embedded:
- if '$maxDistance' in value_dict:
- value_son['$maxDistance'] = value_dict['$maxDistance']
- if '$minDistance' in value_dict:
- value_son['$minDistance'] = value_dict['$minDistance']
+ if "$maxDistance" in value_dict:
+ value_son["$maxDistance"] = value_dict["$maxDistance"]
+ if "$minDistance" in value_dict:
+ value_son["$minDistance"] = value_dict["$minDistance"]
mongo_query[key] = value_son
else:
# Store for manually merging later
@@ -177,10 +216,10 @@ def query(_doc_cls=None, **kwargs):
del mongo_query[k]
if isinstance(v, list):
value = [{k: val} for val in v]
- if '$and' in mongo_query.keys():
- mongo_query['$and'].extend(value)
+ if "$and" in mongo_query.keys():
+ mongo_query["$and"].extend(value)
else:
- mongo_query['$and'] = value
+ mongo_query["$and"] = value
return mongo_query
@@ -192,15 +231,15 @@ def update(_doc_cls=None, **update):
mongo_update = {}
for key, value in update.items():
- if key == '__raw__':
+ if key == "__raw__":
mongo_update.update(value)
continue
- parts = key.split('__')
+ parts = key.split("__")
# if there is no operator, default to 'set'
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
- parts.insert(0, 'set')
+ parts.insert(0, "set")
# Check for an operator and transform to mongo-style if there is
op = None
@@ -208,13 +247,13 @@ def update(_doc_cls=None, **update):
op = parts.pop(0)
# Convert Pythonic names to Mongo equivalents
operator_map = {
- 'push_all': 'pushAll',
- 'pull_all': 'pullAll',
- 'dec': 'inc',
- 'add_to_set': 'addToSet',
- 'set_on_insert': 'setOnInsert'
+ "push_all": "pushAll",
+ "pull_all": "pullAll",
+ "dec": "inc",
+ "add_to_set": "addToSet",
+ "set_on_insert": "setOnInsert",
}
- if op == 'dec':
+ if op == "dec":
# Support decrement by flipping a positive value's sign
# and using 'inc'
value = -value
@@ -227,7 +266,7 @@ def update(_doc_cls=None, **update):
match = parts.pop()
# Allow to escape operator-like field name by __
- if len(parts) > 1 and parts[-1] == '':
+ if len(parts) > 1 and parts[-1] == "":
parts.pop()
if _doc_cls:
@@ -244,8 +283,8 @@ def update(_doc_cls=None, **update):
append_field = True
if isinstance(field, six.string_types):
# Convert the S operator to $
- if field == 'S':
- field = '$'
+ if field == "S":
+ field = "$"
parts.append(field)
append_field = False
else:
@@ -253,7 +292,7 @@ def update(_doc_cls=None, **update):
if append_field:
appended_sub_field = False
cleaned_fields.append(field)
- if hasattr(field, 'field'):
+ if hasattr(field, "field"):
cleaned_fields.append(field.field)
appended_sub_field = True
@@ -263,52 +302,53 @@ def update(_doc_cls=None, **update):
else:
field = cleaned_fields[-1]
- GeoJsonBaseField = _import_class('GeoJsonBaseField')
+ GeoJsonBaseField = _import_class("GeoJsonBaseField")
if isinstance(field, GeoJsonBaseField):
value = field.to_mongo(value)
- if op == 'pull':
+ if op == "pull":
if field.required or value is not None:
- if match in ('in', 'nin') and not isinstance(value, dict):
+ if match in ("in", "nin") and not isinstance(value, dict):
value = _prepare_query_for_iterable(field, op, value)
else:
value = field.prepare_query_value(op, value)
- elif op == 'push' and isinstance(value, (list, tuple, set)):
+ elif op == "push" and isinstance(value, (list, tuple, set)):
value = [field.prepare_query_value(op, v) for v in value]
- elif op in (None, 'set', 'push'):
+ elif op in (None, "set", "push"):
if field.required or value is not None:
value = field.prepare_query_value(op, value)
- elif op in ('pushAll', 'pullAll'):
+ elif op in ("pushAll", "pullAll"):
value = [field.prepare_query_value(op, v) for v in value]
- elif op in ('addToSet', 'setOnInsert'):
+ elif op in ("addToSet", "setOnInsert"):
if isinstance(value, (list, tuple, set)):
value = [field.prepare_query_value(op, v) for v in value]
elif field.required or value is not None:
value = field.prepare_query_value(op, value)
- elif op == 'unset':
+ elif op == "unset":
value = 1
- elif op == 'inc':
+ elif op == "inc":
value = field.prepare_query_value(op, value)
if match:
- match = '$' + match
+ match = "$" + match
value = {match: value}
- key = '.'.join(parts)
+ key = ".".join(parts)
- if 'pull' in op and '.' in key:
+ if "pull" in op and "." in key:
# Dot operators don't work on pull operations
# unless they point to a list field
# Otherwise it uses nested dict syntax
- if op == 'pullAll':
- raise InvalidQueryError('pullAll operations only support '
- 'a single field depth')
+ if op == "pullAll":
+ raise InvalidQueryError(
+ "pullAll operations only support a single field depth"
+ )
# Look for the last list field and use dot notation until there
field_classes = [c.__class__ for c in cleaned_fields]
field_classes.reverse()
- ListField = _import_class('ListField')
- EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
+ ListField = _import_class("ListField")
+ EmbeddedDocumentListField = _import_class("EmbeddedDocumentListField")
if ListField in field_classes or EmbeddedDocumentListField in field_classes:
# Join all fields via dot notation to the last ListField or EmbeddedDocumentListField
# Then process as normal
@@ -317,37 +357,36 @@ def update(_doc_cls=None, **update):
else:
_check_field = EmbeddedDocumentListField
- last_listField = len(
- cleaned_fields) - field_classes.index(_check_field)
- key = '.'.join(parts[:last_listField])
+ last_listField = len(cleaned_fields) - field_classes.index(_check_field)
+ key = ".".join(parts[:last_listField])
parts = parts[last_listField:]
parts.insert(0, key)
parts.reverse()
for key in parts:
value = {key: value}
- elif op == 'addToSet' and isinstance(value, list):
- value = {key: {'$each': value}}
- elif op in ('push', 'pushAll'):
+ elif op == "addToSet" and isinstance(value, list):
+ value = {key: {"$each": value}}
+ elif op in ("push", "pushAll"):
if parts[-1].isdigit():
- key = '.'.join(parts[0:-1])
+ key = ".".join(parts[0:-1])
position = int(parts[-1])
# $position expects an iterable. If pushing a single value,
# wrap it in a list.
if not isinstance(value, (set, tuple, list)):
value = [value]
- value = {key: {'$each': value, '$position': position}}
+ value = {key: {"$each": value, "$position": position}}
else:
- if op == 'pushAll':
- op = 'push' # convert to non-deprecated keyword
+ if op == "pushAll":
+ op = "push" # convert to non-deprecated keyword
if not isinstance(value, (set, tuple, list)):
value = [value]
- value = {key: {'$each': value}}
+ value = {key: {"$each": value}}
else:
value = {key: value}
else:
value = {key: value}
- key = '$' + op
+ key = "$" + op
if key not in mongo_update:
mongo_update[key] = value
elif key in mongo_update and isinstance(mongo_update[key], dict):
@@ -358,45 +397,45 @@ def update(_doc_cls=None, **update):
def _geo_operator(field, op, value):
"""Helper to return the query for a given geo query."""
- if op == 'max_distance':
- value = {'$maxDistance': value}
- elif op == 'min_distance':
- value = {'$minDistance': value}
+ if op == "max_distance":
+ value = {"$maxDistance": value}
+ elif op == "min_distance":
+ value = {"$minDistance": value}
elif field._geo_index == pymongo.GEO2D:
- if op == 'within_distance':
- value = {'$within': {'$center': value}}
- elif op == 'within_spherical_distance':
- value = {'$within': {'$centerSphere': value}}
- elif op == 'within_polygon':
- value = {'$within': {'$polygon': value}}
- elif op == 'near':
- value = {'$near': value}
- elif op == 'near_sphere':
- value = {'$nearSphere': value}
- elif op == 'within_box':
- value = {'$within': {'$box': value}}
- else:
- raise NotImplementedError('Geo method "%s" has not been '
- 'implemented for a GeoPointField' % op)
- else:
- if op == 'geo_within':
- value = {'$geoWithin': _infer_geometry(value)}
- elif op == 'geo_within_box':
- value = {'$geoWithin': {'$box': value}}
- elif op == 'geo_within_polygon':
- value = {'$geoWithin': {'$polygon': value}}
- elif op == 'geo_within_center':
- value = {'$geoWithin': {'$center': value}}
- elif op == 'geo_within_sphere':
- value = {'$geoWithin': {'$centerSphere': value}}
- elif op == 'geo_intersects':
- value = {'$geoIntersects': _infer_geometry(value)}
- elif op == 'near':
- value = {'$near': _infer_geometry(value)}
+ if op == "within_distance":
+ value = {"$within": {"$center": value}}
+ elif op == "within_spherical_distance":
+ value = {"$within": {"$centerSphere": value}}
+ elif op == "within_polygon":
+ value = {"$within": {"$polygon": value}}
+ elif op == "near":
+ value = {"$near": value}
+ elif op == "near_sphere":
+ value = {"$nearSphere": value}
+ elif op == "within_box":
+ value = {"$within": {"$box": value}}
else:
raise NotImplementedError(
- 'Geo method "%s" has not been implemented for a %s '
- % (op, field._name)
+ 'Geo method "%s" has not been ' "implemented for a GeoPointField" % op
+ )
+ else:
+ if op == "geo_within":
+ value = {"$geoWithin": _infer_geometry(value)}
+ elif op == "geo_within_box":
+ value = {"$geoWithin": {"$box": value}}
+ elif op == "geo_within_polygon":
+ value = {"$geoWithin": {"$polygon": value}}
+ elif op == "geo_within_center":
+ value = {"$geoWithin": {"$center": value}}
+ elif op == "geo_within_sphere":
+ value = {"$geoWithin": {"$centerSphere": value}}
+ elif op == "geo_intersects":
+ value = {"$geoIntersects": _infer_geometry(value)}
+ elif op == "near":
+ value = {"$near": _infer_geometry(value)}
+ else:
+ raise NotImplementedError(
+ 'Geo method "%s" has not been implemented for a %s ' % (op, field._name)
)
return value
@@ -406,51 +445,58 @@ def _infer_geometry(value):
given value.
"""
if isinstance(value, dict):
- if '$geometry' in value:
+ if "$geometry" in value:
return value
- elif 'coordinates' in value and 'type' in value:
- return {'$geometry': value}
- raise InvalidQueryError('Invalid $geometry dictionary should have '
- 'type and coordinates keys')
+ elif "coordinates" in value and "type" in value:
+ return {"$geometry": value}
+ raise InvalidQueryError(
+ "Invalid $geometry dictionary should have type and coordinates keys"
+ )
elif isinstance(value, (list, set)):
# TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
try:
value[0][0][0]
- return {'$geometry': {'type': 'Polygon', 'coordinates': value}}
+ return {"$geometry": {"type": "Polygon", "coordinates": value}}
except (TypeError, IndexError):
pass
try:
value[0][0]
- return {'$geometry': {'type': 'LineString', 'coordinates': value}}
+ return {"$geometry": {"type": "LineString", "coordinates": value}}
except (TypeError, IndexError):
pass
try:
value[0]
- return {'$geometry': {'type': 'Point', 'coordinates': value}}
+ return {"$geometry": {"type": "Point", "coordinates": value}}
except (TypeError, IndexError):
pass
- raise InvalidQueryError('Invalid $geometry data. Can be either a '
- 'dictionary or (nested) lists of coordinate(s)')
+ raise InvalidQueryError(
+ "Invalid $geometry data. Can be either a "
+ "dictionary or (nested) lists of coordinate(s)"
+ )
def _prepare_query_for_iterable(field, op, value):
# We need a special check for BaseDocument, because - although it's iterable - using
# it as such in the context of this method is most definitely a mistake.
- BaseDocument = _import_class('BaseDocument')
+ BaseDocument = _import_class("BaseDocument")
if isinstance(value, BaseDocument):
- raise TypeError("When using the `in`, `nin`, `all`, or "
- "`near`-operators you can\'t use a "
- "`Document`, you must wrap your object "
- "in a list (object -> [object]).")
+ raise TypeError(
+ "When using the `in`, `nin`, `all`, or "
+ "`near`-operators you can't use a "
+ "`Document`, you must wrap your object "
+ "in a list (object -> [object])."
+ )
- if not hasattr(value, '__iter__'):
- raise TypeError("The `in`, `nin`, `all`, or "
- "`near`-operators must be applied to an "
- "iterable (e.g. a list).")
+ if not hasattr(value, "__iter__"):
+ raise TypeError(
+ "The `in`, `nin`, `all`, or "
+ "`near`-operators must be applied to an "
+ "iterable (e.g. a list)."
+ )
return [field.prepare_query_value(op, v) for v in value]
diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py
index 9d97094b..0fe139fd 100644
--- a/mongoengine/queryset/visitor.py
+++ b/mongoengine/queryset/visitor.py
@@ -3,7 +3,7 @@ import copy
from mongoengine.errors import InvalidQueryError
from mongoengine.queryset import transform
-__all__ = ('Q', 'QNode')
+__all__ = ("Q", "QNode")
class QNodeVisitor(object):
@@ -69,9 +69,9 @@ class QueryCompilerVisitor(QNodeVisitor):
self.document = document
def visit_combination(self, combination):
- operator = '$and'
+ operator = "$and"
if combination.operation == combination.OR:
- operator = '$or'
+ operator = "$or"
return {operator: combination.children}
def visit_query(self, query):
@@ -96,7 +96,7 @@ class QNode(object):
"""Combine this node with another node into a QCombination
object.
"""
- if getattr(other, 'empty', True):
+ if getattr(other, "empty", True):
return self
if self.empty:
@@ -132,8 +132,8 @@ class QCombination(QNode):
self.children.append(node)
def __repr__(self):
- op = ' & ' if self.operation is self.AND else ' | '
- return '(%s)' % op.join([repr(node) for node in self.children])
+ op = " & " if self.operation is self.AND else " | "
+ return "(%s)" % op.join([repr(node) for node in self.children])
def accept(self, visitor):
for i in range(len(self.children)):
@@ -156,7 +156,7 @@ class Q(QNode):
self.query = query
def __repr__(self):
- return 'Q(**%s)' % repr(self.query)
+ return "Q(**%s)" % repr(self.query)
def accept(self, visitor):
return visitor.visit_query(self)
diff --git a/mongoengine/signals.py b/mongoengine/signals.py
index a892dec0..0db63604 100644
--- a/mongoengine/signals.py
+++ b/mongoengine/signals.py
@@ -1,5 +1,12 @@
-__all__ = ('pre_init', 'post_init', 'pre_save', 'pre_save_post_validation',
- 'post_save', 'pre_delete', 'post_delete')
+__all__ = (
+ "pre_init",
+ "post_init",
+ "pre_save",
+ "pre_save_post_validation",
+ "post_save",
+ "pre_delete",
+ "post_delete",
+)
signals_available = False
try:
@@ -7,6 +14,7 @@ try:
signals_available = True
except ImportError:
+
class Namespace(object):
def signal(self, name, doc=None):
return _FakeSignal(name, doc)
@@ -23,13 +31,16 @@ except ImportError:
self.__doc__ = doc
def _fail(self, *args, **kwargs):
- raise RuntimeError('signalling support is unavailable '
- 'because the blinker library is '
- 'not installed.')
+ raise RuntimeError(
+ "signalling support is unavailable "
+ "because the blinker library is "
+ "not installed."
+ )
send = lambda *a, **kw: None # noqa
- connect = disconnect = has_receivers_for = receivers_for = \
- temporarily_connected_to = _fail
+ connect = (
+ disconnect
+ ) = has_receivers_for = receivers_for = temporarily_connected_to = _fail
del _fail
@@ -37,12 +48,12 @@ except ImportError:
# not put signals in here. Create your own namespace instead.
_signals = Namespace()
-pre_init = _signals.signal('pre_init')
-post_init = _signals.signal('post_init')
-pre_save = _signals.signal('pre_save')
-pre_save_post_validation = _signals.signal('pre_save_post_validation')
-post_save = _signals.signal('post_save')
-pre_delete = _signals.signal('pre_delete')
-post_delete = _signals.signal('post_delete')
-pre_bulk_insert = _signals.signal('pre_bulk_insert')
-post_bulk_insert = _signals.signal('post_bulk_insert')
+pre_init = _signals.signal("pre_init")
+post_init = _signals.signal("post_init")
+pre_save = _signals.signal("pre_save")
+pre_save_post_validation = _signals.signal("pre_save_post_validation")
+post_save = _signals.signal("post_save")
+pre_delete = _signals.signal("pre_delete")
+post_delete = _signals.signal("post_delete")
+pre_bulk_insert = _signals.signal("pre_bulk_insert")
+post_bulk_insert = _signals.signal("post_bulk_insert")
diff --git a/requirements.txt b/requirements.txt
index 9bb319a5..62ad8766 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+black
nose
pymongo>=3.4
six==1.10.0
diff --git a/setup.cfg b/setup.cfg
index 84086601..4bded428 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,7 +5,7 @@ detailed-errors=1
cover-package=mongoengine
[flake8]
-ignore=E501,F401,F403,F405,I201,I202,W504, W605
+ignore=E501,F401,F403,F405,I201,I202,W504, W605, W503
exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests
max-complexity=47
application-import-names=mongoengine,tests
diff --git a/setup.py b/setup.py
index f1f5dea7..c73a93ff 100644
--- a/setup.py
+++ b/setup.py
@@ -8,13 +8,10 @@ try:
except ImportError:
pass
-DESCRIPTION = (
- 'MongoEngine is a Python Object-Document '
- 'Mapper for working with MongoDB.'
-)
+DESCRIPTION = "MongoEngine is a Python Object-Document Mapper for working with MongoDB."
try:
- with open('README.rst') as fin:
+ with open("README.rst") as fin:
LONG_DESCRIPTION = fin.read()
except Exception:
LONG_DESCRIPTION = None
@@ -24,23 +21,23 @@ def get_version(version_tuple):
"""Return the version tuple as a string, e.g. for (0, 10, 7),
return '0.10.7'.
"""
- return '.'.join(map(str, version_tuple))
+ return ".".join(map(str, version_tuple))
# Dirty hack to get version number from monogengine/__init__.py - we can't
# import it as it depends on PyMongo and PyMongo isn't installed until this
# file is read
-init = os.path.join(os.path.dirname(__file__), 'mongoengine', '__init__.py')
-version_line = list(filter(lambda l: l.startswith('VERSION'), open(init)))[0]
+init = os.path.join(os.path.dirname(__file__), "mongoengine", "__init__.py")
+version_line = list(filter(lambda l: l.startswith("VERSION"), open(init)))[0]
-VERSION = get_version(eval(version_line.split('=')[-1]))
+VERSION = get_version(eval(version_line.split("=")[-1]))
CLASSIFIERS = [
- 'Development Status :: 4 - Beta',
- 'Intended Audience :: Developers',
- 'License :: OSI Approved :: MIT License',
- 'Operating System :: OS Independent',
- 'Programming Language :: Python',
+ "Development Status :: 4 - Beta",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python",
"Programming Language :: Python :: 2",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
@@ -48,39 +45,40 @@ CLASSIFIERS = [
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
- 'Topic :: Database',
- 'Topic :: Software Development :: Libraries :: Python Modules',
+ "Topic :: Database",
+ "Topic :: Software Development :: Libraries :: Python Modules",
]
extra_opts = {
- 'packages': find_packages(exclude=['tests', 'tests.*']),
- 'tests_require': ['nose', 'coverage==4.2', 'blinker', 'Pillow>=2.0.0']
+ "packages": find_packages(exclude=["tests", "tests.*"]),
+ "tests_require": ["nose", "coverage==4.2", "blinker", "Pillow>=2.0.0"],
}
if sys.version_info[0] == 3:
- extra_opts['use_2to3'] = True
- if 'test' in sys.argv or 'nosetests' in sys.argv:
- extra_opts['packages'] = find_packages()
- extra_opts['package_data'] = {
- 'tests': ['fields/mongoengine.png', 'fields/mongodb_leaf.png']}
+ extra_opts["use_2to3"] = True
+ if "test" in sys.argv or "nosetests" in sys.argv:
+ extra_opts["packages"] = find_packages()
+ extra_opts["package_data"] = {
+ "tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]
+ }
else:
- extra_opts['tests_require'] += ['python-dateutil']
+ extra_opts["tests_require"] += ["python-dateutil"]
setup(
- name='mongoengine',
+ name="mongoengine",
version=VERSION,
- author='Harry Marr',
- author_email='harry.marr@gmail.com',
+ author="Harry Marr",
+ author_email="harry.marr@gmail.com",
maintainer="Stefan Wojcik",
maintainer_email="wojcikstefan@gmail.com",
- url='http://mongoengine.org/',
- download_url='https://github.com/MongoEngine/mongoengine/tarball/master',
- license='MIT',
+ url="http://mongoengine.org/",
+ download_url="https://github.com/MongoEngine/mongoengine/tarball/master",
+ license="MIT",
include_package_data=True,
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
- platforms=['any'],
+ platforms=["any"],
classifiers=CLASSIFIERS,
- install_requires=['pymongo>=3.4', 'six'],
- test_suite='nose.collector',
+ install_requires=["pymongo>=3.4", "six"],
+ test_suite="nose.collector",
**extra_opts
)
diff --git a/tests/all_warnings/__init__.py b/tests/all_warnings/__init__.py
index 3aebe4ba..a755e7a3 100644
--- a/tests/all_warnings/__init__.py
+++ b/tests/all_warnings/__init__.py
@@ -9,34 +9,32 @@ import warnings
from mongoengine import *
-__all__ = ('AllWarnings', )
+__all__ = ("AllWarnings",)
class AllWarnings(unittest.TestCase):
-
def setUp(self):
- connect(db='mongoenginetest')
+ connect(db="mongoenginetest")
self.warning_list = []
self.showwarning_default = warnings.showwarning
warnings.showwarning = self.append_to_warning_list
def append_to_warning_list(self, message, category, *args):
- self.warning_list.append({"message": message,
- "category": category})
+ self.warning_list.append({"message": message, "category": category})
def tearDown(self):
# restore default handling of warnings
warnings.showwarning = self.showwarning_default
def test_document_collection_syntax_warning(self):
-
class NonAbstractBase(Document):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class InheritedDocumentFailTest(NonAbstractBase):
- meta = {'collection': 'fail'}
+ meta = {"collection": "fail"}
warning = self.warning_list[0]
self.assertEqual(SyntaxWarning, warning["category"])
- self.assertEqual('non_abstract_base',
- InheritedDocumentFailTest._get_collection_name())
+ self.assertEqual(
+ "non_abstract_base", InheritedDocumentFailTest._get_collection_name()
+ )
diff --git a/tests/document/__init__.py b/tests/document/__init__.py
index dc35c969..f2230c48 100644
--- a/tests/document/__init__.py
+++ b/tests/document/__init__.py
@@ -9,5 +9,5 @@ from .instance import *
from .json_serialisation import *
from .validation import *
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py
index 4fc648b7..87f1215b 100644
--- a/tests/document/class_methods.py
+++ b/tests/document/class_methods.py
@@ -7,13 +7,12 @@ from mongoengine.pymongo_support import list_collection_names
from mongoengine.queryset import NULLIFY, PULL
from mongoengine.connection import get_db
-__all__ = ("ClassMethodsTest", )
+__all__ = ("ClassMethodsTest",)
class ClassMethodsTest(unittest.TestCase):
-
def setUp(self):
- connect(db='mongoenginetest')
+ connect(db="mongoenginetest")
self.db = get_db()
class Person(Document):
@@ -33,11 +32,13 @@ class ClassMethodsTest(unittest.TestCase):
def test_definition(self):
"""Ensure that document may be defined using fields.
"""
- self.assertEqual(['_cls', 'age', 'id', 'name'],
- sorted(self.Person._fields.keys()))
- self.assertEqual(["IntField", "ObjectIdField", "StringField", "StringField"],
- sorted([x.__class__.__name__ for x in
- self.Person._fields.values()]))
+ self.assertEqual(
+ ["_cls", "age", "id", "name"], sorted(self.Person._fields.keys())
+ )
+ self.assertEqual(
+ ["IntField", "ObjectIdField", "StringField", "StringField"],
+ sorted([x.__class__.__name__ for x in self.Person._fields.values()]),
+ )
def test_get_db(self):
"""Ensure that get_db returns the expected db.
@@ -49,21 +50,21 @@ class ClassMethodsTest(unittest.TestCase):
"""Ensure that get_collection_name returns the expected collection
name.
"""
- collection_name = 'person'
+ collection_name = "person"
self.assertEqual(collection_name, self.Person._get_collection_name())
def test_get_collection(self):
"""Ensure that get_collection returns the expected collection.
"""
- collection_name = 'person'
+ collection_name = "person"
collection = self.Person._get_collection()
self.assertEqual(self.db[collection_name], collection)
def test_drop_collection(self):
"""Ensure that the collection may be dropped from the database.
"""
- collection_name = 'person'
- self.Person(name='Test').save()
+ collection_name = "person"
+ self.Person(name="Test").save()
self.assertIn(collection_name, list_collection_names(self.db))
self.Person.drop_collection()
@@ -73,14 +74,16 @@ class ClassMethodsTest(unittest.TestCase):
"""Ensure that register delete rule adds a delete rule to the document
meta.
"""
+
class Job(Document):
employee = ReferenceField(self.Person)
- self.assertEqual(self.Person._meta.get('delete_rules'), None)
+ self.assertEqual(self.Person._meta.get("delete_rules"), None)
- self.Person.register_delete_rule(Job, 'employee', NULLIFY)
- self.assertEqual(self.Person._meta['delete_rules'],
- {(Job, 'employee'): NULLIFY})
+ self.Person.register_delete_rule(Job, "employee", NULLIFY)
+ self.assertEqual(
+ self.Person._meta["delete_rules"], {(Job, "employee"): NULLIFY}
+ )
def test_compare_indexes(self):
""" Ensure that the indexes are properly created and that
@@ -93,23 +96,27 @@ class ClassMethodsTest(unittest.TestCase):
description = StringField()
tags = StringField()
- meta = {
- 'indexes': [('author', 'title')]
- }
+ meta = {"indexes": [("author", "title")]}
BlogPost.drop_collection()
BlogPost.ensure_indexes()
- self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []})
+ self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []})
- BlogPost.ensure_index(['author', 'description'])
- self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': [[('author', 1), ('description', 1)]]})
+ BlogPost.ensure_index(["author", "description"])
+ self.assertEqual(
+ BlogPost.compare_indexes(),
+ {"missing": [], "extra": [[("author", 1), ("description", 1)]]},
+ )
- BlogPost._get_collection().drop_index('author_1_description_1')
- self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []})
+ BlogPost._get_collection().drop_index("author_1_description_1")
+ self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []})
- BlogPost._get_collection().drop_index('author_1_title_1')
- self.assertEqual(BlogPost.compare_indexes(), {'missing': [[('author', 1), ('title', 1)]], 'extra': []})
+ BlogPost._get_collection().drop_index("author_1_title_1")
+ self.assertEqual(
+ BlogPost.compare_indexes(),
+ {"missing": [[("author", 1), ("title", 1)]], "extra": []},
+ )
def test_compare_indexes_inheritance(self):
""" Ensure that the indexes are properly created and that
@@ -122,32 +129,34 @@ class ClassMethodsTest(unittest.TestCase):
title = StringField()
description = StringField()
- meta = {
- 'allow_inheritance': True
- }
+ meta = {"allow_inheritance": True}
class BlogPostWithTags(BlogPost):
tags = StringField()
tag_list = ListField(StringField())
- meta = {
- 'indexes': [('author', 'tags')]
- }
+ meta = {"indexes": [("author", "tags")]}
BlogPost.drop_collection()
BlogPost.ensure_indexes()
BlogPostWithTags.ensure_indexes()
- self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []})
+ self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []})
- BlogPostWithTags.ensure_index(['author', 'tag_list'])
- self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': [[('_cls', 1), ('author', 1), ('tag_list', 1)]]})
+ BlogPostWithTags.ensure_index(["author", "tag_list"])
+ self.assertEqual(
+ BlogPost.compare_indexes(),
+ {"missing": [], "extra": [[("_cls", 1), ("author", 1), ("tag_list", 1)]]},
+ )
- BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tag_list_1')
- self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []})
+ BlogPostWithTags._get_collection().drop_index("_cls_1_author_1_tag_list_1")
+ self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []})
- BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tags_1')
- self.assertEqual(BlogPost.compare_indexes(), {'missing': [[('_cls', 1), ('author', 1), ('tags', 1)]], 'extra': []})
+ BlogPostWithTags._get_collection().drop_index("_cls_1_author_1_tags_1")
+ self.assertEqual(
+ BlogPost.compare_indexes(),
+ {"missing": [[("_cls", 1), ("author", 1), ("tags", 1)]], "extra": []},
+ )
def test_compare_indexes_multiple_subclasses(self):
""" Ensure that compare_indexes behaves correctly if called from a
@@ -159,32 +168,30 @@ class ClassMethodsTest(unittest.TestCase):
title = StringField()
description = StringField()
- meta = {
- 'allow_inheritance': True
- }
+ meta = {"allow_inheritance": True}
class BlogPostWithTags(BlogPost):
tags = StringField()
tag_list = ListField(StringField())
- meta = {
- 'indexes': [('author', 'tags')]
- }
+ meta = {"indexes": [("author", "tags")]}
class BlogPostWithCustomField(BlogPost):
custom = DictField()
- meta = {
- 'indexes': [('author', 'custom')]
- }
+ meta = {"indexes": [("author", "custom")]}
BlogPost.ensure_indexes()
BlogPostWithTags.ensure_indexes()
BlogPostWithCustomField.ensure_indexes()
- self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []})
- self.assertEqual(BlogPostWithTags.compare_indexes(), {'missing': [], 'extra': []})
- self.assertEqual(BlogPostWithCustomField.compare_indexes(), {'missing': [], 'extra': []})
+ self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []})
+ self.assertEqual(
+ BlogPostWithTags.compare_indexes(), {"missing": [], "extra": []}
+ )
+ self.assertEqual(
+ BlogPostWithCustomField.compare_indexes(), {"missing": [], "extra": []}
+ )
def test_compare_indexes_for_text_indexes(self):
""" Ensure that compare_indexes behaves correctly for text indexes """
@@ -192,17 +199,20 @@ class ClassMethodsTest(unittest.TestCase):
class Doc(Document):
a = StringField()
b = StringField()
- meta = {'indexes': [
- {'fields': ['$a', "$b"],
- 'default_language': 'english',
- 'weights': {'a': 10, 'b': 2}
- }
- ]}
+ meta = {
+ "indexes": [
+ {
+ "fields": ["$a", "$b"],
+ "default_language": "english",
+ "weights": {"a": 10, "b": 2},
+ }
+ ]
+ }
Doc.drop_collection()
Doc.ensure_indexes()
actual = Doc.compare_indexes()
- expected = {'missing': [], 'extra': []}
+ expected = {"missing": [], "extra": []}
self.assertEqual(actual, expected)
def test_list_indexes_inheritance(self):
@@ -215,23 +225,17 @@ class ClassMethodsTest(unittest.TestCase):
title = StringField()
description = StringField()
- meta = {
- 'allow_inheritance': True
- }
+ meta = {"allow_inheritance": True}
class BlogPostWithTags(BlogPost):
tags = StringField()
- meta = {
- 'indexes': [('author', 'tags')]
- }
+ meta = {"indexes": [("author", "tags")]}
class BlogPostWithTagsAndExtraText(BlogPostWithTags):
extra_text = StringField()
- meta = {
- 'indexes': [('author', 'tags', 'extra_text')]
- }
+ meta = {"indexes": [("author", "tags", "extra_text")]}
BlogPost.drop_collection()
@@ -239,17 +243,21 @@ class ClassMethodsTest(unittest.TestCase):
BlogPostWithTags.ensure_indexes()
BlogPostWithTagsAndExtraText.ensure_indexes()
- self.assertEqual(BlogPost.list_indexes(),
- BlogPostWithTags.list_indexes())
- self.assertEqual(BlogPost.list_indexes(),
- BlogPostWithTagsAndExtraText.list_indexes())
- self.assertEqual(BlogPost.list_indexes(),
- [[('_cls', 1), ('author', 1), ('tags', 1)],
- [('_cls', 1), ('author', 1), ('tags', 1), ('extra_text', 1)],
- [(u'_id', 1)], [('_cls', 1)]])
+ self.assertEqual(BlogPost.list_indexes(), BlogPostWithTags.list_indexes())
+ self.assertEqual(
+ BlogPost.list_indexes(), BlogPostWithTagsAndExtraText.list_indexes()
+ )
+ self.assertEqual(
+ BlogPost.list_indexes(),
+ [
+ [("_cls", 1), ("author", 1), ("tags", 1)],
+ [("_cls", 1), ("author", 1), ("tags", 1), ("extra_text", 1)],
+ [(u"_id", 1)],
+ [("_cls", 1)],
+ ],
+ )
def test_register_delete_rule_inherited(self):
-
class Vaccine(Document):
name = StringField(required=True)
@@ -257,15 +265,17 @@ class ClassMethodsTest(unittest.TestCase):
class Animal(Document):
family = StringField(required=True)
- vaccine_made = ListField(ReferenceField("Vaccine", reverse_delete_rule=PULL))
+ vaccine_made = ListField(
+ ReferenceField("Vaccine", reverse_delete_rule=PULL)
+ )
meta = {"allow_inheritance": True, "indexes": ["family"]}
class Cat(Animal):
name = StringField(required=True)
- self.assertEqual(Vaccine._meta['delete_rules'][(Animal, 'vaccine_made')], PULL)
- self.assertEqual(Vaccine._meta['delete_rules'][(Cat, 'vaccine_made')], PULL)
+ self.assertEqual(Vaccine._meta["delete_rules"][(Animal, "vaccine_made")], PULL)
+ self.assertEqual(Vaccine._meta["delete_rules"][(Cat, "vaccine_made")], PULL)
def test_collection_naming(self):
"""Ensure that a collection with a specified name may be used.
@@ -273,74 +283,73 @@ class ClassMethodsTest(unittest.TestCase):
class DefaultNamingTest(Document):
pass
- self.assertEqual('default_naming_test',
- DefaultNamingTest._get_collection_name())
+
+ self.assertEqual(
+ "default_naming_test", DefaultNamingTest._get_collection_name()
+ )
class CustomNamingTest(Document):
- meta = {'collection': 'pimp_my_collection'}
+ meta = {"collection": "pimp_my_collection"}
- self.assertEqual('pimp_my_collection',
- CustomNamingTest._get_collection_name())
+ self.assertEqual("pimp_my_collection", CustomNamingTest._get_collection_name())
class DynamicNamingTest(Document):
- meta = {'collection': lambda c: "DYNAMO"}
- self.assertEqual('DYNAMO', DynamicNamingTest._get_collection_name())
+ meta = {"collection": lambda c: "DYNAMO"}
+
+ self.assertEqual("DYNAMO", DynamicNamingTest._get_collection_name())
# Use Abstract class to handle backwards compatibility
class BaseDocument(Document):
- meta = {
- 'abstract': True,
- 'collection': lambda c: c.__name__.lower()
- }
+ meta = {"abstract": True, "collection": lambda c: c.__name__.lower()}
class OldNamingConvention(BaseDocument):
pass
- self.assertEqual('oldnamingconvention',
- OldNamingConvention._get_collection_name())
+
+ self.assertEqual(
+ "oldnamingconvention", OldNamingConvention._get_collection_name()
+ )
class InheritedAbstractNamingTest(BaseDocument):
- meta = {'collection': 'wibble'}
- self.assertEqual('wibble',
- InheritedAbstractNamingTest._get_collection_name())
+ meta = {"collection": "wibble"}
+
+ self.assertEqual("wibble", InheritedAbstractNamingTest._get_collection_name())
# Mixin tests
class BaseMixin(object):
- meta = {
- 'collection': lambda c: c.__name__.lower()
- }
+ meta = {"collection": lambda c: c.__name__.lower()}
class OldMixinNamingConvention(Document, BaseMixin):
pass
- self.assertEqual('oldmixinnamingconvention',
- OldMixinNamingConvention._get_collection_name())
+
+ self.assertEqual(
+ "oldmixinnamingconvention", OldMixinNamingConvention._get_collection_name()
+ )
class BaseMixin(object):
- meta = {
- 'collection': lambda c: c.__name__.lower()
- }
+ meta = {"collection": lambda c: c.__name__.lower()}
class BaseDocument(Document, BaseMixin):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class MyDocument(BaseDocument):
pass
- self.assertEqual('basedocument', MyDocument._get_collection_name())
+ self.assertEqual("basedocument", MyDocument._get_collection_name())
def test_custom_collection_name_operations(self):
"""Ensure that a collection with a specified name is used as expected.
"""
- collection_name = 'personCollTest'
+ collection_name = "personCollTest"
class Person(Document):
name = StringField()
- meta = {'collection': collection_name}
+ meta = {"collection": collection_name}
Person(name="Test User").save()
self.assertIn(collection_name, list_collection_names(self.db))
user_obj = self.db[collection_name].find_one()
- self.assertEqual(user_obj['name'], "Test User")
+ self.assertEqual(user_obj["name"], "Test User")
user_obj = Person.objects[0]
self.assertEqual(user_obj.name, "Test User")
@@ -354,7 +363,7 @@ class ClassMethodsTest(unittest.TestCase):
class Person(Document):
name = StringField(primary_key=True)
- meta = {'collection': 'app'}
+ meta = {"collection": "app"}
Person(name="Test User").save()
@@ -364,5 +373,5 @@ class ClassMethodsTest(unittest.TestCase):
Person.drop_collection()
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/document/delta.py b/tests/document/delta.py
index 504c1707..8f1575e6 100644
--- a/tests/document/delta.py
+++ b/tests/document/delta.py
@@ -8,7 +8,6 @@ from tests.utils import MongoDBTestCase
class DeltaTest(MongoDBTestCase):
-
def setUp(self):
super(DeltaTest, self).setUp()
@@ -31,7 +30,6 @@ class DeltaTest(MongoDBTestCase):
self.delta(DynamicDocument)
def delta(self, DocClass):
-
class Doc(DocClass):
string_field = StringField()
int_field = IntField()
@@ -46,37 +44,37 @@ class DeltaTest(MongoDBTestCase):
self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(doc._delta(), ({}, {}))
- doc.string_field = 'hello'
- self.assertEqual(doc._get_changed_fields(), ['string_field'])
- self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {}))
+ doc.string_field = "hello"
+ self.assertEqual(doc._get_changed_fields(), ["string_field"])
+ self.assertEqual(doc._delta(), ({"string_field": "hello"}, {}))
doc._changed_fields = []
doc.int_field = 1
- self.assertEqual(doc._get_changed_fields(), ['int_field'])
- self.assertEqual(doc._delta(), ({'int_field': 1}, {}))
+ self.assertEqual(doc._get_changed_fields(), ["int_field"])
+ self.assertEqual(doc._delta(), ({"int_field": 1}, {}))
doc._changed_fields = []
- dict_value = {'hello': 'world', 'ping': 'pong'}
+ dict_value = {"hello": "world", "ping": "pong"}
doc.dict_field = dict_value
- self.assertEqual(doc._get_changed_fields(), ['dict_field'])
- self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {}))
+ self.assertEqual(doc._get_changed_fields(), ["dict_field"])
+ self.assertEqual(doc._delta(), ({"dict_field": dict_value}, {}))
doc._changed_fields = []
- list_value = ['1', 2, {'hello': 'world'}]
+ list_value = ["1", 2, {"hello": "world"}]
doc.list_field = list_value
- self.assertEqual(doc._get_changed_fields(), ['list_field'])
- self.assertEqual(doc._delta(), ({'list_field': list_value}, {}))
+ self.assertEqual(doc._get_changed_fields(), ["list_field"])
+ self.assertEqual(doc._delta(), ({"list_field": list_value}, {}))
# Test unsetting
doc._changed_fields = []
doc.dict_field = {}
- self.assertEqual(doc._get_changed_fields(), ['dict_field'])
- self.assertEqual(doc._delta(), ({}, {'dict_field': 1}))
+ self.assertEqual(doc._get_changed_fields(), ["dict_field"])
+ self.assertEqual(doc._delta(), ({}, {"dict_field": 1}))
doc._changed_fields = []
doc.list_field = []
- self.assertEqual(doc._get_changed_fields(), ['list_field'])
- self.assertEqual(doc._delta(), ({}, {'list_field': 1}))
+ self.assertEqual(doc._get_changed_fields(), ["list_field"])
+ self.assertEqual(doc._delta(), ({}, {"list_field": 1}))
def test_delta_recursive(self):
self.delta_recursive(Document, EmbeddedDocument)
@@ -85,7 +83,6 @@ class DeltaTest(MongoDBTestCase):
self.delta_recursive(DynamicDocument, DynamicEmbeddedDocument)
def delta_recursive(self, DocClass, EmbeddedClass):
-
class Embedded(EmbeddedClass):
id = StringField()
string_field = StringField()
@@ -110,165 +107,207 @@ class DeltaTest(MongoDBTestCase):
embedded_1 = Embedded()
embedded_1.id = "010101"
- embedded_1.string_field = 'hello'
+ embedded_1.string_field = "hello"
embedded_1.int_field = 1
- embedded_1.dict_field = {'hello': 'world'}
- embedded_1.list_field = ['1', 2, {'hello': 'world'}]
+ embedded_1.dict_field = {"hello": "world"}
+ embedded_1.list_field = ["1", 2, {"hello": "world"}]
doc.embedded_field = embedded_1
- self.assertEqual(doc._get_changed_fields(), ['embedded_field'])
+ self.assertEqual(doc._get_changed_fields(), ["embedded_field"])
embedded_delta = {
- 'id': "010101",
- 'string_field': 'hello',
- 'int_field': 1,
- 'dict_field': {'hello': 'world'},
- 'list_field': ['1', 2, {'hello': 'world'}]
+ "id": "010101",
+ "string_field": "hello",
+ "int_field": 1,
+ "dict_field": {"hello": "world"},
+ "list_field": ["1", 2, {"hello": "world"}],
}
self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {}))
- self.assertEqual(doc._delta(),
- ({'embedded_field': embedded_delta}, {}))
+ self.assertEqual(doc._delta(), ({"embedded_field": embedded_delta}, {}))
doc.save()
doc = doc.reload(10)
doc.embedded_field.dict_field = {}
- self.assertEqual(doc._get_changed_fields(),
- ['embedded_field.dict_field'])
- self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1}))
- self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1}))
+ self.assertEqual(doc._get_changed_fields(), ["embedded_field.dict_field"])
+ self.assertEqual(doc.embedded_field._delta(), ({}, {"dict_field": 1}))
+ self.assertEqual(doc._delta(), ({}, {"embedded_field.dict_field": 1}))
doc.save()
doc = doc.reload(10)
self.assertEqual(doc.embedded_field.dict_field, {})
doc.embedded_field.list_field = []
- self.assertEqual(doc._get_changed_fields(),
- ['embedded_field.list_field'])
- self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1}))
- self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1}))
+ self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field"])
+ self.assertEqual(doc.embedded_field._delta(), ({}, {"list_field": 1}))
+ self.assertEqual(doc._delta(), ({}, {"embedded_field.list_field": 1}))
doc.save()
doc = doc.reload(10)
self.assertEqual(doc.embedded_field.list_field, [])
embedded_2 = Embedded()
- embedded_2.string_field = 'hello'
+ embedded_2.string_field = "hello"
embedded_2.int_field = 1
- embedded_2.dict_field = {'hello': 'world'}
- embedded_2.list_field = ['1', 2, {'hello': 'world'}]
+ embedded_2.dict_field = {"hello": "world"}
+ embedded_2.list_field = ["1", 2, {"hello": "world"}]
- doc.embedded_field.list_field = ['1', 2, embedded_2]
- self.assertEqual(doc._get_changed_fields(),
- ['embedded_field.list_field'])
+ doc.embedded_field.list_field = ["1", 2, embedded_2]
+ self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field"])
- self.assertEqual(doc.embedded_field._delta(), ({
- 'list_field': ['1', 2, {
- '_cls': 'Embedded',
- 'string_field': 'hello',
- 'dict_field': {'hello': 'world'},
- 'int_field': 1,
- 'list_field': ['1', 2, {'hello': 'world'}],
- }]
- }, {}))
+ self.assertEqual(
+ doc.embedded_field._delta(),
+ (
+ {
+ "list_field": [
+ "1",
+ 2,
+ {
+ "_cls": "Embedded",
+ "string_field": "hello",
+ "dict_field": {"hello": "world"},
+ "int_field": 1,
+ "list_field": ["1", 2, {"hello": "world"}],
+ },
+ ]
+ },
+ {},
+ ),
+ )
- self.assertEqual(doc._delta(), ({
- 'embedded_field.list_field': ['1', 2, {
- '_cls': 'Embedded',
- 'string_field': 'hello',
- 'dict_field': {'hello': 'world'},
- 'int_field': 1,
- 'list_field': ['1', 2, {'hello': 'world'}],
- }]
- }, {}))
+ self.assertEqual(
+ doc._delta(),
+ (
+ {
+ "embedded_field.list_field": [
+ "1",
+ 2,
+ {
+ "_cls": "Embedded",
+ "string_field": "hello",
+ "dict_field": {"hello": "world"},
+ "int_field": 1,
+ "list_field": ["1", 2, {"hello": "world"}],
+ },
+ ]
+ },
+ {},
+ ),
+ )
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.embedded_field.list_field[0], '1')
+ self.assertEqual(doc.embedded_field.list_field[0], "1")
self.assertEqual(doc.embedded_field.list_field[1], 2)
for k in doc.embedded_field.list_field[2]._fields:
- self.assertEqual(doc.embedded_field.list_field[2][k],
- embedded_2[k])
+ self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k])
- doc.embedded_field.list_field[2].string_field = 'world'
- self.assertEqual(doc._get_changed_fields(),
- ['embedded_field.list_field.2.string_field'])
- self.assertEqual(doc.embedded_field._delta(),
- ({'list_field.2.string_field': 'world'}, {}))
- self.assertEqual(doc._delta(),
- ({'embedded_field.list_field.2.string_field': 'world'}, {}))
+ doc.embedded_field.list_field[2].string_field = "world"
+ self.assertEqual(
+ doc._get_changed_fields(), ["embedded_field.list_field.2.string_field"]
+ )
+ self.assertEqual(
+ doc.embedded_field._delta(), ({"list_field.2.string_field": "world"}, {})
+ )
+ self.assertEqual(
+ doc._delta(), ({"embedded_field.list_field.2.string_field": "world"}, {})
+ )
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.embedded_field.list_field[2].string_field,
- 'world')
+ self.assertEqual(doc.embedded_field.list_field[2].string_field, "world")
# Test multiple assignments
- doc.embedded_field.list_field[2].string_field = 'hello world'
+ doc.embedded_field.list_field[2].string_field = "hello world"
doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2]
- self.assertEqual(doc._get_changed_fields(),
- ['embedded_field.list_field.2'])
- self.assertEqual(doc.embedded_field._delta(), ({'list_field.2': {
- '_cls': 'Embedded',
- 'string_field': 'hello world',
- 'int_field': 1,
- 'list_field': ['1', 2, {'hello': 'world'}],
- 'dict_field': {'hello': 'world'}}
- }, {}))
- self.assertEqual(doc._delta(), ({'embedded_field.list_field.2': {
- '_cls': 'Embedded',
- 'string_field': 'hello world',
- 'int_field': 1,
- 'list_field': ['1', 2, {'hello': 'world'}],
- 'dict_field': {'hello': 'world'}}
- }, {}))
+ self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field.2"])
+ self.assertEqual(
+ doc.embedded_field._delta(),
+ (
+ {
+ "list_field.2": {
+ "_cls": "Embedded",
+ "string_field": "hello world",
+ "int_field": 1,
+ "list_field": ["1", 2, {"hello": "world"}],
+ "dict_field": {"hello": "world"},
+ }
+ },
+ {},
+ ),
+ )
+ self.assertEqual(
+ doc._delta(),
+ (
+ {
+ "embedded_field.list_field.2": {
+ "_cls": "Embedded",
+ "string_field": "hello world",
+ "int_field": 1,
+ "list_field": ["1", 2, {"hello": "world"}],
+ "dict_field": {"hello": "world"},
+ }
+ },
+ {},
+ ),
+ )
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.embedded_field.list_field[2].string_field,
- 'hello world')
+ self.assertEqual(doc.embedded_field.list_field[2].string_field, "hello world")
# Test list native methods
doc.embedded_field.list_field[2].list_field.pop(0)
- self.assertEqual(doc._delta(),
- ({'embedded_field.list_field.2.list_field':
- [2, {'hello': 'world'}]}, {}))
+ self.assertEqual(
+ doc._delta(),
+ ({"embedded_field.list_field.2.list_field": [2, {"hello": "world"}]}, {}),
+ )
doc.save()
doc = doc.reload(10)
doc.embedded_field.list_field[2].list_field.append(1)
- self.assertEqual(doc._delta(),
- ({'embedded_field.list_field.2.list_field':
- [2, {'hello': 'world'}, 1]}, {}))
+ self.assertEqual(
+ doc._delta(),
+ (
+ {"embedded_field.list_field.2.list_field": [2, {"hello": "world"}, 1]},
+ {},
+ ),
+ )
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.embedded_field.list_field[2].list_field,
- [2, {'hello': 'world'}, 1])
+ self.assertEqual(
+ doc.embedded_field.list_field[2].list_field, [2, {"hello": "world"}, 1]
+ )
doc.embedded_field.list_field[2].list_field.sort(key=str)
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.embedded_field.list_field[2].list_field,
- [1, 2, {'hello': 'world'}])
+ self.assertEqual(
+ doc.embedded_field.list_field[2].list_field, [1, 2, {"hello": "world"}]
+ )
- del doc.embedded_field.list_field[2].list_field[2]['hello']
- self.assertEqual(doc._delta(),
- ({}, {'embedded_field.list_field.2.list_field.2.hello': 1}))
+ del doc.embedded_field.list_field[2].list_field[2]["hello"]
+ self.assertEqual(
+ doc._delta(), ({}, {"embedded_field.list_field.2.list_field.2.hello": 1})
+ )
doc.save()
doc = doc.reload(10)
del doc.embedded_field.list_field[2].list_field
- self.assertEqual(doc._delta(),
- ({}, {'embedded_field.list_field.2.list_field': 1}))
+ self.assertEqual(
+ doc._delta(), ({}, {"embedded_field.list_field.2.list_field": 1})
+ )
doc.save()
doc = doc.reload(10)
- doc.dict_field['Embedded'] = embedded_1
+ doc.dict_field["Embedded"] = embedded_1
doc.save()
doc = doc.reload(10)
- doc.dict_field['Embedded'].string_field = 'Hello World'
- self.assertEqual(doc._get_changed_fields(),
- ['dict_field.Embedded.string_field'])
- self.assertEqual(doc._delta(),
- ({'dict_field.Embedded.string_field': 'Hello World'}, {}))
+ doc.dict_field["Embedded"].string_field = "Hello World"
+ self.assertEqual(
+ doc._get_changed_fields(), ["dict_field.Embedded.string_field"]
+ )
+ self.assertEqual(
+ doc._delta(), ({"dict_field.Embedded.string_field": "Hello World"}, {})
+ )
def test_circular_reference_deltas(self):
self.circular_reference_deltas(Document, Document)
@@ -277,14 +316,13 @@ class DeltaTest(MongoDBTestCase):
self.circular_reference_deltas(DynamicDocument, DynamicDocument)
def circular_reference_deltas(self, DocClass1, DocClass2):
-
class Person(DocClass1):
name = StringField()
- owns = ListField(ReferenceField('Organization'))
+ owns = ListField(ReferenceField("Organization"))
class Organization(DocClass2):
name = StringField()
- owner = ReferenceField('Person')
+ owner = ReferenceField("Person")
Person.drop_collection()
Organization.drop_collection()
@@ -310,16 +348,15 @@ class DeltaTest(MongoDBTestCase):
self.circular_reference_deltas_2(DynamicDocument, DynamicDocument)
def circular_reference_deltas_2(self, DocClass1, DocClass2, dbref=True):
-
class Person(DocClass1):
name = StringField()
- owns = ListField(ReferenceField('Organization', dbref=dbref))
- employer = ReferenceField('Organization', dbref=dbref)
+ owns = ListField(ReferenceField("Organization", dbref=dbref))
+ employer = ReferenceField("Organization", dbref=dbref)
class Organization(DocClass2):
name = StringField()
- owner = ReferenceField('Person', dbref=dbref)
- employees = ListField(ReferenceField('Person', dbref=dbref))
+ owner = ReferenceField("Person", dbref=dbref)
+ employees = ListField(ReferenceField("Person", dbref=dbref))
Person.drop_collection()
Organization.drop_collection()
@@ -353,12 +390,11 @@ class DeltaTest(MongoDBTestCase):
self.delta_db_field(DynamicDocument)
def delta_db_field(self, DocClass):
-
class Doc(DocClass):
- string_field = StringField(db_field='db_string_field')
- int_field = IntField(db_field='db_int_field')
- dict_field = DictField(db_field='db_dict_field')
- list_field = ListField(db_field='db_list_field')
+ string_field = StringField(db_field="db_string_field")
+ int_field = IntField(db_field="db_int_field")
+ dict_field = DictField(db_field="db_dict_field")
+ list_field = ListField(db_field="db_list_field")
Doc.drop_collection()
doc = Doc()
@@ -368,53 +404,53 @@ class DeltaTest(MongoDBTestCase):
self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(doc._delta(), ({}, {}))
- doc.string_field = 'hello'
- self.assertEqual(doc._get_changed_fields(), ['db_string_field'])
- self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {}))
+ doc.string_field = "hello"
+ self.assertEqual(doc._get_changed_fields(), ["db_string_field"])
+ self.assertEqual(doc._delta(), ({"db_string_field": "hello"}, {}))
doc._changed_fields = []
doc.int_field = 1
- self.assertEqual(doc._get_changed_fields(), ['db_int_field'])
- self.assertEqual(doc._delta(), ({'db_int_field': 1}, {}))
+ self.assertEqual(doc._get_changed_fields(), ["db_int_field"])
+ self.assertEqual(doc._delta(), ({"db_int_field": 1}, {}))
doc._changed_fields = []
- dict_value = {'hello': 'world', 'ping': 'pong'}
+ dict_value = {"hello": "world", "ping": "pong"}
doc.dict_field = dict_value
- self.assertEqual(doc._get_changed_fields(), ['db_dict_field'])
- self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {}))
+ self.assertEqual(doc._get_changed_fields(), ["db_dict_field"])
+ self.assertEqual(doc._delta(), ({"db_dict_field": dict_value}, {}))
doc._changed_fields = []
- list_value = ['1', 2, {'hello': 'world'}]
+ list_value = ["1", 2, {"hello": "world"}]
doc.list_field = list_value
- self.assertEqual(doc._get_changed_fields(), ['db_list_field'])
- self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {}))
+ self.assertEqual(doc._get_changed_fields(), ["db_list_field"])
+ self.assertEqual(doc._delta(), ({"db_list_field": list_value}, {}))
# Test unsetting
doc._changed_fields = []
doc.dict_field = {}
- self.assertEqual(doc._get_changed_fields(), ['db_dict_field'])
- self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1}))
+ self.assertEqual(doc._get_changed_fields(), ["db_dict_field"])
+ self.assertEqual(doc._delta(), ({}, {"db_dict_field": 1}))
doc._changed_fields = []
doc.list_field = []
- self.assertEqual(doc._get_changed_fields(), ['db_list_field'])
- self.assertEqual(doc._delta(), ({}, {'db_list_field': 1}))
+ self.assertEqual(doc._get_changed_fields(), ["db_list_field"])
+ self.assertEqual(doc._delta(), ({}, {"db_list_field": 1}))
# Test it saves that data
doc = Doc()
doc.save()
- doc.string_field = 'hello'
+ doc.string_field = "hello"
doc.int_field = 1
- doc.dict_field = {'hello': 'world'}
- doc.list_field = ['1', 2, {'hello': 'world'}]
+ doc.dict_field = {"hello": "world"}
+ doc.list_field = ["1", 2, {"hello": "world"}]
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.string_field, 'hello')
+ self.assertEqual(doc.string_field, "hello")
self.assertEqual(doc.int_field, 1)
- self.assertEqual(doc.dict_field, {'hello': 'world'})
- self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}])
+ self.assertEqual(doc.dict_field, {"hello": "world"})
+ self.assertEqual(doc.list_field, ["1", 2, {"hello": "world"}])
def test_delta_recursive_db_field(self):
self.delta_recursive_db_field(Document, EmbeddedDocument)
@@ -423,20 +459,20 @@ class DeltaTest(MongoDBTestCase):
self.delta_recursive_db_field(DynamicDocument, DynamicEmbeddedDocument)
def delta_recursive_db_field(self, DocClass, EmbeddedClass):
-
class Embedded(EmbeddedClass):
- string_field = StringField(db_field='db_string_field')
- int_field = IntField(db_field='db_int_field')
- dict_field = DictField(db_field='db_dict_field')
- list_field = ListField(db_field='db_list_field')
+ string_field = StringField(db_field="db_string_field")
+ int_field = IntField(db_field="db_int_field")
+ dict_field = DictField(db_field="db_dict_field")
+ list_field = ListField(db_field="db_list_field")
class Doc(DocClass):
- string_field = StringField(db_field='db_string_field')
- int_field = IntField(db_field='db_int_field')
- dict_field = DictField(db_field='db_dict_field')
- list_field = ListField(db_field='db_list_field')
- embedded_field = EmbeddedDocumentField(Embedded,
- db_field='db_embedded_field')
+ string_field = StringField(db_field="db_string_field")
+ int_field = IntField(db_field="db_int_field")
+ dict_field = DictField(db_field="db_dict_field")
+ list_field = ListField(db_field="db_list_field")
+ embedded_field = EmbeddedDocumentField(
+ Embedded, db_field="db_embedded_field"
+ )
Doc.drop_collection()
doc = Doc()
@@ -447,171 +483,228 @@ class DeltaTest(MongoDBTestCase):
self.assertEqual(doc._delta(), ({}, {}))
embedded_1 = Embedded()
- embedded_1.string_field = 'hello'
+ embedded_1.string_field = "hello"
embedded_1.int_field = 1
- embedded_1.dict_field = {'hello': 'world'}
- embedded_1.list_field = ['1', 2, {'hello': 'world'}]
+ embedded_1.dict_field = {"hello": "world"}
+ embedded_1.list_field = ["1", 2, {"hello": "world"}]
doc.embedded_field = embedded_1
- self.assertEqual(doc._get_changed_fields(), ['db_embedded_field'])
+ self.assertEqual(doc._get_changed_fields(), ["db_embedded_field"])
embedded_delta = {
- 'db_string_field': 'hello',
- 'db_int_field': 1,
- 'db_dict_field': {'hello': 'world'},
- 'db_list_field': ['1', 2, {'hello': 'world'}]
+ "db_string_field": "hello",
+ "db_int_field": 1,
+ "db_dict_field": {"hello": "world"},
+ "db_list_field": ["1", 2, {"hello": "world"}],
}
self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {}))
- self.assertEqual(doc._delta(),
- ({'db_embedded_field': embedded_delta}, {}))
+ self.assertEqual(doc._delta(), ({"db_embedded_field": embedded_delta}, {}))
doc.save()
doc = doc.reload(10)
doc.embedded_field.dict_field = {}
- self.assertEqual(doc._get_changed_fields(),
- ['db_embedded_field.db_dict_field'])
- self.assertEqual(doc.embedded_field._delta(),
- ({}, {'db_dict_field': 1}))
- self.assertEqual(doc._delta(),
- ({}, {'db_embedded_field.db_dict_field': 1}))
+ self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_dict_field"])
+ self.assertEqual(doc.embedded_field._delta(), ({}, {"db_dict_field": 1}))
+ self.assertEqual(doc._delta(), ({}, {"db_embedded_field.db_dict_field": 1}))
doc.save()
doc = doc.reload(10)
self.assertEqual(doc.embedded_field.dict_field, {})
doc.embedded_field.list_field = []
- self.assertEqual(doc._get_changed_fields(),
- ['db_embedded_field.db_list_field'])
- self.assertEqual(doc.embedded_field._delta(),
- ({}, {'db_list_field': 1}))
- self.assertEqual(doc._delta(),
- ({}, {'db_embedded_field.db_list_field': 1}))
+ self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_list_field"])
+ self.assertEqual(doc.embedded_field._delta(), ({}, {"db_list_field": 1}))
+ self.assertEqual(doc._delta(), ({}, {"db_embedded_field.db_list_field": 1}))
doc.save()
doc = doc.reload(10)
self.assertEqual(doc.embedded_field.list_field, [])
embedded_2 = Embedded()
- embedded_2.string_field = 'hello'
+ embedded_2.string_field = "hello"
embedded_2.int_field = 1
- embedded_2.dict_field = {'hello': 'world'}
- embedded_2.list_field = ['1', 2, {'hello': 'world'}]
+ embedded_2.dict_field = {"hello": "world"}
+ embedded_2.list_field = ["1", 2, {"hello": "world"}]
- doc.embedded_field.list_field = ['1', 2, embedded_2]
- self.assertEqual(doc._get_changed_fields(),
- ['db_embedded_field.db_list_field'])
- self.assertEqual(doc.embedded_field._delta(), ({
- 'db_list_field': ['1', 2, {
- '_cls': 'Embedded',
- 'db_string_field': 'hello',
- 'db_dict_field': {'hello': 'world'},
- 'db_int_field': 1,
- 'db_list_field': ['1', 2, {'hello': 'world'}],
- }]
- }, {}))
+ doc.embedded_field.list_field = ["1", 2, embedded_2]
+ self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_list_field"])
+ self.assertEqual(
+ doc.embedded_field._delta(),
+ (
+ {
+ "db_list_field": [
+ "1",
+ 2,
+ {
+ "_cls": "Embedded",
+ "db_string_field": "hello",
+ "db_dict_field": {"hello": "world"},
+ "db_int_field": 1,
+ "db_list_field": ["1", 2, {"hello": "world"}],
+ },
+ ]
+ },
+ {},
+ ),
+ )
- self.assertEqual(doc._delta(), ({
- 'db_embedded_field.db_list_field': ['1', 2, {
- '_cls': 'Embedded',
- 'db_string_field': 'hello',
- 'db_dict_field': {'hello': 'world'},
- 'db_int_field': 1,
- 'db_list_field': ['1', 2, {'hello': 'world'}],
- }]
- }, {}))
+ self.assertEqual(
+ doc._delta(),
+ (
+ {
+ "db_embedded_field.db_list_field": [
+ "1",
+ 2,
+ {
+ "_cls": "Embedded",
+ "db_string_field": "hello",
+ "db_dict_field": {"hello": "world"},
+ "db_int_field": 1,
+ "db_list_field": ["1", 2, {"hello": "world"}],
+ },
+ ]
+ },
+ {},
+ ),
+ )
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.embedded_field.list_field[0], '1')
+ self.assertEqual(doc.embedded_field.list_field[0], "1")
self.assertEqual(doc.embedded_field.list_field[1], 2)
for k in doc.embedded_field.list_field[2]._fields:
- self.assertEqual(doc.embedded_field.list_field[2][k],
- embedded_2[k])
+ self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k])
- doc.embedded_field.list_field[2].string_field = 'world'
- self.assertEqual(doc._get_changed_fields(),
- ['db_embedded_field.db_list_field.2.db_string_field'])
- self.assertEqual(doc.embedded_field._delta(),
- ({'db_list_field.2.db_string_field': 'world'}, {}))
- self.assertEqual(doc._delta(),
- ({'db_embedded_field.db_list_field.2.db_string_field': 'world'},
- {}))
+ doc.embedded_field.list_field[2].string_field = "world"
+ self.assertEqual(
+ doc._get_changed_fields(),
+ ["db_embedded_field.db_list_field.2.db_string_field"],
+ )
+ self.assertEqual(
+ doc.embedded_field._delta(),
+ ({"db_list_field.2.db_string_field": "world"}, {}),
+ )
+ self.assertEqual(
+ doc._delta(),
+ ({"db_embedded_field.db_list_field.2.db_string_field": "world"}, {}),
+ )
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.embedded_field.list_field[2].string_field,
- 'world')
+ self.assertEqual(doc.embedded_field.list_field[2].string_field, "world")
# Test multiple assignments
- doc.embedded_field.list_field[2].string_field = 'hello world'
+ doc.embedded_field.list_field[2].string_field = "hello world"
doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2]
- self.assertEqual(doc._get_changed_fields(),
- ['db_embedded_field.db_list_field.2'])
- self.assertEqual(doc.embedded_field._delta(), ({'db_list_field.2': {
- '_cls': 'Embedded',
- 'db_string_field': 'hello world',
- 'db_int_field': 1,
- 'db_list_field': ['1', 2, {'hello': 'world'}],
- 'db_dict_field': {'hello': 'world'}}}, {}))
- self.assertEqual(doc._delta(), ({
- 'db_embedded_field.db_list_field.2': {
- '_cls': 'Embedded',
- 'db_string_field': 'hello world',
- 'db_int_field': 1,
- 'db_list_field': ['1', 2, {'hello': 'world'}],
- 'db_dict_field': {'hello': 'world'}}
- }, {}))
+ self.assertEqual(
+ doc._get_changed_fields(), ["db_embedded_field.db_list_field.2"]
+ )
+ self.assertEqual(
+ doc.embedded_field._delta(),
+ (
+ {
+ "db_list_field.2": {
+ "_cls": "Embedded",
+ "db_string_field": "hello world",
+ "db_int_field": 1,
+ "db_list_field": ["1", 2, {"hello": "world"}],
+ "db_dict_field": {"hello": "world"},
+ }
+ },
+ {},
+ ),
+ )
+ self.assertEqual(
+ doc._delta(),
+ (
+ {
+ "db_embedded_field.db_list_field.2": {
+ "_cls": "Embedded",
+ "db_string_field": "hello world",
+ "db_int_field": 1,
+ "db_list_field": ["1", 2, {"hello": "world"}],
+ "db_dict_field": {"hello": "world"},
+ }
+ },
+ {},
+ ),
+ )
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.embedded_field.list_field[2].string_field,
- 'hello world')
+ self.assertEqual(doc.embedded_field.list_field[2].string_field, "hello world")
# Test list native methods
doc.embedded_field.list_field[2].list_field.pop(0)
- self.assertEqual(doc._delta(),
- ({'db_embedded_field.db_list_field.2.db_list_field':
- [2, {'hello': 'world'}]}, {}))
+ self.assertEqual(
+ doc._delta(),
+ (
+ {
+ "db_embedded_field.db_list_field.2.db_list_field": [
+ 2,
+ {"hello": "world"},
+ ]
+ },
+ {},
+ ),
+ )
doc.save()
doc = doc.reload(10)
doc.embedded_field.list_field[2].list_field.append(1)
- self.assertEqual(doc._delta(),
- ({'db_embedded_field.db_list_field.2.db_list_field':
- [2, {'hello': 'world'}, 1]}, {}))
+ self.assertEqual(
+ doc._delta(),
+ (
+ {
+ "db_embedded_field.db_list_field.2.db_list_field": [
+ 2,
+ {"hello": "world"},
+ 1,
+ ]
+ },
+ {},
+ ),
+ )
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.embedded_field.list_field[2].list_field,
- [2, {'hello': 'world'}, 1])
+ self.assertEqual(
+ doc.embedded_field.list_field[2].list_field, [2, {"hello": "world"}, 1]
+ )
doc.embedded_field.list_field[2].list_field.sort(key=str)
doc.save()
doc = doc.reload(10)
- self.assertEqual(doc.embedded_field.list_field[2].list_field,
- [1, 2, {'hello': 'world'}])
+ self.assertEqual(
+ doc.embedded_field.list_field[2].list_field, [1, 2, {"hello": "world"}]
+ )
- del doc.embedded_field.list_field[2].list_field[2]['hello']
- self.assertEqual(doc._delta(),
- ({}, {'db_embedded_field.db_list_field.2.db_list_field.2.hello': 1}))
+ del doc.embedded_field.list_field[2].list_field[2]["hello"]
+ self.assertEqual(
+ doc._delta(),
+ ({}, {"db_embedded_field.db_list_field.2.db_list_field.2.hello": 1}),
+ )
doc.save()
doc = doc.reload(10)
del doc.embedded_field.list_field[2].list_field
- self.assertEqual(doc._delta(), ({},
- {'db_embedded_field.db_list_field.2.db_list_field': 1}))
+ self.assertEqual(
+ doc._delta(), ({}, {"db_embedded_field.db_list_field.2.db_list_field": 1})
+ )
def test_delta_for_dynamic_documents(self):
class Person(DynamicDocument):
name = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
Person.drop_collection()
p = Person(name="James", age=34)
- self.assertEqual(p._delta(), (
- SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {}))
+ self.assertEqual(
+ p._delta(), (SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), {})
+ )
p.doc = 123
del p.doc
- self.assertEqual(p._delta(), (
- SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {}))
+ self.assertEqual(
+ p._delta(), (SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), {})
+ )
p = Person()
p.name = "Dean"
@@ -620,20 +713,19 @@ class DeltaTest(MongoDBTestCase):
p.age = 24
self.assertEqual(p.age, 24)
- self.assertEqual(p._get_changed_fields(), ['age'])
- self.assertEqual(p._delta(), ({'age': 24}, {}))
+ self.assertEqual(p._get_changed_fields(), ["age"])
+ self.assertEqual(p._delta(), ({"age": 24}, {}))
p = Person.objects(age=22).get()
p.age = 24
self.assertEqual(p.age, 24)
- self.assertEqual(p._get_changed_fields(), ['age'])
- self.assertEqual(p._delta(), ({'age': 24}, {}))
+ self.assertEqual(p._get_changed_fields(), ["age"])
+ self.assertEqual(p._delta(), ({"age": 24}, {}))
p.save()
self.assertEqual(1, Person.objects(age=24).count())
def test_dynamic_delta(self):
-
class Doc(DynamicDocument):
pass
@@ -645,41 +737,43 @@ class DeltaTest(MongoDBTestCase):
self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(doc._delta(), ({}, {}))
- doc.string_field = 'hello'
- self.assertEqual(doc._get_changed_fields(), ['string_field'])
- self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {}))
+ doc.string_field = "hello"
+ self.assertEqual(doc._get_changed_fields(), ["string_field"])
+ self.assertEqual(doc._delta(), ({"string_field": "hello"}, {}))
doc._changed_fields = []
doc.int_field = 1
- self.assertEqual(doc._get_changed_fields(), ['int_field'])
- self.assertEqual(doc._delta(), ({'int_field': 1}, {}))
+ self.assertEqual(doc._get_changed_fields(), ["int_field"])
+ self.assertEqual(doc._delta(), ({"int_field": 1}, {}))
doc._changed_fields = []
- dict_value = {'hello': 'world', 'ping': 'pong'}
+ dict_value = {"hello": "world", "ping": "pong"}
doc.dict_field = dict_value
- self.assertEqual(doc._get_changed_fields(), ['dict_field'])
- self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {}))
+ self.assertEqual(doc._get_changed_fields(), ["dict_field"])
+ self.assertEqual(doc._delta(), ({"dict_field": dict_value}, {}))
doc._changed_fields = []
- list_value = ['1', 2, {'hello': 'world'}]
+ list_value = ["1", 2, {"hello": "world"}]
doc.list_field = list_value
- self.assertEqual(doc._get_changed_fields(), ['list_field'])
- self.assertEqual(doc._delta(), ({'list_field': list_value}, {}))
+ self.assertEqual(doc._get_changed_fields(), ["list_field"])
+ self.assertEqual(doc._delta(), ({"list_field": list_value}, {}))
# Test unsetting
doc._changed_fields = []
doc.dict_field = {}
- self.assertEqual(doc._get_changed_fields(), ['dict_field'])
- self.assertEqual(doc._delta(), ({}, {'dict_field': 1}))
+ self.assertEqual(doc._get_changed_fields(), ["dict_field"])
+ self.assertEqual(doc._delta(), ({}, {"dict_field": 1}))
doc._changed_fields = []
doc.list_field = []
- self.assertEqual(doc._get_changed_fields(), ['list_field'])
- self.assertEqual(doc._delta(), ({}, {'list_field': 1}))
+ self.assertEqual(doc._get_changed_fields(), ["list_field"])
+ self.assertEqual(doc._delta(), ({}, {"list_field": 1}))
def test_delta_with_dbref_true(self):
- person, organization, employee = self.circular_reference_deltas_2(Document, Document, True)
- employee.name = 'test'
+ person, organization, employee = self.circular_reference_deltas_2(
+ Document, Document, True
+ )
+ employee.name = "test"
self.assertEqual(organization._get_changed_fields(), [])
@@ -690,11 +784,13 @@ class DeltaTest(MongoDBTestCase):
organization.employees.append(person)
updates, removals = organization._delta()
self.assertEqual({}, removals)
- self.assertIn('employees', updates)
+ self.assertIn("employees", updates)
def test_delta_with_dbref_false(self):
- person, organization, employee = self.circular_reference_deltas_2(Document, Document, False)
- employee.name = 'test'
+ person, organization, employee = self.circular_reference_deltas_2(
+ Document, Document, False
+ )
+ employee.name = "test"
self.assertEqual(organization._get_changed_fields(), [])
@@ -705,7 +801,7 @@ class DeltaTest(MongoDBTestCase):
organization.employees.append(person)
updates, removals = organization._delta()
self.assertEqual({}, removals)
- self.assertIn('employees', updates)
+ self.assertIn("employees", updates)
def test_nested_nested_fields_mark_as_changed(self):
class EmbeddedDoc(EmbeddedDocument):
@@ -717,11 +813,13 @@ class DeltaTest(MongoDBTestCase):
MyDoc.drop_collection()
- mydoc = MyDoc(name='testcase1', subs={'a': {'b': EmbeddedDoc(name='foo')}}).save()
+ mydoc = MyDoc(
+ name="testcase1", subs={"a": {"b": EmbeddedDoc(name="foo")}}
+ ).save()
mydoc = MyDoc.objects.first()
- subdoc = mydoc.subs['a']['b']
- subdoc.name = 'bar'
+ subdoc = mydoc.subs["a"]["b"]
+ subdoc.name = "bar"
self.assertEqual(["name"], subdoc._get_changed_fields())
self.assertEqual(["subs.a.b.name"], mydoc._get_changed_fields())
@@ -741,11 +839,11 @@ class DeltaTest(MongoDBTestCase):
MyDoc().save()
mydoc = MyDoc.objects.first()
- mydoc.subs['a'] = EmbeddedDoc()
+ mydoc.subs["a"] = EmbeddedDoc()
self.assertEqual(["subs.a"], mydoc._get_changed_fields())
- subdoc = mydoc.subs['a']
- subdoc.name = 'bar'
+ subdoc = mydoc.subs["a"]
+ subdoc.name = "bar"
self.assertEqual(["name"], subdoc._get_changed_fields())
self.assertEqual(["subs.a"], mydoc._get_changed_fields())
@@ -763,16 +861,16 @@ class DeltaTest(MongoDBTestCase):
MyDoc.drop_collection()
- MyDoc(subs={'a': EmbeddedDoc(name='foo')}).save()
+ MyDoc(subs={"a": EmbeddedDoc(name="foo")}).save()
mydoc = MyDoc.objects.first()
- subdoc = mydoc.subs['a']
- subdoc.name = 'bar'
+ subdoc = mydoc.subs["a"]
+ subdoc.name = "bar"
self.assertEqual(["name"], subdoc._get_changed_fields())
self.assertEqual(["subs.a.name"], mydoc._get_changed_fields())
- mydoc.subs['a'] = EmbeddedDoc()
+ mydoc.subs["a"] = EmbeddedDoc()
self.assertEqual(["subs.a"], mydoc._get_changed_fields())
mydoc.save()
@@ -787,39 +885,39 @@ class DeltaTest(MongoDBTestCase):
class User(Document):
name = StringField()
- org = ReferenceField('Organization', required=True)
+ org = ReferenceField("Organization", required=True)
Organization.drop_collection()
User.drop_collection()
- org1 = Organization(name='Org 1')
+ org1 = Organization(name="Org 1")
org1.save()
- org2 = Organization(name='Org 2')
+ org2 = Organization(name="Org 2")
org2.save()
- user = User(name='Fred', org=org1)
+ user = User(name="Fred", org=org1)
user.save()
org1.reload()
org2.reload()
user.reload()
- self.assertEqual(org1.name, 'Org 1')
- self.assertEqual(org2.name, 'Org 2')
- self.assertEqual(user.name, 'Fred')
+ self.assertEqual(org1.name, "Org 1")
+ self.assertEqual(org2.name, "Org 2")
+ self.assertEqual(user.name, "Fred")
- user.name = 'Harold'
+ user.name = "Harold"
user.org = org2
- org2.name = 'New Org 2'
- self.assertEqual(org2.name, 'New Org 2')
+ org2.name = "New Org 2"
+ self.assertEqual(org2.name, "New Org 2")
user.save()
org2.save()
- self.assertEqual(org2.name, 'New Org 2')
+ self.assertEqual(org2.name, "New Org 2")
org2.reload()
- self.assertEqual(org2.name, 'New Org 2')
+ self.assertEqual(org2.name, "New Org 2")
def test_delta_for_nested_map_fields(self):
class UInfoDocument(Document):
@@ -855,10 +953,10 @@ class DeltaTest(MongoDBTestCase):
self.assertEqual(True, "users.007.roles.666" in delta[0])
self.assertEqual(True, "users.007.rolist" in delta[0])
self.assertEqual(True, "users.007.info" in delta[0])
- self.assertEqual('superadmin', delta[0]["users.007.roles.666"]["type"])
- self.assertEqual('oops', delta[0]["users.007.rolist"][0]["type"])
+ self.assertEqual("superadmin", delta[0]["users.007.roles.666"]["type"])
+ self.assertEqual("oops", delta[0]["users.007.rolist"][0]["type"])
self.assertEqual(uinfo.id, delta[0]["users.007.info"])
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py
index 44548d27..414d3352 100644
--- a/tests/document/dynamic.py
+++ b/tests/document/dynamic.py
@@ -3,17 +3,16 @@ import unittest
from mongoengine import *
from tests.utils import MongoDBTestCase
-__all__ = ("TestDynamicDocument", )
+__all__ = ("TestDynamicDocument",)
class TestDynamicDocument(MongoDBTestCase):
-
def setUp(self):
super(TestDynamicDocument, self).setUp()
class Person(DynamicDocument):
name = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
Person.drop_collection()
@@ -26,8 +25,7 @@ class TestDynamicDocument(MongoDBTestCase):
p.name = "James"
p.age = 34
- self.assertEqual(p.to_mongo(), {"_cls": "Person", "name": "James",
- "age": 34})
+ self.assertEqual(p.to_mongo(), {"_cls": "Person", "name": "James", "age": 34})
self.assertEqual(p.to_mongo().keys(), ["_cls", "name", "age"])
p.save()
self.assertEqual(p.to_mongo().keys(), ["_id", "_cls", "name", "age"])
@@ -35,7 +33,7 @@ class TestDynamicDocument(MongoDBTestCase):
self.assertEqual(self.Person.objects.first().age, 34)
# Confirm no changes to self.Person
- self.assertFalse(hasattr(self.Person, 'age'))
+ self.assertFalse(hasattr(self.Person, "age"))
def test_change_scope_of_variable(self):
"""Test changing the scope of a dynamic field has no adverse effects"""
@@ -45,11 +43,11 @@ class TestDynamicDocument(MongoDBTestCase):
p.save()
p = self.Person.objects.get()
- p.misc = {'hello': 'world'}
+ p.misc = {"hello": "world"}
p.save()
p = self.Person.objects.get()
- self.assertEqual(p.misc, {'hello': 'world'})
+ self.assertEqual(p.misc, {"hello": "world"})
def test_delete_dynamic_field(self):
"""Test deleting a dynamic field works"""
@@ -60,23 +58,23 @@ class TestDynamicDocument(MongoDBTestCase):
p.save()
p = self.Person.objects.get()
- p.misc = {'hello': 'world'}
+ p.misc = {"hello": "world"}
p.save()
p = self.Person.objects.get()
- self.assertEqual(p.misc, {'hello': 'world'})
+ self.assertEqual(p.misc, {"hello": "world"})
collection = self.db[self.Person._get_collection_name()]
obj = collection.find_one()
- self.assertEqual(sorted(obj.keys()), ['_cls', '_id', 'misc', 'name'])
+ self.assertEqual(sorted(obj.keys()), ["_cls", "_id", "misc", "name"])
del p.misc
p.save()
p = self.Person.objects.get()
- self.assertFalse(hasattr(p, 'misc'))
+ self.assertFalse(hasattr(p, "misc"))
obj = collection.find_one()
- self.assertEqual(sorted(obj.keys()), ['_cls', '_id', 'name'])
+ self.assertEqual(sorted(obj.keys()), ["_cls", "_id", "name"])
def test_reload_after_unsetting(self):
p = self.Person()
@@ -91,77 +89,52 @@ class TestDynamicDocument(MongoDBTestCase):
p.update(age=1)
self.assertEqual(len(p._data), 3)
- self.assertEqual(sorted(p._data.keys()), ['_cls', 'id', 'name'])
+ self.assertEqual(sorted(p._data.keys()), ["_cls", "id", "name"])
p.reload()
self.assertEqual(len(p._data), 4)
- self.assertEqual(sorted(p._data.keys()), ['_cls', 'age', 'id', 'name'])
+ self.assertEqual(sorted(p._data.keys()), ["_cls", "age", "id", "name"])
def test_fields_without_underscore(self):
"""Ensure we can query dynamic fields"""
Person = self.Person
- p = self.Person(name='Dean')
+ p = self.Person(name="Dean")
p.save()
raw_p = Person.objects.as_pymongo().get(id=p.id)
- self.assertEqual(
- raw_p,
- {
- '_cls': u'Person',
- '_id': p.id,
- 'name': u'Dean'
- }
- )
+ self.assertEqual(raw_p, {"_cls": u"Person", "_id": p.id, "name": u"Dean"})
- p.name = 'OldDean'
- p.newattr = 'garbage'
+ p.name = "OldDean"
+ p.newattr = "garbage"
p.save()
raw_p = Person.objects.as_pymongo().get(id=p.id)
self.assertEqual(
raw_p,
- {
- '_cls': u'Person',
- '_id': p.id,
- 'name': 'OldDean',
- 'newattr': u'garbage'
- }
+ {"_cls": u"Person", "_id": p.id, "name": "OldDean", "newattr": u"garbage"},
)
def test_fields_containing_underscore(self):
"""Ensure we can query dynamic fields"""
+
class WeirdPerson(DynamicDocument):
name = StringField()
_name = StringField()
WeirdPerson.drop_collection()
- p = WeirdPerson(name='Dean', _name='Dean')
+ p = WeirdPerson(name="Dean", _name="Dean")
p.save()
raw_p = WeirdPerson.objects.as_pymongo().get(id=p.id)
- self.assertEqual(
- raw_p,
- {
- '_id': p.id,
- '_name': u'Dean',
- 'name': u'Dean'
- }
- )
+ self.assertEqual(raw_p, {"_id": p.id, "_name": u"Dean", "name": u"Dean"})
- p.name = 'OldDean'
- p._name = 'NewDean'
- p._newattr1 = 'garbage' # Unknown fields won't be added
+ p.name = "OldDean"
+ p._name = "NewDean"
+ p._newattr1 = "garbage" # Unknown fields won't be added
p.save()
raw_p = WeirdPerson.objects.as_pymongo().get(id=p.id)
- self.assertEqual(
- raw_p,
- {
- '_id': p.id,
- '_name': u'NewDean',
- 'name': u'OldDean',
- }
- )
+ self.assertEqual(raw_p, {"_id": p.id, "_name": u"NewDean", "name": u"OldDean"})
def test_dynamic_document_queries(self):
"""Ensure we can query dynamic fields"""
@@ -193,26 +166,25 @@ class TestDynamicDocument(MongoDBTestCase):
p2.age = 10
p2.save()
- self.assertEqual(Person.objects(age__icontains='ten').count(), 2)
+ self.assertEqual(Person.objects(age__icontains="ten").count(), 2)
self.assertEqual(Person.objects(age__gte=10).count(), 1)
def test_complex_data_lookups(self):
"""Ensure you can query dynamic document dynamic fields"""
p = self.Person()
- p.misc = {'hello': 'world'}
+ p.misc = {"hello": "world"}
p.save()
- self.assertEqual(1, self.Person.objects(misc__hello='world').count())
+ self.assertEqual(1, self.Person.objects(misc__hello="world").count())
def test_three_level_complex_data_lookups(self):
"""Ensure you can query three level document dynamic fields"""
- p = self.Person.objects.create(
- misc={'hello': {'hello2': 'world'}}
- )
- self.assertEqual(1, self.Person.objects(misc__hello__hello2='world').count())
+ p = self.Person.objects.create(misc={"hello": {"hello2": "world"}})
+ self.assertEqual(1, self.Person.objects(misc__hello__hello2="world").count())
def test_complex_embedded_document_validation(self):
"""Ensure embedded dynamic documents may be validated"""
+
class Embedded(DynamicEmbeddedDocument):
content = URLField()
@@ -222,10 +194,10 @@ class TestDynamicDocument(MongoDBTestCase):
Doc.drop_collection()
doc = Doc()
- embedded_doc_1 = Embedded(content='http://mongoengine.org')
+ embedded_doc_1 = Embedded(content="http://mongoengine.org")
embedded_doc_1.validate()
- embedded_doc_2 = Embedded(content='this is not a url')
+ embedded_doc_2 = Embedded(content="this is not a url")
self.assertRaises(ValidationError, embedded_doc_2.validate)
doc.embedded_field_1 = embedded_doc_1
@@ -234,15 +206,17 @@ class TestDynamicDocument(MongoDBTestCase):
def test_inheritance(self):
"""Ensure that dynamic document plays nice with inheritance"""
+
class Employee(self.Person):
salary = IntField()
Employee.drop_collection()
- self.assertIn('name', Employee._fields)
- self.assertIn('salary', Employee._fields)
- self.assertEqual(Employee._get_collection_name(),
- self.Person._get_collection_name())
+ self.assertIn("name", Employee._fields)
+ self.assertIn("salary", Employee._fields)
+ self.assertEqual(
+ Employee._get_collection_name(), self.Person._get_collection_name()
+ )
joe_bloggs = Employee()
joe_bloggs.name = "Joe Bloggs"
@@ -258,6 +232,7 @@ class TestDynamicDocument(MongoDBTestCase):
def test_embedded_dynamic_document(self):
"""Test dynamic embedded documents"""
+
class Embedded(DynamicEmbeddedDocument):
pass
@@ -268,78 +243,88 @@ class TestDynamicDocument(MongoDBTestCase):
doc = Doc()
embedded_1 = Embedded()
- embedded_1.string_field = 'hello'
+ embedded_1.string_field = "hello"
embedded_1.int_field = 1
- embedded_1.dict_field = {'hello': 'world'}
- embedded_1.list_field = ['1', 2, {'hello': 'world'}]
+ embedded_1.dict_field = {"hello": "world"}
+ embedded_1.list_field = ["1", 2, {"hello": "world"}]
doc.embedded_field = embedded_1
- self.assertEqual(doc.to_mongo(), {
- "embedded_field": {
- "_cls": "Embedded",
- "string_field": "hello",
- "int_field": 1,
- "dict_field": {"hello": "world"},
- "list_field": ['1', 2, {'hello': 'world'}]
- }
- })
- doc.save()
-
- doc = Doc.objects.first()
- self.assertEqual(doc.embedded_field.__class__, Embedded)
- self.assertEqual(doc.embedded_field.string_field, "hello")
- self.assertEqual(doc.embedded_field.int_field, 1)
- self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'})
- self.assertEqual(doc.embedded_field.list_field,
- ['1', 2, {'hello': 'world'}])
-
- def test_complex_embedded_documents(self):
- """Test complex dynamic embedded documents setups"""
- class Embedded(DynamicEmbeddedDocument):
- pass
-
- class Doc(DynamicDocument):
- pass
-
- Doc.drop_collection()
- doc = Doc()
-
- embedded_1 = Embedded()
- embedded_1.string_field = 'hello'
- embedded_1.int_field = 1
- embedded_1.dict_field = {'hello': 'world'}
-
- embedded_2 = Embedded()
- embedded_2.string_field = 'hello'
- embedded_2.int_field = 1
- embedded_2.dict_field = {'hello': 'world'}
- embedded_2.list_field = ['1', 2, {'hello': 'world'}]
-
- embedded_1.list_field = ['1', 2, embedded_2]
- doc.embedded_field = embedded_1
-
- self.assertEqual(doc.to_mongo(), {
- "embedded_field": {
- "_cls": "Embedded",
- "string_field": "hello",
- "int_field": 1,
- "dict_field": {"hello": "world"},
- "list_field": ['1', 2,
- {"_cls": "Embedded",
+ self.assertEqual(
+ doc.to_mongo(),
+ {
+ "embedded_field": {
+ "_cls": "Embedded",
"string_field": "hello",
"int_field": 1,
"dict_field": {"hello": "world"},
- "list_field": ['1', 2, {'hello': 'world'}]}
- ]
- }
- })
+ "list_field": ["1", 2, {"hello": "world"}],
+ }
+ },
+ )
+ doc.save()
+
+ doc = Doc.objects.first()
+ self.assertEqual(doc.embedded_field.__class__, Embedded)
+ self.assertEqual(doc.embedded_field.string_field, "hello")
+ self.assertEqual(doc.embedded_field.int_field, 1)
+ self.assertEqual(doc.embedded_field.dict_field, {"hello": "world"})
+ self.assertEqual(doc.embedded_field.list_field, ["1", 2, {"hello": "world"}])
+
+ def test_complex_embedded_documents(self):
+ """Test complex dynamic embedded documents setups"""
+
+ class Embedded(DynamicEmbeddedDocument):
+ pass
+
+ class Doc(DynamicDocument):
+ pass
+
+ Doc.drop_collection()
+ doc = Doc()
+
+ embedded_1 = Embedded()
+ embedded_1.string_field = "hello"
+ embedded_1.int_field = 1
+ embedded_1.dict_field = {"hello": "world"}
+
+ embedded_2 = Embedded()
+ embedded_2.string_field = "hello"
+ embedded_2.int_field = 1
+ embedded_2.dict_field = {"hello": "world"}
+ embedded_2.list_field = ["1", 2, {"hello": "world"}]
+
+ embedded_1.list_field = ["1", 2, embedded_2]
+ doc.embedded_field = embedded_1
+
+ self.assertEqual(
+ doc.to_mongo(),
+ {
+ "embedded_field": {
+ "_cls": "Embedded",
+ "string_field": "hello",
+ "int_field": 1,
+ "dict_field": {"hello": "world"},
+ "list_field": [
+ "1",
+ 2,
+ {
+ "_cls": "Embedded",
+ "string_field": "hello",
+ "int_field": 1,
+ "dict_field": {"hello": "world"},
+ "list_field": ["1", 2, {"hello": "world"}],
+ },
+ ],
+ }
+ },
+ )
doc.save()
doc = Doc.objects.first()
self.assertEqual(doc.embedded_field.__class__, Embedded)
self.assertEqual(doc.embedded_field.string_field, "hello")
self.assertEqual(doc.embedded_field.int_field, 1)
- self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'})
- self.assertEqual(doc.embedded_field.list_field[0], '1')
+ self.assertEqual(doc.embedded_field.dict_field, {"hello": "world"})
+ self.assertEqual(doc.embedded_field.list_field[0], "1")
self.assertEqual(doc.embedded_field.list_field[1], 2)
embedded_field = doc.embedded_field.list_field[2]
@@ -347,9 +332,8 @@ class TestDynamicDocument(MongoDBTestCase):
self.assertEqual(embedded_field.__class__, Embedded)
self.assertEqual(embedded_field.string_field, "hello")
self.assertEqual(embedded_field.int_field, 1)
- self.assertEqual(embedded_field.dict_field, {'hello': 'world'})
- self.assertEqual(embedded_field.list_field, ['1', 2,
- {'hello': 'world'}])
+ self.assertEqual(embedded_field.dict_field, {"hello": "world"})
+ self.assertEqual(embedded_field.list_field, ["1", 2, {"hello": "world"}])
def test_dynamic_and_embedded(self):
"""Ensure embedded documents play nicely"""
@@ -392,10 +376,15 @@ class TestDynamicDocument(MongoDBTestCase):
Person.drop_collection()
- Person(name="Eric", address=Address(city="San Francisco", street_number="1337")).save()
+ Person(
+ name="Eric", address=Address(city="San Francisco", street_number="1337")
+ ).save()
- self.assertEqual(Person.objects.first().address.street_number, '1337')
- self.assertEqual(Person.objects.only('address__street_number').first().address.street_number, '1337')
+ self.assertEqual(Person.objects.first().address.street_number, "1337")
+ self.assertEqual(
+ Person.objects.only("address__street_number").first().address.street_number,
+ "1337",
+ )
def test_dynamic_and_embedded_dict_access(self):
"""Ensure embedded dynamic documents work with dict[] style access"""
@@ -435,5 +424,5 @@ class TestDynamicDocument(MongoDBTestCase):
self.assertEqual(Person.objects.first().age, 35)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/document/indexes.py b/tests/document/indexes.py
index 764ef0c5..570e619e 100644
--- a/tests/document/indexes.py
+++ b/tests/document/indexes.py
@@ -10,13 +10,12 @@ from six import iteritems
from mongoengine import *
from mongoengine.connection import get_db
-__all__ = ("IndexesTest", )
+__all__ = ("IndexesTest",)
class IndexesTest(unittest.TestCase):
-
def setUp(self):
- self.connection = connect(db='mongoenginetest')
+ self.connection = connect(db="mongoenginetest")
self.db = get_db()
class Person(Document):
@@ -45,52 +44,43 @@ class IndexesTest(unittest.TestCase):
self._index_test(DynamicDocument)
def _index_test(self, InheritFrom):
-
class BlogPost(InheritFrom):
- date = DateTimeField(db_field='addDate', default=datetime.now)
+ date = DateTimeField(db_field="addDate", default=datetime.now)
category = StringField()
tags = ListField(StringField())
- meta = {
- 'indexes': [
- '-date',
- 'tags',
- ('category', '-date')
- ]
- }
+ meta = {"indexes": ["-date", "tags", ("category", "-date")]}
- expected_specs = [{'fields': [('addDate', -1)]},
- {'fields': [('tags', 1)]},
- {'fields': [('category', 1), ('addDate', -1)]}]
- self.assertEqual(expected_specs, BlogPost._meta['index_specs'])
+ expected_specs = [
+ {"fields": [("addDate", -1)]},
+ {"fields": [("tags", 1)]},
+ {"fields": [("category", 1), ("addDate", -1)]},
+ ]
+ self.assertEqual(expected_specs, BlogPost._meta["index_specs"])
BlogPost.ensure_indexes()
info = BlogPost.objects._collection.index_information()
# _id, '-date', 'tags', ('cat', 'date')
self.assertEqual(len(info), 4)
- info = [value['key'] for key, value in iteritems(info)]
+ info = [value["key"] for key, value in iteritems(info)]
for expected in expected_specs:
- self.assertIn(expected['fields'], info)
+ self.assertIn(expected["fields"], info)
def _index_test_inheritance(self, InheritFrom):
-
class BlogPost(InheritFrom):
- date = DateTimeField(db_field='addDate', default=datetime.now)
+ date = DateTimeField(db_field="addDate", default=datetime.now)
category = StringField()
tags = ListField(StringField())
meta = {
- 'indexes': [
- '-date',
- 'tags',
- ('category', '-date')
- ],
- 'allow_inheritance': True
+ "indexes": ["-date", "tags", ("category", "-date")],
+ "allow_inheritance": True,
}
- expected_specs = [{'fields': [('_cls', 1), ('addDate', -1)]},
- {'fields': [('_cls', 1), ('tags', 1)]},
- {'fields': [('_cls', 1), ('category', 1),
- ('addDate', -1)]}]
- self.assertEqual(expected_specs, BlogPost._meta['index_specs'])
+ expected_specs = [
+ {"fields": [("_cls", 1), ("addDate", -1)]},
+ {"fields": [("_cls", 1), ("tags", 1)]},
+ {"fields": [("_cls", 1), ("category", 1), ("addDate", -1)]},
+ ]
+ self.assertEqual(expected_specs, BlogPost._meta["index_specs"])
BlogPost.ensure_indexes()
info = BlogPost.objects._collection.index_information()
@@ -99,24 +89,24 @@ class IndexesTest(unittest.TestCase):
# the indices on -date and tags will both contain
# _cls as first element in the key
self.assertEqual(len(info), 4)
- info = [value['key'] for key, value in iteritems(info)]
+ info = [value["key"] for key, value in iteritems(info)]
for expected in expected_specs:
- self.assertIn(expected['fields'], info)
+ self.assertIn(expected["fields"], info)
class ExtendedBlogPost(BlogPost):
title = StringField()
- meta = {'indexes': ['title']}
+ meta = {"indexes": ["title"]}
- expected_specs.append({'fields': [('_cls', 1), ('title', 1)]})
- self.assertEqual(expected_specs, ExtendedBlogPost._meta['index_specs'])
+ expected_specs.append({"fields": [("_cls", 1), ("title", 1)]})
+ self.assertEqual(expected_specs, ExtendedBlogPost._meta["index_specs"])
BlogPost.drop_collection()
ExtendedBlogPost.ensure_indexes()
info = ExtendedBlogPost.objects._collection.index_information()
- info = [value['key'] for key, value in iteritems(info)]
+ info = [value["key"] for key, value in iteritems(info)]
for expected in expected_specs:
- self.assertIn(expected['fields'], info)
+ self.assertIn(expected["fields"], info)
def test_indexes_document_inheritance(self):
"""Ensure that indexes are used when meta[indexes] is specified for
@@ -135,21 +125,15 @@ class IndexesTest(unittest.TestCase):
class A(Document):
title = StringField()
- meta = {
- 'indexes': [
- {
- 'fields': ('title',),
- },
- ],
- 'allow_inheritance': True,
- }
+ meta = {"indexes": [{"fields": ("title",)}], "allow_inheritance": True}
class B(A):
description = StringField()
- self.assertEqual(A._meta['index_specs'], B._meta['index_specs'])
- self.assertEqual([{'fields': [('_cls', 1), ('title', 1)]}],
- A._meta['index_specs'])
+ self.assertEqual(A._meta["index_specs"], B._meta["index_specs"])
+ self.assertEqual(
+ [{"fields": [("_cls", 1), ("title", 1)]}], A._meta["index_specs"]
+ )
def test_index_no_cls(self):
"""Ensure index specs are inhertited correctly"""
@@ -157,14 +141,12 @@ class IndexesTest(unittest.TestCase):
class A(Document):
title = StringField()
meta = {
- 'indexes': [
- {'fields': ('title',), 'cls': False},
- ],
- 'allow_inheritance': True,
- 'index_cls': False
- }
+ "indexes": [{"fields": ("title",), "cls": False}],
+ "allow_inheritance": True,
+ "index_cls": False,
+ }
- self.assertEqual([('title', 1)], A._meta['index_specs'][0]['fields'])
+ self.assertEqual([("title", 1)], A._meta["index_specs"][0]["fields"])
A._get_collection().drop_indexes()
A.ensure_indexes()
info = A._get_collection().index_information()
@@ -174,34 +156,30 @@ class IndexesTest(unittest.TestCase):
c = StringField()
d = StringField()
meta = {
- 'indexes': [{'fields': ['c']}, {'fields': ['d'], 'cls': True}],
- 'allow_inheritance': True
+ "indexes": [{"fields": ["c"]}, {"fields": ["d"], "cls": True}],
+ "allow_inheritance": True,
}
- self.assertEqual([('c', 1)], B._meta['index_specs'][1]['fields'])
- self.assertEqual([('_cls', 1), ('d', 1)], B._meta['index_specs'][2]['fields'])
+
+ self.assertEqual([("c", 1)], B._meta["index_specs"][1]["fields"])
+ self.assertEqual([("_cls", 1), ("d", 1)], B._meta["index_specs"][2]["fields"])
def test_build_index_spec_is_not_destructive(self):
-
class MyDoc(Document):
keywords = StringField()
- meta = {
- 'indexes': ['keywords'],
- 'allow_inheritance': False
- }
+ meta = {"indexes": ["keywords"], "allow_inheritance": False}
- self.assertEqual(MyDoc._meta['index_specs'],
- [{'fields': [('keywords', 1)]}])
+ self.assertEqual(MyDoc._meta["index_specs"], [{"fields": [("keywords", 1)]}])
# Force index creation
MyDoc.ensure_indexes()
- self.assertEqual(MyDoc._meta['index_specs'],
- [{'fields': [('keywords', 1)]}])
+ self.assertEqual(MyDoc._meta["index_specs"], [{"fields": [("keywords", 1)]}])
def test_embedded_document_index_meta(self):
"""Ensure that embedded document indexes are created explicitly
"""
+
class Rank(EmbeddedDocument):
title = StringField(required=True)
@@ -209,138 +187,123 @@ class IndexesTest(unittest.TestCase):
name = StringField(required=True)
rank = EmbeddedDocumentField(Rank, required=False)
- meta = {
- 'indexes': [
- 'rank.title',
- ],
- 'allow_inheritance': False
- }
+ meta = {"indexes": ["rank.title"], "allow_inheritance": False}
- self.assertEqual([{'fields': [('rank.title', 1)]}],
- Person._meta['index_specs'])
+ self.assertEqual([{"fields": [("rank.title", 1)]}], Person._meta["index_specs"])
Person.drop_collection()
# Indexes are lazy so use list() to perform query
list(Person.objects)
info = Person.objects._collection.index_information()
- info = [value['key'] for key, value in iteritems(info)]
- self.assertIn([('rank.title', 1)], info)
+ info = [value["key"] for key, value in iteritems(info)]
+ self.assertIn([("rank.title", 1)], info)
def test_explicit_geo2d_index(self):
"""Ensure that geo2d indexes work when created via meta[indexes]
"""
+
class Place(Document):
location = DictField()
- meta = {
- 'allow_inheritance': True,
- 'indexes': [
- '*location.point',
- ]
- }
+ meta = {"allow_inheritance": True, "indexes": ["*location.point"]}
- self.assertEqual([{'fields': [('location.point', '2d')]}],
- Place._meta['index_specs'])
+ self.assertEqual(
+ [{"fields": [("location.point", "2d")]}], Place._meta["index_specs"]
+ )
Place.ensure_indexes()
info = Place._get_collection().index_information()
- info = [value['key'] for key, value in iteritems(info)]
- self.assertIn([('location.point', '2d')], info)
+ info = [value["key"] for key, value in iteritems(info)]
+ self.assertIn([("location.point", "2d")], info)
def test_explicit_geo2d_index_embedded(self):
"""Ensure that geo2d indexes work when created via meta[indexes]
"""
+
class EmbeddedLocation(EmbeddedDocument):
location = DictField()
class Place(Document):
- current = DictField(field=EmbeddedDocumentField('EmbeddedLocation'))
- meta = {
- 'allow_inheritance': True,
- 'indexes': [
- '*current.location.point',
- ]
- }
+ current = DictField(field=EmbeddedDocumentField("EmbeddedLocation"))
+ meta = {"allow_inheritance": True, "indexes": ["*current.location.point"]}
- self.assertEqual([{'fields': [('current.location.point', '2d')]}],
- Place._meta['index_specs'])
+ self.assertEqual(
+ [{"fields": [("current.location.point", "2d")]}], Place._meta["index_specs"]
+ )
Place.ensure_indexes()
info = Place._get_collection().index_information()
- info = [value['key'] for key, value in iteritems(info)]
- self.assertIn([('current.location.point', '2d')], info)
+ info = [value["key"] for key, value in iteritems(info)]
+ self.assertIn([("current.location.point", "2d")], info)
def test_explicit_geosphere_index(self):
"""Ensure that geosphere indexes work when created via meta[indexes]
"""
+
class Place(Document):
location = DictField()
- meta = {
- 'allow_inheritance': True,
- 'indexes': [
- '(location.point',
- ]
- }
+ meta = {"allow_inheritance": True, "indexes": ["(location.point"]}
- self.assertEqual([{'fields': [('location.point', '2dsphere')]}],
- Place._meta['index_specs'])
+ self.assertEqual(
+ [{"fields": [("location.point", "2dsphere")]}], Place._meta["index_specs"]
+ )
Place.ensure_indexes()
info = Place._get_collection().index_information()
- info = [value['key'] for key, value in iteritems(info)]
- self.assertIn([('location.point', '2dsphere')], info)
+ info = [value["key"] for key, value in iteritems(info)]
+ self.assertIn([("location.point", "2dsphere")], info)
def test_explicit_geohaystack_index(self):
"""Ensure that geohaystack indexes work when created via meta[indexes]
"""
- raise SkipTest('GeoHaystack index creation is not supported for now'
- 'from meta, as it requires a bucketSize parameter.')
+ raise SkipTest(
+ "GeoHaystack index creation is not supported for now"
+ "from meta, as it requires a bucketSize parameter."
+ )
class Place(Document):
location = DictField()
name = StringField()
- meta = {
- 'indexes': [
- (')location.point', 'name')
- ]
- }
- self.assertEqual([{'fields': [('location.point', 'geoHaystack'), ('name', 1)]}],
- Place._meta['index_specs'])
+ meta = {"indexes": [(")location.point", "name")]}
+
+ self.assertEqual(
+ [{"fields": [("location.point", "geoHaystack"), ("name", 1)]}],
+ Place._meta["index_specs"],
+ )
Place.ensure_indexes()
info = Place._get_collection().index_information()
- info = [value['key'] for key, value in iteritems(info)]
- self.assertIn([('location.point', 'geoHaystack')], info)
+ info = [value["key"] for key, value in iteritems(info)]
+ self.assertIn([("location.point", "geoHaystack")], info)
def test_create_geohaystack_index(self):
"""Ensure that geohaystack indexes can be created
"""
+
class Place(Document):
location = DictField()
name = StringField()
- Place.create_index({'fields': (')location.point', 'name')}, bucketSize=10)
+ Place.create_index({"fields": (")location.point", "name")}, bucketSize=10)
info = Place._get_collection().index_information()
- info = [value['key'] for key, value in iteritems(info)]
- self.assertIn([('location.point', 'geoHaystack'), ('name', 1)], info)
+ info = [value["key"] for key, value in iteritems(info)]
+ self.assertIn([("location.point", "geoHaystack"), ("name", 1)], info)
def test_dictionary_indexes(self):
"""Ensure that indexes are used when meta[indexes] contains
dictionaries instead of lists.
"""
+
class BlogPost(Document):
- date = DateTimeField(db_field='addDate', default=datetime.now)
+ date = DateTimeField(db_field="addDate", default=datetime.now)
category = StringField()
tags = ListField(StringField())
- meta = {
- 'indexes': [
- {'fields': ['-date'], 'unique': True, 'sparse': True},
- ],
- }
+ meta = {"indexes": [{"fields": ["-date"], "unique": True, "sparse": True}]}
- self.assertEqual([{'fields': [('addDate', -1)], 'unique': True,
- 'sparse': True}],
- BlogPost._meta['index_specs'])
+ self.assertEqual(
+ [{"fields": [("addDate", -1)], "unique": True, "sparse": True}],
+ BlogPost._meta["index_specs"],
+ )
BlogPost.drop_collection()
@@ -351,48 +314,48 @@ class IndexesTest(unittest.TestCase):
# Indexes are lazy so use list() to perform query
list(BlogPost.objects)
info = BlogPost.objects._collection.index_information()
- info = [(value['key'],
- value.get('unique', False),
- value.get('sparse', False))
- for key, value in iteritems(info)]
- self.assertIn(([('addDate', -1)], True, True), info)
+ info = [
+ (value["key"], value.get("unique", False), value.get("sparse", False))
+ for key, value in iteritems(info)
+ ]
+ self.assertIn(([("addDate", -1)], True, True), info)
BlogPost.drop_collection()
def test_abstract_index_inheritance(self):
-
class UserBase(Document):
user_guid = StringField(required=True)
meta = {
- 'abstract': True,
- 'indexes': ['user_guid'],
- 'allow_inheritance': True
+ "abstract": True,
+ "indexes": ["user_guid"],
+ "allow_inheritance": True,
}
class Person(UserBase):
name = StringField()
- meta = {
- 'indexes': ['name'],
- }
+ meta = {"indexes": ["name"]}
+
Person.drop_collection()
- Person(name="test", user_guid='123').save()
+ Person(name="test", user_guid="123").save()
self.assertEqual(1, Person.objects.count())
info = Person.objects._collection.index_information()
- self.assertEqual(sorted(info.keys()),
- ['_cls_1_name_1', '_cls_1_user_guid_1', '_id_'])
+ self.assertEqual(
+ sorted(info.keys()), ["_cls_1_name_1", "_cls_1_user_guid_1", "_id_"]
+ )
def test_disable_index_creation(self):
"""Tests setting auto_create_index to False on the connection will
disable any index generation.
"""
+
class User(Document):
meta = {
- 'allow_inheritance': True,
- 'indexes': ['user_guid'],
- 'auto_create_index': False
+ "allow_inheritance": True,
+ "indexes": ["user_guid"],
+ "auto_create_index": False,
}
user_guid = StringField(required=True)
@@ -401,88 +364,81 @@ class IndexesTest(unittest.TestCase):
User.drop_collection()
- User(user_guid='123').save()
- MongoUser(user_guid='123').save()
+ User(user_guid="123").save()
+ MongoUser(user_guid="123").save()
self.assertEqual(2, User.objects.count())
info = User.objects._collection.index_information()
- self.assertEqual(list(info.keys()), ['_id_'])
+ self.assertEqual(list(info.keys()), ["_id_"])
User.ensure_indexes()
info = User.objects._collection.index_information()
- self.assertEqual(sorted(info.keys()), ['_cls_1_user_guid_1', '_id_'])
+ self.assertEqual(sorted(info.keys()), ["_cls_1_user_guid_1", "_id_"])
def test_embedded_document_index(self):
"""Tests settings an index on an embedded document
"""
+
class Date(EmbeddedDocument):
- year = IntField(db_field='yr')
+ year = IntField(db_field="yr")
class BlogPost(Document):
title = StringField()
date = EmbeddedDocumentField(Date)
- meta = {
- 'indexes': [
- '-date.year'
- ],
- }
+ meta = {"indexes": ["-date.year"]}
BlogPost.drop_collection()
info = BlogPost.objects._collection.index_information()
- self.assertEqual(sorted(info.keys()), ['_id_', 'date.yr_-1'])
+ self.assertEqual(sorted(info.keys()), ["_id_", "date.yr_-1"])
def test_list_embedded_document_index(self):
"""Ensure list embedded documents can be indexed
"""
+
class Tag(EmbeddedDocument):
- name = StringField(db_field='tag')
+ name = StringField(db_field="tag")
class BlogPost(Document):
title = StringField()
tags = ListField(EmbeddedDocumentField(Tag))
- meta = {
- 'indexes': [
- 'tags.name'
- ]
- }
+ meta = {"indexes": ["tags.name"]}
BlogPost.drop_collection()
info = BlogPost.objects._collection.index_information()
# we don't use _cls in with list fields by default
- self.assertEqual(sorted(info.keys()), ['_id_', 'tags.tag_1'])
+ self.assertEqual(sorted(info.keys()), ["_id_", "tags.tag_1"])
- post1 = BlogPost(title="Embedded Indexes tests in place",
- tags=[Tag(name="about"), Tag(name="time")])
+ post1 = BlogPost(
+ title="Embedded Indexes tests in place",
+ tags=[Tag(name="about"), Tag(name="time")],
+ )
post1.save()
def test_recursive_embedded_objects_dont_break_indexes(self):
-
class RecursiveObject(EmbeddedDocument):
- obj = EmbeddedDocumentField('self')
+ obj = EmbeddedDocumentField("self")
class RecursiveDocument(Document):
recursive_obj = EmbeddedDocumentField(RecursiveObject)
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
RecursiveDocument.ensure_indexes()
info = RecursiveDocument._get_collection().index_information()
- self.assertEqual(sorted(info.keys()), ['_cls_1', '_id_'])
+ self.assertEqual(sorted(info.keys()), ["_cls_1", "_id_"])
def test_covered_index(self):
"""Ensure that covered indexes can be used
"""
+
class Test(Document):
a = IntField()
b = IntField()
- meta = {
- 'indexes': ['a'],
- 'allow_inheritance': False
- }
+ meta = {"indexes": ["a"], "allow_inheritance": False}
Test.drop_collection()
@@ -491,45 +447,51 @@ class IndexesTest(unittest.TestCase):
# Need to be explicit about covered indexes as mongoDB doesn't know if
# the documents returned might have more keys in that here.
- query_plan = Test.objects(id=obj.id).exclude('a').explain()
+ query_plan = Test.objects(id=obj.id).exclude("a").explain()
self.assertEqual(
- query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'),
- 'IDHACK'
+ query_plan.get("queryPlanner")
+ .get("winningPlan")
+ .get("inputStage")
+ .get("stage"),
+ "IDHACK",
)
- query_plan = Test.objects(id=obj.id).only('id').explain()
+ query_plan = Test.objects(id=obj.id).only("id").explain()
self.assertEqual(
- query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'),
- 'IDHACK'
+ query_plan.get("queryPlanner")
+ .get("winningPlan")
+ .get("inputStage")
+ .get("stage"),
+ "IDHACK",
)
- query_plan = Test.objects(a=1).only('a').exclude('id').explain()
+ query_plan = Test.objects(a=1).only("a").exclude("id").explain()
self.assertEqual(
- query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'),
- 'IXSCAN'
+ query_plan.get("queryPlanner")
+ .get("winningPlan")
+ .get("inputStage")
+ .get("stage"),
+ "IXSCAN",
)
self.assertEqual(
- query_plan.get('queryPlanner').get('winningPlan').get('stage'),
- 'PROJECTION'
+ query_plan.get("queryPlanner").get("winningPlan").get("stage"), "PROJECTION"
)
query_plan = Test.objects(a=1).explain()
self.assertEqual(
- query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'),
- 'IXSCAN'
+ query_plan.get("queryPlanner")
+ .get("winningPlan")
+ .get("inputStage")
+ .get("stage"),
+ "IXSCAN",
)
self.assertEqual(
- query_plan.get('queryPlanner').get('winningPlan').get('stage'),
- 'FETCH'
+ query_plan.get("queryPlanner").get("winningPlan").get("stage"), "FETCH"
)
def test_index_on_id(self):
class BlogPost(Document):
- meta = {
- 'indexes': [
- ['categories', 'id']
- ]
- }
+ meta = {"indexes": [["categories", "id"]]}
title = StringField(required=True)
description = StringField(required=True)
@@ -538,22 +500,16 @@ class IndexesTest(unittest.TestCase):
BlogPost.drop_collection()
indexes = BlogPost.objects._collection.index_information()
- self.assertEqual(indexes['categories_1__id_1']['key'],
- [('categories', 1), ('_id', 1)])
+ self.assertEqual(
+ indexes["categories_1__id_1"]["key"], [("categories", 1), ("_id", 1)]
+ )
def test_hint(self):
- TAGS_INDEX_NAME = 'tags_1'
+ TAGS_INDEX_NAME = "tags_1"
class BlogPost(Document):
tags = ListField(StringField())
- meta = {
- 'indexes': [
- {
- 'fields': ['tags'],
- 'name': TAGS_INDEX_NAME
- }
- ],
- }
+ meta = {"indexes": [{"fields": ["tags"], "name": TAGS_INDEX_NAME}]}
BlogPost.drop_collection()
@@ -562,41 +518,42 @@ class IndexesTest(unittest.TestCase):
BlogPost(tags=tags).save()
# Hinting by shape should work.
- self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10)
+ self.assertEqual(BlogPost.objects.hint([("tags", 1)]).count(), 10)
# Hinting by index name should work.
self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME).count(), 10)
# Clearing the hint should work fine.
self.assertEqual(BlogPost.objects.hint().count(), 10)
- self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).hint().count(), 10)
+ self.assertEqual(BlogPost.objects.hint([("ZZ", 1)]).hint().count(), 10)
# Hinting on a non-existent index shape should fail.
with self.assertRaises(OperationFailure):
- BlogPost.objects.hint([('ZZ', 1)]).count()
+ BlogPost.objects.hint([("ZZ", 1)]).count()
# Hinting on a non-existent index name should fail.
with self.assertRaises(OperationFailure):
- BlogPost.objects.hint('Bad Name').count()
+ BlogPost.objects.hint("Bad Name").count()
# Invalid shape argument (missing list brackets) should fail.
with self.assertRaises(ValueError):
- BlogPost.objects.hint(('tags', 1)).count()
+ BlogPost.objects.hint(("tags", 1)).count()
def test_unique(self):
"""Ensure that uniqueness constraints are applied to fields.
"""
+
class BlogPost(Document):
title = StringField()
slug = StringField(unique=True)
BlogPost.drop_collection()
- post1 = BlogPost(title='test1', slug='test')
+ post1 = BlogPost(title="test1", slug="test")
post1.save()
# Two posts with the same slug is not allowed
- post2 = BlogPost(title='test2', slug='test')
+ post2 = BlogPost(title="test2", slug="test")
self.assertRaises(NotUniqueError, post2.save)
self.assertRaises(NotUniqueError, BlogPost.objects.insert, post2)
@@ -605,54 +562,62 @@ class IndexesTest(unittest.TestCase):
def test_primary_key_unique_not_working(self):
"""Relates to #1445"""
+
class Blog(Document):
id = StringField(primary_key=True, unique=True)
Blog.drop_collection()
with self.assertRaises(OperationFailure) as ctx_err:
- Blog(id='garbage').save()
+ Blog(id="garbage").save()
# One of the errors below should happen. Which one depends on the
# PyMongo version and dict order.
err_msg = str(ctx_err.exception)
self.assertTrue(
- any([
- "The field 'unique' is not valid for an _id index specification" in err_msg,
- "The field 'background' is not valid for an _id index specification" in err_msg,
- "The field 'sparse' is not valid for an _id index specification" in err_msg,
- ])
+ any(
+ [
+ "The field 'unique' is not valid for an _id index specification"
+ in err_msg,
+ "The field 'background' is not valid for an _id index specification"
+ in err_msg,
+ "The field 'sparse' is not valid for an _id index specification"
+ in err_msg,
+ ]
+ )
)
def test_unique_with(self):
"""Ensure that unique_with constraints are applied to fields.
"""
+
class Date(EmbeddedDocument):
- year = IntField(db_field='yr')
+ year = IntField(db_field="yr")
class BlogPost(Document):
title = StringField()
date = EmbeddedDocumentField(Date)
- slug = StringField(unique_with='date.year')
+ slug = StringField(unique_with="date.year")
BlogPost.drop_collection()
- post1 = BlogPost(title='test1', date=Date(year=2009), slug='test')
+ post1 = BlogPost(title="test1", date=Date(year=2009), slug="test")
post1.save()
# day is different so won't raise exception
- post2 = BlogPost(title='test2', date=Date(year=2010), slug='test')
+ post2 = BlogPost(title="test2", date=Date(year=2010), slug="test")
post2.save()
# Now there will be two docs with the same slug and the same day: fail
- post3 = BlogPost(title='test3', date=Date(year=2010), slug='test')
+ post3 = BlogPost(title="test3", date=Date(year=2010), slug="test")
self.assertRaises(OperationError, post3.save)
def test_unique_embedded_document(self):
"""Ensure that uniqueness constraints are applied to fields on embedded documents.
"""
+
class SubDocument(EmbeddedDocument):
- year = IntField(db_field='yr')
+ year = IntField(db_field="yr")
slug = StringField(unique=True)
class BlogPost(Document):
@@ -661,18 +626,15 @@ class IndexesTest(unittest.TestCase):
BlogPost.drop_collection()
- post1 = BlogPost(title='test1',
- sub=SubDocument(year=2009, slug="test"))
+ post1 = BlogPost(title="test1", sub=SubDocument(year=2009, slug="test"))
post1.save()
# sub.slug is different so won't raise exception
- post2 = BlogPost(title='test2',
- sub=SubDocument(year=2010, slug='another-slug'))
+ post2 = BlogPost(title="test2", sub=SubDocument(year=2010, slug="another-slug"))
post2.save()
# Now there will be two docs with the same sub.slug
- post3 = BlogPost(title='test3',
- sub=SubDocument(year=2010, slug='test'))
+ post3 = BlogPost(title="test3", sub=SubDocument(year=2010, slug="test"))
self.assertRaises(NotUniqueError, post3.save)
def test_unique_embedded_document_in_list(self):
@@ -681,8 +643,9 @@ class IndexesTest(unittest.TestCase):
embedded documents, even when the embedded documents in in a
list field.
"""
+
class SubDocument(EmbeddedDocument):
- year = IntField(db_field='yr')
+ year = IntField(db_field="yr")
slug = StringField(unique=True)
class BlogPost(Document):
@@ -692,16 +655,15 @@ class IndexesTest(unittest.TestCase):
BlogPost.drop_collection()
post1 = BlogPost(
- title='test1', subs=[
- SubDocument(year=2009, slug='conflict'),
- SubDocument(year=2009, slug='conflict')
- ]
+ title="test1",
+ subs=[
+ SubDocument(year=2009, slug="conflict"),
+ SubDocument(year=2009, slug="conflict"),
+ ],
)
post1.save()
- post2 = BlogPost(
- title='test2', subs=[SubDocument(year=2014, slug='conflict')]
- )
+ post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")])
self.assertRaises(NotUniqueError, post2.save)
@@ -711,33 +673,32 @@ class IndexesTest(unittest.TestCase):
embedded documents, even when the embedded documents in a sorted list
field.
"""
+
class SubDocument(EmbeddedDocument):
year = IntField()
slug = StringField(unique=True)
class BlogPost(Document):
title = StringField()
- subs = SortedListField(EmbeddedDocumentField(SubDocument),
- ordering='year')
+ subs = SortedListField(EmbeddedDocumentField(SubDocument), ordering="year")
BlogPost.drop_collection()
post1 = BlogPost(
- title='test1', subs=[
- SubDocument(year=2009, slug='conflict'),
- SubDocument(year=2009, slug='conflict')
- ]
+ title="test1",
+ subs=[
+ SubDocument(year=2009, slug="conflict"),
+ SubDocument(year=2009, slug="conflict"),
+ ],
)
post1.save()
# confirm that the unique index is created
indexes = BlogPost._get_collection().index_information()
- self.assertIn('subs.slug_1', indexes)
- self.assertTrue(indexes['subs.slug_1']['unique'])
+ self.assertIn("subs.slug_1", indexes)
+ self.assertTrue(indexes["subs.slug_1"]["unique"])
- post2 = BlogPost(
- title='test2', subs=[SubDocument(year=2014, slug='conflict')]
- )
+ post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")])
self.assertRaises(NotUniqueError, post2.save)
@@ -747,6 +708,7 @@ class IndexesTest(unittest.TestCase):
embedded documents, even when the embedded documents in an embedded
list field.
"""
+
class SubDocument(EmbeddedDocument):
year = IntField()
slug = StringField(unique=True)
@@ -758,21 +720,20 @@ class IndexesTest(unittest.TestCase):
BlogPost.drop_collection()
post1 = BlogPost(
- title='test1', subs=[
- SubDocument(year=2009, slug='conflict'),
- SubDocument(year=2009, slug='conflict')
- ]
+ title="test1",
+ subs=[
+ SubDocument(year=2009, slug="conflict"),
+ SubDocument(year=2009, slug="conflict"),
+ ],
)
post1.save()
# confirm that the unique index is created
indexes = BlogPost._get_collection().index_information()
- self.assertIn('subs.slug_1', indexes)
- self.assertTrue(indexes['subs.slug_1']['unique'])
+ self.assertIn("subs.slug_1", indexes)
+ self.assertTrue(indexes["subs.slug_1"]["unique"])
- post2 = BlogPost(
- title='test2', subs=[SubDocument(year=2014, slug='conflict')]
- )
+ post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")])
self.assertRaises(NotUniqueError, post2.save)
@@ -780,60 +741,51 @@ class IndexesTest(unittest.TestCase):
"""Ensure that uniqueness constraints are applied to fields on
embedded documents. And work with unique_with as well.
"""
+
class SubDocument(EmbeddedDocument):
- year = IntField(db_field='yr')
+ year = IntField(db_field="yr")
slug = StringField(unique=True)
class BlogPost(Document):
- title = StringField(unique_with='sub.year')
+ title = StringField(unique_with="sub.year")
sub = EmbeddedDocumentField(SubDocument)
BlogPost.drop_collection()
- post1 = BlogPost(title='test1',
- sub=SubDocument(year=2009, slug="test"))
+ post1 = BlogPost(title="test1", sub=SubDocument(year=2009, slug="test"))
post1.save()
# sub.slug is different so won't raise exception
- post2 = BlogPost(title='test2',
- sub=SubDocument(year=2010, slug='another-slug'))
+ post2 = BlogPost(title="test2", sub=SubDocument(year=2010, slug="another-slug"))
post2.save()
# Now there will be two docs with the same sub.slug
- post3 = BlogPost(title='test3',
- sub=SubDocument(year=2010, slug='test'))
+ post3 = BlogPost(title="test3", sub=SubDocument(year=2010, slug="test"))
self.assertRaises(NotUniqueError, post3.save)
# Now there will be two docs with the same title and year
- post3 = BlogPost(title='test1',
- sub=SubDocument(year=2009, slug='test-1'))
+ post3 = BlogPost(title="test1", sub=SubDocument(year=2009, slug="test-1"))
self.assertRaises(NotUniqueError, post3.save)
def test_ttl_indexes(self):
-
class Log(Document):
created = DateTimeField(default=datetime.now)
- meta = {
- 'indexes': [
- {'fields': ['created'], 'expireAfterSeconds': 3600}
- ]
- }
+ meta = {"indexes": [{"fields": ["created"], "expireAfterSeconds": 3600}]}
Log.drop_collection()
# Indexes are lazy so use list() to perform query
list(Log.objects)
info = Log.objects._collection.index_information()
- self.assertEqual(3600,
- info['created_1']['expireAfterSeconds'])
+ self.assertEqual(3600, info["created_1"]["expireAfterSeconds"])
def test_index_drop_dups_silently_ignored(self):
class Customer(Document):
cust_id = IntField(unique=True, required=True)
meta = {
- 'indexes': ['cust_id'],
- 'index_drop_dups': True,
- 'allow_inheritance': False,
+ "indexes": ["cust_id"],
+ "index_drop_dups": True,
+ "allow_inheritance": False,
}
Customer.drop_collection()
@@ -843,12 +795,10 @@ class IndexesTest(unittest.TestCase):
"""Ensure that 'unique' constraints aren't overridden by
meta.indexes.
"""
+
class Customer(Document):
cust_id = IntField(unique=True, required=True)
- meta = {
- 'indexes': ['cust_id'],
- 'allow_inheritance': False,
- }
+ meta = {"indexes": ["cust_id"], "allow_inheritance": False}
Customer.drop_collection()
cust = Customer(cust_id=1)
@@ -870,37 +820,39 @@ class IndexesTest(unittest.TestCase):
"""If you set a field as primary, then unexpected behaviour can occur.
You won't create a duplicate but you will update an existing document.
"""
+
class User(Document):
name = StringField(primary_key=True)
password = StringField()
User.drop_collection()
- user = User(name='huangz', password='secret')
+ user = User(name="huangz", password="secret")
user.save()
- user = User(name='huangz', password='secret2')
+ user = User(name="huangz", password="secret2")
user.save()
self.assertEqual(User.objects.count(), 1)
- self.assertEqual(User.objects.get().password, 'secret2')
+ self.assertEqual(User.objects.get().password, "secret2")
def test_unique_and_primary_create(self):
"""Create a new record with a duplicate primary key
throws an exception
"""
+
class User(Document):
name = StringField(primary_key=True)
password = StringField()
User.drop_collection()
- User.objects.create(name='huangz', password='secret')
+ User.objects.create(name="huangz", password="secret")
with self.assertRaises(NotUniqueError):
- User.objects.create(name='huangz', password='secret2')
+ User.objects.create(name="huangz", password="secret2")
self.assertEqual(User.objects.count(), 1)
- self.assertEqual(User.objects.get().password, 'secret')
+ self.assertEqual(User.objects.get().password, "secret")
def test_index_with_pk(self):
"""Ensure you can use `pk` as part of a query"""
@@ -909,21 +861,24 @@ class IndexesTest(unittest.TestCase):
comment_id = IntField(required=True)
try:
+
class BlogPost(Document):
comments = EmbeddedDocumentField(Comment)
- meta = {'indexes': [
- {'fields': ['pk', 'comments.comment_id'],
- 'unique': True}]}
+ meta = {
+ "indexes": [
+ {"fields": ["pk", "comments.comment_id"], "unique": True}
+ ]
+ }
+
except UnboundLocalError:
- self.fail('Unbound local error at index + pk definition')
+ self.fail("Unbound local error at index + pk definition")
info = BlogPost.objects._collection.index_information()
- info = [value['key'] for key, value in iteritems(info)]
- index_item = [('_id', 1), ('comments.comment_id', 1)]
+ info = [value["key"] for key, value in iteritems(info)]
+ index_item = [("_id", 1), ("comments.comment_id", 1)]
self.assertIn(index_item, info)
def test_compound_key_embedded(self):
-
class CompoundKey(EmbeddedDocument):
name = StringField(required=True)
term = StringField(required=True)
@@ -935,12 +890,12 @@ class IndexesTest(unittest.TestCase):
my_key = CompoundKey(name="n", term="ok")
report = ReportEmbedded(text="OK", key=my_key).save()
- self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}},
- report.to_mongo())
+ self.assertEqual(
+ {"text": "OK", "_id": {"term": "ok", "name": "n"}}, report.to_mongo()
+ )
self.assertEqual(report, ReportEmbedded.objects.get(pk=my_key))
def test_compound_key_dictfield(self):
-
class ReportDictField(Document):
key = DictField(primary_key=True)
text = StringField()
@@ -948,65 +903,60 @@ class IndexesTest(unittest.TestCase):
my_key = {"name": "n", "term": "ok"}
report = ReportDictField(text="OK", key=my_key).save()
- self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}},
- report.to_mongo())
+ self.assertEqual(
+ {"text": "OK", "_id": {"term": "ok", "name": "n"}}, report.to_mongo()
+ )
# We can't directly call ReportDictField.objects.get(pk=my_key),
# because dicts are unordered, and if the order in MongoDB is
# different than the one in `my_key`, this test will fail.
- self.assertEqual(report, ReportDictField.objects.get(pk__name=my_key['name']))
- self.assertEqual(report, ReportDictField.objects.get(pk__term=my_key['term']))
+ self.assertEqual(report, ReportDictField.objects.get(pk__name=my_key["name"]))
+ self.assertEqual(report, ReportDictField.objects.get(pk__term=my_key["term"]))
def test_string_indexes(self):
-
class MyDoc(Document):
provider_ids = DictField()
- meta = {
- "indexes": ["provider_ids.foo", "provider_ids.bar"],
- }
+ meta = {"indexes": ["provider_ids.foo", "provider_ids.bar"]}
info = MyDoc.objects._collection.index_information()
- info = [value['key'] for key, value in iteritems(info)]
- self.assertIn([('provider_ids.foo', 1)], info)
- self.assertIn([('provider_ids.bar', 1)], info)
+ info = [value["key"] for key, value in iteritems(info)]
+ self.assertIn([("provider_ids.foo", 1)], info)
+ self.assertIn([("provider_ids.bar", 1)], info)
def test_sparse_compound_indexes(self):
-
class MyDoc(Document):
provider_ids = DictField()
meta = {
- "indexes": [{'fields': ("provider_ids.foo", "provider_ids.bar"),
- 'sparse': True}],
+ "indexes": [
+ {"fields": ("provider_ids.foo", "provider_ids.bar"), "sparse": True}
+ ]
}
info = MyDoc.objects._collection.index_information()
- self.assertEqual([('provider_ids.foo', 1), ('provider_ids.bar', 1)],
- info['provider_ids.foo_1_provider_ids.bar_1']['key'])
- self.assertTrue(info['provider_ids.foo_1_provider_ids.bar_1']['sparse'])
+ self.assertEqual(
+ [("provider_ids.foo", 1), ("provider_ids.bar", 1)],
+ info["provider_ids.foo_1_provider_ids.bar_1"]["key"],
+ )
+ self.assertTrue(info["provider_ids.foo_1_provider_ids.bar_1"]["sparse"])
def test_text_indexes(self):
class Book(Document):
title = DictField()
- meta = {
- "indexes": ["$title"],
- }
+ meta = {"indexes": ["$title"]}
indexes = Book.objects._collection.index_information()
self.assertIn("title_text", indexes)
key = indexes["title_text"]["key"]
- self.assertIn(('_fts', 'text'), key)
+ self.assertIn(("_fts", "text"), key)
def test_hashed_indexes(self):
-
class Book(Document):
ref_id = StringField()
- meta = {
- "indexes": ["#ref_id"],
- }
+ meta = {"indexes": ["#ref_id"]}
indexes = Book.objects._collection.index_information()
self.assertIn("ref_id_hashed", indexes)
- self.assertIn(('ref_id', 'hashed'), indexes["ref_id_hashed"]["key"])
+ self.assertIn(("ref_id", "hashed"), indexes["ref_id_hashed"]["key"])
def test_indexes_after_database_drop(self):
"""
@@ -1017,35 +967,36 @@ class IndexesTest(unittest.TestCase):
"""
# Use a new connection and database since dropping the database could
# cause concurrent tests to fail.
- connection = connect(db='tempdatabase',
- alias='test_indexes_after_database_drop')
+ connection = connect(
+ db="tempdatabase", alias="test_indexes_after_database_drop"
+ )
class BlogPost(Document):
title = StringField()
slug = StringField(unique=True)
- meta = {'db_alias': 'test_indexes_after_database_drop'}
+ meta = {"db_alias": "test_indexes_after_database_drop"}
try:
BlogPost.drop_collection()
# Create Post #1
- post1 = BlogPost(title='test1', slug='test')
+ post1 = BlogPost(title="test1", slug="test")
post1.save()
# Drop the Database
- connection.drop_database('tempdatabase')
+ connection.drop_database("tempdatabase")
# Re-create Post #1
- post1 = BlogPost(title='test1', slug='test')
+ post1 = BlogPost(title="test1", slug="test")
post1.save()
# Create Post #2
- post2 = BlogPost(title='test2', slug='test')
+ post2 = BlogPost(title="test2", slug="test")
self.assertRaises(NotUniqueError, post2.save)
finally:
# Drop the temporary database at the end
- connection.drop_database('tempdatabase')
+ connection.drop_database("tempdatabase")
def test_index_dont_send_cls_option(self):
"""
@@ -1057,24 +1008,19 @@ class IndexesTest(unittest.TestCase):
options that are passed to ensureIndex. For more details, see:
https://jira.mongodb.org/browse/SERVER-769
"""
+
class TestDoc(Document):
txt = StringField()
meta = {
- 'allow_inheritance': True,
- 'indexes': [
- {'fields': ('txt',), 'cls': False}
- ]
+ "allow_inheritance": True,
+ "indexes": [{"fields": ("txt",), "cls": False}],
}
class TestChildDoc(TestDoc):
txt2 = StringField()
- meta = {
- 'indexes': [
- {'fields': ('txt2',), 'cls': False}
- ]
- }
+ meta = {"indexes": [{"fields": ("txt2",), "cls": False}]}
TestDoc.drop_collection()
TestDoc.ensure_indexes()
@@ -1082,54 +1028,51 @@ class IndexesTest(unittest.TestCase):
index_info = TestDoc._get_collection().index_information()
for key in index_info:
- del index_info[key]['v'] # drop the index version - we don't care about that here
- if 'ns' in index_info[key]:
- del index_info[key]['ns'] # drop the index namespace - we don't care about that here, MongoDB 3+
- if 'dropDups' in index_info[key]:
- del index_info[key]['dropDups'] # drop the index dropDups - it is deprecated in MongoDB 3+
+ del index_info[key][
+ "v"
+ ] # drop the index version - we don't care about that here
+ if "ns" in index_info[key]:
+ del index_info[key][
+ "ns"
+ ] # drop the index namespace - we don't care about that here, MongoDB 3+
+ if "dropDups" in index_info[key]:
+ del index_info[key][
+ "dropDups"
+ ] # drop the index dropDups - it is deprecated in MongoDB 3+
- self.assertEqual(index_info, {
- 'txt_1': {
- 'key': [('txt', 1)],
- 'background': False
+ self.assertEqual(
+ index_info,
+ {
+ "txt_1": {"key": [("txt", 1)], "background": False},
+ "_id_": {"key": [("_id", 1)]},
+ "txt2_1": {"key": [("txt2", 1)], "background": False},
+ "_cls_1": {"key": [("_cls", 1)], "background": False},
},
- '_id_': {
- 'key': [('_id', 1)],
- },
- 'txt2_1': {
- 'key': [('txt2', 1)],
- 'background': False
- },
- '_cls_1': {
- 'key': [('_cls', 1)],
- 'background': False,
- }
- })
+ )
def test_compound_index_underscore_cls_not_overwritten(self):
"""
Test that the compound index doesn't get another _cls when it is specified
"""
+
class TestDoc(Document):
shard_1 = StringField()
txt_1 = StringField()
meta = {
- 'collection': 'test',
- 'allow_inheritance': True,
- 'sparse': True,
- 'shard_key': 'shard_1',
- 'indexes': [
- ('shard_1', '_cls', 'txt_1'),
- ]
+ "collection": "test",
+ "allow_inheritance": True,
+ "sparse": True,
+ "shard_key": "shard_1",
+ "indexes": [("shard_1", "_cls", "txt_1")],
}
TestDoc.drop_collection()
TestDoc.ensure_indexes()
index_info = TestDoc._get_collection().index_information()
- self.assertIn('shard_1_1__cls_1_txt_1_1', index_info)
+ self.assertIn("shard_1_1__cls_1_txt_1_1", index_info)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py
index d81039f4..4f21d5f4 100644
--- a/tests/document/inheritance.py
+++ b/tests/document/inheritance.py
@@ -4,18 +4,24 @@ import warnings
from six import iteritems
-from mongoengine import (BooleanField, Document, EmbeddedDocument,
- EmbeddedDocumentField, GenericReferenceField,
- IntField, ReferenceField, StringField)
+from mongoengine import (
+ BooleanField,
+ Document,
+ EmbeddedDocument,
+ EmbeddedDocumentField,
+ GenericReferenceField,
+ IntField,
+ ReferenceField,
+ StringField,
+)
from mongoengine.pymongo_support import list_collection_names
from tests.utils import MongoDBTestCase
from tests.fixtures import Base
-__all__ = ('InheritanceTest', )
+__all__ = ("InheritanceTest",)
class InheritanceTest(MongoDBTestCase):
-
def tearDown(self):
for collection in list_collection_names(self.db):
self.db.drop_collection(collection)
@@ -25,16 +31,16 @@ class InheritanceTest(MongoDBTestCase):
# and when object gets reloaded (prevent regression of #1950)
class EmbedData(EmbeddedDocument):
data = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class DataDoc(Document):
name = StringField()
embed = EmbeddedDocumentField(EmbedData)
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
- test_doc = DataDoc(name='test', embed=EmbedData(data='data'))
- self.assertEqual(test_doc._cls, 'DataDoc')
- self.assertEqual(test_doc.embed._cls, 'EmbedData')
+ test_doc = DataDoc(name="test", embed=EmbedData(data="data"))
+ self.assertEqual(test_doc._cls, "DataDoc")
+ self.assertEqual(test_doc.embed._cls, "EmbedData")
test_doc.save()
saved_doc = DataDoc.objects.with_id(test_doc.id)
self.assertEqual(test_doc._cls, saved_doc._cls)
@@ -44,163 +50,234 @@ class InheritanceTest(MongoDBTestCase):
def test_superclasses(self):
"""Ensure that the correct list of superclasses is assembled.
"""
+
class Animal(Document):
- meta = {'allow_inheritance': True}
- class Fish(Animal): pass
- class Guppy(Fish): pass
- class Mammal(Animal): pass
- class Dog(Mammal): pass
- class Human(Mammal): pass
+ meta = {"allow_inheritance": True}
+
+ class Fish(Animal):
+ pass
+
+ class Guppy(Fish):
+ pass
+
+ class Mammal(Animal):
+ pass
+
+ class Dog(Mammal):
+ pass
+
+ class Human(Mammal):
+ pass
self.assertEqual(Animal._superclasses, ())
- self.assertEqual(Fish._superclasses, ('Animal',))
- self.assertEqual(Guppy._superclasses, ('Animal', 'Animal.Fish'))
- self.assertEqual(Mammal._superclasses, ('Animal',))
- self.assertEqual(Dog._superclasses, ('Animal', 'Animal.Mammal'))
- self.assertEqual(Human._superclasses, ('Animal', 'Animal.Mammal'))
+ self.assertEqual(Fish._superclasses, ("Animal",))
+ self.assertEqual(Guppy._superclasses, ("Animal", "Animal.Fish"))
+ self.assertEqual(Mammal._superclasses, ("Animal",))
+ self.assertEqual(Dog._superclasses, ("Animal", "Animal.Mammal"))
+ self.assertEqual(Human._superclasses, ("Animal", "Animal.Mammal"))
def test_external_superclasses(self):
"""Ensure that the correct list of super classes is assembled when
importing part of the model.
"""
- class Animal(Base): pass
- class Fish(Animal): pass
- class Guppy(Fish): pass
- class Mammal(Animal): pass
- class Dog(Mammal): pass
- class Human(Mammal): pass
- self.assertEqual(Animal._superclasses, ('Base', ))
- self.assertEqual(Fish._superclasses, ('Base', 'Base.Animal',))
- self.assertEqual(Guppy._superclasses, ('Base', 'Base.Animal',
- 'Base.Animal.Fish'))
- self.assertEqual(Mammal._superclasses, ('Base', 'Base.Animal',))
- self.assertEqual(Dog._superclasses, ('Base', 'Base.Animal',
- 'Base.Animal.Mammal'))
- self.assertEqual(Human._superclasses, ('Base', 'Base.Animal',
- 'Base.Animal.Mammal'))
+ class Animal(Base):
+ pass
+
+ class Fish(Animal):
+ pass
+
+ class Guppy(Fish):
+ pass
+
+ class Mammal(Animal):
+ pass
+
+ class Dog(Mammal):
+ pass
+
+ class Human(Mammal):
+ pass
+
+ self.assertEqual(Animal._superclasses, ("Base",))
+ self.assertEqual(Fish._superclasses, ("Base", "Base.Animal"))
+ self.assertEqual(
+ Guppy._superclasses, ("Base", "Base.Animal", "Base.Animal.Fish")
+ )
+ self.assertEqual(Mammal._superclasses, ("Base", "Base.Animal"))
+ self.assertEqual(
+ Dog._superclasses, ("Base", "Base.Animal", "Base.Animal.Mammal")
+ )
+ self.assertEqual(
+ Human._superclasses, ("Base", "Base.Animal", "Base.Animal.Mammal")
+ )
def test_subclasses(self):
"""Ensure that the correct list of _subclasses (subclasses) is
assembled.
"""
- class Animal(Document):
- meta = {'allow_inheritance': True}
- class Fish(Animal): pass
- class Guppy(Fish): pass
- class Mammal(Animal): pass
- class Dog(Mammal): pass
- class Human(Mammal): pass
- self.assertEqual(Animal._subclasses, ('Animal',
- 'Animal.Fish',
- 'Animal.Fish.Guppy',
- 'Animal.Mammal',
- 'Animal.Mammal.Dog',
- 'Animal.Mammal.Human'))
- self.assertEqual(Fish._subclasses, ('Animal.Fish',
- 'Animal.Fish.Guppy',))
- self.assertEqual(Guppy._subclasses, ('Animal.Fish.Guppy',))
- self.assertEqual(Mammal._subclasses, ('Animal.Mammal',
- 'Animal.Mammal.Dog',
- 'Animal.Mammal.Human'))
- self.assertEqual(Human._subclasses, ('Animal.Mammal.Human',))
+ class Animal(Document):
+ meta = {"allow_inheritance": True}
+
+ class Fish(Animal):
+ pass
+
+ class Guppy(Fish):
+ pass
+
+ class Mammal(Animal):
+ pass
+
+ class Dog(Mammal):
+ pass
+
+ class Human(Mammal):
+ pass
+
+ self.assertEqual(
+ Animal._subclasses,
+ (
+ "Animal",
+ "Animal.Fish",
+ "Animal.Fish.Guppy",
+ "Animal.Mammal",
+ "Animal.Mammal.Dog",
+ "Animal.Mammal.Human",
+ ),
+ )
+ self.assertEqual(Fish._subclasses, ("Animal.Fish", "Animal.Fish.Guppy"))
+ self.assertEqual(Guppy._subclasses, ("Animal.Fish.Guppy",))
+ self.assertEqual(
+ Mammal._subclasses,
+ ("Animal.Mammal", "Animal.Mammal.Dog", "Animal.Mammal.Human"),
+ )
+ self.assertEqual(Human._subclasses, ("Animal.Mammal.Human",))
def test_external_subclasses(self):
"""Ensure that the correct list of _subclasses (subclasses) is
assembled when importing part of the model.
"""
- class Animal(Base): pass
- class Fish(Animal): pass
- class Guppy(Fish): pass
- class Mammal(Animal): pass
- class Dog(Mammal): pass
- class Human(Mammal): pass
- self.assertEqual(Animal._subclasses, ('Base.Animal',
- 'Base.Animal.Fish',
- 'Base.Animal.Fish.Guppy',
- 'Base.Animal.Mammal',
- 'Base.Animal.Mammal.Dog',
- 'Base.Animal.Mammal.Human'))
- self.assertEqual(Fish._subclasses, ('Base.Animal.Fish',
- 'Base.Animal.Fish.Guppy',))
- self.assertEqual(Guppy._subclasses, ('Base.Animal.Fish.Guppy',))
- self.assertEqual(Mammal._subclasses, ('Base.Animal.Mammal',
- 'Base.Animal.Mammal.Dog',
- 'Base.Animal.Mammal.Human'))
- self.assertEqual(Human._subclasses, ('Base.Animal.Mammal.Human',))
+ class Animal(Base):
+ pass
+
+ class Fish(Animal):
+ pass
+
+ class Guppy(Fish):
+ pass
+
+ class Mammal(Animal):
+ pass
+
+ class Dog(Mammal):
+ pass
+
+ class Human(Mammal):
+ pass
+
+ self.assertEqual(
+ Animal._subclasses,
+ (
+ "Base.Animal",
+ "Base.Animal.Fish",
+ "Base.Animal.Fish.Guppy",
+ "Base.Animal.Mammal",
+ "Base.Animal.Mammal.Dog",
+ "Base.Animal.Mammal.Human",
+ ),
+ )
+ self.assertEqual(
+ Fish._subclasses, ("Base.Animal.Fish", "Base.Animal.Fish.Guppy")
+ )
+ self.assertEqual(Guppy._subclasses, ("Base.Animal.Fish.Guppy",))
+ self.assertEqual(
+ Mammal._subclasses,
+ (
+ "Base.Animal.Mammal",
+ "Base.Animal.Mammal.Dog",
+ "Base.Animal.Mammal.Human",
+ ),
+ )
+ self.assertEqual(Human._subclasses, ("Base.Animal.Mammal.Human",))
def test_dynamic_declarations(self):
"""Test that declaring an extra class updates meta data"""
class Animal(Document):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
self.assertEqual(Animal._superclasses, ())
- self.assertEqual(Animal._subclasses, ('Animal',))
+ self.assertEqual(Animal._subclasses, ("Animal",))
# Test dynamically adding a class changes the meta data
class Fish(Animal):
pass
self.assertEqual(Animal._superclasses, ())
- self.assertEqual(Animal._subclasses, ('Animal', 'Animal.Fish'))
+ self.assertEqual(Animal._subclasses, ("Animal", "Animal.Fish"))
- self.assertEqual(Fish._superclasses, ('Animal', ))
- self.assertEqual(Fish._subclasses, ('Animal.Fish',))
+ self.assertEqual(Fish._superclasses, ("Animal",))
+ self.assertEqual(Fish._subclasses, ("Animal.Fish",))
# Test dynamically adding an inherited class changes the meta data
class Pike(Fish):
pass
self.assertEqual(Animal._superclasses, ())
- self.assertEqual(Animal._subclasses, ('Animal', 'Animal.Fish',
- 'Animal.Fish.Pike'))
+ self.assertEqual(
+ Animal._subclasses, ("Animal", "Animal.Fish", "Animal.Fish.Pike")
+ )
- self.assertEqual(Fish._superclasses, ('Animal', ))
- self.assertEqual(Fish._subclasses, ('Animal.Fish', 'Animal.Fish.Pike'))
+ self.assertEqual(Fish._superclasses, ("Animal",))
+ self.assertEqual(Fish._subclasses, ("Animal.Fish", "Animal.Fish.Pike"))
- self.assertEqual(Pike._superclasses, ('Animal', 'Animal.Fish'))
- self.assertEqual(Pike._subclasses, ('Animal.Fish.Pike',))
+ self.assertEqual(Pike._superclasses, ("Animal", "Animal.Fish"))
+ self.assertEqual(Pike._subclasses, ("Animal.Fish.Pike",))
def test_inheritance_meta_data(self):
"""Ensure that document may inherit fields from a superclass document.
"""
+
class Person(Document):
name = StringField()
age = IntField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class Employee(Person):
salary = IntField()
- self.assertEqual(['_cls', 'age', 'id', 'name', 'salary'],
- sorted(Employee._fields.keys()))
- self.assertEqual(Employee._get_collection_name(),
- Person._get_collection_name())
+ self.assertEqual(
+ ["_cls", "age", "id", "name", "salary"], sorted(Employee._fields.keys())
+ )
+ self.assertEqual(Employee._get_collection_name(), Person._get_collection_name())
def test_inheritance_to_mongo_keys(self):
"""Ensure that document may inherit fields from a superclass document.
"""
+
class Person(Document):
name = StringField()
age = IntField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class Employee(Person):
salary = IntField()
- self.assertEqual(['_cls', 'age', 'id', 'name', 'salary'],
- sorted(Employee._fields.keys()))
- self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(),
- ['_cls', 'name', 'age'])
- self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(),
- ['_cls', 'name', 'age', 'salary'])
- self.assertEqual(Employee._get_collection_name(),
- Person._get_collection_name())
+ self.assertEqual(
+ ["_cls", "age", "id", "name", "salary"], sorted(Employee._fields.keys())
+ )
+ self.assertEqual(
+ Person(name="Bob", age=35).to_mongo().keys(), ["_cls", "name", "age"]
+ )
+ self.assertEqual(
+ Employee(name="Bob", age=35, salary=0).to_mongo().keys(),
+ ["_cls", "name", "age", "salary"],
+ )
+ self.assertEqual(Employee._get_collection_name(), Person._get_collection_name())
def test_indexes_and_multiple_inheritance(self):
""" Ensure that all of the indexes are created for a document with
@@ -210,18 +287,12 @@ class InheritanceTest(MongoDBTestCase):
class A(Document):
a = StringField()
- meta = {
- 'allow_inheritance': True,
- 'indexes': ['a']
- }
+ meta = {"allow_inheritance": True, "indexes": ["a"]}
class B(Document):
b = StringField()
- meta = {
- 'allow_inheritance': True,
- 'indexes': ['b']
- }
+ meta = {"allow_inheritance": True, "indexes": ["b"]}
class C(A, B):
pass
@@ -233,8 +304,12 @@ class InheritanceTest(MongoDBTestCase):
C.ensure_indexes()
self.assertEqual(
- sorted([idx['key'] for idx in C._get_collection().index_information().values()]),
- sorted([[(u'_cls', 1), (u'b', 1)], [(u'_id', 1)], [(u'_cls', 1), (u'a', 1)]])
+ sorted(
+ [idx["key"] for idx in C._get_collection().index_information().values()]
+ ),
+ sorted(
+ [[(u"_cls", 1), (u"b", 1)], [(u"_id", 1)], [(u"_cls", 1), (u"a", 1)]]
+ ),
)
def test_polymorphic_queries(self):
@@ -242,11 +317,19 @@ class InheritanceTest(MongoDBTestCase):
"""
class Animal(Document):
- meta = {'allow_inheritance': True}
- class Fish(Animal): pass
- class Mammal(Animal): pass
- class Dog(Mammal): pass
- class Human(Mammal): pass
+ meta = {"allow_inheritance": True}
+
+ class Fish(Animal):
+ pass
+
+ class Mammal(Animal):
+ pass
+
+ class Dog(Mammal):
+ pass
+
+ class Human(Mammal):
+ pass
Animal.drop_collection()
@@ -269,58 +352,68 @@ class InheritanceTest(MongoDBTestCase):
"""Ensure that inheritance is disabled by default on simple
classes and that _cls will not be used.
"""
+
class Animal(Document):
name = StringField()
# can't inherit because Animal didn't explicitly allow inheritance
with self.assertRaises(ValueError) as cm:
+
class Dog(Animal):
pass
+
self.assertIn("Document Animal may not be subclassed", str(cm.exception))
# Check that _cls etc aren't present on simple documents
- dog = Animal(name='dog').save()
- self.assertEqual(dog.to_mongo().keys(), ['_id', 'name'])
+ dog = Animal(name="dog").save()
+ self.assertEqual(dog.to_mongo().keys(), ["_id", "name"])
collection = self.db[Animal._get_collection_name()]
obj = collection.find_one()
- self.assertNotIn('_cls', obj)
+ self.assertNotIn("_cls", obj)
def test_cant_turn_off_inheritance_on_subclass(self):
"""Ensure if inheritance is on in a subclass you cant turn it off.
"""
+
class Animal(Document):
name = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
with self.assertRaises(ValueError) as cm:
+
class Mammal(Animal):
- meta = {'allow_inheritance': False}
- self.assertEqual(str(cm.exception), 'Only direct subclasses of Document may set "allow_inheritance" to False')
+ meta = {"allow_inheritance": False}
+
+ self.assertEqual(
+ str(cm.exception),
+ 'Only direct subclasses of Document may set "allow_inheritance" to False',
+ )
def test_allow_inheritance_abstract_document(self):
"""Ensure that abstract documents can set inheritance rules and that
_cls will not be used.
"""
+
class FinalDocument(Document):
- meta = {'abstract': True,
- 'allow_inheritance': False}
+ meta = {"abstract": True, "allow_inheritance": False}
class Animal(FinalDocument):
name = StringField()
with self.assertRaises(ValueError) as cm:
+
class Mammal(Animal):
pass
# Check that _cls isn't present in simple documents
- doc = Animal(name='dog')
- self.assertNotIn('_cls', doc.to_mongo())
+ doc = Animal(name="dog")
+ self.assertNotIn("_cls", doc.to_mongo())
def test_using_abstract_class_in_reference_field(self):
# Ensures no regression of #1920
class AbstractHuman(Document):
- meta = {'abstract': True}
+ meta = {"abstract": True}
class Dad(AbstractHuman):
name = StringField()
@@ -329,130 +422,122 @@ class InheritanceTest(MongoDBTestCase):
dad = ReferenceField(AbstractHuman) # Referencing the abstract class
address = StringField()
- dad = Dad(name='5').save()
- Home(dad=dad, address='street').save()
+ dad = Dad(name="5").save()
+ Home(dad=dad, address="street").save()
home = Home.objects.first()
- home.address = 'garbage'
- home.save() # Was failing with ValidationError
+ home.address = "garbage"
+ home.save() # Was failing with ValidationError
def test_abstract_class_referencing_self(self):
# Ensures no regression of #1920
class Human(Document):
- meta = {'abstract': True}
- creator = ReferenceField('self', dbref=True)
+ meta = {"abstract": True}
+ creator = ReferenceField("self", dbref=True)
class User(Human):
name = StringField()
- user = User(name='John').save()
- user2 = User(name='Foo', creator=user).save()
+ user = User(name="John").save()
+ user2 = User(name="Foo", creator=user).save()
user2 = User.objects.with_id(user2.id)
- user2.name = 'Bar'
- user2.save() # Was failing with ValidationError
+ user2.name = "Bar"
+ user2.save() # Was failing with ValidationError
def test_abstract_handle_ids_in_metaclass_properly(self):
-
class City(Document):
continent = StringField()
- meta = {'abstract': True,
- 'allow_inheritance': False}
+ meta = {"abstract": True, "allow_inheritance": False}
class EuropeanCity(City):
name = StringField()
- berlin = EuropeanCity(name='Berlin', continent='Europe')
+ berlin = EuropeanCity(name="Berlin", continent="Europe")
self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._fields_ordered), 3)
- self.assertEqual(berlin._fields_ordered[0], 'id')
+ self.assertEqual(berlin._fields_ordered[0], "id")
def test_auto_id_not_set_if_specific_in_parent_class(self):
-
class City(Document):
continent = StringField()
city_id = IntField(primary_key=True)
- meta = {'abstract': True,
- 'allow_inheritance': False}
+ meta = {"abstract": True, "allow_inheritance": False}
class EuropeanCity(City):
name = StringField()
- berlin = EuropeanCity(name='Berlin', continent='Europe')
+ berlin = EuropeanCity(name="Berlin", continent="Europe")
self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._fields_ordered), 3)
- self.assertEqual(berlin._fields_ordered[0], 'city_id')
+ self.assertEqual(berlin._fields_ordered[0], "city_id")
def test_auto_id_vs_non_pk_id_field(self):
-
class City(Document):
continent = StringField()
id = IntField()
- meta = {'abstract': True,
- 'allow_inheritance': False}
+ meta = {"abstract": True, "allow_inheritance": False}
class EuropeanCity(City):
name = StringField()
- berlin = EuropeanCity(name='Berlin', continent='Europe')
+ berlin = EuropeanCity(name="Berlin", continent="Europe")
self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._fields_ordered), 4)
- self.assertEqual(berlin._fields_ordered[0], 'auto_id_0')
+ self.assertEqual(berlin._fields_ordered[0], "auto_id_0")
berlin.save()
self.assertEqual(berlin.pk, berlin.auto_id_0)
def test_abstract_document_creation_does_not_fail(self):
class City(Document):
continent = StringField()
- meta = {'abstract': True,
- 'allow_inheritance': False}
+ meta = {"abstract": True, "allow_inheritance": False}
- city = City(continent='asia')
+ city = City(continent="asia")
self.assertEqual(None, city.pk)
# TODO: expected error? Shouldn't we create a new error type?
with self.assertRaises(KeyError):
- setattr(city, 'pk', 1)
+ setattr(city, "pk", 1)
def test_allow_inheritance_embedded_document(self):
"""Ensure embedded documents respect inheritance."""
+
class Comment(EmbeddedDocument):
content = StringField()
with self.assertRaises(ValueError):
+
class SpecialComment(Comment):
pass
- doc = Comment(content='test')
- self.assertNotIn('_cls', doc.to_mongo())
+ doc = Comment(content="test")
+ self.assertNotIn("_cls", doc.to_mongo())
class Comment(EmbeddedDocument):
content = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
- doc = Comment(content='test')
- self.assertIn('_cls', doc.to_mongo())
+ doc = Comment(content="test")
+ self.assertIn("_cls", doc.to_mongo())
def test_document_inheritance(self):
"""Ensure mutliple inheritance of abstract documents
"""
+
class DateCreatedDocument(Document):
- meta = {
- 'allow_inheritance': True,
- 'abstract': True,
- }
+ meta = {"allow_inheritance": True, "abstract": True}
class DateUpdatedDocument(Document):
- meta = {
- 'allow_inheritance': True,
- 'abstract': True,
- }
+ meta = {"allow_inheritance": True, "abstract": True}
try:
+
class MyDocument(DateCreatedDocument, DateUpdatedDocument):
pass
+
except Exception:
self.assertTrue(False, "Couldn't create MyDocument class")
@@ -460,47 +545,55 @@ class InheritanceTest(MongoDBTestCase):
"""Ensure that a document superclass can be marked as abstract
thereby not using it as the name for the collection."""
- defaults = {'index_background': True,
- 'index_drop_dups': True,
- 'index_opts': {'hello': 'world'},
- 'allow_inheritance': True,
- 'queryset_class': 'QuerySet',
- 'db_alias': 'myDB',
- 'shard_key': ('hello', 'world')}
+ defaults = {
+ "index_background": True,
+ "index_drop_dups": True,
+ "index_opts": {"hello": "world"},
+ "allow_inheritance": True,
+ "queryset_class": "QuerySet",
+ "db_alias": "myDB",
+ "shard_key": ("hello", "world"),
+ }
- meta_settings = {'abstract': True}
+ meta_settings = {"abstract": True}
meta_settings.update(defaults)
class Animal(Document):
name = StringField()
meta = meta_settings
- class Fish(Animal): pass
- class Guppy(Fish): pass
+ class Fish(Animal):
+ pass
+
+ class Guppy(Fish):
+ pass
class Mammal(Animal):
- meta = {'abstract': True}
- class Human(Mammal): pass
+ meta = {"abstract": True}
+
+ class Human(Mammal):
+ pass
for k, v in iteritems(defaults):
for cls in [Animal, Fish, Guppy]:
self.assertEqual(cls._meta[k], v)
- self.assertNotIn('collection', Animal._meta)
- self.assertNotIn('collection', Mammal._meta)
+ self.assertNotIn("collection", Animal._meta)
+ self.assertNotIn("collection", Mammal._meta)
self.assertEqual(Animal._get_collection_name(), None)
self.assertEqual(Mammal._get_collection_name(), None)
- self.assertEqual(Fish._get_collection_name(), 'fish')
- self.assertEqual(Guppy._get_collection_name(), 'fish')
- self.assertEqual(Human._get_collection_name(), 'human')
+ self.assertEqual(Fish._get_collection_name(), "fish")
+ self.assertEqual(Guppy._get_collection_name(), "fish")
+ self.assertEqual(Human._get_collection_name(), "human")
# ensure that a subclass of a non-abstract class can't be abstract
with self.assertRaises(ValueError):
+
class EvilHuman(Human):
evil = BooleanField(default=True)
- meta = {'abstract': True}
+ meta = {"abstract": True}
def test_abstract_embedded_documents(self):
# 789: EmbeddedDocument shouldn't inherit abstract
@@ -519,7 +612,7 @@ class InheritanceTest(MongoDBTestCase):
class Drink(Document):
name = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class Drinker(Document):
drink = GenericReferenceField()
@@ -528,13 +621,13 @@ class InheritanceTest(MongoDBTestCase):
warnings.simplefilter("error")
class AcloholicDrink(Drink):
- meta = {'collection': 'booze'}
+ meta = {"collection": "booze"}
except SyntaxWarning:
warnings.simplefilter("ignore")
class AlcoholicDrink(Drink):
- meta = {'collection': 'booze'}
+ meta = {"collection": "booze"}
else:
raise AssertionError("SyntaxWarning should be triggered")
@@ -545,13 +638,13 @@ class InheritanceTest(MongoDBTestCase):
AlcoholicDrink.drop_collection()
Drinker.drop_collection()
- red_bull = Drink(name='Red Bull')
+ red_bull = Drink(name="Red Bull")
red_bull.save()
programmer = Drinker(drink=red_bull)
programmer.save()
- beer = AlcoholicDrink(name='Beer')
+ beer = AlcoholicDrink(name="Beer")
beer.save()
real_person = Drinker(drink=beer)
real_person.save()
@@ -560,5 +653,5 @@ class InheritanceTest(MongoDBTestCase):
self.assertEqual(Drinker.objects[1].drink.name, beer.name)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/document/instance.py b/tests/document/instance.py
index 06f65076..49606cff 100644
--- a/tests/document/instance.py
+++ b/tests/document/instance.py
@@ -16,24 +16,33 @@ from mongoengine import signals
from mongoengine.base import _document_registry, get_document
from mongoengine.connection import get_db
from mongoengine.context_managers import query_counter, switch_db
-from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, \
- InvalidQueryError, NotRegistered, NotUniqueError, SaveConditionError)
+from mongoengine.errors import (
+ FieldDoesNotExist,
+ InvalidDocumentError,
+ InvalidQueryError,
+ NotRegistered,
+ NotUniqueError,
+ SaveConditionError,
+)
from mongoengine.mongodb_support import MONGODB_34, MONGODB_36, get_mongodb_version
from mongoengine.pymongo_support import list_collection_names
from mongoengine.queryset import NULLIFY, Q
from tests import fixtures
-from tests.fixtures import (PickleDynamicEmbedded, PickleDynamicTest, \
- PickleEmbedded, PickleSignalsTest, PickleTest)
+from tests.fixtures import (
+ PickleDynamicEmbedded,
+ PickleDynamicTest,
+ PickleEmbedded,
+ PickleSignalsTest,
+ PickleTest,
+)
from tests.utils import MongoDBTestCase, get_as_pymongo
-TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__),
- '../fields/mongoengine.png')
+TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "../fields/mongoengine.png")
__all__ = ("InstanceTest",)
class InstanceTest(MongoDBTestCase):
-
def setUp(self):
class Job(EmbeddedDocument):
name = StringField()
@@ -58,7 +67,8 @@ class InstanceTest(MongoDBTestCase):
def assertDbEqual(self, docs):
self.assertEqual(
list(self.Person._get_collection().find().sort("id")),
- sorted(docs, key=lambda doc: doc["_id"]))
+ sorted(docs, key=lambda doc: doc["_id"]),
+ )
def assertHasInstance(self, field, instance):
self.assertTrue(hasattr(field, "_instance"))
@@ -70,12 +80,10 @@ class InstanceTest(MongoDBTestCase):
def test_capped_collection(self):
"""Ensure that capped collections work properly."""
+
class Log(Document):
date = DateTimeField(default=datetime.now)
- meta = {
- 'max_documents': 10,
- 'max_size': 4096,
- }
+ meta = {"max_documents": 10, "max_size": 4096}
Log.drop_collection()
@@ -90,16 +98,14 @@ class InstanceTest(MongoDBTestCase):
self.assertEqual(Log.objects.count(), 10)
options = Log.objects._collection.options()
- self.assertEqual(options['capped'], True)
- self.assertEqual(options['max'], 10)
- self.assertEqual(options['size'], 4096)
+ self.assertEqual(options["capped"], True)
+ self.assertEqual(options["max"], 10)
+ self.assertEqual(options["size"], 4096)
# Check that the document cannot be redefined with different options
class Log(Document):
date = DateTimeField(default=datetime.now)
- meta = {
- 'max_documents': 11,
- }
+ meta = {"max_documents": 11}
# Accessing Document.objects creates the collection
with self.assertRaises(InvalidCollectionError):
@@ -107,11 +113,10 @@ class InstanceTest(MongoDBTestCase):
def test_capped_collection_default(self):
"""Ensure that capped collections defaults work properly."""
+
class Log(Document):
date = DateTimeField(default=datetime.now)
- meta = {
- 'max_documents': 10,
- }
+ meta = {"max_documents": 10}
Log.drop_collection()
@@ -119,16 +124,14 @@ class InstanceTest(MongoDBTestCase):
Log().save()
options = Log.objects._collection.options()
- self.assertEqual(options['capped'], True)
- self.assertEqual(options['max'], 10)
- self.assertEqual(options['size'], 10 * 2**20)
+ self.assertEqual(options["capped"], True)
+ self.assertEqual(options["max"], 10)
+ self.assertEqual(options["size"], 10 * 2 ** 20)
# Check that the document with default value can be recreated
class Log(Document):
date = DateTimeField(default=datetime.now)
- meta = {
- 'max_documents': 10,
- }
+ meta = {"max_documents": 10}
# Create the collection by accessing Document.objects
Log.objects
@@ -138,11 +141,10 @@ class InstanceTest(MongoDBTestCase):
MongoDB rounds up max_size to next multiple of 256, recreating a doc
with the same spec failed in mongoengine <0.10
"""
+
class Log(Document):
date = DateTimeField(default=datetime.now)
- meta = {
- 'max_size': 10000,
- }
+ meta = {"max_size": 10000}
Log.drop_collection()
@@ -150,15 +152,13 @@ class InstanceTest(MongoDBTestCase):
Log().save()
options = Log.objects._collection.options()
- self.assertEqual(options['capped'], True)
- self.assertTrue(options['size'] >= 10000)
+ self.assertEqual(options["capped"], True)
+ self.assertTrue(options["size"] >= 10000)
# Check that the document with odd max_size value can be recreated
class Log(Document):
date = DateTimeField(default=datetime.now)
- meta = {
- 'max_size': 10000,
- }
+ meta = {"max_size": 10000}
# Create the collection by accessing Document.objects
Log.objects
@@ -166,26 +166,28 @@ class InstanceTest(MongoDBTestCase):
def test_repr(self):
"""Ensure that unicode representation works
"""
+
class Article(Document):
title = StringField()
def __unicode__(self):
return self.title
- doc = Article(title=u'привет мир')
+ doc = Article(title=u"привет мир")
- self.assertEqual('', repr(doc))
+ self.assertEqual("", repr(doc))
def test_repr_none(self):
"""Ensure None values are handled correctly."""
+
class Article(Document):
title = StringField()
def __str__(self):
return None
- doc = Article(title=u'привет мир')
- self.assertEqual('', repr(doc))
+ doc = Article(title=u"привет мир")
+ self.assertEqual("", repr(doc))
def test_queryset_resurrects_dropped_collection(self):
self.Person.drop_collection()
@@ -203,8 +205,9 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that the correct subclasses are returned from a query
when using references / generic references
"""
+
class Animal(Document):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class Fish(Animal):
pass
@@ -255,7 +258,7 @@ class InstanceTest(MongoDBTestCase):
class Stats(Document):
created = DateTimeField(default=datetime.now)
- meta = {'allow_inheritance': False}
+ meta = {"allow_inheritance": False}
class CompareStats(Document):
generated = DateTimeField(default=datetime.now)
@@ -278,6 +281,7 @@ class InstanceTest(MongoDBTestCase):
def test_db_field_load(self):
"""Ensure we load data correctly from the right db field."""
+
class Person(Document):
name = StringField(required=True)
_rank = StringField(required=False, db_field="rank")
@@ -297,14 +301,13 @@ class InstanceTest(MongoDBTestCase):
def test_db_embedded_doc_field_load(self):
"""Ensure we load embedded document data correctly."""
+
class Rank(EmbeddedDocument):
title = StringField(required=True)
class Person(Document):
name = StringField(required=True)
- rank_ = EmbeddedDocumentField(Rank,
- required=False,
- db_field='rank')
+ rank_ = EmbeddedDocumentField(Rank, required=False, db_field="rank")
@property
def rank(self):
@@ -322,45 +325,50 @@ class InstanceTest(MongoDBTestCase):
def test_custom_id_field(self):
"""Ensure that documents may be created with custom primary keys."""
+
class User(Document):
username = StringField(primary_key=True)
name = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
User.drop_collection()
- self.assertEqual(User._fields['username'].db_field, '_id')
- self.assertEqual(User._meta['id_field'], 'username')
+ self.assertEqual(User._fields["username"].db_field, "_id")
+ self.assertEqual(User._meta["id_field"], "username")
- User.objects.create(username='test', name='test user')
+ User.objects.create(username="test", name="test user")
user = User.objects.first()
- self.assertEqual(user.id, 'test')
- self.assertEqual(user.pk, 'test')
+ self.assertEqual(user.id, "test")
+ self.assertEqual(user.pk, "test")
user_dict = User.objects._collection.find_one()
- self.assertEqual(user_dict['_id'], 'test')
+ self.assertEqual(user_dict["_id"], "test")
def test_change_custom_id_field_in_subclass(self):
"""Subclasses cannot override which field is the primary key."""
+
class User(Document):
username = StringField(primary_key=True)
name = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
with self.assertRaises(ValueError) as e:
+
class EmailUser(User):
email = StringField(primary_key=True)
+
exc = e.exception
- self.assertEqual(str(exc), 'Cannot override primary key field')
+ self.assertEqual(str(exc), "Cannot override primary key field")
def test_custom_id_field_is_required(self):
"""Ensure the custom primary key field is required."""
+
class User(Document):
username = StringField(primary_key=True)
name = StringField()
with self.assertRaises(ValidationError) as e:
- User(name='test').save()
+ User(name="test").save()
exc = e.exception
self.assertTrue("Field is required: ['username']" in str(exc))
@@ -368,7 +376,7 @@ class InstanceTest(MongoDBTestCase):
class Place(Document):
name = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class NicePlace(Place):
pass
@@ -380,7 +388,7 @@ class InstanceTest(MongoDBTestCase):
# Mimic Place and NicePlace definitions being in a different file
# and the NicePlace model not being imported in at query time.
- del(_document_registry['Place.NicePlace'])
+ del _document_registry["Place.NicePlace"]
with self.assertRaises(NotRegistered):
list(Place.objects.all())
@@ -388,10 +396,10 @@ class InstanceTest(MongoDBTestCase):
def test_document_registry_regressions(self):
class Location(Document):
name = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class Area(Location):
- location = ReferenceField('Location', dbref=True)
+ location = ReferenceField("Location", dbref=True)
Location.drop_collection()
@@ -413,18 +421,19 @@ class InstanceTest(MongoDBTestCase):
def test_key_like_attribute_access(self):
person = self.Person(age=30)
- self.assertEqual(person['age'], 30)
+ self.assertEqual(person["age"], 30)
with self.assertRaises(KeyError):
- person['unknown_attr']
+ person["unknown_attr"]
def test_save_abstract_document(self):
"""Saving an abstract document should fail."""
+
class Doc(Document):
name = StringField()
- meta = {'abstract': True}
+ meta = {"abstract": True}
with self.assertRaises(InvalidDocumentError):
- Doc(name='aaa').save()
+ Doc(name="aaa").save()
def test_reload(self):
"""Ensure that attributes may be reloaded."""
@@ -439,7 +448,7 @@ class InstanceTest(MongoDBTestCase):
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 20)
- person.reload('age')
+ person.reload("age")
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 21)
@@ -454,19 +463,22 @@ class InstanceTest(MongoDBTestCase):
def test_reload_sharded(self):
class Animal(Document):
superphylum = StringField()
- meta = {'shard_key': ('superphylum',)}
+ meta = {"shard_key": ("superphylum",)}
Animal.drop_collection()
- doc = Animal(superphylum='Deuterostomia')
+ doc = Animal(superphylum="Deuterostomia")
doc.save()
mongo_db = get_mongodb_version()
- CMD_QUERY_KEY = 'command' if mongo_db >= MONGODB_36 else 'query'
+ CMD_QUERY_KEY = "command" if mongo_db >= MONGODB_36 else "query"
with query_counter() as q:
doc.reload()
- query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0]
- self.assertEqual(set(query_op[CMD_QUERY_KEY]['filter'].keys()), set(['_id', 'superphylum']))
+ query_op = q.db.system.profile.find({"ns": "mongoenginetest.animal"})[0]
+ self.assertEqual(
+ set(query_op[CMD_QUERY_KEY]["filter"].keys()),
+ set(["_id", "superphylum"]),
+ )
Animal.drop_collection()
@@ -476,10 +488,10 @@ class InstanceTest(MongoDBTestCase):
class Animal(Document):
superphylum = EmbeddedDocumentField(SuperPhylum)
- meta = {'shard_key': ('superphylum.name',)}
+ meta = {"shard_key": ("superphylum.name",)}
Animal.drop_collection()
- doc = Animal(superphylum=SuperPhylum(name='Deuterostomia'))
+ doc = Animal(superphylum=SuperPhylum(name="Deuterostomia"))
doc.save()
doc.reload()
Animal.drop_collection()
@@ -488,49 +500,57 @@ class InstanceTest(MongoDBTestCase):
"""Ensures updating a doc with a specified shard_key includes it in
the query.
"""
+
class Animal(Document):
is_mammal = BooleanField()
name = StringField()
- meta = {'shard_key': ('is_mammal', 'id')}
+ meta = {"shard_key": ("is_mammal", "id")}
Animal.drop_collection()
- doc = Animal(is_mammal=True, name='Dog')
+ doc = Animal(is_mammal=True, name="Dog")
doc.save()
mongo_db = get_mongodb_version()
with query_counter() as q:
- doc.name = 'Cat'
+ doc.name = "Cat"
doc.save()
- query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0]
- self.assertEqual(query_op['op'], 'update')
+ query_op = q.db.system.profile.find({"ns": "mongoenginetest.animal"})[0]
+ self.assertEqual(query_op["op"], "update")
if mongo_db <= MONGODB_34:
- self.assertEqual(set(query_op['query'].keys()), set(['_id', 'is_mammal']))
+ self.assertEqual(
+ set(query_op["query"].keys()), set(["_id", "is_mammal"])
+ )
else:
- self.assertEqual(set(query_op['command']['q'].keys()), set(['_id', 'is_mammal']))
+ self.assertEqual(
+ set(query_op["command"]["q"].keys()), set(["_id", "is_mammal"])
+ )
Animal.drop_collection()
def test_reload_with_changed_fields(self):
"""Ensures reloading will not affect changed fields"""
+
class User(Document):
name = StringField()
number = IntField()
+
User.drop_collection()
user = User(name="Bob", number=1).save()
user.name = "John"
user.number = 2
- self.assertEqual(user._get_changed_fields(), ['name', 'number'])
- user.reload('number')
- self.assertEqual(user._get_changed_fields(), ['name'])
+ self.assertEqual(user._get_changed_fields(), ["name", "number"])
+ user.reload("number")
+ self.assertEqual(user._get_changed_fields(), ["name"])
user.save()
user.reload()
self.assertEqual(user.name, "John")
def test_reload_referencing(self):
"""Ensures reloading updates weakrefs correctly."""
+
class Embedded(EmbeddedDocument):
dict_field = DictField()
list_field = ListField()
@@ -542,24 +562,30 @@ class InstanceTest(MongoDBTestCase):
Doc.drop_collection()
doc = Doc()
- doc.dict_field = {'hello': 'world'}
- doc.list_field = ['1', 2, {'hello': 'world'}]
+ doc.dict_field = {"hello": "world"}
+ doc.list_field = ["1", 2, {"hello": "world"}]
embedded_1 = Embedded()
- embedded_1.dict_field = {'hello': 'world'}
- embedded_1.list_field = ['1', 2, {'hello': 'world'}]
+ embedded_1.dict_field = {"hello": "world"}
+ embedded_1.list_field = ["1", 2, {"hello": "world"}]
doc.embedded_field = embedded_1
doc.save()
doc = doc.reload(10)
doc.list_field.append(1)
- doc.dict_field['woot'] = "woot"
+ doc.dict_field["woot"] = "woot"
doc.embedded_field.list_field.append(1)
- doc.embedded_field.dict_field['woot'] = "woot"
+ doc.embedded_field.dict_field["woot"] = "woot"
- self.assertEqual(doc._get_changed_fields(), [
- 'list_field', 'dict_field.woot', 'embedded_field.list_field',
- 'embedded_field.dict_field.woot'])
+ self.assertEqual(
+ doc._get_changed_fields(),
+ [
+ "list_field",
+ "dict_field.woot",
+ "embedded_field.list_field",
+ "embedded_field.dict_field.woot",
+ ],
+ )
doc.save()
self.assertEqual(len(doc.list_field), 4)
@@ -572,9 +598,9 @@ class InstanceTest(MongoDBTestCase):
doc.list_field.append(1)
doc.save()
- doc.dict_field['extra'] = 1
- doc = doc.reload(10, 'list_field')
- self.assertEqual(doc._get_changed_fields(), ['dict_field.extra'])
+ doc.dict_field["extra"] = 1
+ doc = doc.reload(10, "list_field")
+ self.assertEqual(doc._get_changed_fields(), ["dict_field.extra"])
self.assertEqual(len(doc.list_field), 5)
self.assertEqual(len(doc.dict_field), 3)
self.assertEqual(len(doc.embedded_field.list_field), 4)
@@ -596,19 +622,17 @@ class InstanceTest(MongoDBTestCase):
def test_reload_of_non_strict_with_special_field_name(self):
"""Ensures reloading works for documents with meta strict == False."""
+
class Post(Document):
- meta = {
- 'strict': False
- }
+ meta = {"strict": False}
title = StringField()
items = ListField()
Post.drop_collection()
- Post._get_collection().insert_one({
- "title": "Items eclipse",
- "items": ["more lorem", "even more ipsum"]
- })
+ Post._get_collection().insert_one(
+ {"title": "Items eclipse", "items": ["more lorem", "even more ipsum"]}
+ )
post = Post.objects.first()
post.reload()
@@ -617,22 +641,22 @@ class InstanceTest(MongoDBTestCase):
def test_dictionary_access(self):
"""Ensure that dictionary-style field access works properly."""
- person = self.Person(name='Test User', age=30, job=self.Job())
- self.assertEqual(person['name'], 'Test User')
+ person = self.Person(name="Test User", age=30, job=self.Job())
+ self.assertEqual(person["name"], "Test User")
- self.assertRaises(KeyError, person.__getitem__, 'salary')
- self.assertRaises(KeyError, person.__setitem__, 'salary', 50)
+ self.assertRaises(KeyError, person.__getitem__, "salary")
+ self.assertRaises(KeyError, person.__setitem__, "salary", 50)
- person['name'] = 'Another User'
- self.assertEqual(person['name'], 'Another User')
+ person["name"] = "Another User"
+ self.assertEqual(person["name"], "Another User")
# Length = length(assigned fields + id)
self.assertEqual(len(person), 5)
- self.assertIn('age', person)
+ self.assertIn("age", person)
person.age = None
- self.assertNotIn('age', person)
- self.assertNotIn('nationality', person)
+ self.assertNotIn("age", person)
+ self.assertNotIn("nationality", person)
def test_embedded_document_to_mongo(self):
class Person(EmbeddedDocument):
@@ -644,29 +668,33 @@ class InstanceTest(MongoDBTestCase):
class Employee(Person):
salary = IntField()
- self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(),
- ['_cls', 'name', 'age'])
+ self.assertEqual(
+ Person(name="Bob", age=35).to_mongo().keys(), ["_cls", "name", "age"]
+ )
self.assertEqual(
Employee(name="Bob", age=35, salary=0).to_mongo().keys(),
- ['_cls', 'name', 'age', 'salary'])
+ ["_cls", "name", "age", "salary"],
+ )
def test_embedded_document_to_mongo_id(self):
class SubDoc(EmbeddedDocument):
id = StringField(required=True)
sub_doc = SubDoc(id="abc")
- self.assertEqual(sub_doc.to_mongo().keys(), ['id'])
+ self.assertEqual(sub_doc.to_mongo().keys(), ["id"])
def test_embedded_document(self):
"""Ensure that embedded documents are set up correctly."""
+
class Comment(EmbeddedDocument):
content = StringField()
- self.assertIn('content', Comment._fields)
- self.assertNotIn('id', Comment._fields)
+ self.assertIn("content", Comment._fields)
+ self.assertNotIn("id", Comment._fields)
def test_embedded_document_instance(self):
"""Ensure that embedded documents can reference parent instance."""
+
class Embedded(EmbeddedDocument):
string = StringField()
@@ -686,6 +714,7 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that embedded documents in complex fields can reference
parent instance.
"""
+
class Embedded(EmbeddedDocument):
string = StringField()
@@ -702,15 +731,19 @@ class InstanceTest(MongoDBTestCase):
def test_embedded_document_complex_instance_no_use_db_field(self):
"""Ensure that use_db_field is propagated to list of Emb Docs."""
+
class Embedded(EmbeddedDocument):
- string = StringField(db_field='s')
+ string = StringField(db_field="s")
class Doc(Document):
embedded_field = ListField(EmbeddedDocumentField(Embedded))
- d = Doc(embedded_field=[Embedded(string="Hi")]).to_mongo(
- use_db_field=False).to_dict()
- self.assertEqual(d['embedded_field'], [{'string': 'Hi'}])
+ d = (
+ Doc(embedded_field=[Embedded(string="Hi")])
+ .to_mongo(use_db_field=False)
+ .to_dict()
+ )
+ self.assertEqual(d["embedded_field"], [{"string": "Hi"}])
def test_instance_is_set_on_setattr(self):
class Email(EmbeddedDocument):
@@ -722,7 +755,7 @@ class InstanceTest(MongoDBTestCase):
Account.drop_collection()
acc = Account()
- acc.email = Email(email='test@example.com')
+ acc.email = Email(email="test@example.com")
self.assertHasInstance(acc._data["email"], acc)
acc.save()
@@ -738,7 +771,7 @@ class InstanceTest(MongoDBTestCase):
Account.drop_collection()
acc = Account()
- acc.emails = [Email(email='test@example.com')]
+ acc.emails = [Email(email="test@example.com")]
self.assertHasInstance(acc._data["emails"][0], acc)
acc.save()
@@ -764,22 +797,19 @@ class InstanceTest(MongoDBTestCase):
@classmethod
def pre_save_post_validation(cls, sender, document, **kwargs):
- document.content = 'checked'
+ document.content = "checked"
- signals.pre_save_post_validation.connect(BlogPost.pre_save_post_validation, sender=BlogPost)
+ signals.pre_save_post_validation.connect(
+ BlogPost.pre_save_post_validation, sender=BlogPost
+ )
BlogPost.drop_collection()
- post = BlogPost(content='unchecked').save()
- self.assertEqual(post.content, 'checked')
+ post = BlogPost(content="unchecked").save()
+ self.assertEqual(post.content, "checked")
# Make sure pre_save_post_validation changes makes it to the db
raw_doc = get_as_pymongo(post)
- self.assertEqual(
- raw_doc,
- {
- 'content': 'checked',
- '_id': post.id
- })
+ self.assertEqual(raw_doc, {"content": "checked", "_id": post.id})
# Important to disconnect as it could cause some assertions in test_signals
# to fail (due to the garbage collection timing of this signal)
@@ -810,13 +840,7 @@ class InstanceTest(MongoDBTestCase):
self.assertEqual(t.cleaned, True)
raw_doc = get_as_pymongo(t)
# Make sure clean changes makes it to the db
- self.assertEqual(
- raw_doc,
- {
- 'status': 'published',
- 'cleaned': True,
- '_id': t.id
- })
+ self.assertEqual(raw_doc, {"status": "published", "cleaned": True, "_id": t.id})
def test_document_embedded_clean(self):
class TestEmbeddedDocument(EmbeddedDocument):
@@ -824,12 +848,12 @@ class InstanceTest(MongoDBTestCase):
y = IntField(required=True)
z = IntField(required=True)
- meta = {'allow_inheritance': False}
+ meta = {"allow_inheritance": False}
def clean(self):
if self.z:
if self.z != self.x + self.y:
- raise ValidationError('Value of z != x + y')
+ raise ValidationError("Value of z != x + y")
else:
self.z = self.x + self.y
@@ -846,7 +870,7 @@ class InstanceTest(MongoDBTestCase):
expected_msg = "Value of z != x + y"
self.assertIn(expected_msg, cm.exception.message)
- self.assertEqual(cm.exception.to_dict(), {'doc': {'__all__': expected_msg}})
+ self.assertEqual(cm.exception.to_dict(), {"doc": {"__all__": expected_msg}})
t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save()
self.assertEqual(t.doc.z, 35)
@@ -869,7 +893,7 @@ class InstanceTest(MongoDBTestCase):
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
with self.assertRaises(InvalidQueryError):
- doc1.modify({'id': doc2.id}, set__value=20)
+ doc1.modify({"id": doc2.id}, set__value=20)
self.assertDbEqual(docs)
@@ -878,7 +902,7 @@ class InstanceTest(MongoDBTestCase):
doc2 = self.Person(name="jim", age=20).save()
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
- n_modified = doc1.modify({'name': doc2.name}, set__age=100)
+ n_modified = doc1.modify({"name": doc2.name}, set__age=100)
self.assertEqual(n_modified, 0)
self.assertDbEqual(docs)
@@ -888,7 +912,7 @@ class InstanceTest(MongoDBTestCase):
doc2 = self.Person(id=ObjectId(), name="jim", age=20)
docs = [dict(doc1.to_mongo())]
- n_modified = doc2.modify({'name': doc2.name}, set__age=100)
+ n_modified = doc2.modify({"name": doc2.name}, set__age=100)
self.assertEqual(n_modified, 0)
self.assertDbEqual(docs)
@@ -896,7 +920,8 @@ class InstanceTest(MongoDBTestCase):
def test_modify_update(self):
other_doc = self.Person(name="bob", age=10).save()
doc = self.Person(
- name="jim", age=20, job=self.Job(name="10gen", years=3)).save()
+ name="jim", age=20, job=self.Job(name="10gen", years=3)
+ ).save()
doc_copy = doc._from_son(doc.to_mongo())
@@ -906,7 +931,8 @@ class InstanceTest(MongoDBTestCase):
doc.job.years = 3
n_modified = doc.modify(
- set__age=21, set__job__name="MongoDB", unset__job__years=True)
+ set__age=21, set__job__name="MongoDB", unset__job__years=True
+ )
self.assertEqual(n_modified, 1)
doc_copy.age = 21
doc_copy.job.name = "MongoDB"
@@ -926,62 +952,56 @@ class InstanceTest(MongoDBTestCase):
content = EmbeddedDocumentField(Content)
post = BlogPost.objects.create(
- tags=['python'], content=Content(keywords=['ipsum']))
-
- self.assertEqual(post.tags, ['python'])
- post.modify(push__tags__0=['code', 'mongo'])
- self.assertEqual(post.tags, ['code', 'mongo', 'python'])
-
- # Assert same order of the list items is maintained in the db
- self.assertEqual(
- BlogPost._get_collection().find_one({'_id': post.pk})['tags'],
- ['code', 'mongo', 'python']
+ tags=["python"], content=Content(keywords=["ipsum"])
)
- self.assertEqual(post.content.keywords, ['ipsum'])
- post.modify(push__content__keywords__0=['lorem'])
- self.assertEqual(post.content.keywords, ['lorem', 'ipsum'])
+ self.assertEqual(post.tags, ["python"])
+ post.modify(push__tags__0=["code", "mongo"])
+ self.assertEqual(post.tags, ["code", "mongo", "python"])
# Assert same order of the list items is maintained in the db
self.assertEqual(
- BlogPost._get_collection().find_one({'_id': post.pk})['content']['keywords'],
- ['lorem', 'ipsum']
+ BlogPost._get_collection().find_one({"_id": post.pk})["tags"],
+ ["code", "mongo", "python"],
+ )
+
+ self.assertEqual(post.content.keywords, ["ipsum"])
+ post.modify(push__content__keywords__0=["lorem"])
+ self.assertEqual(post.content.keywords, ["lorem", "ipsum"])
+
+ # Assert same order of the list items is maintained in the db
+ self.assertEqual(
+ BlogPost._get_collection().find_one({"_id": post.pk})["content"][
+ "keywords"
+ ],
+ ["lorem", "ipsum"],
)
def test_save(self):
"""Ensure that a document may be saved in the database."""
# Create person object and save it to the database
- person = self.Person(name='Test User', age=30)
+ person = self.Person(name="Test User", age=30)
person.save()
# Ensure that the object is in the database
raw_doc = get_as_pymongo(person)
self.assertEqual(
raw_doc,
- {
- '_cls': 'Person',
- 'name': 'Test User',
- 'age': 30,
- '_id': person.id
- })
+ {"_cls": "Person", "name": "Test User", "age": 30, "_id": person.id},
+ )
def test_save_skip_validation(self):
class Recipient(Document):
email = EmailField(required=True)
- recipient = Recipient(email='not-an-email')
+ recipient = Recipient(email="not-an-email")
with self.assertRaises(ValidationError):
recipient.save()
recipient.save(validate=False)
raw_doc = get_as_pymongo(recipient)
- self.assertEqual(
- raw_doc,
- {
- 'email': 'not-an-email',
- '_id': recipient.id
- })
+ self.assertEqual(raw_doc, {"email": "not-an-email", "_id": recipient.id})
def test_save_with_bad_id(self):
class Clown(Document):
@@ -1012,8 +1032,8 @@ class InstanceTest(MongoDBTestCase):
def test_save_max_recursion_not_hit(self):
class Person(Document):
name = StringField()
- parent = ReferenceField('self')
- friend = ReferenceField('self')
+ parent = ReferenceField("self")
+ friend = ReferenceField("self")
Person.drop_collection()
@@ -1031,28 +1051,28 @@ class InstanceTest(MongoDBTestCase):
# Confirm can save and it resets the changed fields without hitting
# max recursion error
p0 = Person.objects.first()
- p0.name = 'wpjunior'
+ p0.name = "wpjunior"
p0.save()
def test_save_max_recursion_not_hit_with_file_field(self):
class Foo(Document):
name = StringField()
picture = FileField()
- bar = ReferenceField('self')
+ bar = ReferenceField("self")
Foo.drop_collection()
- a = Foo(name='hello').save()
+ a = Foo(name="hello").save()
a.bar = a
- with open(TEST_IMAGE_PATH, 'rb') as test_image:
+ with open(TEST_IMAGE_PATH, "rb") as test_image:
a.picture = test_image
a.save()
# Confirm can save and it resets the changed fields without hitting
# max recursion error
b = Foo.objects.with_id(a.id)
- b.name = 'world'
+ b.name = "world"
b.save()
self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture)
@@ -1060,7 +1080,7 @@ class InstanceTest(MongoDBTestCase):
def test_save_cascades(self):
class Person(Document):
name = StringField()
- parent = ReferenceField('self')
+ parent = ReferenceField("self")
Person.drop_collection()
@@ -1082,7 +1102,7 @@ class InstanceTest(MongoDBTestCase):
def test_save_cascade_kwargs(self):
class Person(Document):
name = StringField()
- parent = ReferenceField('self')
+ parent = ReferenceField("self")
Person.drop_collection()
@@ -1102,9 +1122,9 @@ class InstanceTest(MongoDBTestCase):
def test_save_cascade_meta_false(self):
class Person(Document):
name = StringField()
- parent = ReferenceField('self')
+ parent = ReferenceField("self")
- meta = {'cascade': False}
+ meta = {"cascade": False}
Person.drop_collection()
@@ -1130,9 +1150,9 @@ class InstanceTest(MongoDBTestCase):
def test_save_cascade_meta_true(self):
class Person(Document):
name = StringField()
- parent = ReferenceField('self')
+ parent = ReferenceField("self")
- meta = {'cascade': False}
+ meta = {"cascade": False}
Person.drop_collection()
@@ -1194,7 +1214,7 @@ class InstanceTest(MongoDBTestCase):
w1 = Widget(toggle=False, save_id=UUID(1))
# ignore save_condition on new record creation
- w1.save(save_condition={'save_id': UUID(42)})
+ w1.save(save_condition={"save_id": UUID(42)})
w1.reload()
self.assertFalse(w1.toggle)
self.assertEqual(w1.save_id, UUID(1))
@@ -1204,8 +1224,9 @@ class InstanceTest(MongoDBTestCase):
flip(w1)
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 1)
- self.assertRaises(SaveConditionError,
- w1.save, save_condition={'save_id': UUID(42)})
+ self.assertRaises(
+ SaveConditionError, w1.save, save_condition={"save_id": UUID(42)}
+ )
w1.reload()
self.assertFalse(w1.toggle)
self.assertEqual(w1.count, 0)
@@ -1214,7 +1235,7 @@ class InstanceTest(MongoDBTestCase):
flip(w1)
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 1)
- w1.save(save_condition={'save_id': UUID(1)})
+ w1.save(save_condition={"save_id": UUID(1)})
w1.reload()
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 1)
@@ -1227,27 +1248,29 @@ class InstanceTest(MongoDBTestCase):
flip(w1)
w1.save_id = UUID(2)
- w1.save(save_condition={'save_id': old_id})
+ w1.save(save_condition={"save_id": old_id})
w1.reload()
self.assertFalse(w1.toggle)
self.assertEqual(w1.count, 2)
flip(w2)
flip(w2)
- self.assertRaises(SaveConditionError,
- w2.save, save_condition={'save_id': old_id})
+ self.assertRaises(
+ SaveConditionError, w2.save, save_condition={"save_id": old_id}
+ )
w2.reload()
self.assertFalse(w2.toggle)
self.assertEqual(w2.count, 2)
# save_condition uses mongoengine-style operator syntax
flip(w1)
- w1.save(save_condition={'count__lt': w1.count})
+ w1.save(save_condition={"count__lt": w1.count})
w1.reload()
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 3)
flip(w1)
- self.assertRaises(SaveConditionError,
- w1.save, save_condition={'count__gte': w1.count})
+ self.assertRaises(
+ SaveConditionError, w1.save, save_condition={"count__gte": w1.count}
+ )
w1.reload()
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 3)
@@ -1259,19 +1282,19 @@ class InstanceTest(MongoDBTestCase):
WildBoy.drop_collection()
- WildBoy(age=12, name='John').save()
+ WildBoy(age=12, name="John").save()
boy1 = WildBoy.objects().first()
boy2 = WildBoy.objects().first()
boy1.age = 99
boy1.save()
- boy2.name = 'Bob'
+ boy2.name = "Bob"
boy2.save()
fresh_boy = WildBoy.objects().first()
self.assertEqual(fresh_boy.age, 99)
- self.assertEqual(fresh_boy.name, 'Bob')
+ self.assertEqual(fresh_boy.name, "Bob")
def test_save_update_selectively_with_custom_pk(self):
# Prevents regression of #2082
@@ -1282,30 +1305,30 @@ class InstanceTest(MongoDBTestCase):
WildBoy.drop_collection()
- WildBoy(pk_id='A', age=12, name='John').save()
+ WildBoy(pk_id="A", age=12, name="John").save()
boy1 = WildBoy.objects().first()
boy2 = WildBoy.objects().first()
boy1.age = 99
boy1.save()
- boy2.name = 'Bob'
+ boy2.name = "Bob"
boy2.save()
fresh_boy = WildBoy.objects().first()
self.assertEqual(fresh_boy.age, 99)
- self.assertEqual(fresh_boy.name, 'Bob')
+ self.assertEqual(fresh_boy.name, "Bob")
def test_update(self):
"""Ensure that an existing document is updated instead of be
overwritten.
"""
# Create person object and save it to the database
- person = self.Person(name='Test User', age=30)
+ person = self.Person(name="Test User", age=30)
person.save()
# Create same person object, with same id, without age
- same_person = self.Person(name='Test')
+ same_person = self.Person(name="Test")
same_person.id = person.id
same_person.save()
@@ -1322,54 +1345,54 @@ class InstanceTest(MongoDBTestCase):
self.assertEqual(person.age, same_person.age)
# Confirm the saved values
- self.assertEqual(person.name, 'Test')
+ self.assertEqual(person.name, "Test")
self.assertEqual(person.age, 30)
# Test only / exclude only updates included fields
- person = self.Person.objects.only('name').get()
- person.name = 'User'
+ person = self.Person.objects.only("name").get()
+ person.name = "User"
person.save()
person.reload()
- self.assertEqual(person.name, 'User')
+ self.assertEqual(person.name, "User")
self.assertEqual(person.age, 30)
# test exclude only updates set fields
- person = self.Person.objects.exclude('name').get()
+ person = self.Person.objects.exclude("name").get()
person.age = 21
person.save()
person.reload()
- self.assertEqual(person.name, 'User')
+ self.assertEqual(person.name, "User")
self.assertEqual(person.age, 21)
# Test only / exclude can set non excluded / included fields
- person = self.Person.objects.only('name').get()
- person.name = 'Test'
+ person = self.Person.objects.only("name").get()
+ person.name = "Test"
person.age = 30
person.save()
person.reload()
- self.assertEqual(person.name, 'Test')
+ self.assertEqual(person.name, "Test")
self.assertEqual(person.age, 30)
# test exclude only updates set fields
- person = self.Person.objects.exclude('name').get()
- person.name = 'User'
+ person = self.Person.objects.exclude("name").get()
+ person.name = "User"
person.age = 21
person.save()
person.reload()
- self.assertEqual(person.name, 'User')
+ self.assertEqual(person.name, "User")
self.assertEqual(person.age, 21)
# Confirm does remove unrequired fields
- person = self.Person.objects.exclude('name').get()
+ person = self.Person.objects.exclude("name").get()
person.age = None
person.save()
person.reload()
- self.assertEqual(person.name, 'User')
+ self.assertEqual(person.name, "User")
self.assertEqual(person.age, None)
person = self.Person.objects.get()
@@ -1384,19 +1407,18 @@ class InstanceTest(MongoDBTestCase):
def test_update_rename_operator(self):
"""Test the $rename operator."""
coll = self.Person._get_collection()
- doc = self.Person(name='John').save()
- raw_doc = coll.find_one({'_id': doc.pk})
- self.assertEqual(set(raw_doc.keys()), set(['_id', '_cls', 'name']))
+ doc = self.Person(name="John").save()
+ raw_doc = coll.find_one({"_id": doc.pk})
+ self.assertEqual(set(raw_doc.keys()), set(["_id", "_cls", "name"]))
- doc.update(rename__name='first_name')
- raw_doc = coll.find_one({'_id': doc.pk})
- self.assertEqual(set(raw_doc.keys()),
- set(['_id', '_cls', 'first_name']))
- self.assertEqual(raw_doc['first_name'], 'John')
+ doc.update(rename__name="first_name")
+ raw_doc = coll.find_one({"_id": doc.pk})
+ self.assertEqual(set(raw_doc.keys()), set(["_id", "_cls", "first_name"]))
+ self.assertEqual(raw_doc["first_name"], "John")
def test_inserts_if_you_set_the_pk(self):
- p1 = self.Person(name='p1', id=bson.ObjectId()).save()
- p2 = self.Person(name='p2')
+ p1 = self.Person(name="p1", id=bson.ObjectId()).save()
+ p2 = self.Person(name="p2")
p2.id = bson.ObjectId()
p2.save()
@@ -1410,33 +1432,34 @@ class InstanceTest(MongoDBTestCase):
pass
class Doc(Document):
- string_field = StringField(default='1')
+ string_field = StringField(default="1")
int_field = IntField(default=1)
float_field = FloatField(default=1.1)
boolean_field = BooleanField(default=True)
datetime_field = DateTimeField(default=datetime.now)
embedded_document_field = EmbeddedDocumentField(
- EmbeddedDoc, default=lambda: EmbeddedDoc())
+ EmbeddedDoc, default=lambda: EmbeddedDoc()
+ )
list_field = ListField(default=lambda: [1, 2, 3])
dict_field = DictField(default=lambda: {"hello": "world"})
objectid_field = ObjectIdField(default=bson.ObjectId)
- reference_field = ReferenceField(Simple, default=lambda:
- Simple().save())
+ reference_field = ReferenceField(Simple, default=lambda: Simple().save())
map_field = MapField(IntField(), default=lambda: {"simple": 1})
decimal_field = DecimalField(default=1.0)
complex_datetime_field = ComplexDateTimeField(default=datetime.now)
url_field = URLField(default="http://mongoengine.org")
dynamic_field = DynamicField(default=1)
generic_reference_field = GenericReferenceField(
- default=lambda: Simple().save())
- sorted_list_field = SortedListField(IntField(),
- default=lambda: [1, 2, 3])
+ default=lambda: Simple().save()
+ )
+ sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3])
email_field = EmailField(default="ross@example.com")
geo_point_field = GeoPointField(default=lambda: [1, 2])
sequence_field = SequenceField()
uuid_field = UUIDField(default=uuid.uuid4)
generic_embedded_document_field = GenericEmbeddedDocumentField(
- default=lambda: EmbeddedDoc())
+ default=lambda: EmbeddedDoc()
+ )
Simple.drop_collection()
Doc.drop_collection()
@@ -1454,13 +1477,13 @@ class InstanceTest(MongoDBTestCase):
# try updating a non-saved document
with self.assertRaises(OperationError):
- person = self.Person(name='dcrosta')
- person.update(set__name='Dan Crosta')
+ person = self.Person(name="dcrosta")
+ person.update(set__name="Dan Crosta")
- author = self.Person(name='dcrosta')
+ author = self.Person(name="dcrosta")
author.save()
- author.update(set__name='Dan Crosta')
+ author.update(set__name="Dan Crosta")
author.reload()
p1 = self.Person.objects.first()
@@ -1490,9 +1513,9 @@ class InstanceTest(MongoDBTestCase):
def test_embedded_update(self):
"""Test update on `EmbeddedDocumentField` fields."""
+
class Page(EmbeddedDocument):
- log_message = StringField(verbose_name="Log message",
- required=True)
+ log_message = StringField(verbose_name="Log message", required=True)
class Site(Document):
page = EmbeddedDocumentField(Page)
@@ -1512,28 +1535,30 @@ class InstanceTest(MongoDBTestCase):
def test_update_list_field(self):
"""Test update on `ListField` with $pull + $in.
"""
+
class Doc(Document):
foo = ListField(StringField())
Doc.drop_collection()
- doc = Doc(foo=['a', 'b', 'c'])
+ doc = Doc(foo=["a", "b", "c"])
doc.save()
# Update
doc = Doc.objects.first()
- doc.update(pull__foo__in=['a', 'c'])
+ doc.update(pull__foo__in=["a", "c"])
doc = Doc.objects.first()
- self.assertEqual(doc.foo, ['b'])
+ self.assertEqual(doc.foo, ["b"])
def test_embedded_update_db_field(self):
"""Test update on `EmbeddedDocumentField` fields when db_field
is other than default.
"""
+
class Page(EmbeddedDocument):
- log_message = StringField(verbose_name="Log message",
- db_field="page_log_message",
- required=True)
+ log_message = StringField(
+ verbose_name="Log message", db_field="page_log_message", required=True
+ )
class Site(Document):
page = EmbeddedDocumentField(Page)
@@ -1553,13 +1578,14 @@ class InstanceTest(MongoDBTestCase):
def test_save_only_changed_fields(self):
"""Ensure save only sets / unsets changed fields."""
+
class User(self.Person):
active = BooleanField(default=True)
User.drop_collection()
# Create person object and save it to the database
- user = User(name='Test User', age=30, active=True)
+ user = User(name="Test User", age=30, active=True)
user.save()
user.reload()
@@ -1570,28 +1596,31 @@ class InstanceTest(MongoDBTestCase):
user.age = 21
user.save()
- same_person.name = 'User'
+ same_person.name = "User"
same_person.save()
person = self.Person.objects.get()
- self.assertEqual(person.name, 'User')
+ self.assertEqual(person.name, "User")
self.assertEqual(person.age, 21)
self.assertEqual(person.active, False)
- def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_embedded_doc(self):
+ def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_embedded_doc(
+ self
+ ):
# Refers to Issue #1685
class EmbeddedChildModel(EmbeddedDocument):
id = DictField(primary_key=True)
class ParentModel(Document):
- child = EmbeddedDocumentField(
- EmbeddedChildModel)
+ child = EmbeddedDocumentField(EmbeddedChildModel)
- emb = EmbeddedChildModel(id={'1': [1]})
+ emb = EmbeddedChildModel(id={"1": [1]})
changed_fields = ParentModel(child=emb)._get_changed_fields()
self.assertEqual(changed_fields, [])
- def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_different_doc(self):
+ def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_different_doc(
+ self
+ ):
# Refers to Issue #1685
class User(Document):
id = IntField(primary_key=True)
@@ -1604,12 +1633,12 @@ class InstanceTest(MongoDBTestCase):
Message.drop_collection()
# All objects share the same id, but each in a different collection
- user = User(id=1, name='user-name').save()
+ user = User(id=1, name="user-name").save()
message = Message(id=1, author=user).save()
- message.author.name = 'tutu'
+ message.author.name = "tutu"
self.assertEqual(message._get_changed_fields(), [])
- self.assertEqual(user._get_changed_fields(), ['name'])
+ self.assertEqual(user._get_changed_fields(), ["name"])
def test__get_changed_fields_same_ids_embedded(self):
# Refers to Issue #1768
@@ -1624,24 +1653,25 @@ class InstanceTest(MongoDBTestCase):
Message.drop_collection()
# All objects share the same id, but each in a different collection
- user = User(id=1, name='user-name') # .save()
+ user = User(id=1, name="user-name") # .save()
message = Message(id=1, author=user).save()
- message.author.name = 'tutu'
- self.assertEqual(message._get_changed_fields(), ['author.name'])
+ message.author.name = "tutu"
+ self.assertEqual(message._get_changed_fields(), ["author.name"])
message.save()
message_fetched = Message.objects.with_id(message.id)
- self.assertEqual(message_fetched.author.name, 'tutu')
+ self.assertEqual(message_fetched.author.name, "tutu")
def test_query_count_when_saving(self):
"""Ensure references don't cause extra fetches when saving"""
+
class Organization(Document):
name = StringField()
class User(Document):
name = StringField()
- orgs = ListField(ReferenceField('Organization'))
+ orgs = ListField(ReferenceField("Organization"))
class Feed(Document):
name = StringField()
@@ -1667,9 +1697,9 @@ class InstanceTest(MongoDBTestCase):
user = User.objects.first()
# Even if stored as ObjectId's internally mongoengine uses DBRefs
# As ObjectId's aren't automatically derefenced
- self.assertIsInstance(user._data['orgs'][0], DBRef)
+ self.assertIsInstance(user._data["orgs"][0], DBRef)
self.assertIsInstance(user.orgs[0], Organization)
- self.assertIsInstance(user._data['orgs'][0], Organization)
+ self.assertIsInstance(user._data["orgs"][0], Organization)
# Changing a value
with query_counter() as q:
@@ -1731,6 +1761,7 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that $set and $unset actions are performed in the
same operation.
"""
+
class FooBar(Document):
foo = StringField(default=None)
bar = StringField(default=None)
@@ -1738,11 +1769,11 @@ class InstanceTest(MongoDBTestCase):
FooBar.drop_collection()
# write an entity with a single prop
- foo = FooBar(foo='foo').save()
+ foo = FooBar(foo="foo").save()
- self.assertEqual(foo.foo, 'foo')
+ self.assertEqual(foo.foo, "foo")
del foo.foo
- foo.bar = 'bar'
+ foo.bar = "bar"
with query_counter() as q:
self.assertEqual(0, q)
@@ -1751,6 +1782,7 @@ class InstanceTest(MongoDBTestCase):
def test_save_only_changed_fields_recursive(self):
"""Ensure save only sets / unsets changed fields."""
+
class Comment(EmbeddedDocument):
published = BooleanField(default=True)
@@ -1762,7 +1794,7 @@ class InstanceTest(MongoDBTestCase):
User.drop_collection()
# Create person object and save it to the database
- person = User(name='Test User', age=30, active=True)
+ person = User(name="Test User", age=30, active=True)
person.comments.append(Comment())
person.save()
person.reload()
@@ -1777,17 +1809,17 @@ class InstanceTest(MongoDBTestCase):
self.assertFalse(person.comments[0].published)
# Simple dict w
- person.comments_dict['first_post'] = Comment()
+ person.comments_dict["first_post"] = Comment()
person.save()
person = self.Person.objects.get()
- self.assertTrue(person.comments_dict['first_post'].published)
+ self.assertTrue(person.comments_dict["first_post"].published)
- person.comments_dict['first_post'].published = False
+ person.comments_dict["first_post"].published = False
person.save()
person = self.Person.objects.get()
- self.assertFalse(person.comments_dict['first_post'].published)
+ self.assertFalse(person.comments_dict["first_post"].published)
def test_delete(self):
"""Ensure that document may be deleted using the delete method."""
@@ -1801,31 +1833,30 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that a document may be saved with a custom _id."""
# Create person object and save it to the database
- person = self.Person(name='Test User', age=30,
- id='497ce96f395f2f052a494fd4')
+ person = self.Person(name="Test User", age=30, id="497ce96f395f2f052a494fd4")
person.save()
# Ensure that the object is in the database with the correct _id
collection = self.db[self.Person._get_collection_name()]
- person_obj = collection.find_one({'name': 'Test User'})
- self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
+ person_obj = collection.find_one({"name": "Test User"})
+ self.assertEqual(str(person_obj["_id"]), "497ce96f395f2f052a494fd4")
def test_save_custom_pk(self):
"""Ensure that a document may be saved with a custom _id using
pk alias.
"""
# Create person object and save it to the database
- person = self.Person(name='Test User', age=30,
- pk='497ce96f395f2f052a494fd4')
+ person = self.Person(name="Test User", age=30, pk="497ce96f395f2f052a494fd4")
person.save()
# Ensure that the object is in the database with the correct _id
collection = self.db[self.Person._get_collection_name()]
- person_obj = collection.find_one({'name': 'Test User'})
- self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
+ person_obj = collection.find_one({"name": "Test User"})
+ self.assertEqual(str(person_obj["_id"]), "497ce96f395f2f052a494fd4")
def test_save_list(self):
"""Ensure that a list field may be properly saved."""
+
class Comment(EmbeddedDocument):
content = StringField()
@@ -1836,37 +1867,36 @@ class InstanceTest(MongoDBTestCase):
BlogPost.drop_collection()
- post = BlogPost(content='Went for a walk today...')
- post.tags = tags = ['fun', 'leisure']
- comments = [Comment(content='Good for you'), Comment(content='Yay.')]
+ post = BlogPost(content="Went for a walk today...")
+ post.tags = tags = ["fun", "leisure"]
+ comments = [Comment(content="Good for you"), Comment(content="Yay.")]
post.comments = comments
post.save()
collection = self.db[BlogPost._get_collection_name()]
post_obj = collection.find_one()
- self.assertEqual(post_obj['tags'], tags)
- for comment_obj, comment in zip(post_obj['comments'], comments):
- self.assertEqual(comment_obj['content'], comment['content'])
+ self.assertEqual(post_obj["tags"], tags)
+ for comment_obj, comment in zip(post_obj["comments"], comments):
+ self.assertEqual(comment_obj["content"], comment["content"])
def test_list_search_by_embedded(self):
class User(Document):
username = StringField(required=True)
- meta = {'allow_inheritance': False}
+ meta = {"allow_inheritance": False}
class Comment(EmbeddedDocument):
comment = StringField()
- user = ReferenceField(User,
- required=True)
+ user = ReferenceField(User, required=True)
- meta = {'allow_inheritance': False}
+ meta = {"allow_inheritance": False}
class Page(Document):
comments = ListField(EmbeddedDocumentField(Comment))
- meta = {'allow_inheritance': False,
- 'indexes': [
- {'fields': ['comments.user']}
- ]}
+ meta = {
+ "allow_inheritance": False,
+ "indexes": [{"fields": ["comments.user"]}],
+ }
User.drop_collection()
Page.drop_collection()
@@ -1880,14 +1910,22 @@ class InstanceTest(MongoDBTestCase):
u3 = User(username="hmarr")
u3.save()
- p1 = Page(comments=[Comment(user=u1, comment="Its very good"),
- Comment(user=u2, comment="Hello world"),
- Comment(user=u3, comment="Ping Pong"),
- Comment(user=u1, comment="I like a beer")])
+ p1 = Page(
+ comments=[
+ Comment(user=u1, comment="Its very good"),
+ Comment(user=u2, comment="Hello world"),
+ Comment(user=u3, comment="Ping Pong"),
+ Comment(user=u1, comment="I like a beer"),
+ ]
+ )
p1.save()
- p2 = Page(comments=[Comment(user=u1, comment="Its very good"),
- Comment(user=u2, comment="Hello world")])
+ p2 = Page(
+ comments=[
+ Comment(user=u1, comment="Its very good"),
+ Comment(user=u2, comment="Hello world"),
+ ]
+ )
p2.save()
p3 = Page(comments=[Comment(user=u3, comment="Its very good")])
@@ -1896,20 +1934,15 @@ class InstanceTest(MongoDBTestCase):
p4 = Page(comments=[Comment(user=u2, comment="Heavy Metal song")])
p4.save()
- self.assertEqual(
- [p1, p2],
- list(Page.objects.filter(comments__user=u1)))
- self.assertEqual(
- [p1, p2, p4],
- list(Page.objects.filter(comments__user=u2)))
- self.assertEqual(
- [p1, p3],
- list(Page.objects.filter(comments__user=u3)))
+ self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1)))
+ self.assertEqual([p1, p2, p4], list(Page.objects.filter(comments__user=u2)))
+ self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3)))
def test_save_embedded_document(self):
"""Ensure that a document with an embedded document field may
be saved in the database.
"""
+
class EmployeeDetails(EmbeddedDocument):
position = StringField()
@@ -1918,26 +1951,26 @@ class InstanceTest(MongoDBTestCase):
details = EmbeddedDocumentField(EmployeeDetails)
# Create employee object and save it to the database
- employee = Employee(name='Test Employee', age=50, salary=20000)
- employee.details = EmployeeDetails(position='Developer')
+ employee = Employee(name="Test Employee", age=50, salary=20000)
+ employee.details = EmployeeDetails(position="Developer")
employee.save()
# Ensure that the object is in the database
collection = self.db[self.Person._get_collection_name()]
- employee_obj = collection.find_one({'name': 'Test Employee'})
- self.assertEqual(employee_obj['name'], 'Test Employee')
- self.assertEqual(employee_obj['age'], 50)
+ employee_obj = collection.find_one({"name": "Test Employee"})
+ self.assertEqual(employee_obj["name"], "Test Employee")
+ self.assertEqual(employee_obj["age"], 50)
# Ensure that the 'details' embedded object saved correctly
- self.assertEqual(employee_obj['details']['position'], 'Developer')
+ self.assertEqual(employee_obj["details"]["position"], "Developer")
def test_embedded_update_after_save(self):
"""Test update of `EmbeddedDocumentField` attached to a newly
saved document.
"""
+
class Page(EmbeddedDocument):
- log_message = StringField(verbose_name="Log message",
- required=True)
+ log_message = StringField(verbose_name="Log message", required=True)
class Site(Document):
page = EmbeddedDocumentField(Page)
@@ -1957,6 +1990,7 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that a document with an embedded document field may
be saved in the database.
"""
+
class EmployeeDetails(EmbeddedDocument):
position = StringField()
@@ -1965,22 +1999,21 @@ class InstanceTest(MongoDBTestCase):
details = EmbeddedDocumentField(EmployeeDetails)
# Create employee object and save it to the database
- employee = Employee(name='Test Employee', age=50, salary=20000)
- employee.details = EmployeeDetails(position='Developer')
+ employee = Employee(name="Test Employee", age=50, salary=20000)
+ employee.details = EmployeeDetails(position="Developer")
employee.save()
# Test updating an embedded document
- promoted_employee = Employee.objects.get(name='Test Employee')
- promoted_employee.details.position = 'Senior Developer'
+ promoted_employee = Employee.objects.get(name="Test Employee")
+ promoted_employee.details.position = "Senior Developer"
promoted_employee.save()
promoted_employee.reload()
- self.assertEqual(promoted_employee.name, 'Test Employee')
+ self.assertEqual(promoted_employee.name, "Test Employee")
self.assertEqual(promoted_employee.age, 50)
# Ensure that the 'details' embedded object saved correctly
- self.assertEqual(
- promoted_employee.details.position, 'Senior Developer')
+ self.assertEqual(promoted_employee.details.position, "Senior Developer")
# Test removal
promoted_employee.details = None
@@ -1996,12 +2029,12 @@ class InstanceTest(MongoDBTestCase):
class Foo(EmbeddedDocument, NameMixin):
quantity = IntField()
- self.assertEqual(['name', 'quantity'], sorted(Foo._fields.keys()))
+ self.assertEqual(["name", "quantity"], sorted(Foo._fields.keys()))
class Bar(Document, NameMixin):
widgets = StringField()
- self.assertEqual(['id', 'name', 'widgets'], sorted(Bar._fields.keys()))
+ self.assertEqual(["id", "name", "widgets"], sorted(Bar._fields.keys()))
def test_mixin_inheritance(self):
class BaseMixIn(object):
@@ -2015,8 +2048,7 @@ class InstanceTest(MongoDBTestCase):
age = IntField()
TestDoc.drop_collection()
- t = TestDoc(count=12, data="test",
- comment="great!", age=19)
+ t = TestDoc(count=12, data="test", comment="great!", age=19)
t.save()
@@ -2031,17 +2063,18 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that a document reference field may be saved in the
database.
"""
+
class BlogPost(Document):
- meta = {'collection': 'blogpost_1'}
+ meta = {"collection": "blogpost_1"}
content = StringField()
author = ReferenceField(self.Person)
BlogPost.drop_collection()
- author = self.Person(name='Test User')
+ author = self.Person(name="Test User")
author.save()
- post = BlogPost(content='Watched some TV today... how exciting.')
+ post = BlogPost(content="Watched some TV today... how exciting.")
# Should only reference author when saving
post.author = author
post.save()
@@ -2049,15 +2082,15 @@ class InstanceTest(MongoDBTestCase):
post_obj = BlogPost.objects.first()
# Test laziness
- self.assertIsInstance(post_obj._data['author'], bson.DBRef)
+ self.assertIsInstance(post_obj._data["author"], bson.DBRef)
self.assertIsInstance(post_obj.author, self.Person)
- self.assertEqual(post_obj.author.name, 'Test User')
+ self.assertEqual(post_obj.author.name, "Test User")
# Ensure that the dereferenced object may be changed and saved
post_obj.author.age = 25
post_obj.author.save()
- author = list(self.Person.objects(name='Test User'))[-1]
+ author = list(self.Person.objects(name="Test User"))[-1]
self.assertEqual(author.age, 25)
def test_duplicate_db_fields_raise_invalid_document_error(self):
@@ -2065,12 +2098,14 @@ class InstanceTest(MongoDBTestCase):
declare the same db_field.
"""
with self.assertRaises(InvalidDocumentError):
+
class Foo(Document):
name = StringField()
- name2 = StringField(db_field='name')
+ name2 = StringField(db_field="name")
def test_invalid_son(self):
"""Raise an error if loading invalid data."""
+
class Occurrence(EmbeddedDocument):
number = IntField()
@@ -2081,21 +2116,24 @@ class InstanceTest(MongoDBTestCase):
occurs = ListField(EmbeddedDocumentField(Occurrence), default=list)
with self.assertRaises(InvalidDocumentError):
- Word._from_son({
- 'stem': [1, 2, 3],
- 'forms': 1,
- 'count': 'one',
- 'occurs': {"hello": None}
- })
+ Word._from_son(
+ {
+ "stem": [1, 2, 3],
+ "forms": 1,
+ "count": "one",
+ "occurs": {"hello": None},
+ }
+ )
# Tests for issue #1438: https://github.com/MongoEngine/mongoengine/issues/1438
with self.assertRaises(ValueError):
- Word._from_son('this is not a valid SON dict')
+ Word._from_son("this is not a valid SON dict")
def test_reverse_delete_rule_cascade_and_nullify(self):
"""Ensure that a referenced document is also deleted upon
deletion.
"""
+
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
@@ -2104,13 +2142,13 @@ class InstanceTest(MongoDBTestCase):
self.Person.drop_collection()
BlogPost.drop_collection()
- author = self.Person(name='Test User')
+ author = self.Person(name="Test User")
author.save()
- reviewer = self.Person(name='Re Viewer')
+ reviewer = self.Person(name="Re Viewer")
reviewer.save()
- post = BlogPost(content='Watched some TV')
+ post = BlogPost(content="Watched some TV")
post.author = author
post.reviewer = reviewer
post.save()
@@ -2128,24 +2166,26 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that a referenced document is also deleted with
pull.
"""
+
class Record(Document):
name = StringField()
- children = ListField(ReferenceField('self', reverse_delete_rule=PULL))
+ children = ListField(ReferenceField("self", reverse_delete_rule=PULL))
Record.drop_collection()
- parent_record = Record(name='parent').save()
- child_record = Record(name='child').save()
+ parent_record = Record(name="parent").save()
+ child_record = Record(name="child").save()
parent_record.children.append(child_record)
parent_record.save()
child_record.delete()
- self.assertEqual(Record.objects(name='parent').get().children, [])
+ self.assertEqual(Record.objects(name="parent").get().children, [])
def test_reverse_delete_rule_with_custom_id_field(self):
"""Ensure that a referenced document with custom primary key
is also deleted upon deletion.
"""
+
class User(Document):
name = StringField(primary_key=True)
@@ -2156,8 +2196,8 @@ class InstanceTest(MongoDBTestCase):
User.drop_collection()
Book.drop_collection()
- user = User(name='Mike').save()
- reviewer = User(name='John').save()
+ user = User(name="Mike").save()
+ reviewer = User(name="John").save()
book = Book(author=user, reviewer=reviewer).save()
reviewer.delete()
@@ -2171,6 +2211,7 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that cascade delete rule doesn't mix id among
collections.
"""
+
class User(Document):
id = IntField(primary_key=True)
@@ -2203,6 +2244,7 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that a referenced document is also deleted upon
deletion of a child document.
"""
+
class Writer(self.Person):
pass
@@ -2214,13 +2256,13 @@ class InstanceTest(MongoDBTestCase):
self.Person.drop_collection()
BlogPost.drop_collection()
- author = Writer(name='Test User')
+ author = Writer(name="Test User")
author.save()
- reviewer = Writer(name='Re Viewer')
+ reviewer = Writer(name="Re Viewer")
reviewer.save()
- post = BlogPost(content='Watched some TV')
+ post = BlogPost(content="Watched some TV")
post.author = author
post.reviewer = reviewer
post.save()
@@ -2237,23 +2279,26 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that a referenced document is also deleted upon
deletion for complex fields.
"""
+
class BlogPost(Document):
content = StringField()
- authors = ListField(ReferenceField(
- self.Person, reverse_delete_rule=CASCADE))
- reviewers = ListField(ReferenceField(
- self.Person, reverse_delete_rule=NULLIFY))
+ authors = ListField(
+ ReferenceField(self.Person, reverse_delete_rule=CASCADE)
+ )
+ reviewers = ListField(
+ ReferenceField(self.Person, reverse_delete_rule=NULLIFY)
+ )
self.Person.drop_collection()
BlogPost.drop_collection()
- author = self.Person(name='Test User')
+ author = self.Person(name="Test User")
author.save()
- reviewer = self.Person(name='Re Viewer')
+ reviewer = self.Person(name="Re Viewer")
reviewer.save()
- post = BlogPost(content='Watched some TV')
+ post = BlogPost(content="Watched some TV")
post.authors = [author]
post.reviewers = [reviewer]
post.save()
@@ -2273,6 +2318,7 @@ class InstanceTest(MongoDBTestCase):
delete the author which triggers deletion of blogpost via
cascade blog post's pre_delete signal alters an editor attribute.
"""
+
class Editor(self.Person):
review_queue = IntField(default=0)
@@ -2292,32 +2338,32 @@ class InstanceTest(MongoDBTestCase):
BlogPost.drop_collection()
Editor.drop_collection()
- author = self.Person(name='Will S.').save()
- editor = Editor(name='Max P.', review_queue=1).save()
- BlogPost(content='wrote some books', author=author,
- editor=editor).save()
+ author = self.Person(name="Will S.").save()
+ editor = Editor(name="Max P.", review_queue=1).save()
+ BlogPost(content="wrote some books", author=author, editor=editor).save()
# delete the author, the post is also deleted due to the CASCADE rule
author.delete()
# the pre-delete signal should have decremented the editor's queue
- editor = Editor.objects(name='Max P.').get()
+ editor = Editor.objects(name="Max P.").get()
self.assertEqual(editor.review_queue, 0)
def test_two_way_reverse_delete_rule(self):
"""Ensure that Bi-Directional relationships work with
reverse_delete_rule
"""
+
class Bar(Document):
content = StringField()
- foo = ReferenceField('Foo')
+ foo = ReferenceField("Foo")
class Foo(Document):
content = StringField()
bar = ReferenceField(Bar)
- Bar.register_delete_rule(Foo, 'bar', NULLIFY)
- Foo.register_delete_rule(Bar, 'foo', NULLIFY)
+ Bar.register_delete_rule(Foo, "bar", NULLIFY)
+ Foo.register_delete_rule(Bar, "foo", NULLIFY)
Bar.drop_collection()
Foo.drop_collection()
@@ -2338,24 +2384,27 @@ class InstanceTest(MongoDBTestCase):
def test_invalid_reverse_delete_rule_raise_errors(self):
with self.assertRaises(InvalidDocumentError):
+
class Blog(Document):
content = StringField()
- authors = MapField(ReferenceField(
- self.Person, reverse_delete_rule=CASCADE))
+ authors = MapField(
+ ReferenceField(self.Person, reverse_delete_rule=CASCADE)
+ )
reviewers = DictField(
- field=ReferenceField(
- self.Person,
- reverse_delete_rule=NULLIFY))
+ field=ReferenceField(self.Person, reverse_delete_rule=NULLIFY)
+ )
with self.assertRaises(InvalidDocumentError):
+
class Parents(EmbeddedDocument):
- father = ReferenceField('Person', reverse_delete_rule=DENY)
- mother = ReferenceField('Person', reverse_delete_rule=DENY)
+ father = ReferenceField("Person", reverse_delete_rule=DENY)
+ mother = ReferenceField("Person", reverse_delete_rule=DENY)
def test_reverse_delete_rule_cascade_recurs(self):
"""Ensure that a chain of documents is also deleted upon
cascaded deletion.
"""
+
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
@@ -2368,14 +2417,14 @@ class InstanceTest(MongoDBTestCase):
BlogPost.drop_collection()
Comment.drop_collection()
- author = self.Person(name='Test User')
+ author = self.Person(name="Test User")
author.save()
- post = BlogPost(content='Watched some TV')
+ post = BlogPost(content="Watched some TV")
post.author = author
post.save()
- comment = Comment(text='Kudos.')
+ comment = Comment(text="Kudos.")
comment.post = post
comment.save()
@@ -2388,6 +2437,7 @@ class InstanceTest(MongoDBTestCase):
"""Ensure that a document cannot be referenced if there are
still documents referring to it.
"""
+
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=DENY)
@@ -2395,20 +2445,22 @@ class InstanceTest(MongoDBTestCase):
self.Person.drop_collection()
BlogPost.drop_collection()
- author = self.Person(name='Test User')
+ author = self.Person(name="Test User")
author.save()
- post = BlogPost(content='Watched some TV')
+ post = BlogPost(content="Watched some TV")
post.author = author
post.save()
# Delete the Person should be denied
self.assertRaises(OperationError, author.delete) # Should raise denied error
- self.assertEqual(BlogPost.objects.count(), 1) # No objects may have been deleted
+ self.assertEqual(
+ BlogPost.objects.count(), 1
+ ) # No objects may have been deleted
self.assertEqual(self.Person.objects.count(), 1)
# Other users, that don't have BlogPosts must be removable, like normal
- author = self.Person(name='Another User')
+ author = self.Person(name="Another User")
author.save()
self.assertEqual(self.Person.objects.count(), 2)
@@ -2434,6 +2486,7 @@ class InstanceTest(MongoDBTestCase):
def test_document_hash(self):
"""Test document in list, dict, set."""
+
class User(Document):
pass
@@ -2491,9 +2544,11 @@ class InstanceTest(MongoDBTestCase):
self.assertEqual(len(all_user_set), 3)
def test_picklable(self):
- pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])
+ pickle_doc = PickleTest(number=1, string="One", lists=["1", "2"])
pickle_doc.embedded = PickleEmbedded()
- pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
+ pickled_doc = pickle.dumps(
+ pickle_doc
+ ) # make sure pickling works even before the doc is saved
pickle_doc.save()
pickled_doc = pickle.dumps(pickle_doc)
@@ -2516,8 +2571,10 @@ class InstanceTest(MongoDBTestCase):
self.assertEqual(pickle_doc.lists, ["1", "2", "3"])
def test_regular_document_pickle(self):
- pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])
- pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
+ pickle_doc = PickleTest(number=1, string="One", lists=["1", "2"])
+ pickled_doc = pickle.dumps(
+ pickle_doc
+ ) # make sure pickling works even before the doc is saved
pickle_doc.save()
pickled_doc = pickle.dumps(pickle_doc)
@@ -2527,21 +2584,23 @@ class InstanceTest(MongoDBTestCase):
fixtures.PickleTest = fixtures.NewDocumentPickleTest
resurrected = pickle.loads(pickled_doc)
- self.assertEqual(resurrected.__class__,
- fixtures.NewDocumentPickleTest)
- self.assertEqual(resurrected._fields_ordered,
- fixtures.NewDocumentPickleTest._fields_ordered)
- self.assertNotEqual(resurrected._fields_ordered,
- pickle_doc._fields_ordered)
+ self.assertEqual(resurrected.__class__, fixtures.NewDocumentPickleTest)
+ self.assertEqual(
+ resurrected._fields_ordered, fixtures.NewDocumentPickleTest._fields_ordered
+ )
+ self.assertNotEqual(resurrected._fields_ordered, pickle_doc._fields_ordered)
# The local PickleTest is still a ref to the original
fixtures.PickleTest = PickleTest
def test_dynamic_document_pickle(self):
pickle_doc = PickleDynamicTest(
- name="test", number=1, string="One", lists=['1', '2'])
+ name="test", number=1, string="One", lists=["1", "2"]
+ )
pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar")
- pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
+ pickled_doc = pickle.dumps(
+ pickle_doc
+ ) # make sure pickling works even before the doc is saved
pickle_doc.save()
@@ -2549,20 +2608,22 @@ class InstanceTest(MongoDBTestCase):
resurrected = pickle.loads(pickled_doc)
self.assertEqual(resurrected, pickle_doc)
- self.assertEqual(resurrected._fields_ordered,
- pickle_doc._fields_ordered)
- self.assertEqual(resurrected._dynamic_fields.keys(),
- pickle_doc._dynamic_fields.keys())
+ self.assertEqual(resurrected._fields_ordered, pickle_doc._fields_ordered)
+ self.assertEqual(
+ resurrected._dynamic_fields.keys(), pickle_doc._dynamic_fields.keys()
+ )
self.assertEqual(resurrected.embedded, pickle_doc.embedded)
- self.assertEqual(resurrected.embedded._fields_ordered,
- pickle_doc.embedded._fields_ordered)
- self.assertEqual(resurrected.embedded._dynamic_fields.keys(),
- pickle_doc.embedded._dynamic_fields.keys())
+ self.assertEqual(
+ resurrected.embedded._fields_ordered, pickle_doc.embedded._fields_ordered
+ )
+ self.assertEqual(
+ resurrected.embedded._dynamic_fields.keys(),
+ pickle_doc.embedded._dynamic_fields.keys(),
+ )
def test_picklable_on_signals(self):
- pickle_doc = PickleSignalsTest(
- number=1, string="One", lists=['1', '2'])
+ pickle_doc = PickleSignalsTest(number=1, string="One", lists=["1", "2"])
pickle_doc.embedded = PickleEmbedded()
pickle_doc.save()
pickle_doc.delete()
@@ -2572,12 +2633,13 @@ class InstanceTest(MongoDBTestCase):
the "validate" method.
"""
with self.assertRaises(InvalidDocumentError):
+
class Blog(Document):
validate = DictField()
def test_mutating_documents(self):
class B(EmbeddedDocument):
- field1 = StringField(default='field1')
+ field1 = StringField(default="field1")
class A(Document):
b = EmbeddedDocumentField(B, default=lambda: B())
@@ -2587,27 +2649,28 @@ class InstanceTest(MongoDBTestCase):
a = A()
a.save()
a.reload()
- self.assertEqual(a.b.field1, 'field1')
+ self.assertEqual(a.b.field1, "field1")
class C(EmbeddedDocument):
- c_field = StringField(default='cfield')
+ c_field = StringField(default="cfield")
class B(EmbeddedDocument):
- field1 = StringField(default='field1')
+ field1 = StringField(default="field1")
field2 = EmbeddedDocumentField(C, default=lambda: C())
class A(Document):
b = EmbeddedDocumentField(B, default=lambda: B())
a = A.objects()[0]
- a.b.field2.c_field = 'new value'
+ a.b.field2.c_field = "new value"
a.save()
a.reload()
- self.assertEqual(a.b.field2.c_field, 'new value')
+ self.assertEqual(a.b.field2.c_field, "new value")
def test_can_save_false_values(self):
"""Ensures you can save False values on save."""
+
class Doc(Document):
foo = StringField()
archived = BooleanField(default=False, required=True)
@@ -2623,6 +2686,7 @@ class InstanceTest(MongoDBTestCase):
def test_can_save_false_values_dynamic(self):
"""Ensures you can save False values on dynamic docs."""
+
class Doc(DynamicDocument):
foo = StringField()
@@ -2637,6 +2701,7 @@ class InstanceTest(MongoDBTestCase):
def test_do_not_save_unchanged_references(self):
"""Ensures cascading saves dont auto update"""
+
class Job(Document):
name = StringField()
@@ -2655,8 +2720,10 @@ class InstanceTest(MongoDBTestCase):
person = Person(name="name", age=10, job=job)
from pymongo.collection import Collection
+
orig_update = Collection.update
try:
+
def fake_update(*args, **kwargs):
self.fail("Unexpected update for %s" % args[0].name)
return orig_update(*args, **kwargs)
@@ -2670,9 +2737,9 @@ class InstanceTest(MongoDBTestCase):
"""DB Alias tests."""
# mongoenginetest - Is default connection alias from setUp()
# Register Aliases
- register_connection('testdb-1', 'mongoenginetest2')
- register_connection('testdb-2', 'mongoenginetest3')
- register_connection('testdb-3', 'mongoenginetest4')
+ register_connection("testdb-1", "mongoenginetest2")
+ register_connection("testdb-2", "mongoenginetest3")
+ register_connection("testdb-3", "mongoenginetest4")
class User(Document):
name = StringField()
@@ -2719,42 +2786,43 @@ class InstanceTest(MongoDBTestCase):
# Collections
self.assertEqual(
- User._get_collection(),
- get_db("testdb-1")[User._get_collection_name()])
+ User._get_collection(), get_db("testdb-1")[User._get_collection_name()]
+ )
self.assertEqual(
- Book._get_collection(),
- get_db("testdb-2")[Book._get_collection_name()])
+ Book._get_collection(), get_db("testdb-2")[Book._get_collection_name()]
+ )
self.assertEqual(
AuthorBooks._get_collection(),
- get_db("testdb-3")[AuthorBooks._get_collection_name()])
+ get_db("testdb-3")[AuthorBooks._get_collection_name()],
+ )
def test_db_alias_overrides(self):
"""Test db_alias can be overriden."""
# Register a connection with db_alias testdb-2
- register_connection('testdb-2', 'mongoenginetest2')
+ register_connection("testdb-2", "mongoenginetest2")
class A(Document):
"""Uses default db_alias
"""
+
name = StringField()
meta = {"allow_inheritance": True}
class B(A):
"""Uses testdb-2 db_alias
"""
+
meta = {"db_alias": "testdb-2"}
A.objects.all()
- self.assertEqual('testdb-2', B._meta.get('db_alias'))
- self.assertEqual('mongoenginetest',
- A._get_collection().database.name)
- self.assertEqual('mongoenginetest2',
- B._get_collection().database.name)
+ self.assertEqual("testdb-2", B._meta.get("db_alias"))
+ self.assertEqual("mongoenginetest", A._get_collection().database.name)
+ self.assertEqual("mongoenginetest2", B._get_collection().database.name)
def test_db_alias_propagates(self):
"""db_alias propagates?"""
- register_connection('testdb-1', 'mongoenginetest2')
+ register_connection("testdb-1", "mongoenginetest2")
class A(Document):
name = StringField()
@@ -2763,10 +2831,11 @@ class InstanceTest(MongoDBTestCase):
class B(A):
pass
- self.assertEqual('testdb-1', B._meta.get('db_alias'))
+ self.assertEqual("testdb-1", B._meta.get("db_alias"))
def test_db_ref_usage(self):
"""DB Ref usage in dict_fields."""
+
class User(Document):
name = StringField()
@@ -2774,9 +2843,7 @@ class InstanceTest(MongoDBTestCase):
name = StringField()
author = ReferenceField(User)
extra = DictField()
- meta = {
- 'ordering': ['+name']
- }
+ meta = {"ordering": ["+name"]}
def __unicode__(self):
return self.name
@@ -2798,12 +2865,19 @@ class InstanceTest(MongoDBTestCase):
peter = User.objects.create(name="Peter")
# Bob
- Book.objects.create(name="1", author=bob, extra={
- "a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]})
- Book.objects.create(name="2", author=bob, extra={
- "a": bob.to_dbref(), "b": karl.to_dbref()})
- Book.objects.create(name="3", author=bob, extra={
- "a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]})
+ Book.objects.create(
+ name="1",
+ author=bob,
+ extra={"a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]},
+ )
+ Book.objects.create(
+ name="2", author=bob, extra={"a": bob.to_dbref(), "b": karl.to_dbref()}
+ )
+ Book.objects.create(
+ name="3",
+ author=bob,
+ extra={"a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]},
+ )
Book.objects.create(name="4", author=bob)
# Jon
@@ -2811,56 +2885,77 @@ class InstanceTest(MongoDBTestCase):
Book.objects.create(name="6", author=peter)
Book.objects.create(name="7", author=jon)
Book.objects.create(name="8", author=jon)
- Book.objects.create(name="9", author=jon,
- extra={"a": peter.to_dbref()})
+ Book.objects.create(name="9", author=jon, extra={"a": peter.to_dbref()})
# Checks
- self.assertEqual(",".join([str(b) for b in Book.objects.all()]),
- "1,2,3,4,5,6,7,8,9")
+ self.assertEqual(
+ ",".join([str(b) for b in Book.objects.all()]), "1,2,3,4,5,6,7,8,9"
+ )
# bob related books
- self.assertEqual(",".join([str(b) for b in Book.objects.filter(
- Q(extra__a=bob) |
- Q(author=bob) |
- Q(extra__b=bob))]),
- "1,2,3,4")
+ self.assertEqual(
+ ",".join(
+ [
+ str(b)
+ for b in Book.objects.filter(
+ Q(extra__a=bob) | Q(author=bob) | Q(extra__b=bob)
+ )
+ ]
+ ),
+ "1,2,3,4",
+ )
# Susan & Karl related books
- self.assertEqual(",".join([str(b) for b in Book.objects.filter(
- Q(extra__a__all=[karl, susan]) |
- Q(author__all=[karl, susan]) |
- Q(extra__b__all=[
- karl.to_dbref(), susan.to_dbref()]))
- ]), "1")
+ self.assertEqual(
+ ",".join(
+ [
+ str(b)
+ for b in Book.objects.filter(
+ Q(extra__a__all=[karl, susan])
+ | Q(author__all=[karl, susan])
+ | Q(extra__b__all=[karl.to_dbref(), susan.to_dbref()])
+ )
+ ]
+ ),
+ "1",
+ )
# $Where
- self.assertEqual(u",".join([str(b) for b in Book.objects.filter(
- __raw__={
- "$where": """
+ self.assertEqual(
+ u",".join(
+ [
+ str(b)
+ for b in Book.objects.filter(
+ __raw__={
+ "$where": """
function(){
return this.name == '1' ||
this.name == '2';}"""
- })]),
- "1,2")
+ }
+ )
+ ]
+ ),
+ "1,2",
+ )
def test_switch_db_instance(self):
- register_connection('testdb-1', 'mongoenginetest2')
+ register_connection("testdb-1", "mongoenginetest2")
class Group(Document):
name = StringField()
Group.drop_collection()
- with switch_db(Group, 'testdb-1') as Group:
+ with switch_db(Group, "testdb-1") as Group:
Group.drop_collection()
Group(name="hello - default").save()
self.assertEqual(1, Group.objects.count())
group = Group.objects.first()
- group.switch_db('testdb-1')
+ group.switch_db("testdb-1")
group.name = "hello - testdb!"
group.save()
- with switch_db(Group, 'testdb-1') as Group:
+ with switch_db(Group, "testdb-1") as Group:
group = Group.objects.first()
self.assertEqual("hello - testdb!", group.name)
@@ -2869,10 +2964,10 @@ class InstanceTest(MongoDBTestCase):
# Slightly contrived now - perform an update
# Only works as they have the same object_id
- group.switch_db('testdb-1')
+ group.switch_db("testdb-1")
group.update(set__name="hello - update")
- with switch_db(Group, 'testdb-1') as Group:
+ with switch_db(Group, "testdb-1") as Group:
group = Group.objects.first()
self.assertEqual("hello - update", group.name)
Group.drop_collection()
@@ -2883,10 +2978,10 @@ class InstanceTest(MongoDBTestCase):
# Totally contrived now - perform a delete
# Only works as they have the same object_id
- group.switch_db('testdb-1')
+ group.switch_db("testdb-1")
group.delete()
- with switch_db(Group, 'testdb-1') as Group:
+ with switch_db(Group, "testdb-1") as Group:
self.assertEqual(0, Group.objects.count())
group = Group.objects.first()
@@ -2898,11 +2993,9 @@ class InstanceTest(MongoDBTestCase):
User.drop_collection()
- User._get_collection().insert_one({
- 'name': 'John',
- 'foo': 'Bar',
- 'data': [1, 2, 3]
- })
+ User._get_collection().insert_one(
+ {"name": "John", "foo": "Bar", "data": [1, 2, 3]}
+ )
self.assertRaises(FieldDoesNotExist, User.objects.first)
@@ -2910,22 +3003,20 @@ class InstanceTest(MongoDBTestCase):
class User(Document):
name = StringField()
- meta = {'strict': False}
+ meta = {"strict": False}
User.drop_collection()
- User._get_collection().insert_one({
- 'name': 'John',
- 'foo': 'Bar',
- 'data': [1, 2, 3]
- })
+ User._get_collection().insert_one(
+ {"name": "John", "foo": "Bar", "data": [1, 2, 3]}
+ )
user = User.objects.first()
- self.assertEqual(user.name, 'John')
- self.assertFalse(hasattr(user, 'foo'))
- self.assertEqual(user._data['foo'], 'Bar')
- self.assertFalse(hasattr(user, 'data'))
- self.assertEqual(user._data['data'], [1, 2, 3])
+ self.assertEqual(user.name, "John")
+ self.assertFalse(hasattr(user, "foo"))
+ self.assertEqual(user._data["foo"], "Bar")
+ self.assertFalse(hasattr(user, "data"))
+ self.assertEqual(user._data["data"], [1, 2, 3])
def test_load_undefined_fields_on_embedded_document(self):
class Thing(EmbeddedDocument):
@@ -2937,14 +3028,12 @@ class InstanceTest(MongoDBTestCase):
User.drop_collection()
- User._get_collection().insert_one({
- 'name': 'John',
- 'thing': {
- 'name': 'My thing',
- 'foo': 'Bar',
- 'data': [1, 2, 3]
+ User._get_collection().insert_one(
+ {
+ "name": "John",
+ "thing": {"name": "My thing", "foo": "Bar", "data": [1, 2, 3]},
}
- })
+ )
self.assertRaises(FieldDoesNotExist, User.objects.first)
@@ -2956,18 +3045,16 @@ class InstanceTest(MongoDBTestCase):
name = StringField()
thing = EmbeddedDocumentField(Thing)
- meta = {'strict': False}
+ meta = {"strict": False}
User.drop_collection()
- User._get_collection().insert_one({
- 'name': 'John',
- 'thing': {
- 'name': 'My thing',
- 'foo': 'Bar',
- 'data': [1, 2, 3]
+ User._get_collection().insert_one(
+ {
+ "name": "John",
+ "thing": {"name": "My thing", "foo": "Bar", "data": [1, 2, 3]},
}
- })
+ )
self.assertRaises(FieldDoesNotExist, User.objects.first)
@@ -2975,7 +3062,7 @@ class InstanceTest(MongoDBTestCase):
class Thing(EmbeddedDocument):
name = StringField()
- meta = {'strict': False}
+ meta = {"strict": False}
class User(Document):
name = StringField()
@@ -2983,22 +3070,20 @@ class InstanceTest(MongoDBTestCase):
User.drop_collection()
- User._get_collection().insert_one({
- 'name': 'John',
- 'thing': {
- 'name': 'My thing',
- 'foo': 'Bar',
- 'data': [1, 2, 3]
+ User._get_collection().insert_one(
+ {
+ "name": "John",
+ "thing": {"name": "My thing", "foo": "Bar", "data": [1, 2, 3]},
}
- })
+ )
user = User.objects.first()
- self.assertEqual(user.name, 'John')
- self.assertEqual(user.thing.name, 'My thing')
- self.assertFalse(hasattr(user.thing, 'foo'))
- self.assertEqual(user.thing._data['foo'], 'Bar')
- self.assertFalse(hasattr(user.thing, 'data'))
- self.assertEqual(user.thing._data['data'], [1, 2, 3])
+ self.assertEqual(user.name, "John")
+ self.assertEqual(user.thing.name, "My thing")
+ self.assertFalse(hasattr(user.thing, "foo"))
+ self.assertEqual(user.thing._data["foo"], "Bar")
+ self.assertFalse(hasattr(user.thing, "data"))
+ self.assertEqual(user.thing._data["data"], [1, 2, 3])
def test_spaces_in_keys(self):
class Embedded(DynamicEmbeddedDocument):
@@ -3009,10 +3094,10 @@ class InstanceTest(MongoDBTestCase):
Doc.drop_collection()
doc = Doc()
- setattr(doc, 'hello world', 1)
+ setattr(doc, "hello world", 1)
doc.save()
- one = Doc.objects.filter(**{'hello world': 1}).count()
+ one = Doc.objects.filter(**{"hello world": 1}).count()
self.assertEqual(1, one)
def test_shard_key(self):
@@ -3020,9 +3105,7 @@ class InstanceTest(MongoDBTestCase):
machine = StringField()
log = StringField()
- meta = {
- 'shard_key': ('machine',)
- }
+ meta = {"shard_key": ("machine",)}
LogEntry.drop_collection()
@@ -3044,24 +3127,22 @@ class InstanceTest(MongoDBTestCase):
foo = StringField()
class Bar(Document):
- meta = {
- 'shard_key': ('foo.foo',)
- }
+ meta = {"shard_key": ("foo.foo",)}
foo = EmbeddedDocumentField(Foo)
bar = StringField()
- foo_doc = Foo(foo='hello')
- bar_doc = Bar(foo=foo_doc, bar='world')
+ foo_doc = Foo(foo="hello")
+ bar_doc = Bar(foo=foo_doc, bar="world")
bar_doc.save()
self.assertTrue(bar_doc.id is not None)
- bar_doc.bar = 'baz'
+ bar_doc.bar = "baz"
bar_doc.save()
# try to change the shard key
with self.assertRaises(OperationError):
- bar_doc.foo.foo = 'something'
+ bar_doc.foo.foo = "something"
bar_doc.save()
def test_shard_key_primary(self):
@@ -3069,9 +3150,7 @@ class InstanceTest(MongoDBTestCase):
machine = StringField(primary_key=True)
log = StringField()
- meta = {
- 'shard_key': ('machine',)
- }
+ meta = {"shard_key": ("machine",)}
LogEntry.drop_collection()
@@ -3097,12 +3176,10 @@ class InstanceTest(MongoDBTestCase):
doc = EmbeddedDocumentField(Embedded)
def __eq__(self, other):
- return (self.doc_name == other.doc_name and
- self.doc == other.doc)
+ return self.doc_name == other.doc_name and self.doc == other.doc
classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc"))
- dict_doc = Doc(**{"doc_name": "my doc",
- "doc": {"name": "embedded doc"}})
+ dict_doc = Doc(**{"doc_name": "my doc", "doc": {"name": "embedded doc"}})
self.assertEqual(classic_doc, dict_doc)
self.assertEqual(classic_doc._data, dict_doc._data)
@@ -3116,15 +3193,18 @@ class InstanceTest(MongoDBTestCase):
docs = ListField(EmbeddedDocumentField(Embedded))
def __eq__(self, other):
- return (self.doc_name == other.doc_name and
- self.docs == other.docs)
+ return self.doc_name == other.doc_name and self.docs == other.docs
- classic_doc = Doc(doc_name="my doc", docs=[
- Embedded(name="embedded doc1"),
- Embedded(name="embedded doc2")])
- dict_doc = Doc(**{"doc_name": "my doc",
- "docs": [{"name": "embedded doc1"},
- {"name": "embedded doc2"}]})
+ classic_doc = Doc(
+ doc_name="my doc",
+ docs=[Embedded(name="embedded doc1"), Embedded(name="embedded doc2")],
+ )
+ dict_doc = Doc(
+ **{
+ "doc_name": "my doc",
+ "docs": [{"name": "embedded doc1"}, {"name": "embedded doc2"}],
+ }
+ )
self.assertEqual(classic_doc, dict_doc)
self.assertEqual(classic_doc._data, dict_doc._data)
@@ -3134,8 +3214,8 @@ class InstanceTest(MongoDBTestCase):
with self.assertRaises(TypeError) as e:
person = self.Person("Test User", 42)
expected_msg = (
- 'Instantiating a document with positional arguments is not '
- 'supported. Please use `field_name=value` keyword arguments.'
+ "Instantiating a document with positional arguments is not "
+ "supported. Please use `field_name=value` keyword arguments."
)
self.assertEqual(str(e.exception), expected_msg)
@@ -3144,8 +3224,8 @@ class InstanceTest(MongoDBTestCase):
with self.assertRaises(TypeError) as e:
person = self.Person("Test User", age=42)
expected_msg = (
- 'Instantiating a document with positional arguments is not '
- 'supported. Please use `field_name=value` keyword arguments.'
+ "Instantiating a document with positional arguments is not "
+ "supported. Please use `field_name=value` keyword arguments."
)
self.assertEqual(str(e.exception), expected_msg)
@@ -3154,8 +3234,8 @@ class InstanceTest(MongoDBTestCase):
with self.assertRaises(TypeError) as e:
job = self.Job("Test Job", 4)
expected_msg = (
- 'Instantiating a document with positional arguments is not '
- 'supported. Please use `field_name=value` keyword arguments.'
+ "Instantiating a document with positional arguments is not "
+ "supported. Please use `field_name=value` keyword arguments."
)
self.assertEqual(str(e.exception), expected_msg)
@@ -3164,13 +3244,14 @@ class InstanceTest(MongoDBTestCase):
with self.assertRaises(TypeError) as e:
job = self.Job("Test Job", years=4)
expected_msg = (
- 'Instantiating a document with positional arguments is not '
- 'supported. Please use `field_name=value` keyword arguments.'
+ "Instantiating a document with positional arguments is not "
+ "supported. Please use `field_name=value` keyword arguments."
)
self.assertEqual(str(e.exception), expected_msg)
def test_data_contains_id_field(self):
"""Ensure that asking for _data returns 'id'."""
+
class Person(Document):
name = StringField()
@@ -3178,8 +3259,8 @@ class InstanceTest(MongoDBTestCase):
Person(name="Harry Potter").save()
person = Person.objects.first()
- self.assertIn('id', person._data.keys())
- self.assertEqual(person._data.get('id'), person.id)
+ self.assertIn("id", person._data.keys())
+ self.assertEqual(person._data.get("id"), person.id)
def test_complex_nesting_document_and_embedded_document(self):
class Macro(EmbeddedDocument):
@@ -3220,8 +3301,8 @@ class InstanceTest(MongoDBTestCase):
system = NodesSystem.objects.first()
self.assertEqual(
- "UNDEFINED",
- system.nodes["node"].parameters["param"].macros["test"].value)
+ "UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value
+ )
def test_embedded_document_equality(self):
class Test(Document):
@@ -3231,7 +3312,7 @@ class InstanceTest(MongoDBTestCase):
ref = ReferenceField(Test)
Test.drop_collection()
- test = Test(field='123').save() # has id
+ test = Test(field="123").save() # has id
e = Embedded(ref=test)
f1 = Embedded._from_son(e.to_mongo())
@@ -3250,25 +3331,25 @@ class InstanceTest(MongoDBTestCase):
class Test(Document):
name = StringField()
- test2 = ReferenceField('Test2')
- test3 = ReferenceField('Test3')
+ test2 = ReferenceField("Test2")
+ test3 = ReferenceField("Test3")
Test.drop_collection()
Test2.drop_collection()
Test3.drop_collection()
- t2 = Test2(name='a')
+ t2 = Test2(name="a")
t2.save()
- t3 = Test3(name='x')
+ t3 = Test3(name="x")
t3.id = t2.id
t3.save()
- t = Test(name='b', test2=t2, test3=t3)
+ t = Test(name="b", test2=t2, test3=t3)
f = Test._from_son(t.to_mongo())
- dbref2 = f._data['test2']
+ dbref2 = f._data["test2"]
obj2 = f.test2
self.assertIsInstance(dbref2, DBRef)
self.assertIsInstance(obj2, Test2)
@@ -3276,7 +3357,7 @@ class InstanceTest(MongoDBTestCase):
self.assertEqual(obj2, dbref2)
self.assertEqual(dbref2, obj2)
- dbref3 = f._data['test3']
+ dbref3 = f._data["test3"]
obj3 = f.test3
self.assertIsInstance(dbref3, DBRef)
self.assertIsInstance(obj3, Test3)
@@ -3306,14 +3387,14 @@ class InstanceTest(MongoDBTestCase):
created_on = DateTimeField(default=lambda: datetime.utcnow())
name = StringField()
- p = Person(name='alon')
+ p = Person(name="alon")
p.save()
- orig_created_on = Person.objects().only('created_on')[0].created_on
+ orig_created_on = Person.objects().only("created_on")[0].created_on
- p2 = Person.objects().only('name')[0]
- p2.name = 'alon2'
+ p2 = Person.objects().only("name")[0]
+ p2.name = "alon2"
p2.save()
- p3 = Person.objects().only('created_on')[0]
+ p3 = Person.objects().only("created_on")[0]
self.assertEqual(orig_created_on, p3.created_on)
class Person(Document):
@@ -3331,8 +3412,8 @@ class InstanceTest(MongoDBTestCase):
# alter DB for the new default
coll = Person._get_collection()
for person in Person.objects.as_pymongo():
- if 'height' not in person:
- coll.update_one({'_id': person['_id']}, {'$set': {'height': 189}})
+ if "height" not in person:
+ coll.update_one({"_id": person["_id"]}, {"$set": {"height": 189}})
self.assertEqual(Person.objects(height=189).count(), 1)
@@ -3340,12 +3421,17 @@ class InstanceTest(MongoDBTestCase):
# 771
class MyPerson(self.Person):
meta = dict(shard_key=["id"])
+
p = MyPerson.from_json('{"name": "name", "age": 27}', created=True)
self.assertEqual(p.id, None)
- p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here
+ p.id = (
+ "12345"
+ ) # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here
p = MyPerson._from_son({"name": "name", "age": 27}, created=True)
self.assertEqual(p.id, None)
- p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here
+ p.id = (
+ "12345"
+ ) # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here
def test_from_son_created_False_without_id(self):
class MyPerson(Document):
@@ -3359,7 +3445,7 @@ class InstanceTest(MongoDBTestCase):
p.save()
self.assertIsNotNone(p.id)
saved_p = MyPerson.objects.get(id=p.id)
- self.assertEqual(saved_p.name, 'a_fancy_name')
+ self.assertEqual(saved_p.name, "a_fancy_name")
def test_from_son_created_False_with_id(self):
# 1854
@@ -3368,11 +3454,13 @@ class InstanceTest(MongoDBTestCase):
MyPerson.objects.delete()
- p = MyPerson.from_json('{"_id": "5b85a8b04ec5dc2da388296e", "name": "a_fancy_name"}', created=False)
+ p = MyPerson.from_json(
+ '{"_id": "5b85a8b04ec5dc2da388296e", "name": "a_fancy_name"}', created=False
+ )
self.assertFalse(p._created)
self.assertEqual(p._changed_fields, [])
- self.assertEqual(p.name, 'a_fancy_name')
- self.assertEqual(p.id, ObjectId('5b85a8b04ec5dc2da388296e'))
+ self.assertEqual(p.name, "a_fancy_name")
+ self.assertEqual(p.id, ObjectId("5b85a8b04ec5dc2da388296e"))
p.save()
with self.assertRaises(DoesNotExist):
@@ -3382,8 +3470,8 @@ class InstanceTest(MongoDBTestCase):
MyPerson.objects.get(id=p.id)
self.assertFalse(p._created)
- p.name = 'a new fancy name'
- self.assertEqual(p._changed_fields, ['name'])
+ p.name = "a new fancy name"
+ self.assertEqual(p._changed_fields, ["name"])
p.save()
saved_p = MyPerson.objects.get(id=p.id)
self.assertEqual(saved_p.name, p.name)
@@ -3394,16 +3482,18 @@ class InstanceTest(MongoDBTestCase):
MyPerson.objects.delete()
- p = MyPerson.from_json('{"_id": "5b85a8b04ec5dc2da388296e", "name": "a_fancy_name"}', created=True)
+ p = MyPerson.from_json(
+ '{"_id": "5b85a8b04ec5dc2da388296e", "name": "a_fancy_name"}', created=True
+ )
self.assertTrue(p._created)
self.assertEqual(p._changed_fields, [])
- self.assertEqual(p.name, 'a_fancy_name')
- self.assertEqual(p.id, ObjectId('5b85a8b04ec5dc2da388296e'))
+ self.assertEqual(p.name, "a_fancy_name")
+ self.assertEqual(p.id, ObjectId("5b85a8b04ec5dc2da388296e"))
p.save()
saved_p = MyPerson.objects.get(id=p.id)
self.assertEqual(saved_p, p)
- self.assertEqual(p.name, 'a_fancy_name')
+ self.assertEqual(p.name, "a_fancy_name")
def test_null_field(self):
# 734
@@ -3417,9 +3507,9 @@ class InstanceTest(MongoDBTestCase):
cdt_fld = ComplexDateTimeField(null=True)
User.objects.delete()
- u = User(name='user')
+ u = User(name="user")
u.save()
- u_from_db = User.objects.get(name='user')
+ u_from_db = User.objects.get(name="user")
u_from_db.height = None
u_from_db.save()
self.assertEqual(u_from_db.height, None)
@@ -3432,15 +3522,16 @@ class InstanceTest(MongoDBTestCase):
# 735
User.objects.delete()
- u = User(name='user')
+ u = User(name="user")
u.save()
- User.objects(name='user').update_one(set__height=None, upsert=True)
- u_from_db = User.objects.get(name='user')
+ User.objects(name="user").update_one(set__height=None, upsert=True)
+ u_from_db = User.objects.get(name="user")
self.assertEqual(u_from_db.height, None)
def test_not_saved_eq(self):
"""Ensure we can compare documents not saved.
"""
+
class Person(Document):
pass
@@ -3458,7 +3549,7 @@ class InstanceTest(MongoDBTestCase):
l = ListField(EmbeddedDocumentField(B))
A.objects.delete()
- A(l=[B(v='1'), B(v='2'), B(v='3')]).save()
+ A(l=[B(v="1"), B(v="2"), B(v="3")]).save()
a = A.objects.get()
self.assertEqual(a.l._instance, a)
for idx, b in enumerate(a.l):
@@ -3467,6 +3558,7 @@ class InstanceTest(MongoDBTestCase):
def test_falsey_pk(self):
"""Ensure that we can create and update a document with Falsey PK."""
+
class Person(Document):
age = IntField(primary_key=True)
height = FloatField()
@@ -3480,6 +3572,7 @@ class InstanceTest(MongoDBTestCase):
def test_push_with_position(self):
"""Ensure that push with position works properly for an instance."""
+
class BlogPost(Document):
slug = StringField()
tags = ListField(StringField())
@@ -3491,10 +3584,11 @@ class InstanceTest(MongoDBTestCase):
blog.update(push__tags__0=["mongodb", "code"])
blog.reload()
- self.assertEqual(blog.tags, ['mongodb', 'code', 'python'])
+ self.assertEqual(blog.tags, ["mongodb", "code", "python"])
def test_push_nested_list(self):
"""Ensure that push update works in nested list"""
+
class BlogPost(Document):
slug = StringField()
tags = ListField()
@@ -3505,10 +3599,11 @@ class InstanceTest(MongoDBTestCase):
self.assertEqual(blog.tags, [["value1", 123]])
def test_accessing_objects_with_indexes_error(self):
- insert_result = self.db.company.insert_many([{'name': 'Foo'},
- {'name': 'Foo'}]) # Force 2 doc with same name
+ insert_result = self.db.company.insert_many(
+ [{"name": "Foo"}, {"name": "Foo"}]
+ ) # Force 2 doc with same name
REF_OID = insert_result.inserted_ids[0]
- self.db.user.insert_one({'company': REF_OID}) # Force 2 doc with same name
+ self.db.user.insert_one({"company": REF_OID}) # Force 2 doc with same name
class Company(Document):
name = StringField(unique=True)
@@ -3521,5 +3616,5 @@ class InstanceTest(MongoDBTestCase):
User.objects().select_related()
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py
index 251b65a2..33d5a6d9 100644
--- a/tests/document/json_serialisation.py
+++ b/tests/document/json_serialisation.py
@@ -13,9 +13,8 @@ __all__ = ("TestJson",)
class TestJson(unittest.TestCase):
-
def setUp(self):
- connect(db='mongoenginetest')
+ connect(db="mongoenginetest")
def test_json_names(self):
"""
@@ -25,22 +24,24 @@ class TestJson(unittest.TestCase):
a to_json with the original class names and not the abreviated
mongodb document keys
"""
+
class Embedded(EmbeddedDocument):
- string = StringField(db_field='s')
+ string = StringField(db_field="s")
class Doc(Document):
- string = StringField(db_field='s')
- embedded = EmbeddedDocumentField(Embedded, db_field='e')
+ string = StringField(db_field="s")
+ embedded = EmbeddedDocumentField(Embedded, db_field="e")
doc = Doc(string="Hello", embedded=Embedded(string="Inner Hello"))
- doc_json = doc.to_json(sort_keys=True, use_db_field=False, separators=(',', ':'))
+ doc_json = doc.to_json(
+ sort_keys=True, use_db_field=False, separators=(",", ":")
+ )
expected_json = """{"embedded":{"string":"Inner Hello"},"string":"Hello"}"""
self.assertEqual(doc_json, expected_json)
def test_json_simple(self):
-
class Embedded(EmbeddedDocument):
string = StringField()
@@ -49,12 +50,14 @@ class TestJson(unittest.TestCase):
embedded_field = EmbeddedDocumentField(Embedded)
def __eq__(self, other):
- return (self.string == other.string and
- self.embedded_field == other.embedded_field)
+ return (
+ self.string == other.string
+ and self.embedded_field == other.embedded_field
+ )
doc = Doc(string="Hi", embedded_field=Embedded(string="Hi"))
- doc_json = doc.to_json(sort_keys=True, separators=(',', ':'))
+ doc_json = doc.to_json(sort_keys=True, separators=(",", ":"))
expected_json = """{"embedded_field":{"string":"Hi"},"string":"Hi"}"""
self.assertEqual(doc_json, expected_json)
@@ -68,41 +71,43 @@ class TestJson(unittest.TestCase):
pass
class Doc(Document):
- string_field = StringField(default='1')
+ string_field = StringField(default="1")
int_field = IntField(default=1)
float_field = FloatField(default=1.1)
boolean_field = BooleanField(default=True)
datetime_field = DateTimeField(default=datetime.now)
- embedded_document_field = EmbeddedDocumentField(EmbeddedDoc,
- default=lambda: EmbeddedDoc())
+ embedded_document_field = EmbeddedDocumentField(
+ EmbeddedDoc, default=lambda: EmbeddedDoc()
+ )
list_field = ListField(default=lambda: [1, 2, 3])
dict_field = DictField(default=lambda: {"hello": "world"})
objectid_field = ObjectIdField(default=ObjectId)
- reference_field = ReferenceField(Simple, default=lambda:
- Simple().save())
+ reference_field = ReferenceField(Simple, default=lambda: Simple().save())
map_field = MapField(IntField(), default=lambda: {"simple": 1})
decimal_field = DecimalField(default=1.0)
complex_datetime_field = ComplexDateTimeField(default=datetime.now)
url_field = URLField(default="http://mongoengine.org")
dynamic_field = DynamicField(default=1)
generic_reference_field = GenericReferenceField(
- default=lambda: Simple().save())
- sorted_list_field = SortedListField(IntField(),
- default=lambda: [1, 2, 3])
+ default=lambda: Simple().save()
+ )
+ sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3])
email_field = EmailField(default="ross@example.com")
geo_point_field = GeoPointField(default=lambda: [1, 2])
sequence_field = SequenceField()
uuid_field = UUIDField(default=uuid.uuid4)
generic_embedded_document_field = GenericEmbeddedDocumentField(
- default=lambda: EmbeddedDoc())
+ default=lambda: EmbeddedDoc()
+ )
def __eq__(self, other):
import json
+
return json.loads(self.to_json()) == json.loads(other.to_json())
doc = Doc()
self.assertEqual(doc, Doc.from_json(doc.to_json()))
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/document/validation.py b/tests/document/validation.py
index 30a285b2..78199231 100644
--- a/tests/document/validation.py
+++ b/tests/document/validation.py
@@ -8,49 +8,56 @@ __all__ = ("ValidatorErrorTest",)
class ValidatorErrorTest(unittest.TestCase):
-
def setUp(self):
- connect(db='mongoenginetest')
+ connect(db="mongoenginetest")
def test_to_dict(self):
"""Ensure a ValidationError handles error to_dict correctly.
"""
- error = ValidationError('root')
+ error = ValidationError("root")
self.assertEqual(error.to_dict(), {})
# 1st level error schema
- error.errors = {'1st': ValidationError('bad 1st'), }
- self.assertIn('1st', error.to_dict())
- self.assertEqual(error.to_dict()['1st'], 'bad 1st')
+ error.errors = {"1st": ValidationError("bad 1st")}
+ self.assertIn("1st", error.to_dict())
+ self.assertEqual(error.to_dict()["1st"], "bad 1st")
# 2nd level error schema
- error.errors = {'1st': ValidationError('bad 1st', errors={
- '2nd': ValidationError('bad 2nd'),
- })}
- self.assertIn('1st', error.to_dict())
- self.assertIsInstance(error.to_dict()['1st'], dict)
- self.assertIn('2nd', error.to_dict()['1st'])
- self.assertEqual(error.to_dict()['1st']['2nd'], 'bad 2nd')
+ error.errors = {
+ "1st": ValidationError(
+ "bad 1st", errors={"2nd": ValidationError("bad 2nd")}
+ )
+ }
+ self.assertIn("1st", error.to_dict())
+ self.assertIsInstance(error.to_dict()["1st"], dict)
+ self.assertIn("2nd", error.to_dict()["1st"])
+ self.assertEqual(error.to_dict()["1st"]["2nd"], "bad 2nd")
# moar levels
- error.errors = {'1st': ValidationError('bad 1st', errors={
- '2nd': ValidationError('bad 2nd', errors={
- '3rd': ValidationError('bad 3rd', errors={
- '4th': ValidationError('Inception'),
- }),
- }),
- })}
- self.assertIn('1st', error.to_dict())
- self.assertIn('2nd', error.to_dict()['1st'])
- self.assertIn('3rd', error.to_dict()['1st']['2nd'])
- self.assertIn('4th', error.to_dict()['1st']['2nd']['3rd'])
- self.assertEqual(error.to_dict()['1st']['2nd']['3rd']['4th'],
- 'Inception')
+ error.errors = {
+ "1st": ValidationError(
+ "bad 1st",
+ errors={
+ "2nd": ValidationError(
+ "bad 2nd",
+ errors={
+ "3rd": ValidationError(
+ "bad 3rd", errors={"4th": ValidationError("Inception")}
+ )
+ },
+ )
+ },
+ )
+ }
+ self.assertIn("1st", error.to_dict())
+ self.assertIn("2nd", error.to_dict()["1st"])
+ self.assertIn("3rd", error.to_dict()["1st"]["2nd"])
+ self.assertIn("4th", error.to_dict()["1st"]["2nd"]["3rd"])
+ self.assertEqual(error.to_dict()["1st"]["2nd"]["3rd"]["4th"], "Inception")
self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])")
def test_model_validation(self):
-
class User(Document):
username = StringField(primary_key=True)
name = StringField(required=True)
@@ -59,9 +66,10 @@ class ValidatorErrorTest(unittest.TestCase):
User().validate()
except ValidationError as e:
self.assertIn("User:None", e.message)
- self.assertEqual(e.to_dict(), {
- 'username': 'Field is required',
- 'name': 'Field is required'})
+ self.assertEqual(
+ e.to_dict(),
+ {"username": "Field is required", "name": "Field is required"},
+ )
user = User(username="RossC0", name="Ross").save()
user.name = None
@@ -69,14 +77,13 @@ class ValidatorErrorTest(unittest.TestCase):
user.save()
except ValidationError as e:
self.assertIn("User:RossC0", e.message)
- self.assertEqual(e.to_dict(), {
- 'name': 'Field is required'})
+ self.assertEqual(e.to_dict(), {"name": "Field is required"})
def test_fields_rewrite(self):
class BasePerson(Document):
name = StringField()
age = IntField()
- meta = {'abstract': True}
+ meta = {"abstract": True}
class Person(BasePerson):
name = StringField(required=True)
@@ -87,6 +94,7 @@ class ValidatorErrorTest(unittest.TestCase):
def test_embedded_document_validation(self):
"""Ensure that embedded documents may be validated.
"""
+
class Comment(EmbeddedDocument):
date = DateTimeField()
content = StringField(required=True)
@@ -94,7 +102,7 @@ class ValidatorErrorTest(unittest.TestCase):
comment = Comment()
self.assertRaises(ValidationError, comment.validate)
- comment.content = 'test'
+ comment.content = "test"
comment.validate()
comment.date = 4
@@ -105,20 +113,20 @@ class ValidatorErrorTest(unittest.TestCase):
self.assertEqual(comment._instance, None)
def test_embedded_db_field_validate(self):
-
class SubDoc(EmbeddedDocument):
val = IntField(required=True)
class Doc(Document):
id = StringField(primary_key=True)
- e = EmbeddedDocumentField(SubDoc, db_field='eb')
+ e = EmbeddedDocumentField(SubDoc, db_field="eb")
try:
Doc(id="bad").validate()
except ValidationError as e:
self.assertIn("SubDoc:None", e.message)
- self.assertEqual(e.to_dict(), {
- "e": {'val': 'OK could not be converted to int'}})
+ self.assertEqual(
+ e.to_dict(), {"e": {"val": "OK could not be converted to int"}}
+ )
Doc.drop_collection()
@@ -127,24 +135,24 @@ class ValidatorErrorTest(unittest.TestCase):
doc = Doc.objects.first()
keys = doc._data.keys()
self.assertEqual(2, len(keys))
- self.assertIn('e', keys)
- self.assertIn('id', keys)
+ self.assertIn("e", keys)
+ self.assertIn("id", keys)
doc.e.val = "OK"
try:
doc.save()
except ValidationError as e:
self.assertIn("Doc:test", e.message)
- self.assertEqual(e.to_dict(), {
- "e": {'val': 'OK could not be converted to int'}})
+ self.assertEqual(
+ e.to_dict(), {"e": {"val": "OK could not be converted to int"}}
+ )
def test_embedded_weakref(self):
-
class SubDoc(EmbeddedDocument):
val = IntField(required=True)
class Doc(Document):
- e = EmbeddedDocumentField(SubDoc, db_field='eb')
+ e = EmbeddedDocumentField(SubDoc, db_field="eb")
Doc.drop_collection()
@@ -167,9 +175,10 @@ class ValidatorErrorTest(unittest.TestCase):
Test to ensure a ReferenceField can store a reference to a parent
class when inherited. Issue #954.
"""
+
class Parent(Document):
- meta = {'allow_inheritance': True}
- reference = ReferenceField('self')
+ meta = {"allow_inheritance": True}
+ reference = ReferenceField("self")
class Child(Parent):
pass
@@ -190,9 +199,10 @@ class ValidatorErrorTest(unittest.TestCase):
Test to ensure a ReferenceField can store a reference to a parent
class when inherited and when set via attribute. Issue #954.
"""
+
class Parent(Document):
- meta = {'allow_inheritance': True}
- reference = ReferenceField('self')
+ meta = {"allow_inheritance": True}
+ reference = ReferenceField("self")
class Child(Parent):
pass
@@ -210,5 +220,5 @@ class ValidatorErrorTest(unittest.TestCase):
self.fail("ValidationError raised: %s" % e.message)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/fields/fields.py b/tests/fields/fields.py
index 68baab46..87acf27f 100644
--- a/tests/fields/fields.py
+++ b/tests/fields/fields.py
@@ -6,27 +6,52 @@ from nose.plugins.skip import SkipTest
from bson import DBRef, ObjectId, SON
-from mongoengine import Document, StringField, IntField, DateTimeField, DateField, ValidationError, \
- ComplexDateTimeField, FloatField, ListField, ReferenceField, DictField, EmbeddedDocument, EmbeddedDocumentField, \
- GenericReferenceField, DoesNotExist, NotRegistered, OperationError, DynamicField, \
- FieldDoesNotExist, EmbeddedDocumentListField, MultipleObjectsReturned, NotUniqueError, BooleanField,\
- ObjectIdField, SortedListField, GenericLazyReferenceField, LazyReferenceField, DynamicDocument
-from mongoengine.base import (BaseField, EmbeddedDocumentList, _document_registry)
+from mongoengine import (
+ Document,
+ StringField,
+ IntField,
+ DateTimeField,
+ DateField,
+ ValidationError,
+ ComplexDateTimeField,
+ FloatField,
+ ListField,
+ ReferenceField,
+ DictField,
+ EmbeddedDocument,
+ EmbeddedDocumentField,
+ GenericReferenceField,
+ DoesNotExist,
+ NotRegistered,
+ OperationError,
+ DynamicField,
+ FieldDoesNotExist,
+ EmbeddedDocumentListField,
+ MultipleObjectsReturned,
+ NotUniqueError,
+ BooleanField,
+ ObjectIdField,
+ SortedListField,
+ GenericLazyReferenceField,
+ LazyReferenceField,
+ DynamicDocument,
+)
+from mongoengine.base import BaseField, EmbeddedDocumentList, _document_registry
from mongoengine.errors import DeprecatedError
from tests.utils import MongoDBTestCase
class FieldTest(MongoDBTestCase):
-
def test_default_values_nothing_set(self):
"""Ensure that default field values are used when creating
a document.
"""
+
class Person(Document):
name = StringField()
age = IntField(default=30, required=False)
- userid = StringField(default=lambda: 'test', required=True)
+ userid = StringField(default=lambda: "test", required=True)
created = DateTimeField(default=datetime.datetime.utcnow)
day = DateField(default=datetime.date.today)
@@ -34,9 +59,7 @@ class FieldTest(MongoDBTestCase):
# Confirm saving now would store values
data_to_be_saved = sorted(person.to_mongo().keys())
- self.assertEqual(data_to_be_saved,
- ['age', 'created', 'day', 'name', 'userid']
- )
+ self.assertEqual(data_to_be_saved, ["age", "created", "day", "name", "userid"])
self.assertTrue(person.validate() is None)
@@ -46,18 +69,19 @@ class FieldTest(MongoDBTestCase):
self.assertEqual(person.created, person.created)
self.assertEqual(person.day, person.day)
- self.assertEqual(person._data['name'], person.name)
- self.assertEqual(person._data['age'], person.age)
- self.assertEqual(person._data['userid'], person.userid)
- self.assertEqual(person._data['created'], person.created)
- self.assertEqual(person._data['day'], person.day)
+ self.assertEqual(person._data["name"], person.name)
+ self.assertEqual(person._data["age"], person.age)
+ self.assertEqual(person._data["userid"], person.userid)
+ self.assertEqual(person._data["created"], person.created)
+ self.assertEqual(person._data["day"], person.day)
# Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys())
- self.assertEqual(
- data_to_be_saved, ['age', 'created', 'day', 'name', 'userid'])
+ self.assertEqual(data_to_be_saved, ["age", "created", "day", "name", "userid"])
- def test_custom_field_validation_raise_deprecated_error_when_validation_return_something(self):
+ def test_custom_field_validation_raise_deprecated_error_when_validation_return_something(
+ self
+ ):
# Covers introduction of a breaking change in the validation parameter (0.18)
def _not_empty(z):
return bool(z)
@@ -67,8 +91,10 @@ class FieldTest(MongoDBTestCase):
Person.drop_collection()
- error = ("validation argument for `name` must not return anything, "
- "it should raise a ValidationError if validation fails")
+ error = (
+ "validation argument for `name` must not return anything, "
+ "it should raise a ValidationError if validation fails"
+ )
with self.assertRaises(DeprecatedError) as ctx_err:
Person(name="").validate()
@@ -81,7 +107,7 @@ class FieldTest(MongoDBTestCase):
def test_custom_field_validation_raise_validation_error(self):
def _not_empty(z):
if not z:
- raise ValidationError('cantbeempty')
+ raise ValidationError("cantbeempty")
class Person(Document):
name = StringField(validation=_not_empty)
@@ -90,11 +116,17 @@ class FieldTest(MongoDBTestCase):
with self.assertRaises(ValidationError) as ctx_err:
Person(name="").validate()
- self.assertEqual("ValidationError (Person:None) (cantbeempty: ['name'])", str(ctx_err.exception))
+ self.assertEqual(
+ "ValidationError (Person:None) (cantbeempty: ['name'])",
+ str(ctx_err.exception),
+ )
with self.assertRaises(ValidationError):
Person(name="").save()
- self.assertEqual("ValidationError (Person:None) (cantbeempty: ['name'])", str(ctx_err.exception))
+ self.assertEqual(
+ "ValidationError (Person:None) (cantbeempty: ['name'])",
+ str(ctx_err.exception),
+ )
Person(name="garbage").validate()
Person(name="garbage").save()
@@ -103,10 +135,11 @@ class FieldTest(MongoDBTestCase):
"""Ensure that default field values are used even when
we explcitly initialize the doc with None values.
"""
+
class Person(Document):
name = StringField()
age = IntField(default=30, required=False)
- userid = StringField(default=lambda: 'test', required=True)
+ userid = StringField(default=lambda: "test", required=True)
created = DateTimeField(default=datetime.datetime.utcnow)
# Trying setting values to None
@@ -114,7 +147,7 @@ class FieldTest(MongoDBTestCase):
# Confirm saving now would store values
data_to_be_saved = sorted(person.to_mongo().keys())
- self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
+ self.assertEqual(data_to_be_saved, ["age", "created", "userid"])
self.assertTrue(person.validate() is None)
@@ -123,23 +156,24 @@ class FieldTest(MongoDBTestCase):
self.assertEqual(person.userid, person.userid)
self.assertEqual(person.created, person.created)
- self.assertEqual(person._data['name'], person.name)
- self.assertEqual(person._data['age'], person.age)
- self.assertEqual(person._data['userid'], person.userid)
- self.assertEqual(person._data['created'], person.created)
+ self.assertEqual(person._data["name"], person.name)
+ self.assertEqual(person._data["age"], person.age)
+ self.assertEqual(person._data["userid"], person.userid)
+ self.assertEqual(person._data["created"], person.created)
# Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys())
- self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
+ self.assertEqual(data_to_be_saved, ["age", "created", "userid"])
def test_default_values_when_setting_to_None(self):
"""Ensure that default field values are used when creating
a document.
"""
+
class Person(Document):
name = StringField()
age = IntField(default=30, required=False)
- userid = StringField(default=lambda: 'test', required=True)
+ userid = StringField(default=lambda: "test", required=True)
created = DateTimeField(default=datetime.datetime.utcnow)
person = Person()
@@ -150,25 +184,27 @@ class FieldTest(MongoDBTestCase):
# Confirm saving now would store values
data_to_be_saved = sorted(person.to_mongo().keys())
- self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
+ self.assertEqual(data_to_be_saved, ["age", "created", "userid"])
self.assertTrue(person.validate() is None)
self.assertEqual(person.name, None)
self.assertEqual(person.age, 30)
- self.assertEqual(person.userid, 'test')
+ self.assertEqual(person.userid, "test")
self.assertIsInstance(person.created, datetime.datetime)
- self.assertEqual(person._data['name'], person.name)
- self.assertEqual(person._data['age'], person.age)
- self.assertEqual(person._data['userid'], person.userid)
- self.assertEqual(person._data['created'], person.created)
+ self.assertEqual(person._data["name"], person.name)
+ self.assertEqual(person._data["age"], person.age)
+ self.assertEqual(person._data["userid"], person.userid)
+ self.assertEqual(person._data["created"], person.created)
# Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys())
- self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
+ self.assertEqual(data_to_be_saved, ["age", "created", "userid"])
- def test_default_value_is_not_used_when_changing_value_to_empty_list_for_strict_doc(self):
+ def test_default_value_is_not_used_when_changing_value_to_empty_list_for_strict_doc(
+ self
+ ):
"""List field with default can be set to the empty list (strict)"""
# Issue #1733
class Doc(Document):
@@ -180,7 +216,9 @@ class FieldTest(MongoDBTestCase):
reloaded = Doc.objects.get(id=doc.id)
self.assertEqual(reloaded.x, [])
- def test_default_value_is_not_used_when_changing_value_to_empty_list_for_dyn_doc(self):
+ def test_default_value_is_not_used_when_changing_value_to_empty_list_for_dyn_doc(
+ self
+ ):
"""List field with default can be set to the empty list (dynamic)"""
# Issue #1733
class Doc(DynamicDocument):
@@ -188,7 +226,7 @@ class FieldTest(MongoDBTestCase):
doc = Doc(x=[1]).save()
doc.x = []
- doc.y = 2 # Was triggering the bug
+ doc.y = 2 # Was triggering the bug
doc.save()
reloaded = Doc.objects.get(id=doc.id)
self.assertEqual(reloaded.x, [])
@@ -197,41 +235,47 @@ class FieldTest(MongoDBTestCase):
"""Ensure that default field values are used after non-default
values are explicitly deleted.
"""
+
class Person(Document):
name = StringField()
age = IntField(default=30, required=False)
- userid = StringField(default=lambda: 'test', required=True)
+ userid = StringField(default=lambda: "test", required=True)
created = DateTimeField(default=datetime.datetime.utcnow)
- person = Person(name="Ross", age=50, userid='different',
- created=datetime.datetime(2014, 6, 12))
+ person = Person(
+ name="Ross",
+ age=50,
+ userid="different",
+ created=datetime.datetime(2014, 6, 12),
+ )
del person.name
del person.age
del person.userid
del person.created
data_to_be_saved = sorted(person.to_mongo().keys())
- self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
+ self.assertEqual(data_to_be_saved, ["age", "created", "userid"])
self.assertTrue(person.validate() is None)
self.assertEqual(person.name, None)
self.assertEqual(person.age, 30)
- self.assertEqual(person.userid, 'test')
+ self.assertEqual(person.userid, "test")
self.assertIsInstance(person.created, datetime.datetime)
self.assertNotEqual(person.created, datetime.datetime(2014, 6, 12))
- self.assertEqual(person._data['name'], person.name)
- self.assertEqual(person._data['age'], person.age)
- self.assertEqual(person._data['userid'], person.userid)
- self.assertEqual(person._data['created'], person.created)
+ self.assertEqual(person._data["name"], person.name)
+ self.assertEqual(person._data["age"], person.age)
+ self.assertEqual(person._data["userid"], person.userid)
+ self.assertEqual(person._data["created"], person.created)
# Confirm introspection changes nothing
data_to_be_saved = sorted(person.to_mongo().keys())
- self.assertEqual(data_to_be_saved, ['age', 'created', 'userid'])
+ self.assertEqual(data_to_be_saved, ["age", "created", "userid"])
def test_required_values(self):
"""Ensure that required field constraints are enforced."""
+
class Person(Document):
name = StringField(required=True)
age = IntField(required=True)
@@ -246,6 +290,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure that every fields should accept None if required is
False.
"""
+
class HandleNoneFields(Document):
str_fld = StringField()
int_fld = IntField()
@@ -255,7 +300,7 @@ class FieldTest(MongoDBTestCase):
HandleNoneFields.drop_collection()
doc = HandleNoneFields()
- doc.str_fld = u'spam ham egg'
+ doc.str_fld = u"spam ham egg"
doc.int_fld = 42
doc.flt_fld = 4.2
doc.com_dt_fld = datetime.datetime.utcnow()
@@ -281,6 +326,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure that every field can handle null values from the
database.
"""
+
class HandleNoneFields(Document):
str_fld = StringField(required=True)
int_fld = IntField(required=True)
@@ -290,21 +336,17 @@ class FieldTest(MongoDBTestCase):
HandleNoneFields.drop_collection()
doc = HandleNoneFields()
- doc.str_fld = u'spam ham egg'
+ doc.str_fld = u"spam ham egg"
doc.int_fld = 42
doc.flt_fld = 4.2
doc.comp_dt_fld = datetime.datetime.utcnow()
doc.save()
# Unset all the fields
- obj = HandleNoneFields._get_collection().update({"_id": doc.id}, {
- "$unset": {
- "str_fld": 1,
- "int_fld": 1,
- "flt_fld": 1,
- "comp_dt_fld": 1
- }
- })
+ obj = HandleNoneFields._get_collection().update(
+ {"_id": doc.id},
+ {"$unset": {"str_fld": 1, "int_fld": 1, "flt_fld": 1, "comp_dt_fld": 1}},
+ )
# Retrive data from db and verify it.
ret = HandleNoneFields.objects.first()
@@ -321,16 +363,17 @@ class FieldTest(MongoDBTestCase):
"""Ensure that invalid values cannot be assigned to an
ObjectIdField.
"""
+
class Person(Document):
name = StringField()
- person = Person(name='Test User')
+ person = Person(name="Test User")
self.assertEqual(person.id, None)
person.id = 47
self.assertRaises(ValidationError, person.validate)
- person.id = 'abc'
+ person.id = "abc"
self.assertRaises(ValidationError, person.validate)
person.id = str(ObjectId())
@@ -338,26 +381,27 @@ class FieldTest(MongoDBTestCase):
def test_string_validation(self):
"""Ensure that invalid values cannot be assigned to string fields."""
+
class Person(Document):
name = StringField(max_length=20)
- userid = StringField(r'[0-9a-z_]+$')
+ userid = StringField(r"[0-9a-z_]+$")
person = Person(name=34)
self.assertRaises(ValidationError, person.validate)
# Test regex validation on userid
- person = Person(userid='test.User')
+ person = Person(userid="test.User")
self.assertRaises(ValidationError, person.validate)
- person.userid = 'test_user'
- self.assertEqual(person.userid, 'test_user')
+ person.userid = "test_user"
+ self.assertEqual(person.userid, "test_user")
person.validate()
# Test max length validation on name
- person = Person(name='Name that is more than twenty characters')
+ person = Person(name="Name that is more than twenty characters")
self.assertRaises(ValidationError, person.validate)
- person.name = 'Shorter name'
+ person.name = "Shorter name"
person.validate()
def test_db_field_validation(self):
@@ -365,25 +409,28 @@ class FieldTest(MongoDBTestCase):
# dot in the name
with self.assertRaises(ValueError):
+
class User(Document):
- name = StringField(db_field='user.name')
+ name = StringField(db_field="user.name")
# name starting with $
with self.assertRaises(ValueError):
+
class User(Document):
- name = StringField(db_field='$name')
+ name = StringField(db_field="$name")
# name containing a null character
with self.assertRaises(ValueError):
+
class User(Document):
- name = StringField(db_field='name\0')
+ name = StringField(db_field="name\0")
def test_list_validation(self):
"""Ensure that a list field only accepts lists with valid elements."""
access_level_choices = (
- ('a', u'Administration'),
- ('b', u'Manager'),
- ('c', u'Staff'),
+ ("a", u"Administration"),
+ ("b", u"Manager"),
+ ("c", u"Staff"),
)
class User(Document):
@@ -400,41 +447,41 @@ class FieldTest(MongoDBTestCase):
authors_as_lazy = ListField(LazyReferenceField(User))
generic = ListField(GenericReferenceField())
generic_as_lazy = ListField(GenericLazyReferenceField())
- access_list = ListField(choices=access_level_choices, display_sep=', ')
+ access_list = ListField(choices=access_level_choices, display_sep=", ")
User.drop_collection()
BlogPost.drop_collection()
- post = BlogPost(content='Went for a walk today...')
+ post = BlogPost(content="Went for a walk today...")
post.validate()
- post.tags = 'fun'
+ post.tags = "fun"
self.assertRaises(ValidationError, post.validate)
post.tags = [1, 2]
self.assertRaises(ValidationError, post.validate)
- post.tags = ['fun', 'leisure']
+ post.tags = ["fun", "leisure"]
post.validate()
- post.tags = ('fun', 'leisure')
+ post.tags = ("fun", "leisure")
post.validate()
- post.access_list = 'a,b'
+ post.access_list = "a,b"
self.assertRaises(ValidationError, post.validate)
- post.access_list = ['c', 'd']
+ post.access_list = ["c", "d"]
self.assertRaises(ValidationError, post.validate)
- post.access_list = ['a', 'b']
+ post.access_list = ["a", "b"]
post.validate()
- self.assertEqual(post.get_access_list_display(), u'Administration, Manager')
+ self.assertEqual(post.get_access_list_display(), u"Administration, Manager")
- post.comments = ['a']
+ post.comments = ["a"]
self.assertRaises(ValidationError, post.validate)
- post.comments = 'yay'
+ post.comments = "yay"
self.assertRaises(ValidationError, post.validate)
- comments = [Comment(content='Good for you'), Comment(content='Yay.')]
+ comments = [Comment(content="Good for you"), Comment(content="Yay.")]
post.comments = comments
post.validate()
@@ -485,28 +532,28 @@ class FieldTest(MongoDBTestCase):
def test_sorted_list_sorting(self):
"""Ensure that a sorted list field properly sorts values.
"""
+
class Comment(EmbeddedDocument):
order = IntField()
content = StringField()
class BlogPost(Document):
content = StringField()
- comments = SortedListField(EmbeddedDocumentField(Comment),
- ordering='order')
+ comments = SortedListField(EmbeddedDocumentField(Comment), ordering="order")
tags = SortedListField(StringField())
BlogPost.drop_collection()
- post = BlogPost(content='Went for a walk today...')
+ post = BlogPost(content="Went for a walk today...")
post.save()
- post.tags = ['leisure', 'fun']
+ post.tags = ["leisure", "fun"]
post.save()
post.reload()
- self.assertEqual(post.tags, ['fun', 'leisure'])
+ self.assertEqual(post.tags, ["fun", "leisure"])
- comment1 = Comment(content='Good for you', order=1)
- comment2 = Comment(content='Yay.', order=0)
+ comment1 = Comment(content="Good for you", order=1)
+ comment2 = Comment(content="Yay.", order=0)
comments = [comment1, comment2]
post.comments = comments
post.save()
@@ -529,16 +576,17 @@ class FieldTest(MongoDBTestCase):
name = StringField()
class CategoryList(Document):
- categories = SortedListField(EmbeddedDocumentField(Category),
- ordering='count', reverse=True)
+ categories = SortedListField(
+ EmbeddedDocumentField(Category), ordering="count", reverse=True
+ )
name = StringField()
CategoryList.drop_collection()
catlist = CategoryList(name="Top categories")
- cat1 = Category(name='posts', count=10)
- cat2 = Category(name='food', count=100)
- cat3 = Category(name='drink', count=40)
+ cat1 = Category(name="posts", count=10)
+ cat2 = Category(name="food", count=100)
+ cat3 = Category(name="drink", count=40)
catlist.categories = [cat1, cat2, cat3]
catlist.save()
catlist.reload()
@@ -549,57 +597,59 @@ class FieldTest(MongoDBTestCase):
def test_list_field(self):
"""Ensure that list types work as expected."""
+
class BlogPost(Document):
info = ListField()
BlogPost.drop_collection()
post = BlogPost()
- post.info = 'my post'
+ post.info = "my post"
self.assertRaises(ValidationError, post.validate)
- post.info = {'title': 'test'}
+ post.info = {"title": "test"}
self.assertRaises(ValidationError, post.validate)
- post.info = ['test']
+ post.info = ["test"]
post.save()
post = BlogPost()
- post.info = [{'test': 'test'}]
+ post.info = [{"test": "test"}]
post.save()
post = BlogPost()
- post.info = [{'test': 3}]
+ post.info = [{"test": 3}]
post.save()
self.assertEqual(BlogPost.objects.count(), 3)
- self.assertEqual(
- BlogPost.objects.filter(info__exact='test').count(), 1)
- self.assertEqual(
- BlogPost.objects.filter(info__0__test='test').count(), 1)
+ self.assertEqual(BlogPost.objects.filter(info__exact="test").count(), 1)
+ self.assertEqual(BlogPost.objects.filter(info__0__test="test").count(), 1)
# Confirm handles non strings or non existing keys
+ self.assertEqual(BlogPost.objects.filter(info__0__test__exact="5").count(), 0)
self.assertEqual(
- BlogPost.objects.filter(info__0__test__exact='5').count(), 0)
- self.assertEqual(
- BlogPost.objects.filter(info__100__test__exact='test').count(), 0)
+ BlogPost.objects.filter(info__100__test__exact="test").count(), 0
+ )
# test queries by list
post = BlogPost()
- post.info = ['1', '2']
+ post.info = ["1", "2"]
post.save()
- post = BlogPost.objects(info=['1', '2']).get()
- post.info += ['3', '4']
+ post = BlogPost.objects(info=["1", "2"]).get()
+ post.info += ["3", "4"]
post.save()
- self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4']).count(), 1)
- post = BlogPost.objects(info=['1', '2', '3', '4']).get()
+ self.assertEqual(BlogPost.objects(info=["1", "2", "3", "4"]).count(), 1)
+ post = BlogPost.objects(info=["1", "2", "3", "4"]).get()
post.info *= 2
post.save()
- self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4', '1', '2', '3', '4']).count(), 1)
+ self.assertEqual(
+ BlogPost.objects(info=["1", "2", "3", "4", "1", "2", "3", "4"]).count(), 1
+ )
def test_list_field_manipulative_operators(self):
"""Ensure that ListField works with standard list operators that manipulate the list.
"""
+
class BlogPost(Document):
ref = StringField()
info = ListField(StringField())
@@ -608,162 +658,178 @@ class FieldTest(MongoDBTestCase):
post = BlogPost()
post.ref = "1234"
- post.info = ['0', '1', '2', '3', '4', '5']
+ post.info = ["0", "1", "2", "3", "4", "5"]
post.save()
def reset_post():
- post.info = ['0', '1', '2', '3', '4', '5']
+ post.info = ["0", "1", "2", "3", "4", "5"]
post.save()
# '__add__(listB)'
# listA+listB
# operator.add(listA, listB)
reset_post()
- temp = ['a', 'b']
+ temp = ["a", "b"]
post.info = post.info + temp
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b'])
+ self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"])
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b'])
+ self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"])
# '__delitem__(index)'
# aka 'del list[index]'
# aka 'operator.delitem(list, index)'
reset_post()
del post.info[2] # del from middle ('2')
- self.assertEqual(post.info, ['0', '1', '3', '4', '5'])
+ self.assertEqual(post.info, ["0", "1", "3", "4", "5"])
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '3', '4', '5'])
+ self.assertEqual(post.info, ["0", "1", "3", "4", "5"])
# '__delitem__(slice(i, j))'
# aka 'del list[i:j]'
# aka 'operator.delitem(list, slice(i,j))'
reset_post()
del post.info[1:3] # removes '1', '2'
- self.assertEqual(post.info, ['0', '3', '4', '5'])
+ self.assertEqual(post.info, ["0", "3", "4", "5"])
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '3', '4', '5'])
+ self.assertEqual(post.info, ["0", "3", "4", "5"])
# '__iadd__'
# aka 'list += list'
reset_post()
- temp = ['a', 'b']
+ temp = ["a", "b"]
post.info += temp
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b'])
+ self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"])
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b'])
+ self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"])
# '__imul__'
# aka 'list *= number'
reset_post()
post.info *= 2
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
+ self.assertEqual(
+ post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"]
+ )
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
+ self.assertEqual(
+ post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"]
+ )
# '__mul__'
# aka 'listA*listB'
reset_post()
post.info = post.info * 2
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
+ self.assertEqual(
+ post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"]
+ )
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
+ self.assertEqual(
+ post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"]
+ )
# '__rmul__'
# aka 'listB*listA'
reset_post()
post.info = 2 * post.info
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
+ self.assertEqual(
+ post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"]
+ )
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5'])
+ self.assertEqual(
+ post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"]
+ )
# '__setitem__(index, value)'
# aka 'list[index]=value'
# aka 'setitem(list, value)'
reset_post()
- post.info[4] = 'a'
- self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5'])
+ post.info[4] = "a"
+ self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"])
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5'])
+ self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"])
# __setitem__(index, value) with a negative index
reset_post()
- post.info[-2] = 'a'
- self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5'])
+ post.info[-2] = "a"
+ self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"])
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5'])
+ self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"])
# '__setitem__(slice(i, j), listB)'
# aka 'listA[i:j] = listB'
# aka 'setitem(listA, slice(i, j), listB)'
reset_post()
- post.info[1:3] = ['h', 'e', 'l', 'l', 'o']
- self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5'])
+ post.info[1:3] = ["h", "e", "l", "l", "o"]
+ self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"])
post.save()
post.reload()
- self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5'])
+ self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"])
# '__setitem__(slice(i, j), listB)' with negative i and j
reset_post()
- post.info[-5:-3] = ['h', 'e', 'l', 'l', 'o']
- self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5'])
+ post.info[-5:-3] = ["h", "e", "l", "l", "o"]
+ self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"])
post.save()
post.reload()
- self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5'])
+ self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"])
# negative
# 'append'
reset_post()
- post.info.append('h')
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h'])
+ post.info.append("h")
+ self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "h"])
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h'])
+ self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "h"])
# 'extend'
reset_post()
- post.info.extend(['h', 'e', 'l', 'l', 'o'])
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h', 'e', 'l', 'l', 'o'])
+ post.info.extend(["h", "e", "l", "l", "o"])
+ self.assertEqual(
+ post.info, ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"]
+ )
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h', 'e', 'l', 'l', 'o'])
+ self.assertEqual(
+ post.info, ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"]
+ )
# 'insert'
# 'pop'
reset_post()
x = post.info.pop(2)
y = post.info.pop()
- self.assertEqual(post.info, ['0', '1', '3', '4'])
- self.assertEqual(x, '2')
- self.assertEqual(y, '5')
+ self.assertEqual(post.info, ["0", "1", "3", "4"])
+ self.assertEqual(x, "2")
+ self.assertEqual(y, "5")
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '3', '4'])
+ self.assertEqual(post.info, ["0", "1", "3", "4"])
# 'remove'
reset_post()
- post.info.remove('2')
- self.assertEqual(post.info, ['0', '1', '3', '4', '5'])
+ post.info.remove("2")
+ self.assertEqual(post.info, ["0", "1", "3", "4", "5"])
post.save()
post.reload()
- self.assertEqual(post.info, ['0', '1', '3', '4', '5'])
+ self.assertEqual(post.info, ["0", "1", "3", "4", "5"])
# 'reverse'
reset_post()
post.info.reverse()
- self.assertEqual(post.info, ['5', '4', '3', '2', '1', '0'])
+ self.assertEqual(post.info, ["5", "4", "3", "2", "1", "0"])
post.save()
post.reload()
- self.assertEqual(post.info, ['5', '4', '3', '2', '1', '0'])
+ self.assertEqual(post.info, ["5", "4", "3", "2", "1", "0"])
# 'sort': though this operator method does manipulate the list, it is
# tested in the 'test_list_field_lexicograpic_operators' function
@@ -775,7 +841,7 @@ class FieldTest(MongoDBTestCase):
post = BlogPost()
post.ref = "1234"
- post.info = ['0', '1', '2', '3', '4', '5']
+ post.info = ["0", "1", "2", "3", "4", "5"]
# '__hash__'
# aka 'hash(list)'
@@ -785,6 +851,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure that ListField works with standard list operators that
do lexigraphic ordering.
"""
+
class BlogPost(Document):
ref = StringField()
text_info = ListField(StringField())
@@ -810,7 +877,7 @@ class FieldTest(MongoDBTestCase):
blogLargeB.oid_info = [
"54495ad94c934721ede76f90",
"54495ad94c934721ede76d23",
- "54495ad94c934721ede76d00"
+ "54495ad94c934721ede76d00",
]
blogLargeB.bool_info = [False, True]
blogLargeB.save()
@@ -852,7 +919,7 @@ class FieldTest(MongoDBTestCase):
sorted_target_list = [
ObjectId("54495ad94c934721ede76d00"),
ObjectId("54495ad94c934721ede76d23"),
- ObjectId("54495ad94c934721ede76f90")
+ ObjectId("54495ad94c934721ede76f90"),
]
self.assertEqual(blogLargeB.text_info, ["a", "j", "z"])
self.assertEqual(blogLargeB.oid_info, sorted_target_list)
@@ -865,13 +932,14 @@ class FieldTest(MongoDBTestCase):
def test_list_assignment(self):
"""Ensure that list field element assignment and slicing work."""
+
class BlogPost(Document):
info = ListField()
BlogPost.drop_collection()
post = BlogPost()
- post.info = ['e1', 'e2', 3, '4', 5]
+ post.info = ["e1", "e2", 3, "4", 5]
post.save()
post.info[0] = 1
@@ -879,35 +947,35 @@ class FieldTest(MongoDBTestCase):
post.reload()
self.assertEqual(post.info[0], 1)
- post.info[1:3] = ['n2', 'n3']
+ post.info[1:3] = ["n2", "n3"]
post.save()
post.reload()
- self.assertEqual(post.info, [1, 'n2', 'n3', '4', 5])
+ self.assertEqual(post.info, [1, "n2", "n3", "4", 5])
- post.info[-1] = 'n5'
+ post.info[-1] = "n5"
post.save()
post.reload()
- self.assertEqual(post.info, [1, 'n2', 'n3', '4', 'n5'])
+ self.assertEqual(post.info, [1, "n2", "n3", "4", "n5"])
post.info[-2] = 4
post.save()
post.reload()
- self.assertEqual(post.info, [1, 'n2', 'n3', 4, 'n5'])
+ self.assertEqual(post.info, [1, "n2", "n3", 4, "n5"])
post.info[1:-1] = [2]
post.save()
post.reload()
- self.assertEqual(post.info, [1, 2, 'n5'])
+ self.assertEqual(post.info, [1, 2, "n5"])
- post.info[:-1] = [1, 'n2', 'n3', 4]
+ post.info[:-1] = [1, "n2", "n3", 4]
post.save()
post.reload()
- self.assertEqual(post.info, [1, 'n2', 'n3', 4, 'n5'])
+ self.assertEqual(post.info, [1, "n2", "n3", 4, "n5"])
post.info[-4:3] = [2, 3]
post.save()
post.reload()
- self.assertEqual(post.info, [1, 2, 3, 4, 'n5'])
+ self.assertEqual(post.info, [1, 2, 3, 4, "n5"])
def test_list_field_passed_in_value(self):
class Foo(Document):
@@ -921,12 +989,13 @@ class FieldTest(MongoDBTestCase):
foo = Foo(bars=[])
foo.bars.append(bar)
- self.assertEqual(repr(foo.bars), '[]')
+ self.assertEqual(repr(foo.bars), "[]")
def test_list_field_strict(self):
"""Ensure that list field handles validation if provided
a strict field type.
"""
+
class Simple(Document):
mapping = ListField(field=IntField())
@@ -943,17 +1012,19 @@ class FieldTest(MongoDBTestCase):
def test_list_field_rejects_strings(self):
"""Strings aren't valid list field data types."""
+
class Simple(Document):
mapping = ListField()
Simple.drop_collection()
e = Simple()
- e.mapping = 'hello world'
+ e.mapping = "hello world"
self.assertRaises(ValidationError, e.save)
def test_complex_field_required(self):
"""Ensure required cant be None / Empty."""
+
class Simple(Document):
mapping = ListField(required=True)
@@ -975,6 +1046,7 @@ class FieldTest(MongoDBTestCase):
"""If a complex field is set to the same value, it should not
be marked as changed.
"""
+
class Simple(Document):
mapping = ListField()
@@ -999,7 +1071,7 @@ class FieldTest(MongoDBTestCase):
simple = Simple(widgets=[1, 2, 3, 4]).save()
simple.widgets[:3] = []
- self.assertEqual(['widgets'], simple._changed_fields)
+ self.assertEqual(["widgets"], simple._changed_fields)
simple.save()
simple = simple.reload()
@@ -1011,7 +1083,7 @@ class FieldTest(MongoDBTestCase):
simple = Simple(widgets=[1, 2, 3, 4]).save()
del simple.widgets[:3]
- self.assertEqual(['widgets'], simple._changed_fields)
+ self.assertEqual(["widgets"], simple._changed_fields)
simple.save()
simple = simple.reload()
@@ -1023,7 +1095,7 @@ class FieldTest(MongoDBTestCase):
simple = Simple(widgets=[1, 2, 3, 4]).save()
simple.widgets[-1] = 5
- self.assertEqual(['widgets.3'], simple._changed_fields)
+ self.assertEqual(["widgets.3"], simple._changed_fields)
simple.save()
simple = simple.reload()
@@ -1031,8 +1103,9 @@ class FieldTest(MongoDBTestCase):
def test_list_field_complex(self):
"""Ensure that the list fields can handle the complex types."""
+
class SettingBase(EmbeddedDocument):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class StringSetting(SettingBase):
value = StringField()
@@ -1046,12 +1119,17 @@ class FieldTest(MongoDBTestCase):
Simple.drop_collection()
e = Simple()
- e.mapping.append(StringSetting(value='foo'))
+ e.mapping.append(StringSetting(value="foo"))
e.mapping.append(IntegerSetting(value=42))
- e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001,
- 'complex': IntegerSetting(value=42),
- 'list': [IntegerSetting(value=42),
- StringSetting(value='foo')]})
+ e.mapping.append(
+ {
+ "number": 1,
+ "string": "Hi!",
+ "float": 1.001,
+ "complex": IntegerSetting(value=42),
+ "list": [IntegerSetting(value=42), StringSetting(value="foo")],
+ }
+ )
e.save()
e2 = Simple.objects.get(id=e.id)
@@ -1059,35 +1137,36 @@ class FieldTest(MongoDBTestCase):
self.assertIsInstance(e2.mapping[1], IntegerSetting)
# Test querying
+ self.assertEqual(Simple.objects.filter(mapping__1__value=42).count(), 1)
+ self.assertEqual(Simple.objects.filter(mapping__2__number=1).count(), 1)
self.assertEqual(
- Simple.objects.filter(mapping__1__value=42).count(), 1)
+ Simple.objects.filter(mapping__2__complex__value=42).count(), 1
+ )
self.assertEqual(
- Simple.objects.filter(mapping__2__number=1).count(), 1)
+ Simple.objects.filter(mapping__2__list__0__value=42).count(), 1
+ )
self.assertEqual(
- Simple.objects.filter(mapping__2__complex__value=42).count(), 1)
- self.assertEqual(
- Simple.objects.filter(mapping__2__list__0__value=42).count(), 1)
- self.assertEqual(
- Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1)
+ Simple.objects.filter(mapping__2__list__1__value="foo").count(), 1
+ )
# Confirm can update
Simple.objects().update(set__mapping__1=IntegerSetting(value=10))
- self.assertEqual(
- Simple.objects.filter(mapping__1__value=10).count(), 1)
+ self.assertEqual(Simple.objects.filter(mapping__1__value=10).count(), 1)
- Simple.objects().update(
- set__mapping__2__list__1=StringSetting(value='Boo'))
+ Simple.objects().update(set__mapping__2__list__1=StringSetting(value="Boo"))
self.assertEqual(
- Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0)
+ Simple.objects.filter(mapping__2__list__1__value="foo").count(), 0
+ )
self.assertEqual(
- Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1)
+ Simple.objects.filter(mapping__2__list__1__value="Boo").count(), 1
+ )
def test_embedded_db_field(self):
class Embedded(EmbeddedDocument):
- number = IntField(default=0, db_field='i')
+ number = IntField(default=0, db_field="i")
class Test(Document):
- embedded = EmbeddedDocumentField(Embedded, db_field='x')
+ embedded = EmbeddedDocumentField(Embedded, db_field="x")
Test.drop_collection()
@@ -1100,58 +1179,54 @@ class FieldTest(MongoDBTestCase):
test = Test.objects.get()
self.assertEqual(test.embedded.number, 2)
doc = self.db.test.find_one()
- self.assertEqual(doc['x']['i'], 2)
+ self.assertEqual(doc["x"]["i"], 2)
def test_double_embedded_db_field(self):
"""Make sure multiple layers of embedded docs resolve db fields
properly and can be initialized using dicts.
"""
+
class C(EmbeddedDocument):
txt = StringField()
class B(EmbeddedDocument):
- c = EmbeddedDocumentField(C, db_field='fc')
+ c = EmbeddedDocumentField(C, db_field="fc")
class A(Document):
- b = EmbeddedDocumentField(B, db_field='fb')
+ b = EmbeddedDocumentField(B, db_field="fb")
- a = A(
- b=B(
- c=C(txt='hi')
- )
- )
+ a = A(b=B(c=C(txt="hi")))
a.validate()
- a = A(b={'c': {'txt': 'hi'}})
+ a = A(b={"c": {"txt": "hi"}})
a.validate()
def test_double_embedded_db_field_from_son(self):
"""Make sure multiple layers of embedded docs resolve db fields
from SON properly.
"""
+
class C(EmbeddedDocument):
txt = StringField()
class B(EmbeddedDocument):
- c = EmbeddedDocumentField(C, db_field='fc')
+ c = EmbeddedDocumentField(C, db_field="fc")
class A(Document):
- b = EmbeddedDocumentField(B, db_field='fb')
+ b = EmbeddedDocumentField(B, db_field="fb")
- a = A._from_son(SON([
- ('fb', SON([
- ('fc', SON([
- ('txt', 'hi')
- ]))
- ]))
- ]))
- self.assertEqual(a.b.c.txt, 'hi')
+ a = A._from_son(SON([("fb", SON([("fc", SON([("txt", "hi")]))]))]))
+ self.assertEqual(a.b.c.txt, "hi")
- def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet(self):
- raise SkipTest("Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet")
+ def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet(
+ self
+ ):
+ raise SkipTest(
+ "Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet"
+ )
class MyDoc2(Document):
- emb = EmbeddedDocumentField('MyDoc')
+ emb = EmbeddedDocumentField("MyDoc")
class MyDoc(EmbeddedDocument):
name = StringField()
@@ -1160,6 +1235,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure that invalid embedded documents cannot be assigned to
embedded document fields.
"""
+
class Comment(EmbeddedDocument):
content = StringField()
@@ -1173,30 +1249,31 @@ class FieldTest(MongoDBTestCase):
Person.drop_collection()
- person = Person(name='Test User')
- person.preferences = 'My Preferences'
+ person = Person(name="Test User")
+ person.preferences = "My Preferences"
self.assertRaises(ValidationError, person.validate)
# Check that only the right embedded doc works
- person.preferences = Comment(content='Nice blog post...')
+ person.preferences = Comment(content="Nice blog post...")
self.assertRaises(ValidationError, person.validate)
# Check that the embedded doc is valid
person.preferences = PersonPreferences()
self.assertRaises(ValidationError, person.validate)
- person.preferences = PersonPreferences(food='Cheese', number=47)
- self.assertEqual(person.preferences.food, 'Cheese')
+ person.preferences = PersonPreferences(food="Cheese", number=47)
+ self.assertEqual(person.preferences.food, "Cheese")
person.validate()
def test_embedded_document_inheritance(self):
"""Ensure that subclasses of embedded documents may be provided
to EmbeddedDocumentFields of the superclass' type.
"""
+
class User(EmbeddedDocument):
name = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class PowerUser(User):
power = IntField()
@@ -1207,8 +1284,8 @@ class FieldTest(MongoDBTestCase):
BlogPost.drop_collection()
- post = BlogPost(content='What I did today...')
- post.author = PowerUser(name='Test User', power=47)
+ post = BlogPost(content="What I did today...")
+ post.author = PowerUser(name="Test User", power=47)
post.save()
self.assertEqual(47, BlogPost.objects.first().author.power)
@@ -1217,21 +1294,22 @@ class FieldTest(MongoDBTestCase):
"""Ensure that nested list of subclassed embedded documents is
handled correctly.
"""
+
class Group(EmbeddedDocument):
name = StringField()
content = ListField(StringField())
class Basedoc(Document):
groups = ListField(EmbeddedDocumentField(Group))
- meta = {'abstract': True}
+ meta = {"abstract": True}
class User(Basedoc):
- doctype = StringField(require=True, default='userdata')
+ doctype = StringField(require=True, default="userdata")
User.drop_collection()
- content = ['la', 'le', 'lu']
- group = Group(name='foo', content=content)
+ content = ["la", "le", "lu"]
+ group = Group(name="foo", content=content)
foobar = User(groups=[group])
foobar.save()
@@ -1241,6 +1319,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure an exception is raised when dereferencing an unknown
document.
"""
+
class Foo(Document):
pass
@@ -1257,20 +1336,21 @@ class FieldTest(MongoDBTestCase):
# Reference is no longer valid
foo.delete()
bar = Bar.objects.get()
- self.assertRaises(DoesNotExist, getattr, bar, 'ref')
- self.assertRaises(DoesNotExist, getattr, bar, 'generic_ref')
+ self.assertRaises(DoesNotExist, getattr, bar, "ref")
+ self.assertRaises(DoesNotExist, getattr, bar, "generic_ref")
# When auto_dereference is disabled, there is no trouble returning DBRef
bar = Bar.objects.get()
expected = foo.to_dbref()
- bar._fields['ref']._auto_dereference = False
+ bar._fields["ref"]._auto_dereference = False
self.assertEqual(bar.ref, expected)
- bar._fields['generic_ref']._auto_dereference = False
- self.assertEqual(bar.generic_ref, {'_ref': expected, '_cls': 'Foo'})
+ bar._fields["generic_ref"]._auto_dereference = False
+ self.assertEqual(bar.generic_ref, {"_ref": expected, "_cls": "Foo"})
def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
+
class User(Document):
name = StringField()
@@ -1280,9 +1360,9 @@ class FieldTest(MongoDBTestCase):
User.drop_collection()
Group.drop_collection()
- user1 = User(name='user1')
+ user1 = User(name="user1")
user1.save()
- user2 = User(name='user2')
+ user2 = User(name="user2")
user2.save()
group = Group(members=[user1, user2])
@@ -1296,24 +1376,25 @@ class FieldTest(MongoDBTestCase):
def test_recursive_reference(self):
"""Ensure that ReferenceFields can reference their own documents.
"""
+
class Employee(Document):
name = StringField()
- boss = ReferenceField('self')
- friends = ListField(ReferenceField('self'))
+ boss = ReferenceField("self")
+ friends = ListField(ReferenceField("self"))
Employee.drop_collection()
- bill = Employee(name='Bill Lumbergh')
+ bill = Employee(name="Bill Lumbergh")
bill.save()
- michael = Employee(name='Michael Bolton')
+ michael = Employee(name="Michael Bolton")
michael.save()
- samir = Employee(name='Samir Nagheenanajar')
+ samir = Employee(name="Samir Nagheenanajar")
samir.save()
friends = [michael, samir]
- peter = Employee(name='Peter Gibbons', boss=bill, friends=friends)
+ peter = Employee(name="Peter Gibbons", boss=bill, friends=friends)
peter.save()
peter = Employee.objects.with_id(peter.id)
@@ -1323,13 +1404,14 @@ class FieldTest(MongoDBTestCase):
def test_recursive_embedding(self):
"""Ensure that EmbeddedDocumentFields can contain their own documents.
"""
+
class TreeNode(EmbeddedDocument):
name = StringField()
- children = ListField(EmbeddedDocumentField('self'))
+ children = ListField(EmbeddedDocumentField("self"))
class Tree(Document):
name = StringField()
- children = ListField(EmbeddedDocumentField('TreeNode'))
+ children = ListField(EmbeddedDocumentField("TreeNode"))
Tree.drop_collection()
@@ -1356,18 +1438,18 @@ class FieldTest(MongoDBTestCase):
self.assertEqual(tree.children[0].children[1].name, third_child.name)
# Test updating
- tree.children[0].name = 'I am Child 1'
- tree.children[0].children[0].name = 'I am Child 2'
- tree.children[0].children[1].name = 'I am Child 3'
+ tree.children[0].name = "I am Child 1"
+ tree.children[0].children[0].name = "I am Child 2"
+ tree.children[0].children[1].name = "I am Child 3"
tree.save()
- self.assertEqual(tree.children[0].name, 'I am Child 1')
- self.assertEqual(tree.children[0].children[0].name, 'I am Child 2')
- self.assertEqual(tree.children[0].children[1].name, 'I am Child 3')
+ self.assertEqual(tree.children[0].name, "I am Child 1")
+ self.assertEqual(tree.children[0].children[0].name, "I am Child 2")
+ self.assertEqual(tree.children[0].children[1].name, "I am Child 3")
# Test removal
self.assertEqual(len(tree.children[0].children), 2)
- del(tree.children[0].children[1])
+ del tree.children[0].children[1]
tree.save()
self.assertEqual(len(tree.children[0].children), 1)
@@ -1388,6 +1470,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure that an abstract document cannot be dropped given it
has no underlying collection.
"""
+
class AbstractDoc(Document):
name = StringField()
meta = {"abstract": True}
@@ -1397,6 +1480,7 @@ class FieldTest(MongoDBTestCase):
def test_reference_class_with_abstract_parent(self):
"""Ensure that a class with an abstract parent can be referenced.
"""
+
class Sibling(Document):
name = StringField()
meta = {"abstract": True}
@@ -1421,6 +1505,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure that an abstract class instance cannot be used in the
reference of that abstract class.
"""
+
class Sibling(Document):
name = StringField()
meta = {"abstract": True}
@@ -1442,6 +1527,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure that an an abstract reference fails validation when given a
Document that does not inherit from the abstract type.
"""
+
class Sibling(Document):
name = StringField()
meta = {"abstract": True}
@@ -1463,9 +1549,10 @@ class FieldTest(MongoDBTestCase):
def test_generic_reference(self):
"""Ensure that a GenericReferenceField properly dereferences items.
"""
+
class Link(Document):
title = StringField()
- meta = {'allow_inheritance': False}
+ meta = {"allow_inheritance": False}
class Post(Document):
title = StringField()
@@ -1502,6 +1589,7 @@ class FieldTest(MongoDBTestCase):
def test_generic_reference_list(self):
"""Ensure that a ListField properly dereferences generic references.
"""
+
class Link(Document):
title = StringField()
@@ -1533,6 +1621,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure dereferencing out of the document registry throws a
`NotRegistered` error.
"""
+
class Link(Document):
title = StringField()
@@ -1550,7 +1639,7 @@ class FieldTest(MongoDBTestCase):
# Mimic User and Link definitions being in a different file
# and the Link model not being imported in the User file.
- del(_document_registry["Link"])
+ del _document_registry["Link"]
user = User.objects.first()
try:
@@ -1560,7 +1649,6 @@ class FieldTest(MongoDBTestCase):
pass
def test_generic_reference_is_none(self):
-
class Person(Document):
name = StringField()
city = GenericReferenceField()
@@ -1568,11 +1656,11 @@ class FieldTest(MongoDBTestCase):
Person.drop_collection()
Person(name="Wilson Jr").save()
- self.assertEqual(repr(Person.objects(city=None)),
- "[]")
+ self.assertEqual(repr(Person.objects(city=None)), "[]")
def test_generic_reference_choices(self):
"""Ensure that a GenericReferenceField can handle choices."""
+
class Link(Document):
title = StringField()
@@ -1604,6 +1692,7 @@ class FieldTest(MongoDBTestCase):
def test_generic_reference_string_choices(self):
"""Ensure that a GenericReferenceField can handle choices as strings
"""
+
class Link(Document):
title = StringField()
@@ -1611,7 +1700,7 @@ class FieldTest(MongoDBTestCase):
title = StringField()
class Bookmark(Document):
- bookmark_object = GenericReferenceField(choices=('Post', Link))
+ bookmark_object = GenericReferenceField(choices=("Post", Link))
Link.drop_collection()
Post.drop_collection()
@@ -1636,11 +1725,12 @@ class FieldTest(MongoDBTestCase):
"""Ensure that a GenericReferenceField can handle choices on
non-derefenreced (i.e. DBRef) elements
"""
+
class Post(Document):
title = StringField()
class Bookmark(Document):
- bookmark_object = GenericReferenceField(choices=(Post, ))
+ bookmark_object = GenericReferenceField(choices=(Post,))
other_field = StringField()
Post.drop_collection()
@@ -1654,13 +1744,14 @@ class FieldTest(MongoDBTestCase):
bm = Bookmark.objects.get(id=bm.id)
# bookmark_object is now a DBRef
- bm.other_field = 'dummy_change'
+ bm.other_field = "dummy_change"
bm.save()
def test_generic_reference_list_choices(self):
"""Ensure that a ListField properly dereferences generic references and
respects choices.
"""
+
class Link(Document):
title = StringField()
@@ -1692,6 +1783,7 @@ class FieldTest(MongoDBTestCase):
def test_generic_reference_list_item_modification(self):
"""Ensure that modifications of related documents (through generic reference) don't influence on querying
"""
+
class Post(Document):
title = StringField()
@@ -1721,6 +1813,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure we can search for a specific generic reference by
providing its ObjectId.
"""
+
class Doc(Document):
ref = GenericReferenceField()
@@ -1729,13 +1822,14 @@ class FieldTest(MongoDBTestCase):
doc1 = Doc.objects.create()
doc2 = Doc.objects.create(ref=doc1)
- doc = Doc.objects.get(ref=DBRef('doc', doc1.pk))
+ doc = Doc.objects.get(ref=DBRef("doc", doc1.pk))
self.assertEqual(doc, doc2)
def test_generic_reference_is_not_tracked_in_parent_doc(self):
"""Ensure that modifications of related documents (through generic reference) don't influence
the owner changed fields (#1934)
"""
+
class Doc1(Document):
name = StringField()
@@ -1746,14 +1840,14 @@ class FieldTest(MongoDBTestCase):
Doc1.drop_collection()
Doc2.drop_collection()
- doc1 = Doc1(name='garbage1').save()
- doc11 = Doc1(name='garbage11').save()
+ doc1 = Doc1(name="garbage1").save()
+ doc11 = Doc1(name="garbage11").save()
doc2 = Doc2(ref=doc1, refs=[doc11]).save()
- doc2.ref.name = 'garbage2'
+ doc2.ref.name = "garbage2"
self.assertEqual(doc2._get_changed_fields(), [])
- doc2.refs[0].name = 'garbage3'
+ doc2.refs[0].name = "garbage3"
self.assertEqual(doc2._get_changed_fields(), [])
self.assertEqual(doc2._delta(), ({}, {}))
@@ -1761,6 +1855,7 @@ class FieldTest(MongoDBTestCase):
"""Ensure we can search for a specific generic reference by
providing its DBRef.
"""
+
class Doc(Document):
ref = GenericReferenceField()
@@ -1777,17 +1872,19 @@ class FieldTest(MongoDBTestCase):
def test_choices_allow_using_sets_as_choices(self):
"""Ensure that sets can be used when setting choices
"""
- class Shirt(Document):
- size = StringField(choices={'M', 'L'})
- Shirt(size='M').validate()
+ class Shirt(Document):
+ size = StringField(choices={"M", "L"})
+
+ Shirt(size="M").validate()
def test_choices_validation_allow_no_value(self):
"""Ensure that .validate passes and no value was provided
for a field setup with choices
"""
+
class Shirt(Document):
- size = StringField(choices=('S', 'M'))
+ size = StringField(choices=("S", "M"))
shirt = Shirt()
shirt.validate()
@@ -1795,17 +1892,19 @@ class FieldTest(MongoDBTestCase):
def test_choices_validation_accept_possible_value(self):
"""Ensure that value is in a container of allowed values.
"""
- class Shirt(Document):
- size = StringField(choices=('S', 'M'))
- shirt = Shirt(size='S')
+ class Shirt(Document):
+ size = StringField(choices=("S", "M"))
+
+ shirt = Shirt(size="S")
shirt.validate()
def test_choices_validation_reject_unknown_value(self):
"""Ensure that unallowed value are rejected upon validation
"""
+
class Shirt(Document):
- size = StringField(choices=('S', 'M'))
+ size = StringField(choices=("S", "M"))
shirt = Shirt(size="XS")
with self.assertRaises(ValidationError):
@@ -1815,12 +1914,23 @@ class FieldTest(MongoDBTestCase):
"""Test dynamic helper for returning the display value of a choices
field.
"""
+
class Shirt(Document):
- size = StringField(max_length=3, choices=(
- ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
- ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
- style = StringField(max_length=3, choices=(
- ('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W')
+ size = StringField(
+ max_length=3,
+ choices=(
+ ("S", "Small"),
+ ("M", "Medium"),
+ ("L", "Large"),
+ ("XL", "Extra Large"),
+ ("XXL", "Extra Extra Large"),
+ ),
+ )
+ style = StringField(
+ max_length=3,
+ choices=(("S", "Small"), ("B", "Baggy"), ("W", "Wide")),
+ default="W",
+ )
Shirt.drop_collection()
@@ -1829,30 +1939,30 @@ class FieldTest(MongoDBTestCase):
# Make sure get__display returns the default value (or None)
self.assertEqual(shirt1.get_size_display(), None)
- self.assertEqual(shirt1.get_style_display(), 'Wide')
+ self.assertEqual(shirt1.get_style_display(), "Wide")
- shirt1.size = 'XXL'
- shirt1.style = 'B'
- shirt2.size = 'M'
- shirt2.style = 'S'
- self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large')
- self.assertEqual(shirt1.get_style_display(), 'Baggy')
- self.assertEqual(shirt2.get_size_display(), 'Medium')
- self.assertEqual(shirt2.get_style_display(), 'Small')
+ shirt1.size = "XXL"
+ shirt1.style = "B"
+ shirt2.size = "M"
+ shirt2.style = "S"
+ self.assertEqual(shirt1.get_size_display(), "Extra Extra Large")
+ self.assertEqual(shirt1.get_style_display(), "Baggy")
+ self.assertEqual(shirt2.get_size_display(), "Medium")
+ self.assertEqual(shirt2.get_style_display(), "Small")
# Set as Z - an invalid choice
- shirt1.size = 'Z'
- shirt1.style = 'Z'
- self.assertEqual(shirt1.get_size_display(), 'Z')
- self.assertEqual(shirt1.get_style_display(), 'Z')
+ shirt1.size = "Z"
+ shirt1.style = "Z"
+ self.assertEqual(shirt1.get_size_display(), "Z")
+ self.assertEqual(shirt1.get_style_display(), "Z")
self.assertRaises(ValidationError, shirt1.validate)
def test_simple_choices_validation(self):
"""Ensure that value is in a container of allowed values.
"""
+
class Shirt(Document):
- size = StringField(max_length=3,
- choices=('S', 'M', 'L', 'XL', 'XXL'))
+ size = StringField(max_length=3, choices=("S", "M", "L", "XL", "XXL"))
Shirt.drop_collection()
@@ -1869,37 +1979,37 @@ class FieldTest(MongoDBTestCase):
"""Test dynamic helper for returning the display value of a choices
field.
"""
+
class Shirt(Document):
- size = StringField(max_length=3,
- choices=('S', 'M', 'L', 'XL', 'XXL'))
- style = StringField(max_length=3,
- choices=('Small', 'Baggy', 'wide'),
- default='Small')
+ size = StringField(max_length=3, choices=("S", "M", "L", "XL", "XXL"))
+ style = StringField(
+ max_length=3, choices=("Small", "Baggy", "wide"), default="Small"
+ )
Shirt.drop_collection()
shirt = Shirt()
self.assertEqual(shirt.get_size_display(), None)
- self.assertEqual(shirt.get_style_display(), 'Small')
+ self.assertEqual(shirt.get_style_display(), "Small")
shirt.size = "XXL"
shirt.style = "Baggy"
- self.assertEqual(shirt.get_size_display(), 'XXL')
- self.assertEqual(shirt.get_style_display(), 'Baggy')
+ self.assertEqual(shirt.get_size_display(), "XXL")
+ self.assertEqual(shirt.get_style_display(), "Baggy")
# Set as Z - an invalid choice
shirt.size = "Z"
shirt.style = "Z"
- self.assertEqual(shirt.get_size_display(), 'Z')
- self.assertEqual(shirt.get_style_display(), 'Z')
+ self.assertEqual(shirt.get_size_display(), "Z")
+ self.assertEqual(shirt.get_style_display(), "Z")
self.assertRaises(ValidationError, shirt.validate)
def test_simple_choices_validation_invalid_value(self):
"""Ensure that error messages are correct.
"""
- SIZES = ('S', 'M', 'L', 'XL', 'XXL')
- COLORS = (('R', 'Red'), ('B', 'Blue'))
+ SIZES = ("S", "M", "L", "XL", "XXL")
+ COLORS = (("R", "Red"), ("B", "Blue"))
SIZE_MESSAGE = u"Value must be one of ('S', 'M', 'L', 'XL', 'XXL')"
COLOR_MESSAGE = u"Value must be one of ['R', 'B']"
@@ -1924,11 +2034,12 @@ class FieldTest(MongoDBTestCase):
except ValidationError as error:
# get the validation rules
error_dict = error.to_dict()
- self.assertEqual(error_dict['size'], SIZE_MESSAGE)
- self.assertEqual(error_dict['color'], COLOR_MESSAGE)
+ self.assertEqual(error_dict["size"], SIZE_MESSAGE)
+ self.assertEqual(error_dict["color"], COLOR_MESSAGE)
def test_recursive_validation(self):
"""Ensure that a validation result to_dict is available."""
+
class Author(EmbeddedDocument):
name = StringField(required=True)
@@ -1940,9 +2051,9 @@ class FieldTest(MongoDBTestCase):
title = StringField(required=True)
comments = ListField(EmbeddedDocumentField(Comment))
- bob = Author(name='Bob')
- post = Post(title='hello world')
- post.comments.append(Comment(content='hello', author=bob))
+ bob = Author(name="Bob")
+ post = Post(title="hello world")
+ post.comments.append(Comment(content="hello", author=bob))
post.comments.append(Comment(author=bob))
self.assertRaises(ValidationError, post.validate)
@@ -1950,30 +2061,31 @@ class FieldTest(MongoDBTestCase):
post.validate()
except ValidationError as error:
# ValidationError.errors property
- self.assertTrue(hasattr(error, 'errors'))
+ self.assertTrue(hasattr(error, "errors"))
self.assertIsInstance(error.errors, dict)
- self.assertIn('comments', error.errors)
- self.assertIn(1, error.errors['comments'])
- self.assertIsInstance(error.errors['comments'][1]['content'], ValidationError)
+ self.assertIn("comments", error.errors)
+ self.assertIn(1, error.errors["comments"])
+ self.assertIsInstance(
+ error.errors["comments"][1]["content"], ValidationError
+ )
# ValidationError.schema property
error_dict = error.to_dict()
self.assertIsInstance(error_dict, dict)
- self.assertIn('comments', error_dict)
- self.assertIn(1, error_dict['comments'])
- self.assertIn('content', error_dict['comments'][1])
- self.assertEqual(error_dict['comments'][1]['content'],
- u'Field is required')
+ self.assertIn("comments", error_dict)
+ self.assertIn(1, error_dict["comments"])
+ self.assertIn("content", error_dict["comments"][1])
+ self.assertEqual(error_dict["comments"][1]["content"], u"Field is required")
- post.comments[1].content = 'here we go'
+ post.comments[1].content = "here we go"
post.validate()
def test_tuples_as_tuples(self):
"""Ensure that tuples remain tuples when they are inside
a ComplexBaseField.
"""
- class EnumField(BaseField):
+ class EnumField(BaseField):
def __init__(self, **kwargs):
super(EnumField, self).__init__(**kwargs)
@@ -1988,7 +2100,7 @@ class FieldTest(MongoDBTestCase):
TestDoc.drop_collection()
- tuples = [(100, 'Testing')]
+ tuples = [(100, "Testing")]
doc = TestDoc()
doc.items = tuples
doc.save()
@@ -2000,12 +2112,12 @@ class FieldTest(MongoDBTestCase):
def test_dynamic_fields_class(self):
class Doc2(Document):
- field_1 = StringField(db_field='f')
+ field_1 = StringField(db_field="f")
class Doc(Document):
my_id = IntField(primary_key=True)
- embed_me = DynamicField(db_field='e')
- field_x = StringField(db_field='x')
+ embed_me = DynamicField(db_field="e")
+ field_x = StringField(db_field="x")
Doc.drop_collection()
Doc2.drop_collection()
@@ -2022,12 +2134,12 @@ class FieldTest(MongoDBTestCase):
def test_dynamic_fields_embedded_class(self):
class Embed(EmbeddedDocument):
- field_1 = StringField(db_field='f')
+ field_1 = StringField(db_field="f")
class Doc(Document):
my_id = IntField(primary_key=True)
- embed_me = DynamicField(db_field='e')
- field_x = StringField(db_field='x')
+ embed_me = DynamicField(db_field="e")
+ field_x = StringField(db_field="x")
Doc.drop_collection()
@@ -2038,6 +2150,7 @@ class FieldTest(MongoDBTestCase):
def test_dynamicfield_dump_document(self):
"""Ensure a DynamicField can handle another document's dump."""
+
class Doc(Document):
field = DynamicField()
@@ -2049,7 +2162,7 @@ class FieldTest(MongoDBTestCase):
id = IntField(primary_key=True, default=1)
recursive = DynamicField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class ToEmbedChild(ToEmbedParent):
pass
@@ -2070,7 +2183,7 @@ class FieldTest(MongoDBTestCase):
def test_cls_field(self):
class Animal(Document):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class Fish(Animal):
pass
@@ -2088,7 +2201,9 @@ class FieldTest(MongoDBTestCase):
Dog().save()
Fish().save()
Human().save()
- self.assertEqual(Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2)
+ self.assertEqual(
+ Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2
+ )
self.assertEqual(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0)
def test_sparse_field(self):
@@ -2104,32 +2219,34 @@ class FieldTest(MongoDBTestCase):
trying to instantiate a document with a field that's not
defined.
"""
+
class Doc(Document):
foo = StringField()
with self.assertRaises(FieldDoesNotExist):
- Doc(bar='test')
+ Doc(bar="test")
def test_undefined_field_exception_with_strict(self):
"""Tests if a `FieldDoesNotExist` exception is raised when
trying to instantiate a document with a field that's not
defined, even when strict is set to False.
"""
+
class Doc(Document):
foo = StringField()
- meta = {'strict': False}
+ meta = {"strict": False}
with self.assertRaises(FieldDoesNotExist):
- Doc(bar='test')
+ Doc(bar="test")
class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
-
def setUp(self):
"""
Create two BlogPost entries in the database, each with
several EmbeddedDocuments.
"""
+
class Comments(EmbeddedDocument):
author = StringField()
message = StringField()
@@ -2142,20 +2259,24 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
self.Comments = Comments
self.BlogPost = BlogPost
- self.post1 = self.BlogPost(comments=[
- self.Comments(author='user1', message='message1'),
- self.Comments(author='user2', message='message1')
- ]).save()
+ self.post1 = self.BlogPost(
+ comments=[
+ self.Comments(author="user1", message="message1"),
+ self.Comments(author="user2", message="message1"),
+ ]
+ ).save()
- self.post2 = self.BlogPost(comments=[
- self.Comments(author='user2', message='message2'),
- self.Comments(author='user2', message='message3'),
- self.Comments(author='user3', message='message1')
- ]).save()
+ self.post2 = self.BlogPost(
+ comments=[
+ self.Comments(author="user2", message="message2"),
+ self.Comments(author="user2", message="message3"),
+ self.Comments(author="user3", message="message1"),
+ ]
+ ).save()
def test_fails_upon_validate_if_provide_a_doc_instead_of_a_list_of_doc(self):
# Relates to Issue #1464
- comment = self.Comments(author='John')
+ comment = self.Comments(author="John")
class Title(Document):
content = StringField()
@@ -2166,14 +2287,18 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
with self.assertRaises(ValidationError) as ctx_err:
post.validate()
self.assertIn("'comments'", str(ctx_err.exception))
- self.assertIn('Only lists and tuples may be used in a list field', str(ctx_err.exception))
+ self.assertIn(
+ "Only lists and tuples may be used in a list field", str(ctx_err.exception)
+ )
# Test with a Document
- post = self.BlogPost(comments=Title(content='garbage'))
+ post = self.BlogPost(comments=Title(content="garbage"))
with self.assertRaises(ValidationError) as e:
post.validate()
self.assertIn("'comments'", str(ctx_err.exception))
- self.assertIn('Only lists and tuples may be used in a list field', str(ctx_err.exception))
+ self.assertIn(
+ "Only lists and tuples may be used in a list field", str(ctx_err.exception)
+ )
def test_no_keyword_filter(self):
"""
@@ -2190,44 +2315,40 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
Tests the filter method of a List of Embedded Documents
with a single keyword.
"""
- filtered = self.post1.comments.filter(author='user1')
+ filtered = self.post1.comments.filter(author="user1")
# Ensure only 1 entry was returned.
self.assertEqual(len(filtered), 1)
# Ensure the entry returned is the correct entry.
- self.assertEqual(filtered[0].author, 'user1')
+ self.assertEqual(filtered[0].author, "user1")
def test_multi_keyword_filter(self):
"""
Tests the filter method of a List of Embedded Documents
with multiple keywords.
"""
- filtered = self.post2.comments.filter(
- author='user2', message='message2'
- )
+ filtered = self.post2.comments.filter(author="user2", message="message2")
# Ensure only 1 entry was returned.
self.assertEqual(len(filtered), 1)
# Ensure the entry returned is the correct entry.
- self.assertEqual(filtered[0].author, 'user2')
- self.assertEqual(filtered[0].message, 'message2')
+ self.assertEqual(filtered[0].author, "user2")
+ self.assertEqual(filtered[0].message, "message2")
def test_chained_filter(self):
"""
Tests chained filter methods of a List of Embedded Documents
"""
- filtered = self.post2.comments.filter(author='user2').filter(
- message='message2'
- )
+ filtered = self.post2.comments.filter(author="user2").filter(message="message2")
# Ensure only 1 entry was returned.
self.assertEqual(len(filtered), 1)
# Ensure the entry returned is the correct entry.
- self.assertEqual(filtered[0].author, 'user2')
- self.assertEqual(filtered[0].message, 'message2')
+ self.assertEqual(filtered[0].author, "user2")
+ self.assertEqual(filtered[0].message, "message2")
def test_unknown_keyword_filter(self):
"""
@@ -2252,36 +2373,34 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
Tests the exclude method of a List of Embedded Documents
with a single keyword.
"""
- excluded = self.post1.comments.exclude(author='user1')
+ excluded = self.post1.comments.exclude(author="user1")
# Ensure only 1 entry was returned.
self.assertEqual(len(excluded), 1)
# Ensure the entry returned is the correct entry.
- self.assertEqual(excluded[0].author, 'user2')
+ self.assertEqual(excluded[0].author, "user2")
def test_multi_keyword_exclude(self):
"""
Tests the exclude method of a List of Embedded Documents
with multiple keywords.
"""
- excluded = self.post2.comments.exclude(
- author='user3', message='message1'
- )
+ excluded = self.post2.comments.exclude(author="user3", message="message1")
# Ensure only 2 entries were returned.
self.assertEqual(len(excluded), 2)
# Ensure the entries returned are the correct entries.
- self.assertEqual(excluded[0].author, 'user2')
- self.assertEqual(excluded[1].author, 'user2')
+ self.assertEqual(excluded[0].author, "user2")
+ self.assertEqual(excluded[1].author, "user2")
def test_non_matching_exclude(self):
"""
Tests the exclude method of a List of Embedded Documents
when the keyword does not match any entries.
"""
- excluded = self.post2.comments.exclude(author='user4')
+ excluded = self.post2.comments.exclude(author="user4")
# Ensure the 3 entries still exist.
self.assertEqual(len(excluded), 3)
@@ -2299,16 +2418,16 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
Tests the exclude method after a filter method of a List of
Embedded Documents.
"""
- excluded = self.post2.comments.filter(author='user2').exclude(
- message='message2'
+ excluded = self.post2.comments.filter(author="user2").exclude(
+ message="message2"
)
# Ensure only 1 entry was returned.
self.assertEqual(len(excluded), 1)
# Ensure the entry returned is the correct entry.
- self.assertEqual(excluded[0].author, 'user2')
- self.assertEqual(excluded[0].message, 'message3')
+ self.assertEqual(excluded[0].author, "user2")
+ self.assertEqual(excluded[0].message, "message3")
def test_count(self):
"""
@@ -2321,7 +2440,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
"""
Tests the filter + count method of a List of Embedded Documents.
"""
- count = self.post1.comments.filter(author='user1').count()
+ count = self.post1.comments.filter(author="user1").count()
self.assertEqual(count, 1)
def test_single_keyword_get(self):
@@ -2329,19 +2448,19 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
Tests the get method of a List of Embedded Documents using a
single keyword.
"""
- comment = self.post1.comments.get(author='user1')
+ comment = self.post1.comments.get(author="user1")
self.assertIsInstance(comment, self.Comments)
- self.assertEqual(comment.author, 'user1')
+ self.assertEqual(comment.author, "user1")
def test_multi_keyword_get(self):
"""
Tests the get method of a List of Embedded Documents using
multiple keywords.
"""
- comment = self.post2.comments.get(author='user2', message='message2')
+ comment = self.post2.comments.get(author="user2", message="message2")
self.assertIsInstance(comment, self.Comments)
- self.assertEqual(comment.author, 'user2')
- self.assertEqual(comment.message, 'message2')
+ self.assertEqual(comment.author, "user2")
+ self.assertEqual(comment.message, "message2")
def test_no_keyword_multiple_return_get(self):
"""
@@ -2357,7 +2476,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
to return multiple documents.
"""
with self.assertRaises(MultipleObjectsReturned):
- self.post2.comments.get(author='user2')
+ self.post2.comments.get(author="user2")
def test_unknown_keyword_get(self):
"""
@@ -2373,7 +2492,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
returns no results.
"""
with self.assertRaises(DoesNotExist):
- self.post1.comments.get(author='user3')
+ self.post1.comments.get(author="user3")
def test_first(self):
"""
@@ -2390,20 +2509,17 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
"""
Test the create method of a List of Embedded Documents.
"""
- comment = self.post1.comments.create(
- author='user4', message='message1'
- )
+ comment = self.post1.comments.create(author="user4", message="message1")
self.post1.save()
# Ensure the returned value is the comment object.
self.assertIsInstance(comment, self.Comments)
- self.assertEqual(comment.author, 'user4')
- self.assertEqual(comment.message, 'message1')
+ self.assertEqual(comment.author, "user4")
+ self.assertEqual(comment.message, "message1")
# Ensure the new comment was actually saved to the database.
self.assertIn(
- comment,
- self.BlogPost.objects(comments__author='user4')[0].comments
+ comment, self.BlogPost.objects(comments__author="user4")[0].comments
)
def test_filtered_create(self):
@@ -2412,20 +2528,19 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
to a call to the filter method. Filtering should have no effect
on creation.
"""
- comment = self.post1.comments.filter(author='user1').create(
- author='user4', message='message1'
+ comment = self.post1.comments.filter(author="user1").create(
+ author="user4", message="message1"
)
self.post1.save()
# Ensure the returned value is the comment object.
self.assertIsInstance(comment, self.Comments)
- self.assertEqual(comment.author, 'user4')
- self.assertEqual(comment.message, 'message1')
+ self.assertEqual(comment.author, "user4")
+ self.assertEqual(comment.message, "message1")
# Ensure the new comment was actually saved to the database.
self.assertIn(
- comment,
- self.BlogPost.objects(comments__author='user4')[0].comments
+ comment, self.BlogPost.objects(comments__author="user4")[0].comments
)
def test_no_keyword_update(self):
@@ -2438,15 +2553,9 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
self.post1.save()
# Ensure that nothing was altered.
- self.assertIn(
- original[0],
- self.BlogPost.objects(id=self.post1.id)[0].comments
- )
+ self.assertIn(original[0], self.BlogPost.objects(id=self.post1.id)[0].comments)
- self.assertIn(
- original[1],
- self.BlogPost.objects(id=self.post1.id)[0].comments
- )
+ self.assertIn(original[1], self.BlogPost.objects(id=self.post1.id)[0].comments)
# Ensure the method returned 0 as the number of entries
# modified
@@ -2457,14 +2566,14 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
Tests the update method of a List of Embedded Documents with
a single keyword.
"""
- number = self.post1.comments.update(author='user4')
+ number = self.post1.comments.update(author="user4")
self.post1.save()
comments = self.BlogPost.objects(id=self.post1.id)[0].comments
# Ensure that the database was updated properly.
- self.assertEqual(comments[0].author, 'user4')
- self.assertEqual(comments[1].author, 'user4')
+ self.assertEqual(comments[0].author, "user4")
+ self.assertEqual(comments[1].author, "user4")
# Ensure the method returned 2 as the number of entries
# modified
@@ -2474,27 +2583,25 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
"""
Tests that unicode strings handled correctly
"""
- post = self.BlogPost(comments=[
- self.Comments(author='user1', message=u'сообщение'),
- self.Comments(author='user2', message=u'хабарлама')
- ]).save()
- self.assertEqual(post.comments.get(message=u'сообщение').author,
- 'user1')
+ post = self.BlogPost(
+ comments=[
+ self.Comments(author="user1", message=u"сообщение"),
+ self.Comments(author="user2", message=u"хабарлама"),
+ ]
+ ).save()
+ self.assertEqual(post.comments.get(message=u"сообщение").author, "user1")
def test_save(self):
"""
Tests the save method of a List of Embedded Documents.
"""
comments = self.post1.comments
- new_comment = self.Comments(author='user4')
+ new_comment = self.Comments(author="user4")
comments.append(new_comment)
comments.save()
# Ensure that the new comment has been added to the database.
- self.assertIn(
- new_comment,
- self.BlogPost.objects(id=self.post1.id)[0].comments
- )
+ self.assertIn(new_comment, self.BlogPost.objects(id=self.post1.id)[0].comments)
def test_delete(self):
"""
@@ -2505,9 +2612,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
# Ensure that all the comments under post1 were deleted in the
# database.
- self.assertListEqual(
- self.BlogPost.objects(id=self.post1.id)[0].comments, []
- )
+ self.assertListEqual(self.BlogPost.objects(id=self.post1.id)[0].comments, [])
# Ensure that post1 comments were deleted from the list.
self.assertListEqual(self.post1.comments, [])
@@ -2525,6 +2630,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
that have a unique field can be saved, but if the unique field is
also sparse than multiple documents with an empty list can be saved.
"""
+
class EmbeddedWithUnique(EmbeddedDocument):
number = IntField(unique=True)
@@ -2553,16 +2659,12 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
after the filter method has been called.
"""
comment = self.post1.comments[1]
- number = self.post1.comments.filter(author='user2').delete()
+ number = self.post1.comments.filter(author="user2").delete()
self.post1.save()
# Ensure that only the user2 comment was deleted.
- self.assertNotIn(
- comment, self.BlogPost.objects(id=self.post1.id)[0].comments
- )
- self.assertEqual(
- len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1
- )
+ self.assertNotIn(comment, self.BlogPost.objects(id=self.post1.id)[0].comments)
+ self.assertEqual(len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1)
# Ensure that the user2 comment no longer exists in the list.
self.assertNotIn(comment, self.post1.comments)
@@ -2577,7 +2679,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
Tests that custom data is saved in the field object
and doesn't interfere with the rest of field functionalities.
"""
- custom_data = {'a': 'a_value', 'b': [1, 2]}
+ custom_data = {"a": "a_value", "b": [1, 2]}
class CustomData(Document):
a_field = IntField()
@@ -2587,10 +2689,10 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase):
a1 = CustomData(a_field=1, c_field=2).save()
self.assertEqual(2, a1.c_field)
- self.assertFalse(hasattr(a1.c_field, 'custom_data'))
- self.assertTrue(hasattr(CustomData.c_field, 'custom_data'))
- self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a'])
+ self.assertFalse(hasattr(a1.c_field, "custom_data"))
+ self.assertTrue(hasattr(CustomData.c_field, "custom_data"))
+ self.assertEqual(custom_data["a"], CustomData.c_field.custom_data["a"])
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py
index a7722458..dd2fe609 100644
--- a/tests/fields/file_tests.py
+++ b/tests/fields/file_tests.py
@@ -14,36 +14,37 @@ from mongoengine.python_support import StringIO
try:
from PIL import Image
+
HAS_PIL = True
except ImportError:
HAS_PIL = False
from tests.utils import MongoDBTestCase
-TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png')
-TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png')
+TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "mongoengine.png")
+TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), "mongodb_leaf.png")
def get_file(path):
"""Use a BytesIO instead of a file to allow
to have a one-liner and avoid that the file remains opened"""
bytes_io = StringIO()
- with open(path, 'rb') as f:
+ with open(path, "rb") as f:
bytes_io.write(f.read())
bytes_io.seek(0)
return bytes_io
class FileTest(MongoDBTestCase):
-
def tearDown(self):
- self.db.drop_collection('fs.files')
- self.db.drop_collection('fs.chunks')
+ self.db.drop_collection("fs.files")
+ self.db.drop_collection("fs.chunks")
def test_file_field_optional(self):
# Make sure FileField is optional and not required
class DemoFile(Document):
the_file = FileField()
+
DemoFile.objects.create()
def test_file_fields(self):
@@ -55,8 +56,8 @@ class FileTest(MongoDBTestCase):
PutFile.drop_collection()
- text = six.b('Hello, World!')
- content_type = 'text/plain'
+ text = six.b("Hello, World!")
+ content_type = "text/plain"
putfile = PutFile()
putfile.the_file.put(text, content_type=content_type, filename="hello")
@@ -64,7 +65,10 @@ class FileTest(MongoDBTestCase):
result = PutFile.objects.first()
self.assertEqual(putfile, result)
- self.assertEqual("%s" % result.the_file, "" % result.the_file.grid_id)
+ self.assertEqual(
+ "%s" % result.the_file,
+ "" % result.the_file.grid_id,
+ )
self.assertEqual(result.the_file.read(), text)
self.assertEqual(result.the_file.content_type, content_type)
result.the_file.delete() # Remove file from GridFS
@@ -89,14 +93,15 @@ class FileTest(MongoDBTestCase):
def test_file_fields_stream(self):
"""Ensure that file fields can be written to and their data retrieved
"""
+
class StreamFile(Document):
the_file = FileField()
StreamFile.drop_collection()
- text = six.b('Hello, World!')
- more_text = six.b('Foo Bar')
- content_type = 'text/plain'
+ text = six.b("Hello, World!")
+ more_text = six.b("Foo Bar")
+ content_type = "text/plain"
streamfile = StreamFile()
streamfile.the_file.new_file(content_type=content_type)
@@ -124,14 +129,15 @@ class FileTest(MongoDBTestCase):
"""Ensure that a file field can be written to after it has been saved as
None
"""
+
class StreamFile(Document):
the_file = FileField()
StreamFile.drop_collection()
- text = six.b('Hello, World!')
- more_text = six.b('Foo Bar')
- content_type = 'text/plain'
+ text = six.b("Hello, World!")
+ more_text = six.b("Foo Bar")
+ content_type = "text/plain"
streamfile = StreamFile()
streamfile.save()
@@ -157,12 +163,11 @@ class FileTest(MongoDBTestCase):
self.assertTrue(result.the_file.read() is None)
def test_file_fields_set(self):
-
class SetFile(Document):
the_file = FileField()
- text = six.b('Hello, World!')
- more_text = six.b('Foo Bar')
+ text = six.b("Hello, World!")
+ more_text = six.b("Foo Bar")
SetFile.drop_collection()
@@ -184,7 +189,6 @@ class FileTest(MongoDBTestCase):
result.the_file.delete()
def test_file_field_no_default(self):
-
class GridDocument(Document):
the_file = FileField()
@@ -199,7 +203,7 @@ class FileTest(MongoDBTestCase):
doc_a.save()
doc_b = GridDocument.objects.with_id(doc_a.id)
- doc_b.the_file.replace(f, filename='doc_b')
+ doc_b.the_file.replace(f, filename="doc_b")
doc_b.save()
self.assertNotEqual(doc_b.the_file.grid_id, None)
@@ -208,13 +212,13 @@ class FileTest(MongoDBTestCase):
self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id)
# Test with default
- doc_d = GridDocument(the_file=six.b(''))
+ doc_d = GridDocument(the_file=six.b(""))
doc_d.save()
doc_e = GridDocument.objects.with_id(doc_d.id)
self.assertEqual(doc_d.the_file.grid_id, doc_e.the_file.grid_id)
- doc_e.the_file.replace(f, filename='doc_e')
+ doc_e.the_file.replace(f, filename="doc_e")
doc_e.save()
doc_f = GridDocument.objects.with_id(doc_e.id)
@@ -222,11 +226,12 @@ class FileTest(MongoDBTestCase):
db = GridDocument._get_db()
grid_fs = gridfs.GridFS(db)
- self.assertEqual(['doc_b', 'doc_e'], grid_fs.list())
+ self.assertEqual(["doc_b", "doc_e"], grid_fs.list())
def test_file_uniqueness(self):
"""Ensure that each instance of a FileField is unique
"""
+
class TestFile(Document):
name = StringField()
the_file = FileField()
@@ -234,7 +239,7 @@ class FileTest(MongoDBTestCase):
# First instance
test_file = TestFile()
test_file.name = "Hello, World!"
- test_file.the_file.put(six.b('Hello, World!'))
+ test_file.the_file.put(six.b("Hello, World!"))
test_file.save()
# Second instance
@@ -255,20 +260,21 @@ class FileTest(MongoDBTestCase):
photo = FileField()
Animal.drop_collection()
- marmot = Animal(genus='Marmota', family='Sciuridae')
+ marmot = Animal(genus="Marmota", family="Sciuridae")
marmot_photo_content = get_file(TEST_IMAGE_PATH) # Retrieve a photo from disk
- marmot.photo.put(marmot_photo_content, content_type='image/jpeg', foo='bar')
+ marmot.photo.put(marmot_photo_content, content_type="image/jpeg", foo="bar")
marmot.photo.close()
marmot.save()
marmot = Animal.objects.get()
- self.assertEqual(marmot.photo.content_type, 'image/jpeg')
- self.assertEqual(marmot.photo.foo, 'bar')
+ self.assertEqual(marmot.photo.content_type, "image/jpeg")
+ self.assertEqual(marmot.photo.foo, "bar")
def test_file_reassigning(self):
class TestFile(Document):
the_file = FileField()
+
TestFile.drop_collection()
test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save()
@@ -282,13 +288,15 @@ class FileTest(MongoDBTestCase):
def test_file_boolean(self):
"""Ensure that a boolean test of a FileField indicates its presence
"""
+
class TestFile(Document):
the_file = FileField()
+
TestFile.drop_collection()
test_file = TestFile()
self.assertFalse(bool(test_file.the_file))
- test_file.the_file.put(six.b('Hello, World!'), content_type='text/plain')
+ test_file.the_file.put(six.b("Hello, World!"), content_type="text/plain")
test_file.save()
self.assertTrue(bool(test_file.the_file))
@@ -297,6 +305,7 @@ class FileTest(MongoDBTestCase):
def test_file_cmp(self):
"""Test comparing against other types"""
+
class TestFile(Document):
the_file = FileField()
@@ -305,11 +314,12 @@ class FileTest(MongoDBTestCase):
def test_file_disk_space(self):
""" Test disk space usage when we delete/replace a file """
+
class TestFile(Document):
the_file = FileField()
- text = six.b('Hello, World!')
- content_type = 'text/plain'
+ text = six.b("Hello, World!")
+ content_type = "text/plain"
testfile = TestFile()
testfile.the_file.put(text, content_type=content_type, filename="hello")
@@ -352,7 +362,7 @@ class FileTest(MongoDBTestCase):
testfile.the_file.put(text, content_type=content_type, filename="hello")
testfile.save()
- text = six.b('Bonjour, World!')
+ text = six.b("Bonjour, World!")
testfile.the_file.replace(text, content_type=content_type, filename="hello")
testfile.save()
@@ -370,7 +380,7 @@ class FileTest(MongoDBTestCase):
def test_image_field(self):
if not HAS_PIL:
- raise SkipTest('PIL not installed')
+ raise SkipTest("PIL not installed")
class TestImage(Document):
image = ImageField()
@@ -386,7 +396,9 @@ class FileTest(MongoDBTestCase):
t.image.put(f)
self.fail("Should have raised an invalidation error")
except ValidationError as e:
- self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f)
+ self.assertEqual(
+ "%s" % e, "Invalid image: cannot identify image file %s" % f
+ )
t = TestImage()
t.image.put(get_file(TEST_IMAGE_PATH))
@@ -394,7 +406,7 @@ class FileTest(MongoDBTestCase):
t = TestImage.objects.first()
- self.assertEqual(t.image.format, 'PNG')
+ self.assertEqual(t.image.format, "PNG")
w, h = t.image.size
self.assertEqual(w, 371)
@@ -404,10 +416,11 @@ class FileTest(MongoDBTestCase):
def test_image_field_reassigning(self):
if not HAS_PIL:
- raise SkipTest('PIL not installed')
+ raise SkipTest("PIL not installed")
class TestFile(Document):
the_file = ImageField()
+
TestFile.drop_collection()
test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save()
@@ -420,7 +433,7 @@ class FileTest(MongoDBTestCase):
def test_image_field_resize(self):
if not HAS_PIL:
- raise SkipTest('PIL not installed')
+ raise SkipTest("PIL not installed")
class TestImage(Document):
image = ImageField(size=(185, 37))
@@ -433,7 +446,7 @@ class FileTest(MongoDBTestCase):
t = TestImage.objects.first()
- self.assertEqual(t.image.format, 'PNG')
+ self.assertEqual(t.image.format, "PNG")
w, h = t.image.size
self.assertEqual(w, 185)
@@ -443,7 +456,7 @@ class FileTest(MongoDBTestCase):
def test_image_field_resize_force(self):
if not HAS_PIL:
- raise SkipTest('PIL not installed')
+ raise SkipTest("PIL not installed")
class TestImage(Document):
image = ImageField(size=(185, 37, True))
@@ -456,7 +469,7 @@ class FileTest(MongoDBTestCase):
t = TestImage.objects.first()
- self.assertEqual(t.image.format, 'PNG')
+ self.assertEqual(t.image.format, "PNG")
w, h = t.image.size
self.assertEqual(w, 185)
@@ -466,7 +479,7 @@ class FileTest(MongoDBTestCase):
def test_image_field_thumbnail(self):
if not HAS_PIL:
- raise SkipTest('PIL not installed')
+ raise SkipTest("PIL not installed")
class TestImage(Document):
image = ImageField(thumbnail_size=(92, 18))
@@ -479,19 +492,18 @@ class FileTest(MongoDBTestCase):
t = TestImage.objects.first()
- self.assertEqual(t.image.thumbnail.format, 'PNG')
+ self.assertEqual(t.image.thumbnail.format, "PNG")
self.assertEqual(t.image.thumbnail.width, 92)
self.assertEqual(t.image.thumbnail.height, 18)
t.image.delete()
def test_file_multidb(self):
- register_connection('test_files', 'test_files')
+ register_connection("test_files", "test_files")
class TestFile(Document):
name = StringField()
- the_file = FileField(db_alias="test_files",
- collection_name="macumba")
+ the_file = FileField(db_alias="test_files", collection_name="macumba")
TestFile.drop_collection()
@@ -502,23 +514,21 @@ class FileTest(MongoDBTestCase):
# First instance
test_file = TestFile()
test_file.name = "Hello, World!"
- test_file.the_file.put(six.b('Hello, World!'),
- name="hello.txt")
+ test_file.the_file.put(six.b("Hello, World!"), name="hello.txt")
test_file.save()
data = get_db("test_files").macumba.files.find_one()
- self.assertEqual(data.get('name'), 'hello.txt')
+ self.assertEqual(data.get("name"), "hello.txt")
test_file = TestFile.objects.first()
- self.assertEqual(test_file.the_file.read(), six.b('Hello, World!'))
+ self.assertEqual(test_file.the_file.read(), six.b("Hello, World!"))
test_file = TestFile.objects.first()
- test_file.the_file = six.b('HELLO, WORLD!')
+ test_file.the_file = six.b("HELLO, WORLD!")
test_file.save()
test_file = TestFile.objects.first()
- self.assertEqual(test_file.the_file.read(),
- six.b('HELLO, WORLD!'))
+ self.assertEqual(test_file.the_file.read(), six.b("HELLO, WORLD!"))
def test_copyable(self):
class PutFile(Document):
@@ -526,8 +536,8 @@ class FileTest(MongoDBTestCase):
PutFile.drop_collection()
- text = six.b('Hello, World!')
- content_type = 'text/plain'
+ text = six.b("Hello, World!")
+ content_type = "text/plain"
putfile = PutFile()
putfile.the_file.put(text, content_type=content_type)
@@ -542,7 +552,7 @@ class FileTest(MongoDBTestCase):
def test_get_image_by_grid_id(self):
if not HAS_PIL:
- raise SkipTest('PIL not installed')
+ raise SkipTest("PIL not installed")
class TestImage(Document):
@@ -559,8 +569,9 @@ class FileTest(MongoDBTestCase):
test = TestImage.objects.first()
grid_id = test.image1.grid_id
- self.assertEqual(1, TestImage.objects(Q(image1=grid_id)
- or Q(image2=grid_id)).count())
+ self.assertEqual(
+ 1, TestImage.objects(Q(image1=grid_id) or Q(image2=grid_id)).count()
+ )
def test_complex_field_filefield(self):
"""Ensure you can add meta data to file"""
@@ -571,21 +582,21 @@ class FileTest(MongoDBTestCase):
photos = ListField(FileField())
Animal.drop_collection()
- marmot = Animal(genus='Marmota', family='Sciuridae')
+ marmot = Animal(genus="Marmota", family="Sciuridae")
- with open(TEST_IMAGE_PATH, 'rb') as marmot_photo: # Retrieve a photo from disk
- photos_field = marmot._fields['photos'].field
- new_proxy = photos_field.get_proxy_obj('photos', marmot)
- new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar')
+ with open(TEST_IMAGE_PATH, "rb") as marmot_photo: # Retrieve a photo from disk
+ photos_field = marmot._fields["photos"].field
+ new_proxy = photos_field.get_proxy_obj("photos", marmot)
+ new_proxy.put(marmot_photo, content_type="image/jpeg", foo="bar")
marmot.photos.append(new_proxy)
marmot.save()
marmot = Animal.objects.get()
- self.assertEqual(marmot.photos[0].content_type, 'image/jpeg')
- self.assertEqual(marmot.photos[0].foo, 'bar')
+ self.assertEqual(marmot.photos[0].content_type, "image/jpeg")
+ self.assertEqual(marmot.photos[0].foo, "bar")
self.assertEqual(marmot.photos[0].get().length, 8313)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/fields/geo.py b/tests/fields/geo.py
index 37ed97f5..446d7171 100644
--- a/tests/fields/geo.py
+++ b/tests/fields/geo.py
@@ -4,28 +4,27 @@ import unittest
from mongoengine import *
from mongoengine.connection import get_db
-__all__ = ("GeoFieldTest", )
+__all__ = ("GeoFieldTest",)
class GeoFieldTest(unittest.TestCase):
-
def setUp(self):
- connect(db='mongoenginetest')
+ connect(db="mongoenginetest")
self.db = get_db()
def _test_for_expected_error(self, Cls, loc, expected):
try:
Cls(loc=loc).validate()
- self.fail('Should not validate the location {0}'.format(loc))
+ self.fail("Should not validate the location {0}".format(loc))
except ValidationError as e:
- self.assertEqual(expected, e.to_dict()['loc'])
+ self.assertEqual(expected, e.to_dict()["loc"])
def test_geopoint_validation(self):
class Location(Document):
loc = GeoPointField()
invalid_coords = [{"x": 1, "y": 2}, 5, "a"]
- expected = 'GeoPointField can only accept tuples or lists of (x, y)'
+ expected = "GeoPointField can only accept tuples or lists of (x, y)"
for coord in invalid_coords:
self._test_for_expected_error(Location, coord, expected)
@@ -40,7 +39,7 @@ class GeoFieldTest(unittest.TestCase):
expected = "Both values (%s) in point must be float or int" % repr(coord)
self._test_for_expected_error(Location, coord, expected)
- invalid_coords = [21, 4, 'a']
+ invalid_coords = [21, 4, "a"]
for coord in invalid_coords:
expected = "GeoPointField can only accept tuples or lists of (x, y)"
self._test_for_expected_error(Location, coord, expected)
@@ -50,7 +49,9 @@ class GeoFieldTest(unittest.TestCase):
loc = PointField()
invalid_coords = {"x": 1, "y": 2}
- expected = 'PointField can only accept a valid GeoJson dictionary or lists of (x, y)'
+ expected = (
+ "PointField can only accept a valid GeoJson dictionary or lists of (x, y)"
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "MadeUp", "coordinates": []}
@@ -77,19 +78,16 @@ class GeoFieldTest(unittest.TestCase):
self._test_for_expected_error(Location, coord, expected)
Location(loc=[1, 2]).validate()
- Location(loc={
- "type": "Point",
- "coordinates": [
- 81.4471435546875,
- 23.61432859499169
- ]}).validate()
+ Location(
+ loc={"type": "Point", "coordinates": [81.4471435546875, 23.61432859499169]}
+ ).validate()
def test_linestring_validation(self):
class Location(Document):
loc = LineStringField()
invalid_coords = {"x": 1, "y": 2}
- expected = 'LineStringField can only accept a valid GeoJson dictionary or lists of (x, y)'
+ expected = "LineStringField can only accept a valid GeoJson dictionary or lists of (x, y)"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "MadeUp", "coordinates": [[]]}
@@ -97,7 +95,9 @@ class GeoFieldTest(unittest.TestCase):
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "LineString", "coordinates": [[1, 2, 3]]}
- expected = "Invalid LineString:\nValue ([1, 2, 3]) must be a two-dimensional point"
+ expected = (
+ "Invalid LineString:\nValue ([1, 2, 3]) must be a two-dimensional point"
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [5, "a"]
@@ -105,16 +105,25 @@ class GeoFieldTest(unittest.TestCase):
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[1]]
- expected = "Invalid LineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0])
+ expected = (
+ "Invalid LineString:\nValue (%s) must be a two-dimensional point"
+ % repr(invalid_coords[0])
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[1, 2, 3]]
- expected = "Invalid LineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0])
+ expected = (
+ "Invalid LineString:\nValue (%s) must be a two-dimensional point"
+ % repr(invalid_coords[0])
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[{}, {}]], [("a", "b")]]
for coord in invalid_coords:
- expected = "Invalid LineString:\nBoth values (%s) in point must be float or int" % repr(coord[0])
+ expected = (
+ "Invalid LineString:\nBoth values (%s) in point must be float or int"
+ % repr(coord[0])
+ )
self._test_for_expected_error(Location, coord, expected)
Location(loc=[[1, 2], [3, 4], [5, 6], [1, 2]]).validate()
@@ -124,7 +133,9 @@ class GeoFieldTest(unittest.TestCase):
loc = PolygonField()
invalid_coords = {"x": 1, "y": 2}
- expected = 'PolygonField can only accept a valid GeoJson dictionary or lists of (x, y)'
+ expected = (
+ "PolygonField can only accept a valid GeoJson dictionary or lists of (x, y)"
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "MadeUp", "coordinates": [[]]}
@@ -136,7 +147,9 @@ class GeoFieldTest(unittest.TestCase):
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[5, "a"]]]
- expected = "Invalid Polygon:\nBoth values ([5, 'a']) in point must be float or int"
+ expected = (
+ "Invalid Polygon:\nBoth values ([5, 'a']) in point must be float or int"
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[]]]
@@ -162,7 +175,7 @@ class GeoFieldTest(unittest.TestCase):
loc = MultiPointField()
invalid_coords = {"x": 1, "y": 2}
- expected = 'MultiPointField can only accept a valid GeoJson dictionary or lists of (x, y)'
+ expected = "MultiPointField can only accept a valid GeoJson dictionary or lists of (x, y)"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "MadeUp", "coordinates": [[]]}
@@ -188,19 +201,19 @@ class GeoFieldTest(unittest.TestCase):
self._test_for_expected_error(Location, coord, expected)
Location(loc=[[1, 2]]).validate()
- Location(loc={
- "type": "MultiPoint",
- "coordinates": [
- [1, 2],
- [81.4471435546875, 23.61432859499169]
- ]}).validate()
+ Location(
+ loc={
+ "type": "MultiPoint",
+ "coordinates": [[1, 2], [81.4471435546875, 23.61432859499169]],
+ }
+ ).validate()
def test_multilinestring_validation(self):
class Location(Document):
loc = MultiLineStringField()
invalid_coords = {"x": 1, "y": 2}
- expected = 'MultiLineStringField can only accept a valid GeoJson dictionary or lists of (x, y)'
+ expected = "MultiLineStringField can only accept a valid GeoJson dictionary or lists of (x, y)"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "MadeUp", "coordinates": [[]]}
@@ -216,16 +229,25 @@ class GeoFieldTest(unittest.TestCase):
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[1]]]
- expected = "Invalid MultiLineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0][0])
+ expected = (
+ "Invalid MultiLineString:\nValue (%s) must be a two-dimensional point"
+ % repr(invalid_coords[0][0])
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[1, 2, 3]]]
- expected = "Invalid MultiLineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0][0])
+ expected = (
+ "Invalid MultiLineString:\nValue (%s) must be a two-dimensional point"
+ % repr(invalid_coords[0][0])
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[[{}, {}]]], [[("a", "b")]]]
for coord in invalid_coords:
- expected = "Invalid MultiLineString:\nBoth values (%s) in point must be float or int" % repr(coord[0][0])
+ expected = (
+ "Invalid MultiLineString:\nBoth values (%s) in point must be float or int"
+ % repr(coord[0][0])
+ )
self._test_for_expected_error(Location, coord, expected)
Location(loc=[[[1, 2], [3, 4], [5, 6], [1, 2]]]).validate()
@@ -235,7 +257,7 @@ class GeoFieldTest(unittest.TestCase):
loc = MultiPolygonField()
invalid_coords = {"x": 1, "y": 2}
- expected = 'MultiPolygonField can only accept a valid GeoJson dictionary or lists of (x, y)'
+ expected = "MultiPolygonField can only accept a valid GeoJson dictionary or lists of (x, y)"
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "MadeUp", "coordinates": [[]]}
@@ -243,7 +265,9 @@ class GeoFieldTest(unittest.TestCase):
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = {"type": "MultiPolygon", "coordinates": [[[[1, 2, 3]]]]}
- expected = "Invalid MultiPolygon:\nValue ([1, 2, 3]) must be a two-dimensional point"
+ expected = (
+ "Invalid MultiPolygon:\nValue ([1, 2, 3]) must be a two-dimensional point"
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[[5, "a"]]]]
@@ -255,7 +279,9 @@ class GeoFieldTest(unittest.TestCase):
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[[1, 2, 3]]]]
- expected = "Invalid MultiPolygon:\nValue ([1, 2, 3]) must be a two-dimensional point"
+ expected = (
+ "Invalid MultiPolygon:\nValue ([1, 2, 3]) must be a two-dimensional point"
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[[{}, {}]]], [[("a", "b")]]]
@@ -263,7 +289,9 @@ class GeoFieldTest(unittest.TestCase):
self._test_for_expected_error(Location, invalid_coords, expected)
invalid_coords = [[[[1, 2], [3, 4]]]]
- expected = "Invalid MultiPolygon:\nLineStrings must start and end at the same point"
+ expected = (
+ "Invalid MultiPolygon:\nLineStrings must start and end at the same point"
+ )
self._test_for_expected_error(Location, invalid_coords, expected)
Location(loc=[[[[1, 2], [3, 4], [5, 6], [1, 2]]]]).validate()
@@ -271,17 +299,19 @@ class GeoFieldTest(unittest.TestCase):
def test_indexes_geopoint(self):
"""Ensure that indexes are created automatically for GeoPointFields.
"""
+
class Event(Document):
title = StringField()
location = GeoPointField()
geo_indicies = Event._geo_indices()
- self.assertEqual(geo_indicies, [{'fields': [('location', '2d')]}])
+ self.assertEqual(geo_indicies, [{"fields": [("location", "2d")]}])
def test_geopoint_embedded_indexes(self):
"""Ensure that indexes are created automatically for GeoPointFields on
embedded documents.
"""
+
class Venue(EmbeddedDocument):
location = GeoPointField()
name = StringField()
@@ -291,11 +321,12 @@ class GeoFieldTest(unittest.TestCase):
venue = EmbeddedDocumentField(Venue)
geo_indicies = Event._geo_indices()
- self.assertEqual(geo_indicies, [{'fields': [('venue.location', '2d')]}])
+ self.assertEqual(geo_indicies, [{"fields": [("venue.location", "2d")]}])
def test_indexes_2dsphere(self):
"""Ensure that indexes are created automatically for GeoPointFields.
"""
+
class Event(Document):
title = StringField()
point = PointField()
@@ -303,13 +334,14 @@ class GeoFieldTest(unittest.TestCase):
polygon = PolygonField()
geo_indicies = Event._geo_indices()
- self.assertIn({'fields': [('line', '2dsphere')]}, geo_indicies)
- self.assertIn({'fields': [('polygon', '2dsphere')]}, geo_indicies)
- self.assertIn({'fields': [('point', '2dsphere')]}, geo_indicies)
+ self.assertIn({"fields": [("line", "2dsphere")]}, geo_indicies)
+ self.assertIn({"fields": [("polygon", "2dsphere")]}, geo_indicies)
+ self.assertIn({"fields": [("point", "2dsphere")]}, geo_indicies)
def test_indexes_2dsphere_embedded(self):
"""Ensure that indexes are created automatically for GeoPointFields.
"""
+
class Venue(EmbeddedDocument):
name = StringField()
point = PointField()
@@ -321,12 +353,11 @@ class GeoFieldTest(unittest.TestCase):
venue = EmbeddedDocumentField(Venue)
geo_indicies = Event._geo_indices()
- self.assertIn({'fields': [('venue.line', '2dsphere')]}, geo_indicies)
- self.assertIn({'fields': [('venue.polygon', '2dsphere')]}, geo_indicies)
- self.assertIn({'fields': [('venue.point', '2dsphere')]}, geo_indicies)
+ self.assertIn({"fields": [("venue.line", "2dsphere")]}, geo_indicies)
+ self.assertIn({"fields": [("venue.polygon", "2dsphere")]}, geo_indicies)
+ self.assertIn({"fields": [("venue.point", "2dsphere")]}, geo_indicies)
def test_geo_indexes_recursion(self):
-
class Location(Document):
name = StringField()
location = GeoPointField()
@@ -338,11 +369,11 @@ class GeoFieldTest(unittest.TestCase):
Location.drop_collection()
Parent.drop_collection()
- Parent(name='Berlin').save()
+ Parent(name="Berlin").save()
info = Parent._get_collection().index_information()
- self.assertNotIn('location_2d', info)
+ self.assertNotIn("location_2d", info)
info = Location._get_collection().index_information()
- self.assertIn('location_2d', info)
+ self.assertIn("location_2d", info)
self.assertEqual(len(Parent._geo_indices()), 0)
self.assertEqual(len(Location._geo_indices()), 1)
@@ -354,9 +385,7 @@ class GeoFieldTest(unittest.TestCase):
location = PointField(auto_index=False)
datetime = DateTimeField()
- meta = {
- 'indexes': [[("location", "2dsphere"), ("datetime", 1)]]
- }
+ meta = {"indexes": [[("location", "2dsphere"), ("datetime", 1)]]}
self.assertEqual([], Log._geo_indices())
@@ -364,8 +393,10 @@ class GeoFieldTest(unittest.TestCase):
Log.ensure_indexes()
info = Log._get_collection().index_information()
- self.assertEqual(info["location_2dsphere_datetime_1"]["key"],
- [('location', '2dsphere'), ('datetime', 1)])
+ self.assertEqual(
+ info["location_2dsphere_datetime_1"]["key"],
+ [("location", "2dsphere"), ("datetime", 1)],
+ )
# Test listing explicitly
class Log(Document):
@@ -373,9 +404,7 @@ class GeoFieldTest(unittest.TestCase):
datetime = DateTimeField()
meta = {
- 'indexes': [
- {'fields': [("location", "2dsphere"), ("datetime", 1)]}
- ]
+ "indexes": [{"fields": [("location", "2dsphere"), ("datetime", 1)]}]
}
self.assertEqual([], Log._geo_indices())
@@ -384,9 +413,11 @@ class GeoFieldTest(unittest.TestCase):
Log.ensure_indexes()
info = Log._get_collection().index_information()
- self.assertEqual(info["location_2dsphere_datetime_1"]["key"],
- [('location', '2dsphere'), ('datetime', 1)])
+ self.assertEqual(
+ info["location_2dsphere_datetime_1"]["key"],
+ [("location", "2dsphere"), ("datetime", 1)],
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/fields/test_binary_field.py b/tests/fields/test_binary_field.py
index 8af75d4e..df4bf2de 100644
--- a/tests/fields/test_binary_field.py
+++ b/tests/fields/test_binary_field.py
@@ -9,19 +9,22 @@ from bson import Binary
from mongoengine import *
from tests.utils import MongoDBTestCase
-BIN_VALUE = six.b('\xa9\xf3\x8d(\xd7\x03\x84\xb4k[\x0f\xe3\xa2\x19\x85p[J\xa3\xd2>\xde\xe6\x87\xb1\x7f\xc6\xe6\xd9r\x18\xf5')
+BIN_VALUE = six.b(
+ "\xa9\xf3\x8d(\xd7\x03\x84\xb4k[\x0f\xe3\xa2\x19\x85p[J\xa3\xd2>\xde\xe6\x87\xb1\x7f\xc6\xe6\xd9r\x18\xf5"
+)
class TestBinaryField(MongoDBTestCase):
def test_binary_fields(self):
"""Ensure that binary fields can be stored and retrieved.
"""
+
class Attachment(Document):
content_type = StringField()
blob = BinaryField()
- BLOB = six.b('\xe6\x00\xc4\xff\x07')
- MIME_TYPE = 'application/octet-stream'
+ BLOB = six.b("\xe6\x00\xc4\xff\x07")
+ MIME_TYPE = "application/octet-stream"
Attachment.drop_collection()
@@ -35,6 +38,7 @@ class TestBinaryField(MongoDBTestCase):
def test_validation_succeeds(self):
"""Ensure that valid values can be assigned to binary fields.
"""
+
class AttachmentRequired(Document):
blob = BinaryField(required=True)
@@ -43,11 +47,11 @@ class TestBinaryField(MongoDBTestCase):
attachment_required = AttachmentRequired()
self.assertRaises(ValidationError, attachment_required.validate)
- attachment_required.blob = Binary(six.b('\xe6\x00\xc4\xff\x07'))
+ attachment_required.blob = Binary(six.b("\xe6\x00\xc4\xff\x07"))
attachment_required.validate()
- _5_BYTES = six.b('\xe6\x00\xc4\xff\x07')
- _4_BYTES = six.b('\xe6\x00\xc4\xff')
+ _5_BYTES = six.b("\xe6\x00\xc4\xff\x07")
+ _4_BYTES = six.b("\xe6\x00\xc4\xff")
self.assertRaises(ValidationError, AttachmentSizeLimit(blob=_5_BYTES).validate)
AttachmentSizeLimit(blob=_4_BYTES).validate()
@@ -57,7 +61,7 @@ class TestBinaryField(MongoDBTestCase):
class Attachment(Document):
blob = BinaryField()
- for invalid_data in (2, u'Im_a_unicode', ['some_str']):
+ for invalid_data in (2, u"Im_a_unicode", ["some_str"]):
self.assertRaises(ValidationError, Attachment(blob=invalid_data).validate)
def test__primary(self):
@@ -108,17 +112,17 @@ class TestBinaryField(MongoDBTestCase):
def test_modify_operation__set(self):
"""Ensures no regression of bug #1127"""
+
class MyDocument(Document):
some_field = StringField()
bin_field = BinaryField()
MyDocument.drop_collection()
- doc = MyDocument.objects(some_field='test').modify(
- upsert=True, new=True,
- set__bin_field=BIN_VALUE
+ doc = MyDocument.objects(some_field="test").modify(
+ upsert=True, new=True, set__bin_field=BIN_VALUE
)
- self.assertEqual(doc.some_field, 'test')
+ self.assertEqual(doc.some_field, "test")
if six.PY3:
self.assertEqual(doc.bin_field, BIN_VALUE)
else:
@@ -126,15 +130,18 @@ class TestBinaryField(MongoDBTestCase):
def test_update_one(self):
"""Ensures no regression of bug #1127"""
+
class MyDocument(Document):
bin_field = BinaryField()
MyDocument.drop_collection()
- bin_data = six.b('\xe6\x00\xc4\xff\x07')
+ bin_data = six.b("\xe6\x00\xc4\xff\x07")
doc = MyDocument(bin_field=bin_data).save()
- n_updated = MyDocument.objects(bin_field=bin_data).update_one(bin_field=BIN_VALUE)
+ n_updated = MyDocument.objects(bin_field=bin_data).update_one(
+ bin_field=BIN_VALUE
+ )
self.assertEqual(n_updated, 1)
fetched = MyDocument.objects.with_id(doc.id)
if six.PY3:
diff --git a/tests/fields/test_boolean_field.py b/tests/fields/test_boolean_field.py
index 7a2a3db6..22ebb6f7 100644
--- a/tests/fields/test_boolean_field.py
+++ b/tests/fields/test_boolean_field.py
@@ -11,15 +11,13 @@ class TestBooleanField(MongoDBTestCase):
person = Person(admin=True)
person.save()
- self.assertEqual(
- get_as_pymongo(person),
- {'_id': person.id,
- 'admin': True})
+ self.assertEqual(get_as_pymongo(person), {"_id": person.id, "admin": True})
def test_validation(self):
"""Ensure that invalid values cannot be assigned to boolean
fields.
"""
+
class Person(Document):
admin = BooleanField()
@@ -29,9 +27,9 @@ class TestBooleanField(MongoDBTestCase):
person.admin = 2
self.assertRaises(ValidationError, person.validate)
- person.admin = 'Yes'
+ person.admin = "Yes"
self.assertRaises(ValidationError, person.validate)
- person.admin = 'False'
+ person.admin = "False"
self.assertRaises(ValidationError, person.validate)
def test_weirdness_constructor(self):
@@ -39,11 +37,12 @@ class TestBooleanField(MongoDBTestCase):
which causes some weird behavior. We dont necessarily want to maintain this behavior
but its a known issue
"""
+
class Person(Document):
admin = BooleanField()
- new_person = Person(admin='False')
+ new_person = Person(admin="False")
self.assertTrue(new_person.admin)
- new_person = Person(admin='0')
+ new_person = Person(admin="0")
self.assertTrue(new_person.admin)
diff --git a/tests/fields/test_cached_reference_field.py b/tests/fields/test_cached_reference_field.py
index 470ecc5d..4e467587 100644
--- a/tests/fields/test_cached_reference_field.py
+++ b/tests/fields/test_cached_reference_field.py
@@ -7,12 +7,12 @@ from tests.utils import MongoDBTestCase
class TestCachedReferenceField(MongoDBTestCase):
-
def test_get_and_save(self):
"""
Tests #1047: CachedReferenceField creates DBRefs on to_python,
but can't save them on to_mongo.
"""
+
class Animal(Document):
name = StringField()
tag = StringField()
@@ -24,10 +24,11 @@ class TestCachedReferenceField(MongoDBTestCase):
Animal.drop_collection()
Ocorrence.drop_collection()
- Ocorrence(person="testte",
- animal=Animal(name="Leopard", tag="heavy").save()).save()
+ Ocorrence(
+ person="testte", animal=Animal(name="Leopard", tag="heavy").save()
+ ).save()
p = Ocorrence.objects.get()
- p.person = 'new_testte'
+ p.person = "new_testte"
p.save()
def test_general_things(self):
@@ -37,8 +38,7 @@ class TestCachedReferenceField(MongoDBTestCase):
class Ocorrence(Document):
person = StringField()
- animal = CachedReferenceField(
- Animal, fields=['tag'])
+ animal = CachedReferenceField(Animal, fields=["tag"])
Animal.drop_collection()
Ocorrence.drop_collection()
@@ -55,19 +55,18 @@ class TestCachedReferenceField(MongoDBTestCase):
self.assertEqual(Ocorrence.objects(animal=None).count(), 1)
- self.assertEqual(
- a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk})
+ self.assertEqual(a.to_mongo(fields=["tag"]), {"tag": "heavy", "_id": a.pk})
- self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
+ self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy")
# counts
Ocorrence(person="teste 2").save()
Ocorrence(person="teste 3").save()
- count = Ocorrence.objects(animal__tag='heavy').count()
+ count = Ocorrence.objects(animal__tag="heavy").count()
self.assertEqual(count, 1)
- ocorrence = Ocorrence.objects(animal__tag='heavy').first()
+ ocorrence = Ocorrence.objects(animal__tag="heavy").first()
self.assertEqual(ocorrence.person, "teste")
self.assertIsInstance(ocorrence.animal, Animal)
@@ -78,28 +77,21 @@ class TestCachedReferenceField(MongoDBTestCase):
class SocialTest(Document):
group = StringField()
- person = CachedReferenceField(
- PersonAuto,
- fields=('salary',))
+ person = CachedReferenceField(PersonAuto, fields=("salary",))
PersonAuto.drop_collection()
SocialTest.drop_collection()
- p = PersonAuto(name="Alberto", salary=Decimal('7000.00'))
+ p = PersonAuto(name="Alberto", salary=Decimal("7000.00"))
p.save()
s = SocialTest(group="dev", person=p)
s.save()
self.assertEqual(
- SocialTest.objects._collection.find_one({'person.salary': 7000.00}), {
- '_id': s.pk,
- 'group': s.group,
- 'person': {
- '_id': p.pk,
- 'salary': 7000.00
- }
- })
+ SocialTest.objects._collection.find_one({"person.salary": 7000.00}),
+ {"_id": s.pk, "group": s.group, "person": {"_id": p.pk, "salary": 7000.00}},
+ )
def test_cached_reference_field_reference(self):
class Group(Document):
@@ -111,17 +103,14 @@ class TestCachedReferenceField(MongoDBTestCase):
class SocialData(Document):
obs = StringField()
- tags = ListField(
- StringField())
- person = CachedReferenceField(
- Person,
- fields=('group',))
+ tags = ListField(StringField())
+ person = CachedReferenceField(Person, fields=("group",))
Group.drop_collection()
Person.drop_collection()
SocialData.drop_collection()
- g1 = Group(name='dev')
+ g1 = Group(name="dev")
g1.save()
g2 = Group(name="designers")
@@ -136,22 +125,21 @@ class TestCachedReferenceField(MongoDBTestCase):
p3 = Person(name="Afro design", group=g2)
p3.save()
- s1 = SocialData(obs="testing 123", person=p1, tags=['tag1', 'tag2'])
+ s1 = SocialData(obs="testing 123", person=p1, tags=["tag1", "tag2"])
s1.save()
- s2 = SocialData(obs="testing 321", person=p3, tags=['tag3', 'tag4'])
+ s2 = SocialData(obs="testing 321", person=p3, tags=["tag3", "tag4"])
s2.save()
- self.assertEqual(SocialData.objects._collection.find_one(
- {'tags': 'tag2'}), {
- '_id': s1.pk,
- 'obs': 'testing 123',
- 'tags': ['tag1', 'tag2'],
- 'person': {
- '_id': p1.pk,
- 'group': g1.pk
- }
- })
+ self.assertEqual(
+ SocialData.objects._collection.find_one({"tags": "tag2"}),
+ {
+ "_id": s1.pk,
+ "obs": "testing 123",
+ "tags": ["tag1", "tag2"],
+ "person": {"_id": p1.pk, "group": g1.pk},
+ },
+ )
self.assertEqual(SocialData.objects(person__group=g2).count(), 1)
self.assertEqual(SocialData.objects(person__group=g2).first(), s2)
@@ -163,23 +151,18 @@ class TestCachedReferenceField(MongoDBTestCase):
Product.drop_collection()
class Basket(Document):
- products = ListField(CachedReferenceField(Product, fields=['name']))
+ products = ListField(CachedReferenceField(Product, fields=["name"]))
Basket.drop_collection()
- product1 = Product(name='abc').save()
- product2 = Product(name='def').save()
+ product1 = Product(name="abc").save()
+ product2 = Product(name="def").save()
basket = Basket(products=[product1]).save()
self.assertEqual(
Basket.objects._collection.find_one(),
{
- '_id': basket.pk,
- 'products': [
- {
- '_id': product1.pk,
- 'name': product1.name
- }
- ]
- }
+ "_id": basket.pk,
+ "products": [{"_id": product1.pk, "name": product1.name}],
+ },
)
# push to list
basket.update(push__products=product2)
@@ -187,161 +170,135 @@ class TestCachedReferenceField(MongoDBTestCase):
self.assertEqual(
Basket.objects._collection.find_one(),
{
- '_id': basket.pk,
- 'products': [
- {
- '_id': product1.pk,
- 'name': product1.name
- },
- {
- '_id': product2.pk,
- 'name': product2.name
- }
- ]
- }
+ "_id": basket.pk,
+ "products": [
+ {"_id": product1.pk, "name": product1.name},
+ {"_id": product2.pk, "name": product2.name},
+ ],
+ },
)
def test_cached_reference_field_update_all(self):
class Person(Document):
- TYPES = (
- ('pf', "PF"),
- ('pj', "PJ")
- )
+ TYPES = (("pf", "PF"), ("pj", "PJ"))
name = StringField()
tp = StringField(choices=TYPES)
- father = CachedReferenceField('self', fields=('tp',))
+ father = CachedReferenceField("self", fields=("tp",))
Person.drop_collection()
a1 = Person(name="Wilson Father", tp="pj")
a1.save()
- a2 = Person(name='Wilson Junior', tp='pf', father=a1)
+ a2 = Person(name="Wilson Junior", tp="pf", father=a1)
a2.save()
a2 = Person.objects.with_id(a2.id)
self.assertEqual(a2.father.tp, a1.tp)
- self.assertEqual(dict(a2.to_mongo()), {
- "_id": a2.pk,
- "name": u"Wilson Junior",
- "tp": u"pf",
- "father": {
- "_id": a1.pk,
- "tp": u"pj"
- }
- })
+ self.assertEqual(
+ dict(a2.to_mongo()),
+ {
+ "_id": a2.pk,
+ "name": u"Wilson Junior",
+ "tp": u"pf",
+ "father": {"_id": a1.pk, "tp": u"pj"},
+ },
+ )
- self.assertEqual(Person.objects(father=a1)._query, {
- 'father._id': a1.pk
- })
+ self.assertEqual(Person.objects(father=a1)._query, {"father._id": a1.pk})
self.assertEqual(Person.objects(father=a1).count(), 1)
Person.objects.update(set__tp="pf")
Person.father.sync_all()
a2.reload()
- self.assertEqual(dict(a2.to_mongo()), {
- "_id": a2.pk,
- "name": u"Wilson Junior",
- "tp": u"pf",
- "father": {
- "_id": a1.pk,
- "tp": u"pf"
- }
- })
+ self.assertEqual(
+ dict(a2.to_mongo()),
+ {
+ "_id": a2.pk,
+ "name": u"Wilson Junior",
+ "tp": u"pf",
+ "father": {"_id": a1.pk, "tp": u"pf"},
+ },
+ )
def test_cached_reference_fields_on_embedded_documents(self):
with self.assertRaises(InvalidDocumentError):
+
class Test(Document):
name = StringField()
- type('WrongEmbeddedDocument', (
- EmbeddedDocument,), {
- 'test': CachedReferenceField(Test)
- })
+ type(
+ "WrongEmbeddedDocument",
+ (EmbeddedDocument,),
+ {"test": CachedReferenceField(Test)},
+ )
def test_cached_reference_auto_sync(self):
class Person(Document):
- TYPES = (
- ('pf', "PF"),
- ('pj', "PJ")
- )
+ TYPES = (("pf", "PF"), ("pj", "PJ"))
name = StringField()
- tp = StringField(
- choices=TYPES
- )
+ tp = StringField(choices=TYPES)
- father = CachedReferenceField('self', fields=('tp',))
+ father = CachedReferenceField("self", fields=("tp",))
Person.drop_collection()
a1 = Person(name="Wilson Father", tp="pj")
a1.save()
- a2 = Person(name='Wilson Junior', tp='pf', father=a1)
+ a2 = Person(name="Wilson Junior", tp="pf", father=a1)
a2.save()
- a1.tp = 'pf'
+ a1.tp = "pf"
a1.save()
a2.reload()
- self.assertEqual(dict(a2.to_mongo()), {
- '_id': a2.pk,
- 'name': 'Wilson Junior',
- 'tp': 'pf',
- 'father': {
- '_id': a1.pk,
- 'tp': 'pf'
- }
- })
+ self.assertEqual(
+ dict(a2.to_mongo()),
+ {
+ "_id": a2.pk,
+ "name": "Wilson Junior",
+ "tp": "pf",
+ "father": {"_id": a1.pk, "tp": "pf"},
+ },
+ )
def test_cached_reference_auto_sync_disabled(self):
class Persone(Document):
- TYPES = (
- ('pf', "PF"),
- ('pj', "PJ")
- )
+ TYPES = (("pf", "PF"), ("pj", "PJ"))
name = StringField()
- tp = StringField(
- choices=TYPES
- )
+ tp = StringField(choices=TYPES)
- father = CachedReferenceField(
- 'self', fields=('tp',), auto_sync=False)
+ father = CachedReferenceField("self", fields=("tp",), auto_sync=False)
Persone.drop_collection()
a1 = Persone(name="Wilson Father", tp="pj")
a1.save()
- a2 = Persone(name='Wilson Junior', tp='pf', father=a1)
+ a2 = Persone(name="Wilson Junior", tp="pf", father=a1)
a2.save()
- a1.tp = 'pf'
+ a1.tp = "pf"
a1.save()
- self.assertEqual(Persone.objects._collection.find_one({'_id': a2.pk}), {
- '_id': a2.pk,
- 'name': 'Wilson Junior',
- 'tp': 'pf',
- 'father': {
- '_id': a1.pk,
- 'tp': 'pj'
- }
- })
+ self.assertEqual(
+ Persone.objects._collection.find_one({"_id": a2.pk}),
+ {
+ "_id": a2.pk,
+ "name": "Wilson Junior",
+ "tp": "pf",
+ "father": {"_id": a1.pk, "tp": "pj"},
+ },
+ )
def test_cached_reference_embedded_fields(self):
class Owner(EmbeddedDocument):
- TPS = (
- ('n', "Normal"),
- ('u', "Urgent")
- )
+ TPS = (("n", "Normal"), ("u", "Urgent"))
name = StringField()
- tp = StringField(
- verbose_name="Type",
- db_field="t",
- choices=TPS)
+ tp = StringField(verbose_name="Type", db_field="t", choices=TPS)
class Animal(Document):
name = StringField()
@@ -351,43 +308,38 @@ class TestCachedReferenceField(MongoDBTestCase):
class Ocorrence(Document):
person = StringField()
- animal = CachedReferenceField(
- Animal, fields=['tag', 'owner.tp'])
+ animal = CachedReferenceField(Animal, fields=["tag", "owner.tp"])
Animal.drop_collection()
Ocorrence.drop_collection()
- a = Animal(name="Leopard", tag="heavy",
- owner=Owner(tp='u', name="Wilson Júnior")
- )
+ a = Animal(
+ name="Leopard", tag="heavy", owner=Owner(tp="u", name="Wilson Júnior")
+ )
a.save()
o = Ocorrence(person="teste", animal=a)
o.save()
- self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tp'])), {
- '_id': a.pk,
- 'tag': 'heavy',
- 'owner': {
- 't': 'u'
- }
- })
- self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
- self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u')
+ self.assertEqual(
+ dict(a.to_mongo(fields=["tag", "owner.tp"])),
+ {"_id": a.pk, "tag": "heavy", "owner": {"t": "u"}},
+ )
+ self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy")
+ self.assertEqual(o.to_mongo()["animal"]["owner"]["t"], "u")
# Check to_mongo with fields
- self.assertNotIn('animal', o.to_mongo(fields=['person']))
+ self.assertNotIn("animal", o.to_mongo(fields=["person"]))
# counts
Ocorrence(person="teste 2").save()
Ocorrence(person="teste 3").save()
- count = Ocorrence.objects(
- animal__tag='heavy', animal__owner__tp='u').count()
+ count = Ocorrence.objects(animal__tag="heavy", animal__owner__tp="u").count()
self.assertEqual(count, 1)
ocorrence = Ocorrence.objects(
- animal__tag='heavy',
- animal__owner__tp='u').first()
+ animal__tag="heavy", animal__owner__tp="u"
+ ).first()
self.assertEqual(ocorrence.person, "teste")
self.assertIsInstance(ocorrence.animal, Animal)
@@ -404,43 +356,39 @@ class TestCachedReferenceField(MongoDBTestCase):
class Ocorrence(Document):
person = StringField()
- animal = CachedReferenceField(
- Animal, fields=['tag', 'owner.tags'])
+ animal = CachedReferenceField(Animal, fields=["tag", "owner.tags"])
Animal.drop_collection()
Ocorrence.drop_collection()
- a = Animal(name="Leopard", tag="heavy",
- owner=Owner(tags=['cool', 'funny'],
- name="Wilson Júnior")
- )
+ a = Animal(
+ name="Leopard",
+ tag="heavy",
+ owner=Owner(tags=["cool", "funny"], name="Wilson Júnior"),
+ )
a.save()
o = Ocorrence(person="teste 2", animal=a)
o.save()
- self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tags'])), {
- '_id': a.pk,
- 'tag': 'heavy',
- 'owner': {
- 'tags': ['cool', 'funny']
- }
- })
+ self.assertEqual(
+ dict(a.to_mongo(fields=["tag", "owner.tags"])),
+ {"_id": a.pk, "tag": "heavy", "owner": {"tags": ["cool", "funny"]}},
+ )
- self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy')
- self.assertEqual(o.to_mongo()['animal']['owner']['tags'],
- ['cool', 'funny'])
+ self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy")
+ self.assertEqual(o.to_mongo()["animal"]["owner"]["tags"], ["cool", "funny"])
# counts
Ocorrence(person="teste 2").save()
Ocorrence(person="teste 3").save()
query = Ocorrence.objects(
- animal__tag='heavy', animal__owner__tags='cool')._query
- self.assertEqual(
- query, {'animal.owner.tags': 'cool', 'animal.tag': 'heavy'})
+ animal__tag="heavy", animal__owner__tags="cool"
+ )._query
+ self.assertEqual(query, {"animal.owner.tags": "cool", "animal.tag": "heavy"})
ocorrence = Ocorrence.objects(
- animal__tag='heavy',
- animal__owner__tags='cool').first()
+ animal__tag="heavy", animal__owner__tags="cool"
+ ).first()
self.assertEqual(ocorrence.person, "teste 2")
self.assertIsInstance(ocorrence.animal, Animal)
diff --git a/tests/fields/test_complex_datetime_field.py b/tests/fields/test_complex_datetime_field.py
index 58dc4b43..4eea5bdc 100644
--- a/tests/fields/test_complex_datetime_field.py
+++ b/tests/fields/test_complex_datetime_field.py
@@ -14,9 +14,10 @@ class ComplexDateTimeFieldTest(MongoDBTestCase):
"""Tests for complex datetime fields - which can handle
microseconds without rounding.
"""
+
class LogEntry(Document):
date = ComplexDateTimeField()
- date_with_dots = ComplexDateTimeField(separator='.')
+ date_with_dots = ComplexDateTimeField(separator=".")
LogEntry.drop_collection()
@@ -62,17 +63,25 @@ class ComplexDateTimeFieldTest(MongoDBTestCase):
mm = dd = hh = ii = ss = [1, 10]
for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond):
- stored = LogEntry(date=datetime.datetime(*values)).to_mongo()['date']
- self.assertTrue(re.match('^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$', stored) is not None)
+ stored = LogEntry(date=datetime.datetime(*values)).to_mongo()["date"]
+ self.assertTrue(
+                re.match(r"^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$", stored)
+ is not None
+ )
# Test separator
- stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()['date_with_dots']
- self.assertTrue(re.match('^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$', stored) is not None)
+ stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()[
+ "date_with_dots"
+ ]
+ self.assertTrue(
+            re.match(r"^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$", stored) is not None
+ )
def test_complexdatetime_usage(self):
"""Tests for complex datetime fields - which can handle
microseconds without rounding.
"""
+
class LogEntry(Document):
date = ComplexDateTimeField()
@@ -123,22 +132,21 @@ class ComplexDateTimeFieldTest(MongoDBTestCase):
# Test microsecond-level ordering/filtering
for microsecond in (99, 999, 9999, 10000):
- LogEntry(
- date=datetime.datetime(2015, 1, 1, 0, 0, 0, microsecond)
- ).save()
+ LogEntry(date=datetime.datetime(2015, 1, 1, 0, 0, 0, microsecond)).save()
- logs = list(LogEntry.objects.order_by('date'))
+ logs = list(LogEntry.objects.order_by("date"))
for next_idx, log in enumerate(logs[:-1], start=1):
next_log = logs[next_idx]
self.assertTrue(log.date < next_log.date)
- logs = list(LogEntry.objects.order_by('-date'))
+ logs = list(LogEntry.objects.order_by("-date"))
for next_idx, log in enumerate(logs[:-1], start=1):
next_log = logs[next_idx]
self.assertTrue(log.date > next_log.date)
logs = LogEntry.objects.filter(
- date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000))
+ date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000)
+ )
self.assertEqual(logs.count(), 4)
def test_no_default_value(self):
@@ -156,6 +164,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase):
def test_default_static_value(self):
NOW = datetime.datetime.utcnow()
+
class Log(Document):
timestamp = ComplexDateTimeField(default=NOW)
diff --git a/tests/fields/test_date_field.py b/tests/fields/test_date_field.py
index 82adb514..da572134 100644
--- a/tests/fields/test_date_field.py
+++ b/tests/fields/test_date_field.py
@@ -18,10 +18,11 @@ class TestDateField(MongoDBTestCase):
Ensure an exception is raised when trying to
cast an empty string to datetime.
"""
+
class MyDoc(Document):
dt = DateField()
- md = MyDoc(dt='')
+ md = MyDoc(dt="")
self.assertRaises(ValidationError, md.save)
def test_date_from_whitespace_string(self):
@@ -29,16 +30,18 @@ class TestDateField(MongoDBTestCase):
Ensure an exception is raised when trying to
cast a whitespace-only string to datetime.
"""
+
class MyDoc(Document):
dt = DateField()
- md = MyDoc(dt=' ')
+ md = MyDoc(dt=" ")
self.assertRaises(ValidationError, md.save)
def test_default_values_today(self):
"""Ensure that default field values are used when creating
a document.
"""
+
class Person(Document):
day = DateField(default=datetime.date.today)
@@ -46,13 +49,14 @@ class TestDateField(MongoDBTestCase):
person.validate()
self.assertEqual(person.day, person.day)
self.assertEqual(person.day, datetime.date.today())
- self.assertEqual(person._data['day'], person.day)
+ self.assertEqual(person._data["day"], person.day)
def test_date(self):
"""Tests showing pymongo date fields
See: http://api.mongodb.org/python/current/api/bson/son.html#dt
"""
+
class LogEntry(Document):
date = DateField()
@@ -95,6 +99,7 @@ class TestDateField(MongoDBTestCase):
def test_regular_usage(self):
"""Tests for regular datetime fields"""
+
class LogEntry(Document):
date = DateField()
@@ -106,12 +111,12 @@ class TestDateField(MongoDBTestCase):
log.validate()
log.save()
- for query in (d1, d1.isoformat(' ')):
+ for query in (d1, d1.isoformat(" ")):
log1 = LogEntry.objects.get(date=query)
self.assertEqual(log, log1)
if dateutil:
- log1 = LogEntry.objects.get(date=d1.isoformat('T'))
+ log1 = LogEntry.objects.get(date=d1.isoformat("T"))
self.assertEqual(log, log1)
# create additional 19 log entries for a total of 20
@@ -142,6 +147,7 @@ class TestDateField(MongoDBTestCase):
"""Ensure that invalid values cannot be assigned to datetime
fields.
"""
+
class LogEntry(Document):
time = DateField()
@@ -152,14 +158,14 @@ class TestDateField(MongoDBTestCase):
log.time = datetime.date.today()
log.validate()
- log.time = datetime.datetime.now().isoformat(' ')
+ log.time = datetime.datetime.now().isoformat(" ")
log.validate()
if dateutil:
- log.time = datetime.datetime.now().isoformat('T')
+ log.time = datetime.datetime.now().isoformat("T")
log.validate()
log.time = -1
self.assertRaises(ValidationError, log.validate)
- log.time = 'ABC'
+ log.time = "ABC"
self.assertRaises(ValidationError, log.validate)
diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py
index 92f0668a..c911390a 100644
--- a/tests/fields/test_datetime_field.py
+++ b/tests/fields/test_datetime_field.py
@@ -19,10 +19,11 @@ class TestDateTimeField(MongoDBTestCase):
Ensure an exception is raised when trying to
cast an empty string to datetime.
"""
+
class MyDoc(Document):
dt = DateTimeField()
- md = MyDoc(dt='')
+ md = MyDoc(dt="")
self.assertRaises(ValidationError, md.save)
def test_datetime_from_whitespace_string(self):
@@ -30,16 +31,18 @@ class TestDateTimeField(MongoDBTestCase):
Ensure an exception is raised when trying to
cast a whitespace-only string to datetime.
"""
+
class MyDoc(Document):
dt = DateTimeField()
- md = MyDoc(dt=' ')
+ md = MyDoc(dt=" ")
self.assertRaises(ValidationError, md.save)
def test_default_value_utcnow(self):
"""Ensure that default field values are used when creating
a document.
"""
+
class Person(Document):
created = DateTimeField(default=dt.datetime.utcnow)
@@ -48,8 +51,10 @@ class TestDateTimeField(MongoDBTestCase):
person.validate()
person_created_t0 = person.created
self.assertLess(person.created - utcnow, dt.timedelta(seconds=1))
- self.assertEqual(person_created_t0, person.created) # make sure it does not change
- self.assertEqual(person._data['created'], person.created)
+ self.assertEqual(
+ person_created_t0, person.created
+ ) # make sure it does not change
+ self.assertEqual(person._data["created"], person.created)
def test_handling_microseconds(self):
"""Tests showing pymongo datetime fields handling of microseconds.
@@ -58,6 +63,7 @@ class TestDateTimeField(MongoDBTestCase):
See: http://api.mongodb.org/python/current/api/bson/son.html#dt
"""
+
class LogEntry(Document):
date = DateTimeField()
@@ -103,6 +109,7 @@ class TestDateTimeField(MongoDBTestCase):
def test_regular_usage(self):
"""Tests for regular datetime fields"""
+
class LogEntry(Document):
date = DateTimeField()
@@ -114,12 +121,12 @@ class TestDateTimeField(MongoDBTestCase):
log.validate()
log.save()
- for query in (d1, d1.isoformat(' ')):
+ for query in (d1, d1.isoformat(" ")):
log1 = LogEntry.objects.get(date=query)
self.assertEqual(log, log1)
if dateutil:
- log1 = LogEntry.objects.get(date=d1.isoformat('T'))
+ log1 = LogEntry.objects.get(date=d1.isoformat("T"))
self.assertEqual(log, log1)
# create additional 19 log entries for a total of 20
@@ -150,8 +157,7 @@ class TestDateTimeField(MongoDBTestCase):
self.assertEqual(logs.count(), 10)
logs = LogEntry.objects.filter(
- date__lte=dt.datetime(1980, 1, 1),
- date__gte=dt.datetime(1975, 1, 1),
+ date__lte=dt.datetime(1980, 1, 1), date__gte=dt.datetime(1975, 1, 1)
)
self.assertEqual(logs.count(), 5)
@@ -159,6 +165,7 @@ class TestDateTimeField(MongoDBTestCase):
"""Ensure that invalid values cannot be assigned to datetime
fields.
"""
+
class LogEntry(Document):
time = DateTimeField()
@@ -169,32 +176,32 @@ class TestDateTimeField(MongoDBTestCase):
log.time = dt.date.today()
log.validate()
- log.time = dt.datetime.now().isoformat(' ')
+ log.time = dt.datetime.now().isoformat(" ")
log.validate()
- log.time = '2019-05-16 21:42:57.897847'
+ log.time = "2019-05-16 21:42:57.897847"
log.validate()
if dateutil:
- log.time = dt.datetime.now().isoformat('T')
+ log.time = dt.datetime.now().isoformat("T")
log.validate()
log.time = -1
self.assertRaises(ValidationError, log.validate)
- log.time = 'ABC'
+ log.time = "ABC"
self.assertRaises(ValidationError, log.validate)
- log.time = '2019-05-16 21:GARBAGE:12'
+ log.time = "2019-05-16 21:GARBAGE:12"
self.assertRaises(ValidationError, log.validate)
- log.time = '2019-05-16 21:42:57.GARBAGE'
+ log.time = "2019-05-16 21:42:57.GARBAGE"
self.assertRaises(ValidationError, log.validate)
- log.time = '2019-05-16 21:42:57.123.456'
+ log.time = "2019-05-16 21:42:57.123.456"
self.assertRaises(ValidationError, log.validate)
def test_parse_datetime_as_str(self):
class DTDoc(Document):
date = DateTimeField()
- date_str = '2019-03-02 22:26:01'
+ date_str = "2019-03-02 22:26:01"
# make sure that passing a parsable datetime works
dtd = DTDoc()
@@ -206,7 +213,7 @@ class TestDateTimeField(MongoDBTestCase):
self.assertIsInstance(dtd.date, dt.datetime)
self.assertEqual(str(dtd.date), date_str)
- dtd.date = 'January 1st, 9999999999'
+ dtd.date = "January 1st, 9999999999"
self.assertRaises(ValidationError, dtd.validate)
@@ -217,7 +224,7 @@ class TestDateTimeTzAware(MongoDBTestCase):
connection._connections = {}
connection._dbs = {}
- connect(db='mongoenginetest', tz_aware=True)
+ connect(db="mongoenginetest", tz_aware=True)
class LogEntry(Document):
time = DateTimeField()
@@ -228,4 +235,4 @@ class TestDateTimeTzAware(MongoDBTestCase):
log = LogEntry.objects.first()
log.time = dt.datetime(2013, 1, 1, 0, 0, 0)
- self.assertEqual(['time'], log._changed_fields)
+ self.assertEqual(["time"], log._changed_fields)
diff --git a/tests/fields/test_decimal_field.py b/tests/fields/test_decimal_field.py
index 0213b880..30b7e5ea 100644
--- a/tests/fields/test_decimal_field.py
+++ b/tests/fields/test_decimal_field.py
@@ -7,32 +7,31 @@ from tests.utils import MongoDBTestCase
class TestDecimalField(MongoDBTestCase):
-
def test_validation(self):
"""Ensure that invalid values cannot be assigned to decimal fields.
"""
+
class Person(Document):
- height = DecimalField(min_value=Decimal('0.1'),
- max_value=Decimal('3.5'))
+ height = DecimalField(min_value=Decimal("0.1"), max_value=Decimal("3.5"))
Person.drop_collection()
- Person(height=Decimal('1.89')).save()
+ Person(height=Decimal("1.89")).save()
person = Person.objects.first()
- self.assertEqual(person.height, Decimal('1.89'))
+ self.assertEqual(person.height, Decimal("1.89"))
- person.height = '2.0'
+ person.height = "2.0"
person.save()
person.height = 0.01
self.assertRaises(ValidationError, person.validate)
- person.height = Decimal('0.01')
+ person.height = Decimal("0.01")
self.assertRaises(ValidationError, person.validate)
- person.height = Decimal('4.0')
+ person.height = Decimal("4.0")
self.assertRaises(ValidationError, person.validate)
- person.height = 'something invalid'
+ person.height = "something invalid"
self.assertRaises(ValidationError, person.validate)
- person_2 = Person(height='something invalid')
+ person_2 = Person(height="something invalid")
self.assertRaises(ValidationError, person_2.validate)
def test_comparison(self):
@@ -58,7 +57,14 @@ class TestDecimalField(MongoDBTestCase):
string_value = DecimalField(precision=4, force_string=True)
Person.drop_collection()
- values_to_store = [10, 10.1, 10.11, "10.111", Decimal("10.1111"), Decimal("10.11111")]
+ values_to_store = [
+ 10,
+ 10.1,
+ 10.11,
+ "10.111",
+ Decimal("10.1111"),
+ Decimal("10.11111"),
+ ]
for store_at_creation in [True, False]:
for value in values_to_store:
# to_python is called explicitly if values were sent in the kwargs of __init__
@@ -72,20 +78,27 @@ class TestDecimalField(MongoDBTestCase):
# How its stored
expected = [
- {'float_value': 10.0, 'string_value': '10.0000'},
- {'float_value': 10.1, 'string_value': '10.1000'},
- {'float_value': 10.11, 'string_value': '10.1100'},
- {'float_value': 10.111, 'string_value': '10.1110'},
- {'float_value': 10.1111, 'string_value': '10.1111'},
- {'float_value': 10.1111, 'string_value': '10.1111'}]
+ {"float_value": 10.0, "string_value": "10.0000"},
+ {"float_value": 10.1, "string_value": "10.1000"},
+ {"float_value": 10.11, "string_value": "10.1100"},
+ {"float_value": 10.111, "string_value": "10.1110"},
+ {"float_value": 10.1111, "string_value": "10.1111"},
+ {"float_value": 10.1111, "string_value": "10.1111"},
+ ]
expected.extend(expected)
- actual = list(Person.objects.exclude('id').as_pymongo())
+ actual = list(Person.objects.exclude("id").as_pymongo())
self.assertEqual(expected, actual)
# How it comes out locally
- expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'),
- Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')]
+ expected = [
+ Decimal("10.0000"),
+ Decimal("10.1000"),
+ Decimal("10.1100"),
+ Decimal("10.1110"),
+ Decimal("10.1111"),
+ Decimal("10.1111"),
+ ]
expected.extend(expected)
- for field_name in ['float_value', 'string_value']:
+ for field_name in ["float_value", "string_value"]:
actual = list(Person.objects().scalar(field_name))
self.assertEqual(expected, actual)
diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py
index ade02ccf..07bab85b 100644
--- a/tests/fields/test_dict_field.py
+++ b/tests/fields/test_dict_field.py
@@ -6,95 +6,92 @@ from tests.utils import MongoDBTestCase, get_as_pymongo
class TestDictField(MongoDBTestCase):
-
def test_storage(self):
class BlogPost(Document):
info = DictField()
BlogPost.drop_collection()
- info = {'testkey': 'testvalue'}
+ info = {"testkey": "testvalue"}
post = BlogPost(info=info).save()
- self.assertEqual(
- get_as_pymongo(post),
- {
- '_id': post.id,
- 'info': info
- }
- )
+ self.assertEqual(get_as_pymongo(post), {"_id": post.id, "info": info})
def test_general_things(self):
"""Ensure that dict types work as expected."""
+
class BlogPost(Document):
info = DictField()
BlogPost.drop_collection()
post = BlogPost()
- post.info = 'my post'
+ post.info = "my post"
self.assertRaises(ValidationError, post.validate)
- post.info = ['test', 'test']
+ post.info = ["test", "test"]
self.assertRaises(ValidationError, post.validate)
- post.info = {'$title': 'test'}
+ post.info = {"$title": "test"}
self.assertRaises(ValidationError, post.validate)
- post.info = {'nested': {'$title': 'test'}}
+ post.info = {"nested": {"$title": "test"}}
self.assertRaises(ValidationError, post.validate)
- post.info = {'the.title': 'test'}
+ post.info = {"the.title": "test"}
self.assertRaises(ValidationError, post.validate)
- post.info = {'nested': {'the.title': 'test'}}
+ post.info = {"nested": {"the.title": "test"}}
self.assertRaises(ValidationError, post.validate)
- post.info = {1: 'test'}
+ post.info = {1: "test"}
self.assertRaises(ValidationError, post.validate)
- post.info = {'title': 'test'}
+ post.info = {"title": "test"}
post.save()
post = BlogPost()
- post.info = {'title': 'dollar_sign', 'details': {'te$t': 'test'}}
+ post.info = {"title": "dollar_sign", "details": {"te$t": "test"}}
post.save()
post = BlogPost()
- post.info = {'details': {'test': 'test'}}
+ post.info = {"details": {"test": "test"}}
post.save()
post = BlogPost()
- post.info = {'details': {'test': 3}}
+ post.info = {"details": {"test": 3}}
post.save()
self.assertEqual(BlogPost.objects.count(), 4)
+ self.assertEqual(BlogPost.objects.filter(info__title__exact="test").count(), 1)
self.assertEqual(
- BlogPost.objects.filter(info__title__exact='test').count(), 1)
- self.assertEqual(
- BlogPost.objects.filter(info__details__test__exact='test').count(), 1)
+ BlogPost.objects.filter(info__details__test__exact="test").count(), 1
+ )
- post = BlogPost.objects.filter(info__title__exact='dollar_sign').first()
- self.assertIn('te$t', post['info']['details'])
+ post = BlogPost.objects.filter(info__title__exact="dollar_sign").first()
+ self.assertIn("te$t", post["info"]["details"])
# Confirm handles non strings or non existing keys
self.assertEqual(
- BlogPost.objects.filter(info__details__test__exact=5).count(), 0)
+ BlogPost.objects.filter(info__details__test__exact=5).count(), 0
+ )
self.assertEqual(
- BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0)
+ BlogPost.objects.filter(info__made_up__test__exact="test").count(), 0
+ )
- post = BlogPost.objects.create(info={'title': 'original'})
- post.info.update({'title': 'updated'})
+ post = BlogPost.objects.create(info={"title": "original"})
+ post.info.update({"title": "updated"})
post.save()
post.reload()
- self.assertEqual('updated', post.info['title'])
+ self.assertEqual("updated", post.info["title"])
- post.info.setdefault('authors', [])
+ post.info.setdefault("authors", [])
post.save()
post.reload()
- self.assertEqual([], post.info['authors'])
+ self.assertEqual([], post.info["authors"])
def test_dictfield_dump_document(self):
"""Ensure a DictField can handle another document's dump."""
+
class Doc(Document):
field = DictField()
@@ -106,51 +103,62 @@ class TestDictField(MongoDBTestCase):
id = IntField(primary_key=True, default=1)
recursive = DictField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class ToEmbedChild(ToEmbedParent):
pass
to_embed_recursive = ToEmbed(id=1).save()
to_embed = ToEmbed(
- id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save()
+ id=2, recursive=to_embed_recursive.to_mongo().to_dict()
+ ).save()
doc = Doc(field=to_embed.to_mongo().to_dict())
doc.save()
self.assertIsInstance(doc.field, dict)
- self.assertEqual(doc.field, {'_id': 2, 'recursive': {'_id': 1, 'recursive': {}}})
+ self.assertEqual(
+ doc.field, {"_id": 2, "recursive": {"_id": 1, "recursive": {}}}
+ )
# Same thing with a Document with a _cls field
to_embed_recursive = ToEmbedChild(id=1).save()
to_embed_child = ToEmbedChild(
- id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save()
+ id=2, recursive=to_embed_recursive.to_mongo().to_dict()
+ ).save()
doc = Doc(field=to_embed_child.to_mongo().to_dict())
doc.save()
self.assertIsInstance(doc.field, dict)
expected = {
- '_id': 2, '_cls': 'ToEmbedParent.ToEmbedChild',
- 'recursive': {'_id': 1, '_cls': 'ToEmbedParent.ToEmbedChild', 'recursive': {}}
+ "_id": 2,
+ "_cls": "ToEmbedParent.ToEmbedChild",
+ "recursive": {
+ "_id": 1,
+ "_cls": "ToEmbedParent.ToEmbedChild",
+ "recursive": {},
+ },
}
self.assertEqual(doc.field, expected)
def test_dictfield_strict(self):
"""Ensure that dict field handles validation if provided a strict field type."""
+
class Simple(Document):
mapping = DictField(field=IntField())
Simple.drop_collection()
e = Simple()
- e.mapping['someint'] = 1
+ e.mapping["someint"] = 1
e.save()
# try creating an invalid mapping
with self.assertRaises(ValidationError):
- e.mapping['somestring'] = "abc"
+ e.mapping["somestring"] = "abc"
e.save()
def test_dictfield_complex(self):
"""Ensure that the dict field can handle the complex types."""
+
class SettingBase(EmbeddedDocument):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class StringSetting(SettingBase):
value = StringField()
@@ -164,70 +172,72 @@ class TestDictField(MongoDBTestCase):
Simple.drop_collection()
e = Simple()
- e.mapping['somestring'] = StringSetting(value='foo')
- e.mapping['someint'] = IntegerSetting(value=42)
- e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!',
- 'float': 1.001,
- 'complex': IntegerSetting(value=42),
- 'list': [IntegerSetting(value=42),
- StringSetting(value='foo')]}
+ e.mapping["somestring"] = StringSetting(value="foo")
+ e.mapping["someint"] = IntegerSetting(value=42)
+ e.mapping["nested_dict"] = {
+ "number": 1,
+ "string": "Hi!",
+ "float": 1.001,
+ "complex": IntegerSetting(value=42),
+ "list": [IntegerSetting(value=42), StringSetting(value="foo")],
+ }
e.save()
e2 = Simple.objects.get(id=e.id)
- self.assertIsInstance(e2.mapping['somestring'], StringSetting)
- self.assertIsInstance(e2.mapping['someint'], IntegerSetting)
+ self.assertIsInstance(e2.mapping["somestring"], StringSetting)
+ self.assertIsInstance(e2.mapping["someint"], IntegerSetting)
# Test querying
+ self.assertEqual(Simple.objects.filter(mapping__someint__value=42).count(), 1)
self.assertEqual(
- Simple.objects.filter(mapping__someint__value=42).count(), 1)
+ Simple.objects.filter(mapping__nested_dict__number=1).count(), 1
+ )
self.assertEqual(
- Simple.objects.filter(mapping__nested_dict__number=1).count(), 1)
+ Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1
+ )
self.assertEqual(
- Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1)
+ Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1
+ )
self.assertEqual(
- Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1)
- self.assertEqual(
- Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1)
+ Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count(), 1
+ )
# Confirm can update
+ Simple.objects().update(set__mapping={"someint": IntegerSetting(value=10)})
Simple.objects().update(
- set__mapping={"someint": IntegerSetting(value=10)})
- Simple.objects().update(
- set__mapping__nested_dict__list__1=StringSetting(value='Boo'))
+ set__mapping__nested_dict__list__1=StringSetting(value="Boo")
+ )
self.assertEqual(
- Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0)
+ Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count(), 0
+ )
self.assertEqual(
- Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1)
+ Simple.objects.filter(mapping__nested_dict__list__1__value="Boo").count(), 1
+ )
def test_push_dict(self):
class MyModel(Document):
events = ListField(DictField())
- doc = MyModel(events=[{'a': 1}]).save()
+ doc = MyModel(events=[{"a": 1}]).save()
raw_doc = get_as_pymongo(doc)
- expected_raw_doc = {
- '_id': doc.id,
- 'events': [{'a': 1}]
- }
+ expected_raw_doc = {"_id": doc.id, "events": [{"a": 1}]}
self.assertEqual(raw_doc, expected_raw_doc)
MyModel.objects(id=doc.id).update(push__events={})
raw_doc = get_as_pymongo(doc)
- expected_raw_doc = {
- '_id': doc.id,
- 'events': [{'a': 1}, {}]
- }
+ expected_raw_doc = {"_id": doc.id, "events": [{"a": 1}, {}]}
self.assertEqual(raw_doc, expected_raw_doc)
def test_ensure_unique_default_instances(self):
"""Ensure that every field has it's own unique default instance."""
+
class D(Document):
data = DictField()
data2 = DictField(default=lambda: {})
d1 = D()
- d1.data['foo'] = 'bar'
- d1.data2['foo'] = 'bar'
+ d1.data["foo"] = "bar"
+ d1.data2["foo"] = "bar"
d2 = D()
self.assertEqual(d2.data, {})
self.assertEqual(d2.data2, {})
@@ -255,22 +265,25 @@ class TestDictField(MongoDBTestCase):
class Embedded(EmbeddedDocument):
name = StringField()
- embed = Embedded(name='garbage')
+ embed = Embedded(name="garbage")
doc = DictFieldTest(dictionary=embed)
with self.assertRaises(ValidationError) as ctx_err:
doc.validate()
self.assertIn("'dictionary'", str(ctx_err.exception))
- self.assertIn('Only dictionaries may be used in a DictField', str(ctx_err.exception))
+ self.assertIn(
+ "Only dictionaries may be used in a DictField", str(ctx_err.exception)
+ )
def test_atomic_update_dict_field(self):
"""Ensure that the entire DictField can be atomically updated."""
+
class Simple(Document):
mapping = DictField(field=ListField(IntField(required=True)))
Simple.drop_collection()
e = Simple()
- e.mapping['someints'] = [1, 2]
+ e.mapping["someints"] = [1, 2]
e.save()
e.update(set__mapping={"ints": [3, 4]})
e.reload()
@@ -279,7 +292,7 @@ class TestDictField(MongoDBTestCase):
# try creating an invalid mapping
with self.assertRaises(ValueError):
- e.update(set__mapping={"somestrings": ["foo", "bar", ]})
+ e.update(set__mapping={"somestrings": ["foo", "bar"]})
def test_dictfield_with_referencefield_complex_nesting_cases(self):
"""Ensure complex nesting inside DictField handles dereferencing of ReferenceField(dbref=True | False)"""
@@ -296,29 +309,33 @@ class TestDictField(MongoDBTestCase):
mapping5 = DictField(DictField(field=ReferenceField(Doc, dbref=False)))
mapping6 = DictField(ListField(DictField(ReferenceField(Doc, dbref=True))))
mapping7 = DictField(ListField(DictField(ReferenceField(Doc, dbref=False))))
- mapping8 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=True)))))
- mapping9 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=False)))))
+ mapping8 = DictField(
+ ListField(DictField(ListField(ReferenceField(Doc, dbref=True))))
+ )
+ mapping9 = DictField(
+ ListField(DictField(ListField(ReferenceField(Doc, dbref=False))))
+ )
Doc.drop_collection()
Simple.drop_collection()
- d = Doc(s='aa').save()
+ d = Doc(s="aa").save()
e = Simple()
- e.mapping0['someint'] = e.mapping1['someint'] = d
- e.mapping2['someint'] = e.mapping3['someint'] = [d]
- e.mapping4['someint'] = e.mapping5['someint'] = {'d': d}
- e.mapping6['someint'] = e.mapping7['someint'] = [{'d': d}]
- e.mapping8['someint'] = e.mapping9['someint'] = [{'d': [d]}]
+ e.mapping0["someint"] = e.mapping1["someint"] = d
+ e.mapping2["someint"] = e.mapping3["someint"] = [d]
+ e.mapping4["someint"] = e.mapping5["someint"] = {"d": d}
+ e.mapping6["someint"] = e.mapping7["someint"] = [{"d": d}]
+ e.mapping8["someint"] = e.mapping9["someint"] = [{"d": [d]}]
e.save()
s = Simple.objects.first()
- self.assertIsInstance(s.mapping0['someint'], Doc)
- self.assertIsInstance(s.mapping1['someint'], Doc)
- self.assertIsInstance(s.mapping2['someint'][0], Doc)
- self.assertIsInstance(s.mapping3['someint'][0], Doc)
- self.assertIsInstance(s.mapping4['someint']['d'], Doc)
- self.assertIsInstance(s.mapping5['someint']['d'], Doc)
- self.assertIsInstance(s.mapping6['someint'][0]['d'], Doc)
- self.assertIsInstance(s.mapping7['someint'][0]['d'], Doc)
- self.assertIsInstance(s.mapping8['someint'][0]['d'][0], Doc)
- self.assertIsInstance(s.mapping9['someint'][0]['d'][0], Doc)
+ self.assertIsInstance(s.mapping0["someint"], Doc)
+ self.assertIsInstance(s.mapping1["someint"], Doc)
+ self.assertIsInstance(s.mapping2["someint"][0], Doc)
+ self.assertIsInstance(s.mapping3["someint"][0], Doc)
+ self.assertIsInstance(s.mapping4["someint"]["d"], Doc)
+ self.assertIsInstance(s.mapping5["someint"]["d"], Doc)
+ self.assertIsInstance(s.mapping6["someint"][0]["d"], Doc)
+ self.assertIsInstance(s.mapping7["someint"][0]["d"], Doc)
+ self.assertIsInstance(s.mapping8["someint"][0]["d"][0], Doc)
+ self.assertIsInstance(s.mapping9["someint"][0]["d"][0], Doc)
diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py
index 3ce49d62..06ec5151 100644
--- a/tests/fields/test_email_field.py
+++ b/tests/fields/test_email_field.py
@@ -12,28 +12,29 @@ class TestEmailField(MongoDBTestCase):
class User(Document):
email = EmailField()
- user = User(email='ross@example.com')
+ user = User(email="ross@example.com")
user.validate()
- user = User(email='ross@example.co.uk')
+ user = User(email="ross@example.co.uk")
user.validate()
- user = User(email=('Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S'
- 'aJIazqqWkm7.net'))
+ user = User(
+ email=("Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5SaJIazqqWkm7.net")
+ )
user.validate()
- user = User(email='new-tld@example.technology')
+ user = User(email="new-tld@example.technology")
user.validate()
- user = User(email='ross@example.com.')
+ user = User(email="ross@example.com.")
self.assertRaises(ValidationError, user.validate)
# unicode domain
- user = User(email=u'user@пример.рф')
+ user = User(email=u"user@пример.рф")
user.validate()
# invalid unicode domain
- user = User(email=u'user@пример')
+ user = User(email=u"user@пример")
self.assertRaises(ValidationError, user.validate)
# invalid data type
@@ -44,20 +45,20 @@ class TestEmailField(MongoDBTestCase):
# Don't run this test on pypy3, which doesn't support unicode regex:
# https://bitbucket.org/pypy/pypy/issues/1821/regular-expression-doesnt-find-unicode
if sys.version_info[:2] == (3, 2):
- raise SkipTest('unicode email addresses are not supported on PyPy 3')
+ raise SkipTest("unicode email addresses are not supported on PyPy 3")
class User(Document):
email = EmailField()
# unicode user shouldn't validate by default...
- user = User(email=u'Dörte@Sörensen.example.com')
+ user = User(email=u"Dörte@Sörensen.example.com")
self.assertRaises(ValidationError, user.validate)
# ...but it should be fine with allow_utf8_user set to True
class User(Document):
email = EmailField(allow_utf8_user=True)
- user = User(email=u'Dörte@Sörensen.example.com')
+ user = User(email=u"Dörte@Sörensen.example.com")
user.validate()
def test_email_field_domain_whitelist(self):
@@ -65,22 +66,22 @@ class TestEmailField(MongoDBTestCase):
email = EmailField()
# localhost domain shouldn't validate by default...
- user = User(email='me@localhost')
+ user = User(email="me@localhost")
self.assertRaises(ValidationError, user.validate)
# ...but it should be fine if it's whitelisted
class User(Document):
- email = EmailField(domain_whitelist=['localhost'])
+ email = EmailField(domain_whitelist=["localhost"])
- user = User(email='me@localhost')
+ user = User(email="me@localhost")
user.validate()
def test_email_domain_validation_fails_if_invalid_idn(self):
class User(Document):
email = EmailField()
- invalid_idn = '.google.com'
- user = User(email='me@%s' % invalid_idn)
+ invalid_idn = ".google.com"
+ user = User(email="me@%s" % invalid_idn)
with self.assertRaises(ValidationError) as ctx_err:
user.validate()
self.assertIn("domain failed IDN encoding", str(ctx_err.exception))
@@ -89,9 +90,9 @@ class TestEmailField(MongoDBTestCase):
class User(Document):
email = EmailField()
- valid_ipv4 = 'email@[127.0.0.1]'
- valid_ipv6 = 'email@[2001:dB8::1]'
- invalid_ip = 'email@[324.0.0.1]'
+ valid_ipv4 = "email@[127.0.0.1]"
+ valid_ipv6 = "email@[2001:dB8::1]"
+ invalid_ip = "email@[324.0.0.1]"
# IP address as a domain shouldn't validate by default...
user = User(email=valid_ipv4)
@@ -119,12 +120,12 @@ class TestEmailField(MongoDBTestCase):
def test_email_field_honors_regex(self):
class User(Document):
- email = EmailField(regex=r'\w+@example.com')
+ email = EmailField(regex=r"\w+@example.com")
# Fails regex validation
- user = User(email='me@foo.com')
+ user = User(email="me@foo.com")
self.assertRaises(ValidationError, user.validate)
# Passes regex validation
- user = User(email='me@example.com')
+ user = User(email="me@example.com")
self.assertIsNone(user.validate())
diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py
index a262d054..6b420781 100644
--- a/tests/fields/test_embedded_document_field.py
+++ b/tests/fields/test_embedded_document_field.py
@@ -1,7 +1,18 @@
# -*- coding: utf-8 -*-
-from mongoengine import Document, StringField, ValidationError, EmbeddedDocument, EmbeddedDocumentField, \
- InvalidQueryError, LookUpError, IntField, GenericEmbeddedDocumentField, ListField, EmbeddedDocumentListField, \
- ReferenceField
+from mongoengine import (
+ Document,
+ StringField,
+ ValidationError,
+ EmbeddedDocument,
+ EmbeddedDocumentField,
+ InvalidQueryError,
+ LookUpError,
+ IntField,
+ GenericEmbeddedDocumentField,
+ ListField,
+ EmbeddedDocumentListField,
+ ReferenceField,
+)
from tests.utils import MongoDBTestCase
@@ -14,22 +25,24 @@ class TestEmbeddedDocumentField(MongoDBTestCase):
field = EmbeddedDocumentField(MyDoc)
self.assertEqual(field.document_type_obj, MyDoc)
- field2 = EmbeddedDocumentField('MyDoc')
- self.assertEqual(field2.document_type_obj, 'MyDoc')
+ field2 = EmbeddedDocumentField("MyDoc")
+ self.assertEqual(field2.document_type_obj, "MyDoc")
def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self):
with self.assertRaises(ValidationError):
EmbeddedDocumentField(dict)
def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self):
-
class MyDoc(Document):
name = StringField()
- emb = EmbeddedDocumentField('MyDoc')
+ emb = EmbeddedDocumentField("MyDoc")
with self.assertRaises(ValidationError) as ctx:
emb.document_type
- self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception))
+ self.assertIn(
+ "Invalid embedded document class provided to an EmbeddedDocumentField",
+ str(ctx.exception),
+ )
def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self):
# Relates to #1661
@@ -37,12 +50,14 @@ class TestEmbeddedDocumentField(MongoDBTestCase):
name = StringField()
with self.assertRaises(ValidationError):
+
class MyFailingDoc(Document):
emb = EmbeddedDocumentField(MyDoc)
with self.assertRaises(ValidationError):
+
class MyFailingdoc2(Document):
- emb = EmbeddedDocumentField('MyDoc')
+ emb = EmbeddedDocumentField("MyDoc")
def test_query_embedded_document_attribute(self):
class AdminSettings(EmbeddedDocument):
@@ -55,34 +70,31 @@ class TestEmbeddedDocumentField(MongoDBTestCase):
Person.drop_collection()
- p = Person(
- settings=AdminSettings(foo1='bar1', foo2='bar2'),
- name='John',
- ).save()
+ p = Person(settings=AdminSettings(foo1="bar1", foo2="bar2"), name="John").save()
# Test non exiting attribute
with self.assertRaises(InvalidQueryError) as ctx_err:
- Person.objects(settings__notexist='bar').first()
+ Person.objects(settings__notexist="bar").first()
self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"')
with self.assertRaises(LookUpError):
- Person.objects.only('settings.notexist')
+ Person.objects.only("settings.notexist")
# Test existing attribute
- self.assertEqual(Person.objects(settings__foo1='bar1').first().id, p.id)
- only_p = Person.objects.only('settings.foo1').first()
+ self.assertEqual(Person.objects(settings__foo1="bar1").first().id, p.id)
+ only_p = Person.objects.only("settings.foo1").first()
self.assertEqual(only_p.settings.foo1, p.settings.foo1)
self.assertIsNone(only_p.settings.foo2)
self.assertIsNone(only_p.name)
- exclude_p = Person.objects.exclude('settings.foo1').first()
+ exclude_p = Person.objects.exclude("settings.foo1").first()
self.assertIsNone(exclude_p.settings.foo1)
self.assertEqual(exclude_p.settings.foo2, p.settings.foo2)
self.assertEqual(exclude_p.name, p.name)
def test_query_embedded_document_attribute_with_inheritance(self):
class BaseSettings(EmbeddedDocument):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
base_foo = StringField()
class AdminSettings(BaseSettings):
@@ -93,26 +105,26 @@ class TestEmbeddedDocumentField(MongoDBTestCase):
Person.drop_collection()
- p = Person(settings=AdminSettings(base_foo='basefoo', sub_foo='subfoo'))
+ p = Person(settings=AdminSettings(base_foo="basefoo", sub_foo="subfoo"))
p.save()
# Test non exiting attribute
with self.assertRaises(InvalidQueryError) as ctx_err:
- self.assertEqual(Person.objects(settings__notexist='bar').first().id, p.id)
+ self.assertEqual(Person.objects(settings__notexist="bar").first().id, p.id)
self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"')
# Test existing attribute
- self.assertEqual(Person.objects(settings__base_foo='basefoo').first().id, p.id)
- self.assertEqual(Person.objects(settings__sub_foo='subfoo').first().id, p.id)
+ self.assertEqual(Person.objects(settings__base_foo="basefoo").first().id, p.id)
+ self.assertEqual(Person.objects(settings__sub_foo="subfoo").first().id, p.id)
- only_p = Person.objects.only('settings.base_foo', 'settings._cls').first()
- self.assertEqual(only_p.settings.base_foo, 'basefoo')
+ only_p = Person.objects.only("settings.base_foo", "settings._cls").first()
+ self.assertEqual(only_p.settings.base_foo, "basefoo")
self.assertIsNone(only_p.settings.sub_foo)
def test_query_list_embedded_document_with_inheritance(self):
class Post(EmbeddedDocument):
title = StringField(max_length=120, required=True)
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class TextPost(Post):
content = StringField()
@@ -123,8 +135,8 @@ class TestEmbeddedDocumentField(MongoDBTestCase):
class Record(Document):
posts = ListField(EmbeddedDocumentField(Post))
- record_movie = Record(posts=[MoviePost(author='John', title='foo')]).save()
- record_text = Record(posts=[TextPost(content='a', title='foo')]).save()
+ record_movie = Record(posts=[MoviePost(author="John", title="foo")]).save()
+ record_text = Record(posts=[TextPost(content="a", title="foo")]).save()
records = list(Record.objects(posts__author=record_movie.posts[0].author))
self.assertEqual(len(records), 1)
@@ -134,11 +146,10 @@ class TestEmbeddedDocumentField(MongoDBTestCase):
self.assertEqual(len(records), 1)
self.assertEqual(records[0].id, record_text.id)
- self.assertEqual(Record.objects(posts__title='foo').count(), 2)
+ self.assertEqual(Record.objects(posts__title="foo").count(), 2)
class TestGenericEmbeddedDocumentField(MongoDBTestCase):
-
def test_generic_embedded_document(self):
class Car(EmbeddedDocument):
name = StringField()
@@ -153,8 +164,8 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
Person.drop_collection()
- person = Person(name='Test User')
- person.like = Car(name='Fiat')
+ person = Person(name="Test User")
+ person.like = Car(name="Fiat")
person.save()
person = Person.objects.first()
@@ -168,6 +179,7 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
def test_generic_embedded_document_choices(self):
"""Ensure you can limit GenericEmbeddedDocument choices."""
+
class Car(EmbeddedDocument):
name = StringField()
@@ -181,8 +193,8 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
Person.drop_collection()
- person = Person(name='Test User')
- person.like = Car(name='Fiat')
+ person = Person(name="Test User")
+ person.like = Car(name="Fiat")
self.assertRaises(ValidationError, person.validate)
person.like = Dish(food="arroz", number=15)
@@ -195,6 +207,7 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
"""Ensure you can limit GenericEmbeddedDocument choices inside
a list field.
"""
+
class Car(EmbeddedDocument):
name = StringField()
@@ -208,8 +221,8 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
Person.drop_collection()
- person = Person(name='Test User')
- person.likes = [Car(name='Fiat')]
+ person = Person(name="Test User")
+ person.likes = [Car(name="Fiat")]
self.assertRaises(ValidationError, person.validate)
person.likes = [Dish(food="arroz", number=15)]
@@ -222,25 +235,23 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
"""
Ensure fields with document choices validate given a valid choice.
"""
+
class UserComments(EmbeddedDocument):
author = StringField()
message = StringField()
class BlogPost(Document):
- comments = ListField(
- GenericEmbeddedDocumentField(choices=(UserComments,))
- )
+ comments = ListField(GenericEmbeddedDocumentField(choices=(UserComments,)))
# Ensure Validation Passes
- BlogPost(comments=[
- UserComments(author='user2', message='message2'),
- ]).save()
+ BlogPost(comments=[UserComments(author="user2", message="message2")]).save()
def test_choices_validation_documents_invalid(self):
"""
Ensure fields with document choices validate given an invalid choice.
This should throw a ValidationError exception.
"""
+
class UserComments(EmbeddedDocument):
author = StringField()
message = StringField()
@@ -250,31 +261,28 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
message = StringField()
class BlogPost(Document):
- comments = ListField(
- GenericEmbeddedDocumentField(choices=(UserComments,))
- )
+ comments = ListField(GenericEmbeddedDocumentField(choices=(UserComments,)))
# Single Entry Failure
- post = BlogPost(comments=[
- ModeratorComments(author='mod1', message='message1'),
- ])
+ post = BlogPost(comments=[ModeratorComments(author="mod1", message="message1")])
self.assertRaises(ValidationError, post.save)
# Mixed Entry Failure
- post = BlogPost(comments=[
- ModeratorComments(author='mod1', message='message1'),
- UserComments(author='user2', message='message2'),
- ])
+ post = BlogPost(
+ comments=[
+ ModeratorComments(author="mod1", message="message1"),
+ UserComments(author="user2", message="message2"),
+ ]
+ )
self.assertRaises(ValidationError, post.save)
def test_choices_validation_documents_inheritance(self):
"""
Ensure fields with document choices validate given subclass of choice.
"""
+
class Comments(EmbeddedDocument):
- meta = {
- 'abstract': True
- }
+ meta = {"abstract": True}
author = StringField()
message = StringField()
@@ -282,14 +290,10 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
pass
class BlogPost(Document):
- comments = ListField(
- GenericEmbeddedDocumentField(choices=(Comments,))
- )
+ comments = ListField(GenericEmbeddedDocumentField(choices=(Comments,)))
# Save Valid EmbeddedDocument Type
- BlogPost(comments=[
- UserComments(author='user2', message='message2'),
- ]).save()
+ BlogPost(comments=[UserComments(author="user2", message="message2")]).save()
def test_query_generic_embedded_document_attribute(self):
class AdminSettings(EmbeddedDocument):
@@ -299,28 +303,30 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
foo2 = StringField()
class Person(Document):
- settings = GenericEmbeddedDocumentField(choices=(AdminSettings, NonAdminSettings))
+ settings = GenericEmbeddedDocumentField(
+ choices=(AdminSettings, NonAdminSettings)
+ )
Person.drop_collection()
- p1 = Person(settings=AdminSettings(foo1='bar1')).save()
- p2 = Person(settings=NonAdminSettings(foo2='bar2')).save()
+ p1 = Person(settings=AdminSettings(foo1="bar1")).save()
+ p2 = Person(settings=NonAdminSettings(foo2="bar2")).save()
# Test non exiting attribute
with self.assertRaises(InvalidQueryError) as ctx_err:
- Person.objects(settings__notexist='bar').first()
+ Person.objects(settings__notexist="bar").first()
self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"')
with self.assertRaises(LookUpError):
- Person.objects.only('settings.notexist')
+ Person.objects.only("settings.notexist")
# Test existing attribute
- self.assertEqual(Person.objects(settings__foo1='bar1').first().id, p1.id)
- self.assertEqual(Person.objects(settings__foo2='bar2').first().id, p2.id)
+ self.assertEqual(Person.objects(settings__foo1="bar1").first().id, p1.id)
+ self.assertEqual(Person.objects(settings__foo2="bar2").first().id, p2.id)
def test_query_generic_embedded_document_attribute_with_inheritance(self):
class BaseSettings(EmbeddedDocument):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
base_foo = StringField()
class AdminSettings(BaseSettings):
@@ -331,14 +337,14 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
Person.drop_collection()
- p = Person(settings=AdminSettings(base_foo='basefoo', sub_foo='subfoo'))
+ p = Person(settings=AdminSettings(base_foo="basefoo", sub_foo="subfoo"))
p.save()
# Test non exiting attribute
with self.assertRaises(InvalidQueryError) as ctx_err:
- self.assertEqual(Person.objects(settings__notexist='bar').first().id, p.id)
+ self.assertEqual(Person.objects(settings__notexist="bar").first().id, p.id)
self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"')
# Test existing attribute
- self.assertEqual(Person.objects(settings__base_foo='basefoo').first().id, p.id)
- self.assertEqual(Person.objects(settings__sub_foo='subfoo').first().id, p.id)
+ self.assertEqual(Person.objects(settings__base_foo="basefoo").first().id, p.id)
+ self.assertEqual(Person.objects(settings__sub_foo="subfoo").first().id, p.id)
diff --git a/tests/fields/test_float_field.py b/tests/fields/test_float_field.py
index fa92cf20..9f357ce5 100644
--- a/tests/fields/test_float_field.py
+++ b/tests/fields/test_float_field.py
@@ -7,7 +7,6 @@ from tests.utils import MongoDBTestCase
class TestFloatField(MongoDBTestCase):
-
def test_float_ne_operator(self):
class TestDocument(Document):
float_fld = FloatField()
@@ -23,6 +22,7 @@ class TestFloatField(MongoDBTestCase):
def test_validation(self):
"""Ensure that invalid values cannot be assigned to float fields.
"""
+
class Person(Document):
height = FloatField(min_value=0.1, max_value=3.5)
@@ -33,7 +33,7 @@ class TestFloatField(MongoDBTestCase):
person.height = 1.89
person.validate()
- person.height = '2.0'
+ person.height = "2.0"
self.assertRaises(ValidationError, person.validate)
person.height = 0.01
@@ -42,7 +42,7 @@ class TestFloatField(MongoDBTestCase):
person.height = 4.0
self.assertRaises(ValidationError, person.validate)
- person_2 = Person(height='something invalid')
+ person_2 = Person(height="something invalid")
self.assertRaises(ValidationError, person_2.validate)
big_person = BigPerson()
diff --git a/tests/fields/test_int_field.py b/tests/fields/test_int_field.py
index 1b1f7ad9..b7db0416 100644
--- a/tests/fields/test_int_field.py
+++ b/tests/fields/test_int_field.py
@@ -5,10 +5,10 @@ from tests.utils import MongoDBTestCase
class TestIntField(MongoDBTestCase):
-
def test_int_validation(self):
"""Ensure that invalid values cannot be assigned to int fields.
"""
+
class Person(Document):
age = IntField(min_value=0, max_value=110)
@@ -26,7 +26,7 @@ class TestIntField(MongoDBTestCase):
self.assertRaises(ValidationError, person.validate)
person.age = 120
self.assertRaises(ValidationError, person.validate)
- person.age = 'ten'
+ person.age = "ten"
self.assertRaises(ValidationError, person.validate)
def test_ne_operator(self):
diff --git a/tests/fields/test_lazy_reference_field.py b/tests/fields/test_lazy_reference_field.py
index 1d6e6e79..2a686d7f 100644
--- a/tests/fields/test_lazy_reference_field.py
+++ b/tests/fields/test_lazy_reference_field.py
@@ -25,7 +25,7 @@ class TestLazyReferenceField(MongoDBTestCase):
animal = Animal()
oc = Ocurrence(animal=animal)
- self.assertIn('LazyReference', repr(oc.animal))
+ self.assertIn("LazyReference", repr(oc.animal))
def test___getattr___unknown_attr_raises_attribute_error(self):
class Animal(Document):
@@ -93,7 +93,7 @@ class TestLazyReferenceField(MongoDBTestCase):
def test_lazy_reference_set(self):
class Animal(Document):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
name = StringField()
tag = StringField()
@@ -109,18 +109,17 @@ class TestLazyReferenceField(MongoDBTestCase):
nick = StringField()
animal = Animal(name="Leopard", tag="heavy").save()
- sub_animal = SubAnimal(nick='doggo', name='dog').save()
+ sub_animal = SubAnimal(nick="doggo", name="dog").save()
for ref in (
- animal,
- animal.pk,
- DBRef(animal._get_collection_name(), animal.pk),
- LazyReference(Animal, animal.pk),
-
- sub_animal,
- sub_animal.pk,
- DBRef(sub_animal._get_collection_name(), sub_animal.pk),
- LazyReference(SubAnimal, sub_animal.pk),
- ):
+ animal,
+ animal.pk,
+ DBRef(animal._get_collection_name(), animal.pk),
+ LazyReference(Animal, animal.pk),
+ sub_animal,
+ sub_animal.pk,
+ DBRef(sub_animal._get_collection_name(), sub_animal.pk),
+ LazyReference(SubAnimal, sub_animal.pk),
+ ):
p = Ocurrence(person="test", animal=ref).save()
p.reload()
self.assertIsInstance(p.animal, LazyReference)
@@ -144,12 +143,12 @@ class TestLazyReferenceField(MongoDBTestCase):
animal = Animal(name="Leopard", tag="heavy").save()
baddoc = BadDoc().save()
for bad in (
- 42,
- 'foo',
- baddoc,
- DBRef(baddoc._get_collection_name(), animal.pk),
- LazyReference(BadDoc, animal.pk)
- ):
+ 42,
+ "foo",
+ baddoc,
+ DBRef(baddoc._get_collection_name(), animal.pk),
+ LazyReference(BadDoc, animal.pk),
+ ):
with self.assertRaises(ValidationError):
p = Ocurrence(person="test", animal=bad).save()
@@ -157,6 +156,7 @@ class TestLazyReferenceField(MongoDBTestCase):
"""Ensure that LazyReferenceFields can be queried using objects and values
of the type of the primary key of the referenced object.
"""
+
class Member(Document):
user_num = IntField(primary_key=True)
@@ -172,10 +172,10 @@ class TestLazyReferenceField(MongoDBTestCase):
m2 = Member(user_num=2)
m2.save()
- post1 = BlogPost(title='post 1', author=m1)
+ post1 = BlogPost(title="post 1", author=m1)
post1.save()
- post2 = BlogPost(title='post 2', author=m2)
+ post2 = BlogPost(title="post 2", author=m2)
post2.save()
post = BlogPost.objects(author=m1).first()
@@ -192,6 +192,7 @@ class TestLazyReferenceField(MongoDBTestCase):
"""Ensure that LazyReferenceFields can be queried using objects and values
of the type of the primary key of the referenced object.
"""
+
class Member(Document):
user_num = IntField(primary_key=True)
@@ -207,10 +208,10 @@ class TestLazyReferenceField(MongoDBTestCase):
m2 = Member(user_num=2)
m2.save()
- post1 = BlogPost(title='post 1', author=m1)
+ post1 = BlogPost(title="post 1", author=m1)
post1.save()
- post2 = BlogPost(title='post 2', author=m2)
+ post2 = BlogPost(title="post 2", author=m2)
post2.save()
post = BlogPost.objects(author=m1).first()
@@ -240,19 +241,19 @@ class TestLazyReferenceField(MongoDBTestCase):
p = Ocurrence.objects.get()
self.assertIsInstance(p.animal, LazyReference)
with self.assertRaises(KeyError):
- p.animal['name']
+ p.animal["name"]
with self.assertRaises(AttributeError):
p.animal.name
self.assertEqual(p.animal.pk, animal.pk)
self.assertEqual(p.animal_passthrough.name, "Leopard")
- self.assertEqual(p.animal_passthrough['name'], "Leopard")
+ self.assertEqual(p.animal_passthrough["name"], "Leopard")
# Should not be able to access referenced document's methods
with self.assertRaises(AttributeError):
p.animal.save
with self.assertRaises(KeyError):
- p.animal['save']
+ p.animal["save"]
def test_lazy_reference_not_set(self):
class Animal(Document):
@@ -266,7 +267,7 @@ class TestLazyReferenceField(MongoDBTestCase):
Animal.drop_collection()
Ocurrence.drop_collection()
- Ocurrence(person='foo').save()
+ Ocurrence(person="foo").save()
p = Ocurrence.objects.get()
self.assertIs(p.animal, None)
@@ -303,8 +304,8 @@ class TestLazyReferenceField(MongoDBTestCase):
Animal.drop_collection()
Ocurrence.drop_collection()
- animal1 = Animal(name='doggo').save()
- animal2 = Animal(name='cheeta').save()
+ animal1 = Animal(name="doggo").save()
+ animal2 = Animal(name="cheeta").save()
def check_fields_type(occ):
self.assertIsInstance(occ.direct, LazyReference)
@@ -316,8 +317,8 @@ class TestLazyReferenceField(MongoDBTestCase):
occ = Ocurrence(
in_list=[animal1, animal2],
- in_embedded={'in_list': [animal1, animal2], 'direct': animal1},
- direct=animal1
+ in_embedded={"in_list": [animal1, animal2], "direct": animal1},
+ direct=animal1,
).save()
check_fields_type(occ)
occ.reload()
@@ -403,7 +404,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
def test_generic_lazy_reference_set(self):
class Animal(Document):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
name = StringField()
tag = StringField()
@@ -419,16 +420,18 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
nick = StringField()
animal = Animal(name="Leopard", tag="heavy").save()
- sub_animal = SubAnimal(nick='doggo', name='dog').save()
+ sub_animal = SubAnimal(nick="doggo", name="dog").save()
for ref in (
- animal,
- LazyReference(Animal, animal.pk),
- {'_cls': 'Animal', '_ref': DBRef(animal._get_collection_name(), animal.pk)},
-
- sub_animal,
- LazyReference(SubAnimal, sub_animal.pk),
- {'_cls': 'SubAnimal', '_ref': DBRef(sub_animal._get_collection_name(), sub_animal.pk)},
- ):
+ animal,
+ LazyReference(Animal, animal.pk),
+ {"_cls": "Animal", "_ref": DBRef(animal._get_collection_name(), animal.pk)},
+ sub_animal,
+ LazyReference(SubAnimal, sub_animal.pk),
+ {
+ "_cls": "SubAnimal",
+ "_ref": DBRef(sub_animal._get_collection_name(), sub_animal.pk),
+ },
+ ):
p = Ocurrence(person="test", animal=ref).save()
p.reload()
self.assertIsInstance(p.animal, (LazyReference, Document))
@@ -441,7 +444,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
class Ocurrence(Document):
person = StringField()
- animal = GenericLazyReferenceField(choices=['Animal'])
+ animal = GenericLazyReferenceField(choices=["Animal"])
Animal.drop_collection()
Ocurrence.drop_collection()
@@ -451,12 +454,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
animal = Animal(name="Leopard", tag="heavy").save()
baddoc = BadDoc().save()
- for bad in (
- 42,
- 'foo',
- baddoc,
- LazyReference(BadDoc, animal.pk)
- ):
+ for bad in (42, "foo", baddoc, LazyReference(BadDoc, animal.pk)):
with self.assertRaises(ValidationError):
p = Ocurrence(person="test", animal=bad).save()
@@ -476,10 +474,10 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
m2 = Member(user_num=2)
m2.save()
- post1 = BlogPost(title='post 1', author=m1)
+ post1 = BlogPost(title="post 1", author=m1)
post1.save()
- post2 = BlogPost(title='post 2', author=m2)
+ post2 = BlogPost(title="post 2", author=m2)
post2.save()
post = BlogPost.objects(author=m1).first()
@@ -504,7 +502,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
Animal.drop_collection()
Ocurrence.drop_collection()
- Ocurrence(person='foo').save()
+ Ocurrence(person="foo").save()
p = Ocurrence.objects.get()
self.assertIs(p.animal, None)
@@ -515,7 +513,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
class Ocurrence(Document):
person = StringField()
- animal = GenericLazyReferenceField('Animal')
+ animal = GenericLazyReferenceField("Animal")
Animal.drop_collection()
Ocurrence.drop_collection()
@@ -542,8 +540,8 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
Animal.drop_collection()
Ocurrence.drop_collection()
- animal1 = Animal(name='doggo').save()
- animal2 = Animal(name='cheeta').save()
+ animal1 = Animal(name="doggo").save()
+ animal2 = Animal(name="cheeta").save()
def check_fields_type(occ):
self.assertIsInstance(occ.direct, LazyReference)
@@ -555,14 +553,20 @@ class TestGenericLazyReferenceField(MongoDBTestCase):
occ = Ocurrence(
in_list=[animal1, animal2],
- in_embedded={'in_list': [animal1, animal2], 'direct': animal1},
- direct=animal1
+ in_embedded={"in_list": [animal1, animal2], "direct": animal1},
+ direct=animal1,
).save()
check_fields_type(occ)
occ.reload()
check_fields_type(occ)
- animal1_ref = {'_cls': 'Animal', '_ref': DBRef(animal1._get_collection_name(), animal1.pk)}
- animal2_ref = {'_cls': 'Animal', '_ref': DBRef(animal2._get_collection_name(), animal2.pk)}
+ animal1_ref = {
+ "_cls": "Animal",
+ "_ref": DBRef(animal1._get_collection_name(), animal1.pk),
+ }
+ animal2_ref = {
+ "_cls": "Animal",
+ "_ref": DBRef(animal2._get_collection_name(), animal2.pk),
+ }
occ.direct = animal1_ref
occ.in_list = [animal1_ref, animal2_ref]
occ.in_embedded.direct = animal1_ref
diff --git a/tests/fields/test_long_field.py b/tests/fields/test_long_field.py
index 3f307809..ab86eccd 100644
--- a/tests/fields/test_long_field.py
+++ b/tests/fields/test_long_field.py
@@ -13,23 +13,26 @@ from tests.utils import MongoDBTestCase
class TestLongField(MongoDBTestCase):
-
def test_long_field_is_considered_as_int64(self):
"""
Tests that long fields are stored as long in mongo, even if long
value is small enough to be an int.
"""
+
class TestLongFieldConsideredAsInt64(Document):
some_long = LongField()
doc = TestLongFieldConsideredAsInt64(some_long=42).save()
db = get_db()
- self.assertIsInstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64)
+ self.assertIsInstance(
+ db.test_long_field_considered_as_int64.find()[0]["some_long"], Int64
+ )
self.assertIsInstance(doc.some_long, six.integer_types)
def test_long_validation(self):
"""Ensure that invalid values cannot be assigned to long fields.
"""
+
class TestDocument(Document):
value = LongField(min_value=0, max_value=110)
@@ -41,7 +44,7 @@ class TestLongField(MongoDBTestCase):
self.assertRaises(ValidationError, doc.validate)
doc.value = 120
self.assertRaises(ValidationError, doc.validate)
- doc.value = 'ten'
+ doc.value = "ten"
self.assertRaises(ValidationError, doc.validate)
def test_long_ne_operator(self):
diff --git a/tests/fields/test_map_field.py b/tests/fields/test_map_field.py
index cb27cfff..54f70aa1 100644
--- a/tests/fields/test_map_field.py
+++ b/tests/fields/test_map_field.py
@@ -7,23 +7,24 @@ from tests.utils import MongoDBTestCase
class TestMapField(MongoDBTestCase):
-
def test_mapfield(self):
"""Ensure that the MapField handles the declared type."""
+
class Simple(Document):
mapping = MapField(IntField())
Simple.drop_collection()
e = Simple()
- e.mapping['someint'] = 1
+ e.mapping["someint"] = 1
e.save()
with self.assertRaises(ValidationError):
- e.mapping['somestring'] = "abc"
+ e.mapping["somestring"] = "abc"
e.save()
with self.assertRaises(ValidationError):
+
class NoDeclaredType(Document):
mapping = MapField()
@@ -45,38 +46,37 @@ class TestMapField(MongoDBTestCase):
Extensible.drop_collection()
e = Extensible()
- e.mapping['somestring'] = StringSetting(value='foo')
- e.mapping['someint'] = IntegerSetting(value=42)
+ e.mapping["somestring"] = StringSetting(value="foo")
+ e.mapping["someint"] = IntegerSetting(value=42)
e.save()
e2 = Extensible.objects.get(id=e.id)
- self.assertIsInstance(e2.mapping['somestring'], StringSetting)
- self.assertIsInstance(e2.mapping['someint'], IntegerSetting)
+ self.assertIsInstance(e2.mapping["somestring"], StringSetting)
+ self.assertIsInstance(e2.mapping["someint"], IntegerSetting)
with self.assertRaises(ValidationError):
- e.mapping['someint'] = 123
+ e.mapping["someint"] = 123
e.save()
def test_embedded_mapfield_db_field(self):
class Embedded(EmbeddedDocument):
- number = IntField(default=0, db_field='i')
+ number = IntField(default=0, db_field="i")
class Test(Document):
- my_map = MapField(field=EmbeddedDocumentField(Embedded),
- db_field='x')
+ my_map = MapField(field=EmbeddedDocumentField(Embedded), db_field="x")
Test.drop_collection()
test = Test()
- test.my_map['DICTIONARY_KEY'] = Embedded(number=1)
+ test.my_map["DICTIONARY_KEY"] = Embedded(number=1)
test.save()
Test.objects.update_one(inc__my_map__DICTIONARY_KEY__number=1)
test = Test.objects.get()
- self.assertEqual(test.my_map['DICTIONARY_KEY'].number, 2)
+ self.assertEqual(test.my_map["DICTIONARY_KEY"].number, 2)
doc = self.db.test.find_one()
- self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2)
+ self.assertEqual(doc["x"]["DICTIONARY_KEY"]["i"], 2)
def test_mapfield_numerical_index(self):
"""Ensure that MapField accept numeric strings as indexes."""
@@ -90,9 +90,9 @@ class TestMapField(MongoDBTestCase):
Test.drop_collection()
test = Test()
- test.my_map['1'] = Embedded(name='test')
+ test.my_map["1"] = Embedded(name="test")
test.save()
- test.my_map['1'].name = 'test updated'
+ test.my_map["1"].name = "test updated"
test.save()
def test_map_field_lookup(self):
@@ -110,15 +110,20 @@ class TestMapField(MongoDBTestCase):
actions = MapField(EmbeddedDocumentField(Action))
Log.drop_collection()
- Log(name="wilson", visited={'friends': datetime.datetime.now()},
- actions={'friends': Action(operation='drink', object='beer')}).save()
+ Log(
+ name="wilson",
+ visited={"friends": datetime.datetime.now()},
+ actions={"friends": Action(operation="drink", object="beer")},
+ ).save()
- self.assertEqual(1, Log.objects(
- visited__friends__exists=True).count())
+ self.assertEqual(1, Log.objects(visited__friends__exists=True).count())
- self.assertEqual(1, Log.objects(
- actions__friends__operation='drink',
- actions__friends__object='beer').count())
+ self.assertEqual(
+ 1,
+ Log.objects(
+ actions__friends__operation="drink", actions__friends__object="beer"
+ ).count(),
+ )
def test_map_field_unicode(self):
class Info(EmbeddedDocument):
@@ -130,15 +135,11 @@ class TestMapField(MongoDBTestCase):
BlogPost.drop_collection()
- tree = BlogPost(info_dict={
- u"éééé": {
- 'description': u"VALUE: éééé"
- }
- })
+ tree = BlogPost(info_dict={u"éééé": {"description": u"VALUE: éééé"}})
tree.save()
self.assertEqual(
BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description,
- u"VALUE: éééé"
+ u"VALUE: éééé",
)
diff --git a/tests/fields/test_reference_field.py b/tests/fields/test_reference_field.py
index 5e1fc605..5fd053fe 100644
--- a/tests/fields/test_reference_field.py
+++ b/tests/fields/test_reference_field.py
@@ -26,15 +26,15 @@ class TestReferenceField(MongoDBTestCase):
# with a document class name.
self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument)
- user = User(name='Test User')
+ user = User(name="Test User")
# Ensure that the referenced object must have been saved
- post1 = BlogPost(content='Chips and gravy taste good.')
+ post1 = BlogPost(content="Chips and gravy taste good.")
post1.author = user
self.assertRaises(ValidationError, post1.save)
# Check that an invalid object type cannot be used
- post2 = BlogPost(content='Chips and chilli taste good.')
+ post2 = BlogPost(content="Chips and chilli taste good.")
post1.author = post2
self.assertRaises(ValidationError, post1.validate)
@@ -59,7 +59,7 @@ class TestReferenceField(MongoDBTestCase):
class Person(Document):
name = StringField()
- parent = ReferenceField('self')
+ parent = ReferenceField("self")
Person.drop_collection()
@@ -74,7 +74,7 @@ class TestReferenceField(MongoDBTestCase):
class Person(Document):
name = StringField()
- parent = ReferenceField('self', dbref=True)
+ parent = ReferenceField("self", dbref=True)
Person.drop_collection()
@@ -82,8 +82,8 @@ class TestReferenceField(MongoDBTestCase):
Person(name="Ross", parent=p1).save()
self.assertEqual(
- Person._get_collection().find_one({'name': 'Ross'})['parent'],
- DBRef('person', p1.pk)
+ Person._get_collection().find_one({"name": "Ross"})["parent"],
+ DBRef("person", p1.pk),
)
p = Person.objects.get(name="Ross")
@@ -97,21 +97,17 @@ class TestReferenceField(MongoDBTestCase):
class Person(Document):
name = StringField()
- parent = ReferenceField('self', dbref=False)
+ parent = ReferenceField("self", dbref=False)
- p = Person(
- name='Steve',
- parent=DBRef('person', 'abcdefghijklmnop')
+ p = Person(name="Steve", parent=DBRef("person", "abcdefghijklmnop"))
+ self.assertEqual(
+ p.to_mongo(), SON([("name", u"Steve"), ("parent", "abcdefghijklmnop")])
)
- self.assertEqual(p.to_mongo(), SON([
- ('name', u'Steve'),
- ('parent', 'abcdefghijklmnop')
- ]))
def test_objectid_reference_fields(self):
class Person(Document):
name = StringField()
- parent = ReferenceField('self', dbref=False)
+ parent = ReferenceField("self", dbref=False)
Person.drop_collection()
@@ -119,8 +115,8 @@ class TestReferenceField(MongoDBTestCase):
Person(name="Ross", parent=p1).save()
col = Person._get_collection()
- data = col.find_one({'name': 'Ross'})
- self.assertEqual(data['parent'], p1.pk)
+ data = col.find_one({"name": "Ross"})
+ self.assertEqual(data["parent"], p1.pk)
p = Person.objects.get(name="Ross")
self.assertEqual(p.parent, p1)
@@ -128,9 +124,10 @@ class TestReferenceField(MongoDBTestCase):
def test_undefined_reference(self):
"""Ensure that ReferenceFields may reference undefined Documents.
"""
+
class Product(Document):
name = StringField()
- company = ReferenceField('Company')
+ company = ReferenceField("Company")
class Company(Document):
name = StringField()
@@ -138,12 +135,12 @@ class TestReferenceField(MongoDBTestCase):
Product.drop_collection()
Company.drop_collection()
- ten_gen = Company(name='10gen')
+ ten_gen = Company(name="10gen")
ten_gen.save()
- mongodb = Product(name='MongoDB', company=ten_gen)
+ mongodb = Product(name="MongoDB", company=ten_gen)
mongodb.save()
- me = Product(name='MongoEngine')
+ me = Product(name="MongoEngine")
me.save()
obj = Product.objects(company=ten_gen).first()
@@ -160,6 +157,7 @@ class TestReferenceField(MongoDBTestCase):
"""Ensure that ReferenceFields can be queried using objects and values
of the type of the primary key of the referenced object.
"""
+
class Member(Document):
user_num = IntField(primary_key=True)
@@ -175,10 +173,10 @@ class TestReferenceField(MongoDBTestCase):
m2 = Member(user_num=2)
m2.save()
- post1 = BlogPost(title='post 1', author=m1)
+ post1 = BlogPost(title="post 1", author=m1)
post1.save()
- post2 = BlogPost(title='post 2', author=m2)
+ post2 = BlogPost(title="post 2", author=m2)
post2.save()
post = BlogPost.objects(author=m1).first()
@@ -191,6 +189,7 @@ class TestReferenceField(MongoDBTestCase):
"""Ensure that ReferenceFields can be queried using objects and values
of the type of the primary key of the referenced object.
"""
+
class Member(Document):
user_num = IntField(primary_key=True)
@@ -206,10 +205,10 @@ class TestReferenceField(MongoDBTestCase):
m2 = Member(user_num=2)
m2.save()
- post1 = BlogPost(title='post 1', author=m1)
+ post1 = BlogPost(title="post 1", author=m1)
post1.save()
- post2 = BlogPost(title='post 2', author=m2)
+ post2 = BlogPost(title="post 2", author=m2)
post2.save()
post = BlogPost.objects(author=m1).first()
diff --git a/tests/fields/test_sequence_field.py b/tests/fields/test_sequence_field.py
index 6124c65e..f2c8388b 100644
--- a/tests/fields/test_sequence_field.py
+++ b/tests/fields/test_sequence_field.py
@@ -11,38 +11,38 @@ class TestSequenceField(MongoDBTestCase):
id = SequenceField(primary_key=True)
name = StringField()
- self.db['mongoengine.counters'].drop()
+ self.db["mongoengine.counters"].drop()
Person.drop_collection()
for x in range(10):
Person(name="Person %s" % x).save()
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
+ self.assertEqual(c["next"], 10)
ids = [i.id for i in Person.objects]
self.assertEqual(ids, range(1, 11))
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
+ self.assertEqual(c["next"], 10)
Person.id.set_next_value(1000)
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
- self.assertEqual(c['next'], 1000)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
+ self.assertEqual(c["next"], 1000)
def test_sequence_field_get_next_value(self):
class Person(Document):
id = SequenceField(primary_key=True)
name = StringField()
- self.db['mongoengine.counters'].drop()
+ self.db["mongoengine.counters"].drop()
Person.drop_collection()
for x in range(10):
Person(name="Person %s" % x).save()
self.assertEqual(Person.id.get_next_value(), 11)
- self.db['mongoengine.counters'].drop()
+ self.db["mongoengine.counters"].drop()
self.assertEqual(Person.id.get_next_value(), 1)
@@ -50,40 +50,40 @@ class TestSequenceField(MongoDBTestCase):
id = SequenceField(primary_key=True, value_decorator=str)
name = StringField()
- self.db['mongoengine.counters'].drop()
+ self.db["mongoengine.counters"].drop()
Person.drop_collection()
for x in range(10):
Person(name="Person %s" % x).save()
- self.assertEqual(Person.id.get_next_value(), '11')
- self.db['mongoengine.counters'].drop()
+ self.assertEqual(Person.id.get_next_value(), "11")
+ self.db["mongoengine.counters"].drop()
- self.assertEqual(Person.id.get_next_value(), '1')
+ self.assertEqual(Person.id.get_next_value(), "1")
def test_sequence_field_sequence_name(self):
class Person(Document):
- id = SequenceField(primary_key=True, sequence_name='jelly')
+ id = SequenceField(primary_key=True, sequence_name="jelly")
name = StringField()
- self.db['mongoengine.counters'].drop()
+ self.db["mongoengine.counters"].drop()
Person.drop_collection()
for x in range(10):
Person(name="Person %s" % x).save()
- c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"})
+ self.assertEqual(c["next"], 10)
ids = [i.id for i in Person.objects]
self.assertEqual(ids, range(1, 11))
- c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"})
+ self.assertEqual(c["next"], 10)
Person.id.set_next_value(1000)
- c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'})
- self.assertEqual(c['next'], 1000)
+ c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"})
+ self.assertEqual(c["next"], 1000)
def test_multiple_sequence_fields(self):
class Person(Document):
@@ -91,14 +91,14 @@ class TestSequenceField(MongoDBTestCase):
counter = SequenceField()
name = StringField()
- self.db['mongoengine.counters'].drop()
+ self.db["mongoengine.counters"].drop()
Person.drop_collection()
for x in range(10):
Person(name="Person %s" % x).save()
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
+ self.assertEqual(c["next"], 10)
ids = [i.id for i in Person.objects]
self.assertEqual(ids, range(1, 11))
@@ -106,23 +106,23 @@ class TestSequenceField(MongoDBTestCase):
counters = [i.counter for i in Person.objects]
self.assertEqual(counters, range(1, 11))
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
+ self.assertEqual(c["next"], 10)
Person.id.set_next_value(1000)
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
- self.assertEqual(c['next'], 1000)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
+ self.assertEqual(c["next"], 1000)
Person.counter.set_next_value(999)
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'})
- self.assertEqual(c['next'], 999)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.counter"})
+ self.assertEqual(c["next"], 999)
def test_sequence_fields_reload(self):
class Animal(Document):
counter = SequenceField()
name = StringField()
- self.db['mongoengine.counters'].drop()
+ self.db["mongoengine.counters"].drop()
Animal.drop_collection()
a = Animal(name="Boi").save()
@@ -151,7 +151,7 @@ class TestSequenceField(MongoDBTestCase):
id = SequenceField(primary_key=True)
name = StringField()
- self.db['mongoengine.counters'].drop()
+ self.db["mongoengine.counters"].drop()
Animal.drop_collection()
Person.drop_collection()
@@ -159,11 +159,11 @@ class TestSequenceField(MongoDBTestCase):
Animal(name="Animal %s" % x).save()
Person(name="Person %s" % x).save()
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
+ self.assertEqual(c["next"], 10)
- c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "animal.id"})
+ self.assertEqual(c["next"], 10)
ids = [i.id for i in Person.objects]
self.assertEqual(ids, range(1, 11))
@@ -171,32 +171,32 @@ class TestSequenceField(MongoDBTestCase):
id = [i.id for i in Animal.objects]
self.assertEqual(id, range(1, 11))
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
+ self.assertEqual(c["next"], 10)
- c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "animal.id"})
+ self.assertEqual(c["next"], 10)
def test_sequence_field_value_decorator(self):
class Person(Document):
id = SequenceField(primary_key=True, value_decorator=str)
name = StringField()
- self.db['mongoengine.counters'].drop()
+ self.db["mongoengine.counters"].drop()
Person.drop_collection()
for x in range(10):
p = Person(name="Person %s" % x)
p.save()
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
+ self.assertEqual(c["next"], 10)
ids = [i.id for i in Person.objects]
self.assertEqual(ids, map(str, range(1, 11)))
- c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
- self.assertEqual(c['next'], 10)
+ c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
+ self.assertEqual(c["next"], 10)
def test_embedded_sequence_field(self):
class Comment(EmbeddedDocument):
@@ -207,14 +207,18 @@ class TestSequenceField(MongoDBTestCase):
title = StringField(required=True)
comments = ListField(EmbeddedDocumentField(Comment))
- self.db['mongoengine.counters'].drop()
+ self.db["mongoengine.counters"].drop()
Post.drop_collection()
- Post(title="MongoEngine",
- comments=[Comment(content="NoSQL Rocks"),
- Comment(content="MongoEngine Rocks")]).save()
- c = self.db['mongoengine.counters'].find_one({'_id': 'comment.id'})
- self.assertEqual(c['next'], 2)
+ Post(
+ title="MongoEngine",
+ comments=[
+ Comment(content="NoSQL Rocks"),
+ Comment(content="MongoEngine Rocks"),
+ ],
+ ).save()
+ c = self.db["mongoengine.counters"].find_one({"_id": "comment.id"})
+ self.assertEqual(c["next"], 2)
post = Post.objects.first()
self.assertEqual(1, post.comments[0].id)
self.assertEqual(2, post.comments[1].id)
@@ -223,7 +227,7 @@ class TestSequenceField(MongoDBTestCase):
class Base(Document):
name = StringField()
counter = SequenceField()
- meta = {'abstract': True}
+ meta = {"abstract": True}
class Foo(Base):
pass
@@ -231,24 +235,27 @@ class TestSequenceField(MongoDBTestCase):
class Bar(Base):
pass
- bar = Bar(name='Bar')
+ bar = Bar(name="Bar")
bar.save()
- foo = Foo(name='Foo')
+ foo = Foo(name="Foo")
foo.save()
- self.assertTrue('base.counter' in
- self.db['mongoengine.counters'].find().distinct('_id'))
- self.assertFalse(('foo.counter' or 'bar.counter') in
- self.db['mongoengine.counters'].find().distinct('_id'))
+ self.assertTrue(
+ "base.counter" in self.db["mongoengine.counters"].find().distinct("_id")
+ )
+ self.assertFalse(
+ ("foo.counter" or "bar.counter")
+ in self.db["mongoengine.counters"].find().distinct("_id")
+ )
self.assertNotEqual(foo.counter, bar.counter)
- self.assertEqual(foo._fields['counter'].owner_document, Base)
- self.assertEqual(bar._fields['counter'].owner_document, Base)
+ self.assertEqual(foo._fields["counter"].owner_document, Base)
+ self.assertEqual(bar._fields["counter"].owner_document, Base)
def test_no_inherited_sequencefield(self):
class Base(Document):
name = StringField()
- meta = {'abstract': True}
+ meta = {"abstract": True}
class Foo(Base):
counter = SequenceField()
@@ -256,16 +263,19 @@ class TestSequenceField(MongoDBTestCase):
class Bar(Base):
counter = SequenceField()
- bar = Bar(name='Bar')
+ bar = Bar(name="Bar")
bar.save()
- foo = Foo(name='Foo')
+ foo = Foo(name="Foo")
foo.save()
- self.assertFalse('base.counter' in
- self.db['mongoengine.counters'].find().distinct('_id'))
- self.assertTrue(('foo.counter' and 'bar.counter') in
- self.db['mongoengine.counters'].find().distinct('_id'))
+ self.assertFalse(
+ "base.counter" in self.db["mongoengine.counters"].find().distinct("_id")
+ )
+ self.assertTrue(
+ ("foo.counter" and "bar.counter")
+ in self.db["mongoengine.counters"].find().distinct("_id")
+ )
self.assertEqual(foo.counter, bar.counter)
- self.assertEqual(foo._fields['counter'].owner_document, Foo)
- self.assertEqual(bar._fields['counter'].owner_document, Bar)
+ self.assertEqual(foo._fields["counter"].owner_document, Foo)
+ self.assertEqual(bar._fields["counter"].owner_document, Bar)
diff --git a/tests/fields/test_url_field.py b/tests/fields/test_url_field.py
index ddbf707e..81baf8d0 100644
--- a/tests/fields/test_url_field.py
+++ b/tests/fields/test_url_field.py
@@ -5,49 +5,53 @@ from tests.utils import MongoDBTestCase
class TestURLField(MongoDBTestCase):
-
def test_validation(self):
"""Ensure that URLFields validate urls properly."""
+
class Link(Document):
url = URLField()
link = Link()
- link.url = 'google'
+ link.url = "google"
self.assertRaises(ValidationError, link.validate)
- link.url = 'http://www.google.com:8080'
+ link.url = "http://www.google.com:8080"
link.validate()
def test_unicode_url_validation(self):
"""Ensure unicode URLs are validated properly."""
+
class Link(Document):
url = URLField()
link = Link()
- link.url = u'http://привет.com'
+ link.url = u"http://привет.com"
# TODO fix URL validation - this *IS* a valid URL
# For now we just want to make sure that the error message is correct
with self.assertRaises(ValidationError) as ctx_err:
link.validate()
- self.assertEqual(unicode(ctx_err.exception),
- u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])")
+ self.assertEqual(
+ unicode(ctx_err.exception),
+ u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])",
+ )
def test_url_scheme_validation(self):
"""Ensure that URLFields validate urls with specific schemes properly.
"""
+
class Link(Document):
url = URLField()
class SchemeLink(Document):
- url = URLField(schemes=['ws', 'irc'])
+ url = URLField(schemes=["ws", "irc"])
link = Link()
- link.url = 'ws://google.com'
+ link.url = "ws://google.com"
self.assertRaises(ValidationError, link.validate)
scheme_link = SchemeLink()
- scheme_link.url = 'ws://google.com'
+ scheme_link.url = "ws://google.com"
scheme_link.validate()
def test_underscore_allowed_in_domains_names(self):
@@ -55,5 +59,5 @@ class TestURLField(MongoDBTestCase):
url = URLField()
link = Link()
- link.url = 'https://san_leandro-ca.geebo.com'
+ link.url = "https://san_leandro-ca.geebo.com"
link.validate()
diff --git a/tests/fields/test_uuid_field.py b/tests/fields/test_uuid_field.py
index 7b7faaf2..647dceaf 100644
--- a/tests/fields/test_uuid_field.py
+++ b/tests/fields/test_uuid_field.py
@@ -15,11 +15,8 @@ class TestUUIDField(MongoDBTestCase):
uid = uuid.uuid4()
person = Person(api_key=uid).save()
self.assertEqual(
- get_as_pymongo(person),
- {'_id': person.id,
- 'api_key': str(uid)
- }
- )
+ get_as_pymongo(person), {"_id": person.id, "api_key": str(uid)}
+ )
def test_field_string(self):
"""Test UUID fields storing as String
@@ -37,8 +34,10 @@ class TestUUIDField(MongoDBTestCase):
person.api_key = api_key
person.validate()
- invalid = ('9d159858-549b-4975-9f98-dd2f987c113g',
- '9d159858-549b-4975-9f98-dd2f987c113')
+ invalid = (
+ "9d159858-549b-4975-9f98-dd2f987c113g",
+ "9d159858-549b-4975-9f98-dd2f987c113",
+ )
for api_key in invalid:
person.api_key = api_key
self.assertRaises(ValidationError, person.validate)
@@ -58,8 +57,10 @@ class TestUUIDField(MongoDBTestCase):
person.api_key = api_key
person.validate()
- invalid = ('9d159858-549b-4975-9f98-dd2f987c113g',
- '9d159858-549b-4975-9f98-dd2f987c113')
+ invalid = (
+ "9d159858-549b-4975-9f98-dd2f987c113g",
+ "9d159858-549b-4975-9f98-dd2f987c113",
+ )
for api_key in invalid:
person.api_key = api_key
self.assertRaises(ValidationError, person.validate)
diff --git a/tests/fixtures.py b/tests/fixtures.py
index b8303b99..9f06f1ab 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -11,7 +11,7 @@ class PickleEmbedded(EmbeddedDocument):
class PickleTest(Document):
number = IntField()
- string = StringField(choices=(('One', '1'), ('Two', '2')))
+ string = StringField(choices=(("One", "1"), ("Two", "2")))
embedded = EmbeddedDocumentField(PickleEmbedded)
lists = ListField(StringField())
photo = FileField()
@@ -19,7 +19,7 @@ class PickleTest(Document):
class NewDocumentPickleTest(Document):
number = IntField()
- string = StringField(choices=(('One', '1'), ('Two', '2')))
+ string = StringField(choices=(("One", "1"), ("Two", "2")))
embedded = EmbeddedDocumentField(PickleEmbedded)
lists = ListField(StringField())
photo = FileField()
@@ -36,7 +36,7 @@ class PickleDynamicTest(DynamicDocument):
class PickleSignalsTest(Document):
number = IntField()
- string = StringField(choices=(('One', '1'), ('Two', '2')))
+ string = StringField(choices=(("One", "1"), ("Two", "2")))
embedded = EmbeddedDocumentField(PickleEmbedded)
lists = ListField(StringField())
@@ -58,4 +58,4 @@ class Mixin(object):
class Base(Document):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py
index 250e2601..9f0fe827 100644
--- a/tests/queryset/field_list.py
+++ b/tests/queryset/field_list.py
@@ -7,79 +7,78 @@ __all__ = ("QueryFieldListTest", "OnlyExcludeAllTest")
class QueryFieldListTest(unittest.TestCase):
-
def test_empty(self):
q = QueryFieldList()
self.assertFalse(q)
- q = QueryFieldList(always_include=['_cls'])
+ q = QueryFieldList(always_include=["_cls"])
self.assertFalse(q)
def test_include_include(self):
q = QueryFieldList()
- q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY, _only_called=True)
- self.assertEqual(q.as_dict(), {'a': 1, 'b': 1})
- q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
- self.assertEqual(q.as_dict(), {'a': 1, 'b': 1, 'c': 1})
+ q += QueryFieldList(
+ fields=["a", "b"], value=QueryFieldList.ONLY, _only_called=True
+ )
+ self.assertEqual(q.as_dict(), {"a": 1, "b": 1})
+ q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY)
+ self.assertEqual(q.as_dict(), {"a": 1, "b": 1, "c": 1})
def test_include_exclude(self):
q = QueryFieldList()
- q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
- self.assertEqual(q.as_dict(), {'a': 1, 'b': 1})
- q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
- self.assertEqual(q.as_dict(), {'a': 1})
+ q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.ONLY)
+ self.assertEqual(q.as_dict(), {"a": 1, "b": 1})
+ q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.EXCLUDE)
+ self.assertEqual(q.as_dict(), {"a": 1})
def test_exclude_exclude(self):
q = QueryFieldList()
- q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
- self.assertEqual(q.as_dict(), {'a': 0, 'b': 0})
- q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
- self.assertEqual(q.as_dict(), {'a': 0, 'b': 0, 'c': 0})
+ q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.EXCLUDE)
+ self.assertEqual(q.as_dict(), {"a": 0, "b": 0})
+ q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.EXCLUDE)
+ self.assertEqual(q.as_dict(), {"a": 0, "b": 0, "c": 0})
def test_exclude_include(self):
q = QueryFieldList()
- q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
- self.assertEqual(q.as_dict(), {'a': 0, 'b': 0})
- q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
- self.assertEqual(q.as_dict(), {'c': 1})
+ q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.EXCLUDE)
+ self.assertEqual(q.as_dict(), {"a": 0, "b": 0})
+ q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY)
+ self.assertEqual(q.as_dict(), {"c": 1})
def test_always_include(self):
- q = QueryFieldList(always_include=['x', 'y'])
- q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
- q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
- self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'c': 1})
+ q = QueryFieldList(always_include=["x", "y"])
+ q += QueryFieldList(fields=["a", "b", "x"], value=QueryFieldList.EXCLUDE)
+ q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY)
+ self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "c": 1})
def test_reset(self):
- q = QueryFieldList(always_include=['x', 'y'])
- q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
- q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
- self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'c': 1})
+ q = QueryFieldList(always_include=["x", "y"])
+ q += QueryFieldList(fields=["a", "b", "x"], value=QueryFieldList.EXCLUDE)
+ q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY)
+ self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "c": 1})
q.reset()
self.assertFalse(q)
- q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
- self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'b': 1, 'c': 1})
+ q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY)
+ self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "b": 1, "c": 1})
def test_using_a_slice(self):
q = QueryFieldList()
- q += QueryFieldList(fields=['a'], value={"$slice": 5})
- self.assertEqual(q.as_dict(), {'a': {"$slice": 5}})
+ q += QueryFieldList(fields=["a"], value={"$slice": 5})
+ self.assertEqual(q.as_dict(), {"a": {"$slice": 5}})
class OnlyExcludeAllTest(unittest.TestCase):
-
def setUp(self):
- connect(db='mongoenginetest')
+ connect(db="mongoenginetest")
class Person(Document):
name = StringField()
age = IntField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
Person.drop_collection()
self.Person = Person
def test_mixing_only_exclude(self):
-
class MyDoc(Document):
a = StringField()
b = StringField()
@@ -88,32 +87,32 @@ class OnlyExcludeAllTest(unittest.TestCase):
e = StringField()
f = StringField()
- include = ['a', 'b', 'c', 'd', 'e']
- exclude = ['d', 'e']
- only = ['b', 'c']
+ include = ["a", "b", "c", "d", "e"]
+ exclude = ["d", "e"]
+ only = ["b", "c"]
qs = MyDoc.objects.fields(**{i: 1 for i in include})
- self.assertEqual(qs._loaded_fields.as_dict(),
- {'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1})
+ self.assertEqual(
+ qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}
+ )
qs = qs.only(*only)
- self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
+ self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1})
qs = qs.exclude(*exclude)
- self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
+ self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1})
qs = MyDoc.objects.fields(**{i: 1 for i in include})
qs = qs.exclude(*exclude)
- self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1})
+ self.assertEqual(qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1})
qs = qs.only(*only)
- self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
+ self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1})
qs = MyDoc.objects.exclude(*exclude)
qs = qs.fields(**{i: 1 for i in include})
- self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1})
+ self.assertEqual(qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1})
qs = qs.only(*only)
- self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
+ self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1})
def test_slicing(self):
-
class MyDoc(Document):
a = ListField()
b = ListField()
@@ -122,24 +121,23 @@ class OnlyExcludeAllTest(unittest.TestCase):
e = ListField()
f = ListField()
- include = ['a', 'b', 'c', 'd', 'e']
- exclude = ['d', 'e']
- only = ['b', 'c']
+ include = ["a", "b", "c", "d", "e"]
+ exclude = ["d", "e"]
+ only = ["b", "c"]
qs = MyDoc.objects.fields(**{i: 1 for i in include})
qs = qs.exclude(*exclude)
qs = qs.only(*only)
qs = qs.fields(slice__b=5)
- self.assertEqual(qs._loaded_fields.as_dict(),
- {'b': {'$slice': 5}, 'c': 1})
+ self.assertEqual(qs._loaded_fields.as_dict(), {"b": {"$slice": 5}, "c": 1})
qs = qs.fields(slice__c=[5, 1])
- self.assertEqual(qs._loaded_fields.as_dict(),
- {'b': {'$slice': 5}, 'c': {'$slice': [5, 1]}})
+ self.assertEqual(
+ qs._loaded_fields.as_dict(), {"b": {"$slice": 5}, "c": {"$slice": [5, 1]}}
+ )
- qs = qs.exclude('c')
- self.assertEqual(qs._loaded_fields.as_dict(),
- {'b': {'$slice': 5}})
+ qs = qs.exclude("c")
+ self.assertEqual(qs._loaded_fields.as_dict(), {"b": {"$slice": 5}})
def test_mix_slice_with_other_fields(self):
class MyDoc(Document):
@@ -148,43 +146,42 @@ class OnlyExcludeAllTest(unittest.TestCase):
c = ListField()
qs = MyDoc.objects.fields(a=1, b=0, slice__c=2)
- self.assertEqual(qs._loaded_fields.as_dict(),
- {'c': {'$slice': 2}, 'a': 1})
+ self.assertEqual(qs._loaded_fields.as_dict(), {"c": {"$slice": 2}, "a": 1})
def test_only(self):
"""Ensure that QuerySet.only only returns the requested fields.
"""
- person = self.Person(name='test', age=25)
+ person = self.Person(name="test", age=25)
person.save()
- obj = self.Person.objects.only('name').get()
+ obj = self.Person.objects.only("name").get()
self.assertEqual(obj.name, person.name)
self.assertEqual(obj.age, None)
- obj = self.Person.objects.only('age').get()
+ obj = self.Person.objects.only("age").get()
self.assertEqual(obj.name, None)
self.assertEqual(obj.age, person.age)
- obj = self.Person.objects.only('name', 'age').get()
+ obj = self.Person.objects.only("name", "age").get()
self.assertEqual(obj.name, person.name)
self.assertEqual(obj.age, person.age)
- obj = self.Person.objects.only(*('id', 'name',)).get()
+ obj = self.Person.objects.only(*("id", "name")).get()
self.assertEqual(obj.name, person.name)
self.assertEqual(obj.age, None)
# Check polymorphism still works
class Employee(self.Person):
- salary = IntField(db_field='wage')
+ salary = IntField(db_field="wage")
- employee = Employee(name='test employee', age=40, salary=30000)
+ employee = Employee(name="test employee", age=40, salary=30000)
employee.save()
- obj = self.Person.objects(id=employee.id).only('age').get()
+ obj = self.Person.objects(id=employee.id).only("age").get()
self.assertIsInstance(obj, Employee)
# Check field names are looked up properly
- obj = Employee.objects(id=employee.id).only('salary').get()
+ obj = Employee.objects(id=employee.id).only("salary").get()
self.assertEqual(obj.salary, employee.salary)
self.assertEqual(obj.name, None)
@@ -208,35 +205,41 @@ class OnlyExcludeAllTest(unittest.TestCase):
BlogPost.drop_collection()
- post = BlogPost(content='Had a good coffee today...', various={'test_dynamic': {'some': True}})
- post.author = User(name='Test User')
- post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')]
+ post = BlogPost(
+ content="Had a good coffee today...",
+ various={"test_dynamic": {"some": True}},
+ )
+ post.author = User(name="Test User")
+ post.comments = [
+ Comment(title="I aggree", text="Great post!"),
+ Comment(title="Coffee", text="I hate coffee"),
+ ]
post.save()
- obj = BlogPost.objects.only('author.name',).get()
+ obj = BlogPost.objects.only("author.name").get()
self.assertEqual(obj.content, None)
self.assertEqual(obj.author.email, None)
- self.assertEqual(obj.author.name, 'Test User')
+ self.assertEqual(obj.author.name, "Test User")
self.assertEqual(obj.comments, [])
- obj = BlogPost.objects.only('various.test_dynamic.some').get()
+ obj = BlogPost.objects.only("various.test_dynamic.some").get()
self.assertEqual(obj.various["test_dynamic"].some, True)
- obj = BlogPost.objects.only('content', 'comments.title',).get()
- self.assertEqual(obj.content, 'Had a good coffee today...')
+ obj = BlogPost.objects.only("content", "comments.title").get()
+ self.assertEqual(obj.content, "Had a good coffee today...")
self.assertEqual(obj.author, None)
- self.assertEqual(obj.comments[0].title, 'I aggree')
- self.assertEqual(obj.comments[1].title, 'Coffee')
+ self.assertEqual(obj.comments[0].title, "I aggree")
+ self.assertEqual(obj.comments[1].title, "Coffee")
self.assertEqual(obj.comments[0].text, None)
self.assertEqual(obj.comments[1].text, None)
- obj = BlogPost.objects.only('comments',).get()
+ obj = BlogPost.objects.only("comments").get()
self.assertEqual(obj.content, None)
self.assertEqual(obj.author, None)
- self.assertEqual(obj.comments[0].title, 'I aggree')
- self.assertEqual(obj.comments[1].title, 'Coffee')
- self.assertEqual(obj.comments[0].text, 'Great post!')
- self.assertEqual(obj.comments[1].text, 'I hate coffee')
+ self.assertEqual(obj.comments[0].title, "I aggree")
+ self.assertEqual(obj.comments[1].title, "Coffee")
+ self.assertEqual(obj.comments[0].text, "Great post!")
+ self.assertEqual(obj.comments[1].text, "I hate coffee")
BlogPost.drop_collection()
@@ -256,15 +259,18 @@ class OnlyExcludeAllTest(unittest.TestCase):
BlogPost.drop_collection()
- post = BlogPost(content='Had a good coffee today...')
- post.author = User(name='Test User')
- post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')]
+ post = BlogPost(content="Had a good coffee today...")
+ post.author = User(name="Test User")
+ post.comments = [
+ Comment(title="I aggree", text="Great post!"),
+ Comment(title="Coffee", text="I hate coffee"),
+ ]
post.save()
- obj = BlogPost.objects.exclude('author', 'comments.text').get()
+ obj = BlogPost.objects.exclude("author", "comments.text").get()
self.assertEqual(obj.author, None)
- self.assertEqual(obj.content, 'Had a good coffee today...')
- self.assertEqual(obj.comments[0].title, 'I aggree')
+ self.assertEqual(obj.content, "Had a good coffee today...")
+ self.assertEqual(obj.comments[0].title, "I aggree")
self.assertEqual(obj.comments[0].text, None)
BlogPost.drop_collection()
@@ -283,32 +289,43 @@ class OnlyExcludeAllTest(unittest.TestCase):
attachments = ListField(EmbeddedDocumentField(Attachment))
Email.drop_collection()
- email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain')
+ email = Email(
+ sender="me",
+ to="you",
+ subject="From Russia with Love",
+ body="Hello!",
+ content_type="text/plain",
+ )
email.attachments = [
- Attachment(name='file1.doc', content='ABC'),
- Attachment(name='file2.doc', content='XYZ'),
+ Attachment(name="file1.doc", content="ABC"),
+ Attachment(name="file2.doc", content="XYZ"),
]
email.save()
- obj = Email.objects.exclude('content_type').exclude('body').get()
- self.assertEqual(obj.sender, 'me')
- self.assertEqual(obj.to, 'you')
- self.assertEqual(obj.subject, 'From Russia with Love')
+ obj = Email.objects.exclude("content_type").exclude("body").get()
+ self.assertEqual(obj.sender, "me")
+ self.assertEqual(obj.to, "you")
+ self.assertEqual(obj.subject, "From Russia with Love")
self.assertEqual(obj.body, None)
self.assertEqual(obj.content_type, None)
- obj = Email.objects.only('sender', 'to').exclude('body', 'sender').get()
+ obj = Email.objects.only("sender", "to").exclude("body", "sender").get()
self.assertEqual(obj.sender, None)
- self.assertEqual(obj.to, 'you')
+ self.assertEqual(obj.to, "you")
self.assertEqual(obj.subject, None)
self.assertEqual(obj.body, None)
self.assertEqual(obj.content_type, None)
- obj = Email.objects.exclude('attachments.content').exclude('body').only('to', 'attachments.name').get()
- self.assertEqual(obj.attachments[0].name, 'file1.doc')
+ obj = (
+ Email.objects.exclude("attachments.content")
+ .exclude("body")
+ .only("to", "attachments.name")
+ .get()
+ )
+ self.assertEqual(obj.attachments[0].name, "file1.doc")
self.assertEqual(obj.attachments[0].content, None)
self.assertEqual(obj.sender, None)
- self.assertEqual(obj.to, 'you')
+ self.assertEqual(obj.to, "you")
self.assertEqual(obj.subject, None)
self.assertEqual(obj.body, None)
self.assertEqual(obj.content_type, None)
@@ -316,7 +333,6 @@ class OnlyExcludeAllTest(unittest.TestCase):
Email.drop_collection()
def test_all_fields(self):
-
class Email(Document):
sender = StringField()
to = StringField()
@@ -326,21 +342,33 @@ class OnlyExcludeAllTest(unittest.TestCase):
Email.drop_collection()
- email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain')
+ email = Email(
+ sender="me",
+ to="you",
+ subject="From Russia with Love",
+ body="Hello!",
+ content_type="text/plain",
+ )
email.save()
- obj = Email.objects.exclude('content_type', 'body').only('to', 'body').all_fields().get()
- self.assertEqual(obj.sender, 'me')
- self.assertEqual(obj.to, 'you')
- self.assertEqual(obj.subject, 'From Russia with Love')
- self.assertEqual(obj.body, 'Hello!')
- self.assertEqual(obj.content_type, 'text/plain')
+ obj = (
+ Email.objects.exclude("content_type", "body")
+ .only("to", "body")
+ .all_fields()
+ .get()
+ )
+ self.assertEqual(obj.sender, "me")
+ self.assertEqual(obj.to, "you")
+ self.assertEqual(obj.subject, "From Russia with Love")
+ self.assertEqual(obj.body, "Hello!")
+ self.assertEqual(obj.content_type, "text/plain")
Email.drop_collection()
def test_slicing_fields(self):
"""Ensure that query slicing an array works.
"""
+
class Numbers(Document):
n = ListField(IntField())
@@ -414,11 +442,10 @@ class OnlyExcludeAllTest(unittest.TestCase):
self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1])
def test_exclude_from_subclasses_docs(self):
-
class Base(Document):
username = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class Anon(Base):
anon = BooleanField()
@@ -436,5 +463,5 @@ class OnlyExcludeAllTest(unittest.TestCase):
self.assertRaises(LookUpError, Base.objects.exclude, "made_up")
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py
index 45e6a089..95dc913d 100644
--- a/tests/queryset/geo.py
+++ b/tests/queryset/geo.py
@@ -10,9 +10,9 @@ __all__ = ("GeoQueriesTest",)
class GeoQueriesTest(MongoDBTestCase):
-
def _create_event_data(self, point_field_class=GeoPointField):
"""Create some sample data re-used in many of the tests below."""
+
class Event(Document):
title = StringField()
date = DateTimeField()
@@ -28,15 +28,18 @@ class GeoQueriesTest(MongoDBTestCase):
event1 = Event.objects.create(
title="Coltrane Motion @ Double Door",
date=datetime.datetime.now() - datetime.timedelta(days=1),
- location=[-87.677137, 41.909889])
+ location=[-87.677137, 41.909889],
+ )
event2 = Event.objects.create(
title="Coltrane Motion @ Bottom of the Hill",
date=datetime.datetime.now() - datetime.timedelta(days=10),
- location=[-122.4194155, 37.7749295])
+ location=[-122.4194155, 37.7749295],
+ )
event3 = Event.objects.create(
title="Coltrane Motion @ Empty Bottle",
date=datetime.datetime.now(),
- location=[-87.686638, 41.900474])
+ location=[-87.686638, 41.900474],
+ )
return event1, event2, event3
@@ -65,8 +68,7 @@ class GeoQueriesTest(MongoDBTestCase):
# find events within 10 degrees of san francisco
point = [-122.415579, 37.7566023]
- events = self.Event.objects(location__near=point,
- location__max_distance=10)
+ events = self.Event.objects(location__near=point, location__max_distance=10)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
@@ -78,8 +80,7 @@ class GeoQueriesTest(MongoDBTestCase):
# find events at least 10 degrees away of san francisco
point = [-122.415579, 37.7566023]
- events = self.Event.objects(location__near=point,
- location__min_distance=10)
+ events = self.Event.objects(location__near=point, location__min_distance=10)
self.assertEqual(events.count(), 2)
def test_within_distance(self):
@@ -88,8 +89,7 @@ class GeoQueriesTest(MongoDBTestCase):
# find events within 5 degrees of pitchfork office, chicago
point_and_distance = [[-87.67892, 41.9120459], 5]
- events = self.Event.objects(
- location__within_distance=point_and_distance)
+ events = self.Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 2)
events = list(events)
self.assertNotIn(event2, events)
@@ -98,21 +98,18 @@ class GeoQueriesTest(MongoDBTestCase):
# find events within 10 degrees of san francisco
point_and_distance = [[-122.415579, 37.7566023], 10]
- events = self.Event.objects(
- location__within_distance=point_and_distance)
+ events = self.Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
# find events within 1 degree of greenpoint, broolyn, nyc, ny
point_and_distance = [[-73.9509714, 40.7237134], 1]
- events = self.Event.objects(
- location__within_distance=point_and_distance)
+ events = self.Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 0)
# ensure ordering is respected by "within_distance"
point_and_distance = [[-87.67892, 41.9120459], 10]
- events = self.Event.objects(
- location__within_distance=point_and_distance)
+ events = self.Event.objects(location__within_distance=point_and_distance)
events = events.order_by("-date")
self.assertEqual(events.count(), 2)
self.assertEqual(events[0], event3)
@@ -145,7 +142,7 @@ class GeoQueriesTest(MongoDBTestCase):
polygon2 = [
(-1.742249, 54.033586),
(-1.225891, 52.792797),
- (-4.40094, 53.389881)
+ (-4.40094, 53.389881),
]
events = self.Event.objects(location__within_polygon=polygon2)
self.assertEqual(events.count(), 0)
@@ -154,9 +151,7 @@ class GeoQueriesTest(MongoDBTestCase):
"""Make sure the "near" operator works with a PointField, which
corresponds to a 2dsphere index.
"""
- event1, event2, event3 = self._create_event_data(
- point_field_class=PointField
- )
+ event1, event2, event3 = self._create_event_data(point_field_class=PointField)
# find all events "near" pitchfork office, chicago.
# note that "near" will show the san francisco event, too,
@@ -175,26 +170,23 @@ class GeoQueriesTest(MongoDBTestCase):
"""Ensure the "max_distance" operator works alongside the "near"
operator with a 2dsphere index.
"""
- event1, event2, event3 = self._create_event_data(
- point_field_class=PointField
- )
+ event1, event2, event3 = self._create_event_data(point_field_class=PointField)
# find events within 10km of san francisco
point = [-122.415579, 37.7566023]
- events = self.Event.objects(location__near=point,
- location__max_distance=10000)
+ events = self.Event.objects(location__near=point, location__max_distance=10000)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
# find events within 1km of greenpoint, broolyn, nyc, ny
- events = self.Event.objects(location__near=[-73.9509714, 40.7237134],
- location__max_distance=1000)
+ events = self.Event.objects(
+ location__near=[-73.9509714, 40.7237134], location__max_distance=1000
+ )
self.assertEqual(events.count(), 0)
# ensure ordering is respected by "near"
events = self.Event.objects(
- location__near=[-87.67892, 41.9120459],
- location__max_distance=10000
+ location__near=[-87.67892, 41.9120459], location__max_distance=10000
).order_by("-date")
self.assertEqual(events.count(), 2)
self.assertEqual(events[0], event3)
@@ -203,9 +195,7 @@ class GeoQueriesTest(MongoDBTestCase):
"""Ensure the "geo_within_box" operator works with a 2dsphere
index.
"""
- event1, event2, event3 = self._create_event_data(
- point_field_class=PointField
- )
+ event1, event2, event3 = self._create_event_data(point_field_class=PointField)
# check that within_box works
box = [(-125.0, 35.0), (-100.0, 40.0)]
@@ -217,9 +207,7 @@ class GeoQueriesTest(MongoDBTestCase):
"""Ensure the "geo_within_polygon" operator works with a
2dsphere index.
"""
- event1, event2, event3 = self._create_event_data(
- point_field_class=PointField
- )
+ event1, event2, event3 = self._create_event_data(point_field_class=PointField)
polygon = [
(-87.694445, 41.912114),
@@ -235,7 +223,7 @@ class GeoQueriesTest(MongoDBTestCase):
polygon2 = [
(-1.742249, 54.033586),
(-1.225891, 52.792797),
- (-4.40094, 53.389881)
+ (-4.40094, 53.389881),
]
events = self.Event.objects(location__geo_within_polygon=polygon2)
self.assertEqual(events.count(), 0)
@@ -244,23 +232,20 @@ class GeoQueriesTest(MongoDBTestCase):
"""Ensure "min_distace" and "max_distance" operators work well
together with the "near" operator in a 2dsphere index.
"""
- event1, event2, event3 = self._create_event_data(
- point_field_class=PointField
- )
+ event1, event2, event3 = self._create_event_data(point_field_class=PointField)
# ensure min_distance and max_distance combine well
events = self.Event.objects(
location__near=[-87.67892, 41.9120459],
location__min_distance=1000,
- location__max_distance=10000
+ location__max_distance=10000,
).order_by("-date")
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event3)
# ensure ordering is respected by "near" with "min_distance"
events = self.Event.objects(
- location__near=[-87.67892, 41.9120459],
- location__min_distance=10000
+ location__near=[-87.67892, 41.9120459], location__min_distance=10000
).order_by("-date")
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
@@ -269,14 +254,11 @@ class GeoQueriesTest(MongoDBTestCase):
"""Make sure the "geo_within_center" operator works with a
2dsphere index.
"""
- event1, event2, event3 = self._create_event_data(
- point_field_class=PointField
- )
+ event1, event2, event3 = self._create_event_data(point_field_class=PointField)
# find events within 5 degrees of pitchfork office, chicago
point_and_distance = [[-87.67892, 41.9120459], 2]
- events = self.Event.objects(
- location__geo_within_center=point_and_distance)
+ events = self.Event.objects(location__geo_within_center=point_and_distance)
self.assertEqual(events.count(), 2)
events = list(events)
self.assertNotIn(event2, events)
@@ -287,6 +269,7 @@ class GeoQueriesTest(MongoDBTestCase):
"""Helper test method ensuring given point field class works
well in an embedded document.
"""
+
class Venue(EmbeddedDocument):
location = point_field_class()
name = StringField()
@@ -300,12 +283,11 @@ class GeoQueriesTest(MongoDBTestCase):
venue1 = Venue(name="The Rock", location=[-87.677137, 41.909889])
venue2 = Venue(name="The Bridge", location=[-122.4194155, 37.7749295])
- event1 = Event(title="Coltrane Motion @ Double Door",
- venue=venue1).save()
- event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
- venue=venue2).save()
- event3 = Event(title="Coltrane Motion @ Empty Bottle",
- venue=venue1).save()
+ event1 = Event(title="Coltrane Motion @ Double Door", venue=venue1).save()
+ event2 = Event(
+ title="Coltrane Motion @ Bottom of the Hill", venue=venue2
+ ).save()
+ event3 = Event(title="Coltrane Motion @ Empty Bottle", venue=venue1).save()
# find all events "near" pitchfork office, chicago.
# note that "near" will show the san francisco event, too,
@@ -324,6 +306,7 @@ class GeoQueriesTest(MongoDBTestCase):
def test_spherical_geospatial_operators(self):
"""Ensure that spherical geospatial queries are working."""
+
class Point(Document):
location = GeoPointField()
@@ -343,26 +326,26 @@ class GeoQueriesTest(MongoDBTestCase):
# Same behavior for _within_spherical_distance
points = Point.objects(
- location__within_spherical_distance=[
- [-122, 37.5],
- 60 / earth_radius
- ]
+ location__within_spherical_distance=[[-122, 37.5], 60 / earth_radius]
)
self.assertEqual(points.count(), 2)
- points = Point.objects(location__near_sphere=[-122, 37.5],
- location__max_distance=60 / earth_radius)
+ points = Point.objects(
+ location__near_sphere=[-122, 37.5], location__max_distance=60 / earth_radius
+ )
self.assertEqual(points.count(), 2)
# Test query works with max_distance, being farer from one point
- points = Point.objects(location__near_sphere=[-122, 37.8],
- location__max_distance=60 / earth_radius)
+ points = Point.objects(
+ location__near_sphere=[-122, 37.8], location__max_distance=60 / earth_radius
+ )
close_point = points.first()
self.assertEqual(points.count(), 1)
# Test query works with min_distance, being farer from one point
- points = Point.objects(location__near_sphere=[-122, 37.8],
- location__min_distance=60 / earth_radius)
+ points = Point.objects(
+ location__near_sphere=[-122, 37.8], location__min_distance=60 / earth_radius
+ )
self.assertEqual(points.count(), 1)
far_point = points.first()
self.assertNotEqual(close_point, far_point)
@@ -384,10 +367,7 @@ class GeoQueriesTest(MongoDBTestCase):
# Finds only one point because only the first point is within 60km of
# the reference point to the south.
points = Point.objects(
- location__within_spherical_distance=[
- [-122, 36.5],
- 60 / earth_radius
- ]
+ location__within_spherical_distance=[[-122, 36.5], 60 / earth_radius]
)
self.assertEqual(points.count(), 1)
self.assertEqual(points[0].id, south_point.id)
@@ -413,8 +393,10 @@ class GeoQueriesTest(MongoDBTestCase):
self.assertEqual(1, roads)
# Within
- polygon = {"type": "Polygon",
- "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
+ polygon = {
+ "type": "Polygon",
+ "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]],
+ }
roads = Road.objects.filter(line__geo_within=polygon["coordinates"]).count()
self.assertEqual(1, roads)
@@ -425,8 +407,7 @@ class GeoQueriesTest(MongoDBTestCase):
self.assertEqual(1, roads)
# Intersects
- line = {"type": "LineString",
- "coordinates": [[40, 5], [40, 6]]}
+ line = {"type": "LineString", "coordinates": [[40, 5], [40, 6]]}
roads = Road.objects.filter(line__geo_intersects=line["coordinates"]).count()
self.assertEqual(1, roads)
@@ -436,8 +417,10 @@ class GeoQueriesTest(MongoDBTestCase):
roads = Road.objects.filter(line__geo_intersects={"$geometry": line}).count()
self.assertEqual(1, roads)
- polygon = {"type": "Polygon",
- "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
+ polygon = {
+ "type": "Polygon",
+ "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]],
+ }
roads = Road.objects.filter(line__geo_intersects=polygon["coordinates"]).count()
self.assertEqual(1, roads)
@@ -468,8 +451,10 @@ class GeoQueriesTest(MongoDBTestCase):
self.assertEqual(1, roads)
# Within
- polygon = {"type": "Polygon",
- "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
+ polygon = {
+ "type": "Polygon",
+ "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]],
+ }
roads = Road.objects.filter(poly__geo_within=polygon["coordinates"]).count()
self.assertEqual(1, roads)
@@ -480,8 +465,7 @@ class GeoQueriesTest(MongoDBTestCase):
self.assertEqual(1, roads)
# Intersects
- line = {"type": "LineString",
- "coordinates": [[40, 5], [41, 6]]}
+ line = {"type": "LineString", "coordinates": [[40, 5], [41, 6]]}
roads = Road.objects.filter(poly__geo_intersects=line["coordinates"]).count()
self.assertEqual(1, roads)
@@ -491,8 +475,10 @@ class GeoQueriesTest(MongoDBTestCase):
roads = Road.objects.filter(poly__geo_intersects={"$geometry": line}).count()
self.assertEqual(1, roads)
- polygon = {"type": "Polygon",
- "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
+ polygon = {
+ "type": "Polygon",
+ "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]],
+ }
roads = Road.objects.filter(poly__geo_intersects=polygon["coordinates"]).count()
self.assertEqual(1, roads)
@@ -504,20 +490,20 @@ class GeoQueriesTest(MongoDBTestCase):
def test_aspymongo_with_only(self):
"""Ensure as_pymongo works with only"""
+
class Place(Document):
location = PointField()
Place.drop_collection()
p = Place(location=[24.946861267089844, 60.16311983618494])
p.save()
- qs = Place.objects().only('location')
+ qs = Place.objects().only("location")
self.assertDictEqual(
- qs.as_pymongo()[0]['location'],
- {u'type': u'Point',
- u'coordinates': [
- 24.946861267089844,
- 60.16311983618494]
- }
+ qs.as_pymongo()[0]["location"],
+ {
+ u"type": u"Point",
+ u"coordinates": [24.946861267089844, 60.16311983618494],
+ },
)
def test_2dsphere_point_sets_correctly(self):
@@ -542,11 +528,15 @@ class GeoQueriesTest(MongoDBTestCase):
Location(line=[[1, 2], [2, 2]]).save()
loc = Location.objects.as_pymongo()[0]
- self.assertEqual(loc["line"], {"type": "LineString", "coordinates": [[1, 2], [2, 2]]})
+ self.assertEqual(
+ loc["line"], {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}
+ )
Location.objects.update(set__line=[[2, 1], [1, 2]])
loc = Location.objects.as_pymongo()[0]
- self.assertEqual(loc["line"], {"type": "LineString", "coordinates": [[2, 1], [1, 2]]})
+ self.assertEqual(
+ loc["line"], {"type": "LineString", "coordinates": [[2, 1], [1, 2]]}
+ )
def test_geojson_PolygonField(self):
class Location(Document):
@@ -556,12 +546,18 @@ class GeoQueriesTest(MongoDBTestCase):
Location(poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]]).save()
loc = Location.objects.as_pymongo()[0]
- self.assertEqual(loc["poly"], {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]})
+ self.assertEqual(
+ loc["poly"],
+ {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]},
+ )
Location.objects.update(set__poly=[[[40, 4], [40, 6], [41, 6], [40, 4]]])
loc = Location.objects.as_pymongo()[0]
- self.assertEqual(loc["poly"], {"type": "Polygon", "coordinates": [[[40, 4], [40, 6], [41, 6], [40, 4]]]})
+ self.assertEqual(
+ loc["poly"],
+ {"type": "Polygon", "coordinates": [[[40, 4], [40, 6], [41, 6], [40, 4]]]},
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/queryset/modify.py b/tests/queryset/modify.py
index 3c5879ba..e092d11c 100644
--- a/tests/queryset/modify.py
+++ b/tests/queryset/modify.py
@@ -11,7 +11,6 @@ class Doc(Document):
class FindAndModifyTest(unittest.TestCase):
-
def setUp(self):
connect(db="mongoenginetest")
Doc.drop_collection()
@@ -82,9 +81,14 @@ class FindAndModifyTest(unittest.TestCase):
old_doc = Doc.objects().order_by("-id").modify(set__value=-1)
self.assertEqual(old_doc.to_json(), doc.to_json())
- self.assertDbEqual([
- {"_id": 0, "value": 3}, {"_id": 1, "value": 2},
- {"_id": 2, "value": 1}, {"_id": 3, "value": -1}])
+ self.assertDbEqual(
+ [
+ {"_id": 0, "value": 3},
+ {"_id": 1, "value": 2},
+ {"_id": 2, "value": 1},
+ {"_id": 3, "value": -1},
+ ]
+ )
def test_modify_with_fields(self):
Doc(id=0, value=0).save()
@@ -103,27 +107,25 @@ class FindAndModifyTest(unittest.TestCase):
blog = BlogPost.objects.create()
# Push a new tag via modify with new=False (default).
- BlogPost(id=blog.id).modify(push__tags='code')
+ BlogPost(id=blog.id).modify(push__tags="code")
self.assertEqual(blog.tags, [])
blog.reload()
- self.assertEqual(blog.tags, ['code'])
+ self.assertEqual(blog.tags, ["code"])
# Push a new tag via modify with new=True.
- blog = BlogPost.objects(id=blog.id).modify(push__tags='java', new=True)
- self.assertEqual(blog.tags, ['code', 'java'])
+ blog = BlogPost.objects(id=blog.id).modify(push__tags="java", new=True)
+ self.assertEqual(blog.tags, ["code", "java"])
# Push a new tag with a positional argument.
- blog = BlogPost.objects(id=blog.id).modify(
- push__tags__0='python',
- new=True)
- self.assertEqual(blog.tags, ['python', 'code', 'java'])
+ blog = BlogPost.objects(id=blog.id).modify(push__tags__0="python", new=True)
+ self.assertEqual(blog.tags, ["python", "code", "java"])
# Push multiple new tags with a positional argument.
blog = BlogPost.objects(id=blog.id).modify(
- push__tags__1=['go', 'rust'],
- new=True)
- self.assertEqual(blog.tags, ['python', 'go', 'rust', 'code', 'java'])
+ push__tags__1=["go", "rust"], new=True
+ )
+ self.assertEqual(blog.tags, ["python", "go", "rust", "code", "java"])
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/queryset/pickable.py b/tests/queryset/pickable.py
index bf7bb31c..0945fcbc 100644
--- a/tests/queryset/pickable.py
+++ b/tests/queryset/pickable.py
@@ -4,7 +4,7 @@ from pymongo.mongo_client import MongoClient
from mongoengine import Document, StringField, IntField
from mongoengine.connection import connect
-__author__ = 'stas'
+__author__ = "stas"
class Person(Document):
@@ -17,6 +17,7 @@ class TestQuerysetPickable(unittest.TestCase):
Test for adding pickling support for QuerySet instances
See issue https://github.com/MongoEngine/mongoengine/issues/442
"""
+
def setUp(self):
super(TestQuerysetPickable, self).setUp()
@@ -24,10 +25,7 @@ class TestQuerysetPickable(unittest.TestCase):
connection.drop_database("test")
- self.john = Person.objects.create(
- name="John",
- age=21
- )
+ self.john = Person.objects.create(name="John", age=21)
def test_picke_simple_qs(self):
@@ -54,15 +52,9 @@ class TestQuerysetPickable(unittest.TestCase):
self.assertEqual(Person.objects.first().age, 23)
def test_pickle_support_filtration(self):
- Person.objects.create(
- name="Alice",
- age=22
- )
+ Person.objects.create(name="Alice", age=22)
- Person.objects.create(
- name="Bob",
- age=23
- )
+ Person.objects.create(name="Bob", age=23)
qs = Person.objects.filter(age__gte=22)
self.assertEqual(qs.count(), 2)
@@ -71,9 +63,3 @@ class TestQuerysetPickable(unittest.TestCase):
self.assertEqual(loaded.count(), 2)
self.assertEqual(loaded.filter(name="Bob").first().age, 23)
-
-
-
-
-
-
diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py
index c86e4095..21f35012 100644
--- a/tests/queryset/queryset.py
+++ b/tests/queryset/queryset.py
@@ -17,29 +17,34 @@ from mongoengine.connection import get_connection, get_db
from mongoengine.context_managers import query_counter, switch_db
from mongoengine.errors import InvalidQueryError
from mongoengine.mongodb_support import MONGODB_36, get_mongodb_version
-from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned,
- QuerySet, QuerySetManager, queryset_manager)
+from mongoengine.queryset import (
+ DoesNotExist,
+ MultipleObjectsReturned,
+ QuerySet,
+ QuerySetManager,
+ queryset_manager,
+)
class db_ops_tracker(query_counter):
-
def get_ops(self):
ignore_query = dict(self._ignored_query)
- ignore_query['command.count'] = {'$ne': 'system.profile'} # Ignore the query issued by query_counter
+ ignore_query["command.count"] = {
+ "$ne": "system.profile"
+ } # Ignore the query issued by query_counter
return list(self.db.system.profile.find(ignore_query))
def get_key_compat(mongo_ver):
- ORDER_BY_KEY = 'sort'
- CMD_QUERY_KEY = 'command' if mongo_ver >= MONGODB_36 else 'query'
+ ORDER_BY_KEY = "sort"
+ CMD_QUERY_KEY = "command" if mongo_ver >= MONGODB_36 else "query"
return ORDER_BY_KEY, CMD_QUERY_KEY
class QuerySetTest(unittest.TestCase):
-
def setUp(self):
- connect(db='mongoenginetest')
- connect(db='mongoenginetest2', alias='test2')
+ connect(db="mongoenginetest")
+ connect(db="mongoenginetest2", alias="test2")
class PersonMeta(EmbeddedDocument):
weight = IntField()
@@ -48,7 +53,7 @@ class QuerySetTest(unittest.TestCase):
name = StringField()
age = IntField()
person_meta = EmbeddedDocumentField(PersonMeta)
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
Person.drop_collection()
self.PersonMeta = PersonMeta
@@ -60,12 +65,14 @@ class QuerySetTest(unittest.TestCase):
"""Ensure that a QuerySet is correctly initialised by QuerySetManager.
"""
self.assertIsInstance(self.Person.objects, QuerySet)
- self.assertEqual(self.Person.objects._collection.name,
- self.Person._get_collection_name())
- self.assertIsInstance(self.Person.objects._collection, pymongo.collection.Collection)
+ self.assertEqual(
+ self.Person.objects._collection.name, self.Person._get_collection_name()
+ )
+ self.assertIsInstance(
+ self.Person.objects._collection, pymongo.collection.Collection
+ )
def test_cannot_perform_joins_references(self):
-
class BlogPost(Document):
author = ReferenceField(self.Person)
author2 = GenericReferenceField()
@@ -80,8 +87,8 @@ class QuerySetTest(unittest.TestCase):
def test_find(self):
"""Ensure that a query returns a valid set of results."""
- user_a = self.Person.objects.create(name='User A', age=20)
- user_b = self.Person.objects.create(name='User B', age=30)
+ user_a = self.Person.objects.create(name="User A", age=20)
+ user_b = self.Person.objects.create(name="User B", age=30)
# Find all people in the collection
people = self.Person.objects
@@ -92,11 +99,11 @@ class QuerySetTest(unittest.TestCase):
self.assertIsInstance(results[0].id, ObjectId)
self.assertEqual(results[0], user_a)
- self.assertEqual(results[0].name, 'User A')
+ self.assertEqual(results[0].name, "User A")
self.assertEqual(results[0].age, 20)
self.assertEqual(results[1], user_b)
- self.assertEqual(results[1].name, 'User B')
+ self.assertEqual(results[1].name, "User B")
self.assertEqual(results[1].age, 30)
# Filter people by age
@@ -109,8 +116,8 @@ class QuerySetTest(unittest.TestCase):
def test_limit(self):
"""Ensure that QuerySet.limit works as expected."""
- user_a = self.Person.objects.create(name='User A', age=20)
- user_b = self.Person.objects.create(name='User B', age=30)
+ user_a = self.Person.objects.create(name="User A", age=20)
+ user_b = self.Person.objects.create(name="User B", age=30)
# Test limit on a new queryset
people = list(self.Person.objects.limit(1))
@@ -131,15 +138,15 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(people), 2)
# Test chaining of only after limit
- person = self.Person.objects().limit(1).only('name').first()
+ person = self.Person.objects().limit(1).only("name").first()
self.assertEqual(person, user_a)
- self.assertEqual(person.name, 'User A')
+ self.assertEqual(person.name, "User A")
self.assertEqual(person.age, None)
def test_skip(self):
"""Ensure that QuerySet.skip works as expected."""
- user_a = self.Person.objects.create(name='User A', age=20)
- user_b = self.Person.objects.create(name='User B', age=30)
+ user_a = self.Person.objects.create(name="User A", age=20)
+ user_b = self.Person.objects.create(name="User B", age=30)
# Test skip on a new queryset
people = list(self.Person.objects.skip(1))
@@ -155,20 +162,20 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(people2[0], user_b)
# Test chaining of only after skip
- person = self.Person.objects().skip(1).only('name').first()
+ person = self.Person.objects().skip(1).only("name").first()
self.assertEqual(person, user_b)
- self.assertEqual(person.name, 'User B')
+ self.assertEqual(person.name, "User B")
self.assertEqual(person.age, None)
def test___getitem___invalid_index(self):
"""Ensure slicing a queryset works as expected."""
with self.assertRaises(TypeError):
- self.Person.objects()['a']
+ self.Person.objects()["a"]
def test_slice(self):
"""Ensure slicing a queryset works as expected."""
- user_a = self.Person.objects.create(name='User A', age=20)
- user_b = self.Person.objects.create(name='User B', age=30)
+ user_a = self.Person.objects.create(name="User A", age=20)
+ user_b = self.Person.objects.create(name="User B", age=30)
user_c = self.Person.objects.create(name="User C", age=40)
# Test slice limit
@@ -202,7 +209,7 @@ class QuerySetTest(unittest.TestCase):
qs._cursor_obj = None
people = list(qs)
self.assertEqual(len(people), 1)
- self.assertEqual(people[0].name, 'User B')
+ self.assertEqual(people[0].name, "User B")
# Test empty slice
people = list(self.Person.objects[1:1])
@@ -215,14 +222,18 @@ class QuerySetTest(unittest.TestCase):
# Test larger slice __repr__
self.Person.objects.delete()
for i in range(55):
- self.Person(name='A%s' % i, age=i).save()
+ self.Person(name="A%s" % i, age=i).save()
self.assertEqual(self.Person.objects.count(), 55)
self.assertEqual("Person object", "%s" % self.Person.objects[0])
- self.assertEqual("[, ]",
- "%s" % self.Person.objects[1:3])
- self.assertEqual("[, ]",
- "%s" % self.Person.objects[51:53])
+ self.assertEqual(
+ "[, ]",
+ "%s" % self.Person.objects[1:3],
+ )
+ self.assertEqual(
+ "[, ]",
+ "%s" % self.Person.objects[51:53],
+ )
def test_find_one(self):
"""Ensure that a query using find_one returns a valid result.
@@ -276,8 +287,7 @@ class QuerySetTest(unittest.TestCase):
# Retrieve the first person from the database
self.assertRaises(MultipleObjectsReturned, self.Person.objects.get)
- self.assertRaises(self.Person.MultipleObjectsReturned,
- self.Person.objects.get)
+ self.assertRaises(self.Person.MultipleObjectsReturned, self.Person.objects.get)
# Use a query to filter the people found to just person2
person = self.Person.objects.get(age=30)
@@ -289,6 +299,7 @@ class QuerySetTest(unittest.TestCase):
def test_find_array_position(self):
"""Ensure that query by array position works.
"""
+
class Comment(EmbeddedDocument):
name = StringField()
@@ -301,34 +312,34 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection()
- Blog.objects.create(tags=['a', 'b'])
- self.assertEqual(Blog.objects(tags__0='a').count(), 1)
- self.assertEqual(Blog.objects(tags__0='b').count(), 0)
- self.assertEqual(Blog.objects(tags__1='a').count(), 0)
- self.assertEqual(Blog.objects(tags__1='b').count(), 1)
+ Blog.objects.create(tags=["a", "b"])
+ self.assertEqual(Blog.objects(tags__0="a").count(), 1)
+ self.assertEqual(Blog.objects(tags__0="b").count(), 0)
+ self.assertEqual(Blog.objects(tags__1="a").count(), 0)
+ self.assertEqual(Blog.objects(tags__1="b").count(), 1)
Blog.drop_collection()
- comment1 = Comment(name='testa')
- comment2 = Comment(name='testb')
+ comment1 = Comment(name="testa")
+ comment2 = Comment(name="testb")
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
- blog = Blog.objects(posts__0__comments__0__name='testa').get()
+ blog = Blog.objects(posts__0__comments__0__name="testa").get()
self.assertEqual(blog, blog1)
- blog = Blog.objects(posts__0__comments__0__name='testb').get()
+ blog = Blog.objects(posts__0__comments__0__name="testb").get()
self.assertEqual(blog, blog2)
- query = Blog.objects(posts__1__comments__1__name='testb')
+ query = Blog.objects(posts__1__comments__1__name="testb")
self.assertEqual(query.count(), 2)
- query = Blog.objects(posts__1__comments__1__name='testa')
+ query = Blog.objects(posts__1__comments__1__name="testa")
self.assertEqual(query.count(), 0)
- query = Blog.objects(posts__0__comments__1__name='testa')
+ query = Blog.objects(posts__0__comments__1__name="testa")
self.assertEqual(query.count(), 0)
Blog.drop_collection()
@@ -367,13 +378,14 @@ class QuerySetTest(unittest.TestCase):
q2 = q2.filter(ref=a1)._query
self.assertEqual(q1, q2)
- a_objects = A.objects(s='test1')
+ a_objects = A.objects(s="test1")
query = B.objects(ref__in=a_objects)
query = query.filter(boolfield=True)
self.assertEqual(query.count(), 1)
def test_batch_size(self):
"""Ensure that batch_size works."""
+
class A(Document):
s = StringField()
@@ -416,33 +428,33 @@ class QuerySetTest(unittest.TestCase):
self.Person.drop_collection()
write_concern = {"fsync": True}
- author = self.Person.objects.create(name='Test User')
+ author = self.Person.objects.create(name="Test User")
author.save(write_concern=write_concern)
# Ensure no regression of #1958
- author = self.Person(name='Test User2')
+ author = self.Person(name="Test User2")
author.save(write_concern=None) # will default to {w: 1}
- result = self.Person.objects.update(
- set__name='Ross', write_concern={"w": 1})
+ result = self.Person.objects.update(set__name="Ross", write_concern={"w": 1})
self.assertEqual(result, 2)
- result = self.Person.objects.update(
- set__name='Ross', write_concern={"w": 0})
+ result = self.Person.objects.update(set__name="Ross", write_concern={"w": 0})
self.assertEqual(result, None)
result = self.Person.objects.update_one(
- set__name='Test User', write_concern={"w": 1})
+ set__name="Test User", write_concern={"w": 1}
+ )
self.assertEqual(result, 1)
result = self.Person.objects.update_one(
- set__name='Test User', write_concern={"w": 0})
+ set__name="Test User", write_concern={"w": 0}
+ )
self.assertEqual(result, None)
def test_update_update_has_a_value(self):
"""Test to ensure that update is passed a value to update to"""
self.Person.drop_collection()
- author = self.Person.objects.create(name='Test User')
+ author = self.Person.objects.create(name="Test User")
with self.assertRaises(OperationError):
self.Person.objects(pk=author.pk).update({})
@@ -457,6 +469,7 @@ class QuerySetTest(unittest.TestCase):
set__posts__1__comments__1__name="testc"
Check that it only works for ListFields.
"""
+
class Comment(EmbeddedDocument):
name = StringField()
@@ -469,16 +482,16 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection()
- comment1 = Comment(name='testa')
- comment2 = Comment(name='testb')
+ comment1 = Comment(name="testa")
+ comment2 = Comment(name="testb")
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
Blog.objects.create(posts=[post1, post2])
Blog.objects.create(posts=[post2, post1])
# Update all of the first comments of second posts of all blogs
- Blog.objects().update(set__posts__1__comments__0__name='testc')
- testc_blogs = Blog.objects(posts__1__comments__0__name='testc')
+ Blog.objects().update(set__posts__1__comments__0__name="testc")
+ testc_blogs = Blog.objects(posts__1__comments__0__name="testc")
self.assertEqual(testc_blogs.count(), 2)
Blog.drop_collection()
@@ -486,14 +499,13 @@ class QuerySetTest(unittest.TestCase):
Blog.objects.create(posts=[post2, post1])
# Update only the first blog returned by the query
- Blog.objects().update_one(
- set__posts__1__comments__1__name='testc')
- testc_blogs = Blog.objects(posts__1__comments__1__name='testc')
+ Blog.objects().update_one(set__posts__1__comments__1__name="testc")
+ testc_blogs = Blog.objects(posts__1__comments__1__name="testc")
self.assertEqual(testc_blogs.count(), 1)
# Check that using this indexing syntax on a non-list fails
with self.assertRaises(InvalidQueryError):
- Blog.objects().update(set__posts__1__comments__0__name__1='asdf')
+ Blog.objects().update(set__posts__1__comments__0__name__1="asdf")
Blog.drop_collection()
@@ -519,7 +531,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.objects(comments__by="jane").update(inc__comments__S__votes=1)
post = BlogPost.objects.first()
- self.assertEqual(post.comments[1].by, 'jane')
+ self.assertEqual(post.comments[1].by, "jane")
self.assertEqual(post.comments[1].votes, 8)
def test_update_using_positional_operator_matches_first(self):
@@ -563,7 +575,7 @@ class QuerySetTest(unittest.TestCase):
# Nested updates arent supported yet..
with self.assertRaises(OperationError):
Simple.drop_collection()
- Simple(x=[{'test': [1, 2, 3, 4]}]).save()
+ Simple(x=[{"test": [1, 2, 3, 4]}]).save()
Simple.objects(x__test=2).update(set__x__S__test__S=3)
self.assertEqual(simple.x, [1, 2, 3, 4])
@@ -590,10 +602,11 @@ class QuerySetTest(unittest.TestCase):
BlogPost(title="ABC", comments=[c1, c2]).save()
BlogPost.objects(comments__by="joe").update(
- set__comments__S__votes=Vote(score=4))
+ set__comments__S__votes=Vote(score=4)
+ )
post = BlogPost.objects.first()
- self.assertEqual(post.comments[0].by, 'joe')
+ self.assertEqual(post.comments[0].by, "joe")
self.assertEqual(post.comments[0].votes.score, 4)
def test_update_min_max(self):
@@ -618,16 +631,15 @@ class QuerySetTest(unittest.TestCase):
item = StringField()
price = FloatField()
- product = Product.objects.create(item='ABC', price=10.99)
- product = Product.objects.create(item='ABC', price=10.99)
+ product = Product.objects.create(item="ABC", price=10.99)
+ product = Product.objects.create(item="ABC", price=10.99)
Product.objects(id=product.id).update(mul__price=1.25)
self.assertEqual(Product.objects.get(id=product.id).price, 13.7375)
- unknown_product = Product.objects.create(item='Unknown')
+ unknown_product = Product.objects.create(item="Unknown")
Product.objects(id=unknown_product.id).update(mul__price=100)
self.assertEqual(Product.objects.get(id=unknown_product.id).price, 0)
def test_updates_can_have_match_operators(self):
-
class Comment(EmbeddedDocument):
content = StringField()
name = StringField(max_length=120)
@@ -643,8 +655,11 @@ class QuerySetTest(unittest.TestCase):
comm1 = Comment(content="very funny indeed", name="John S", vote=1)
comm2 = Comment(content="kind of funny", name="Mark P", vote=0)
- Post(title='Fun with MongoEngine', tags=['mongodb', 'mongoengine'],
- comments=[comm1, comm2]).save()
+ Post(
+ title="Fun with MongoEngine",
+ tags=["mongodb", "mongoengine"],
+ comments=[comm1, comm2],
+ ).save()
Post.objects().update_one(pull__comments__vote__lt=1)
@@ -652,6 +667,7 @@ class QuerySetTest(unittest.TestCase):
def test_mapfield_update(self):
"""Ensure that the MapField can be updated."""
+
class Member(EmbeddedDocument):
gender = StringField()
age = IntField()
@@ -662,37 +678,35 @@ class QuerySetTest(unittest.TestCase):
Club.drop_collection()
club = Club()
- club.members['John'] = Member(gender="M", age=13)
+ club.members["John"] = Member(gender="M", age=13)
club.save()
- Club.objects().update(
- set__members={"John": Member(gender="F", age=14)})
+ Club.objects().update(set__members={"John": Member(gender="F", age=14)})
club = Club.objects().first()
- self.assertEqual(club.members['John'].gender, "F")
- self.assertEqual(club.members['John'].age, 14)
+ self.assertEqual(club.members["John"].gender, "F")
+ self.assertEqual(club.members["John"].age, 14)
def test_dictfield_update(self):
"""Ensure that the DictField can be updated."""
+
class Club(Document):
members = DictField()
club = Club()
- club.members['John'] = {'gender': 'M', 'age': 13}
+ club.members["John"] = {"gender": "M", "age": 13}
club.save()
- Club.objects().update(
- set__members={"John": {'gender': 'F', 'age': 14}})
+ Club.objects().update(set__members={"John": {"gender": "F", "age": 14}})
club = Club.objects().first()
- self.assertEqual(club.members['John']['gender'], "F")
- self.assertEqual(club.members['John']['age'], 14)
+ self.assertEqual(club.members["John"]["gender"], "F")
+ self.assertEqual(club.members["John"]["age"], 14)
def test_update_results(self):
self.Person.drop_collection()
- result = self.Person(name="Bob", age=25).update(
- upsert=True, full_result=True)
+ result = self.Person(name="Bob", age=25).update(upsert=True, full_result=True)
self.assertIsInstance(result, UpdateResult)
self.assertIn("upserted", result.raw_result)
self.assertFalse(result.raw_result["updatedExisting"])
@@ -703,8 +717,7 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue(result.raw_result["updatedExisting"])
self.Person(name="Bob", age=20).save()
- result = self.Person.objects(name="Bob").update(
- set__name="bobby", multi=True)
+ result = self.Person.objects(name="Bob").update(set__name="bobby", multi=True)
self.assertEqual(result, 2)
def test_update_validate(self):
@@ -718,8 +731,12 @@ class QuerySetTest(unittest.TestCase):
ed_f = EmbeddedDocumentField(EmDoc)
self.assertRaises(ValidationError, Doc.objects().update, str_f=1, upsert=True)
- self.assertRaises(ValidationError, Doc.objects().update, dt_f="datetime", upsert=True)
- self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True)
+ self.assertRaises(
+ ValidationError, Doc.objects().update, dt_f="datetime", upsert=True
+ )
+ self.assertRaises(
+ ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True
+ )
def test_update_related_models(self):
class TestPerson(Document):
@@ -732,34 +749,33 @@ class QuerySetTest(unittest.TestCase):
TestPerson.drop_collection()
TestOrganization.drop_collection()
- p = TestPerson(name='p1')
+ p = TestPerson(name="p1")
p.save()
- o = TestOrganization(name='o1')
+ o = TestOrganization(name="o1")
o.save()
o.owner = p
- p.name = 'p2'
+ p.name = "p2"
- self.assertEqual(o._get_changed_fields(), ['owner'])
- self.assertEqual(p._get_changed_fields(), ['name'])
+ self.assertEqual(o._get_changed_fields(), ["owner"])
+ self.assertEqual(p._get_changed_fields(), ["name"])
o.save()
self.assertEqual(o._get_changed_fields(), [])
- self.assertEqual(p._get_changed_fields(), ['name']) # Fails; it's empty
+ self.assertEqual(p._get_changed_fields(), ["name"]) # Fails; it's empty
# This will do NOTHING at all, even though we changed the name
p.save()
p.reload()
- self.assertEqual(p.name, 'p2') # Fails; it's still `p1`
+ self.assertEqual(p.name, "p2") # Fails; it's still `p1`
def test_upsert(self):
self.Person.drop_collection()
- self.Person.objects(
- pk=ObjectId(), name="Bob", age=30).update(upsert=True)
+ self.Person.objects(pk=ObjectId(), name="Bob", age=30).update(upsert=True)
bob = self.Person.objects.first()
self.assertEqual("Bob", bob.name)
@@ -786,7 +802,8 @@ class QuerySetTest(unittest.TestCase):
self.Person.drop_collection()
self.Person.objects(pk=ObjectId()).update(
- set__name='Bob', set_on_insert__age=30, upsert=True)
+ set__name="Bob", set_on_insert__age=30, upsert=True
+ )
bob = self.Person.objects.first()
self.assertEqual("Bob", bob.name)
@@ -797,7 +814,7 @@ class QuerySetTest(unittest.TestCase):
field = IntField()
class B(Document):
- meta = {'collection': 'b'}
+ meta = {"collection": "b"}
field = IntField(default=1)
embed = EmbeddedDocumentField(Embed, default=Embed)
@@ -820,7 +837,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(record.embed.field, 2)
# Request only the _id field and save
- clone = B.objects().only('id').first()
+ clone = B.objects().only("id").first()
clone.save()
# Reload the record and see that the embed data is not lost
@@ -831,6 +848,7 @@ class QuerySetTest(unittest.TestCase):
def test_bulk_insert(self):
"""Ensure that bulk insert works"""
+
class Comment(EmbeddedDocument):
name = StringField()
@@ -847,14 +865,13 @@ class QuerySetTest(unittest.TestCase):
# Recreates the collection
self.assertEqual(0, Blog.objects.count())
- comment1 = Comment(name='testa')
- comment2 = Comment(name='testb')
+ comment1 = Comment(name="testa")
+ comment2 = Comment(name="testb")
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
# Check bulk insert using load_bulk=False
- blogs = [Blog(title="%s" % i, posts=[post1, post2])
- for i in range(99)]
+ blogs = [Blog(title="%s" % i, posts=[post1, post2]) for i in range(99)]
with query_counter() as q:
self.assertEqual(q, 0)
Blog.objects.insert(blogs, load_bulk=False)
@@ -866,8 +883,7 @@ class QuerySetTest(unittest.TestCase):
Blog.ensure_indexes()
# Check bulk insert using load_bulk=True
- blogs = [Blog(title="%s" % i, posts=[post1, post2])
- for i in range(99)]
+ blogs = [Blog(title="%s" % i, posts=[post1, post2]) for i in range(99)]
with query_counter() as q:
self.assertEqual(q, 0)
Blog.objects.insert(blogs)
@@ -875,8 +891,8 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection()
- comment1 = Comment(name='testa')
- comment2 = Comment(name='testb')
+ comment1 = Comment(name="testa")
+ comment2 = Comment(name="testb")
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blog1 = Blog(title="code", posts=[post1, post2])
@@ -892,8 +908,7 @@ class QuerySetTest(unittest.TestCase):
blog = Blog.objects.first()
Blog.objects.insert(blog)
self.assertEqual(
- str(cm.exception),
- 'Some documents have ObjectIds, use doc.update() instead'
+ str(cm.exception), "Some documents have ObjectIds, use doc.update() instead"
)
# test inserting a query set
@@ -901,8 +916,7 @@ class QuerySetTest(unittest.TestCase):
blogs_qs = Blog.objects
Blog.objects.insert(blogs_qs)
self.assertEqual(
- str(cm.exception),
- 'Some documents have ObjectIds, use doc.update() instead'
+ str(cm.exception), "Some documents have ObjectIds, use doc.update() instead"
)
# insert 1 new doc
@@ -948,13 +962,13 @@ class QuerySetTest(unittest.TestCase):
name = StringField()
Blog.drop_collection()
- Blog(name='test').save()
+ Blog(name="test").save()
with self.assertRaises(OperationError):
Blog.objects.insert("HELLO WORLD")
with self.assertRaises(OperationError):
- Blog.objects.insert({'name': 'garbage'})
+ Blog.objects.insert({"name": "garbage"})
def test_bulk_insert_update_input_document_ids(self):
class Comment(Document):
@@ -1010,10 +1024,11 @@ class QuerySetTest(unittest.TestCase):
"""Make sure we don't perform unnecessary db operations when
none of document's fields were updated.
"""
+
class Person(Document):
name = StringField()
- owns = ListField(ReferenceField('Organization'))
- projects = ListField(ReferenceField('Project'))
+ owns = ListField(ReferenceField("Organization"))
+ projects = ListField(ReferenceField("Project"))
class Organization(Document):
name = StringField()
@@ -1070,8 +1085,8 @@ class QuerySetTest(unittest.TestCase):
def test_repeated_iteration(self):
"""Ensure that QuerySet rewinds itself one iteration finishes.
"""
- self.Person(name='Person 1').save()
- self.Person(name='Person 2').save()
+ self.Person(name="Person 1").save()
+ self.Person(name="Person 2").save()
queryset = self.Person.objects
people1 = [person for person in queryset]
@@ -1099,7 +1114,7 @@ class QuerySetTest(unittest.TestCase):
for i in range(1000):
Doc(number=i).save()
- docs = Doc.objects.order_by('number')
+ docs = Doc.objects.order_by("number")
self.assertEqual(docs.count(), 1000)
@@ -1107,88 +1122,89 @@ class QuerySetTest(unittest.TestCase):
self.assertIn("Doc: 0", docs_string)
self.assertEqual(docs.count(), 1000)
- self.assertIn('(remaining elements truncated)', "%s" % docs)
+ self.assertIn("(remaining elements truncated)", "%s" % docs)
# Limit and skip
docs = docs[1:4]
- self.assertEqual('[, , ]', "%s" % docs)
+ self.assertEqual("[, , ]", "%s" % docs)
self.assertEqual(docs.count(with_limit_and_skip=True), 3)
for doc in docs:
- self.assertEqual('.. queryset mid-iteration ..', repr(docs))
+ self.assertEqual(".. queryset mid-iteration ..", repr(docs))
def test_regex_query_shortcuts(self):
"""Ensure that contains, startswith, endswith, etc work.
"""
- person = self.Person(name='Guido van Rossum')
+ person = self.Person(name="Guido van Rossum")
person.save()
# Test contains
- obj = self.Person.objects(name__contains='van').first()
+ obj = self.Person.objects(name__contains="van").first()
self.assertEqual(obj, person)
- obj = self.Person.objects(name__contains='Van').first()
+ obj = self.Person.objects(name__contains="Van").first()
self.assertEqual(obj, None)
# Test icontains
- obj = self.Person.objects(name__icontains='Van').first()
+ obj = self.Person.objects(name__icontains="Van").first()
self.assertEqual(obj, person)
# Test startswith
- obj = self.Person.objects(name__startswith='Guido').first()
+ obj = self.Person.objects(name__startswith="Guido").first()
self.assertEqual(obj, person)
- obj = self.Person.objects(name__startswith='guido').first()
+ obj = self.Person.objects(name__startswith="guido").first()
self.assertEqual(obj, None)
# Test istartswith
- obj = self.Person.objects(name__istartswith='guido').first()
+ obj = self.Person.objects(name__istartswith="guido").first()
self.assertEqual(obj, person)
# Test endswith
- obj = self.Person.objects(name__endswith='Rossum').first()
+ obj = self.Person.objects(name__endswith="Rossum").first()
self.assertEqual(obj, person)
- obj = self.Person.objects(name__endswith='rossuM').first()
+ obj = self.Person.objects(name__endswith="rossuM").first()
self.assertEqual(obj, None)
# Test iendswith
- obj = self.Person.objects(name__iendswith='rossuM').first()
+ obj = self.Person.objects(name__iendswith="rossuM").first()
self.assertEqual(obj, person)
# Test exact
- obj = self.Person.objects(name__exact='Guido van Rossum').first()
+ obj = self.Person.objects(name__exact="Guido van Rossum").first()
self.assertEqual(obj, person)
- obj = self.Person.objects(name__exact='Guido van rossum').first()
+ obj = self.Person.objects(name__exact="Guido van rossum").first()
self.assertEqual(obj, None)
- obj = self.Person.objects(name__exact='Guido van Rossu').first()
+ obj = self.Person.objects(name__exact="Guido van Rossu").first()
self.assertEqual(obj, None)
# Test iexact
- obj = self.Person.objects(name__iexact='gUIDO VAN rOSSUM').first()
+ obj = self.Person.objects(name__iexact="gUIDO VAN rOSSUM").first()
self.assertEqual(obj, person)
- obj = self.Person.objects(name__iexact='gUIDO VAN rOSSU').first()
+ obj = self.Person.objects(name__iexact="gUIDO VAN rOSSU").first()
self.assertEqual(obj, None)
# Test unsafe expressions
- person = self.Person(name='Guido van Rossum [.\'Geek\']')
+ person = self.Person(name="Guido van Rossum [.'Geek']")
person.save()
- obj = self.Person.objects(name__icontains='[.\'Geek').first()
+ obj = self.Person.objects(name__icontains="[.'Geek").first()
self.assertEqual(obj, person)
def test_not(self):
"""Ensure that the __not operator works as expected.
"""
- alice = self.Person(name='Alice', age=25)
+ alice = self.Person(name="Alice", age=25)
alice.save()
- obj = self.Person.objects(name__iexact='alice').first()
+ obj = self.Person.objects(name__iexact="alice").first()
self.assertEqual(obj, alice)
- obj = self.Person.objects(name__not__iexact='alice').first()
+ obj = self.Person.objects(name__not__iexact="alice").first()
self.assertEqual(obj, None)
def test_filter_chaining(self):
"""Ensure filters can be chained together.
"""
+
class Blog(Document):
id = StringField(primary_key=True)
@@ -1217,25 +1233,26 @@ class QuerySetTest(unittest.TestCase):
blog=blog_1,
title="Blog Post #1",
is_published=True,
- published_date=datetime.datetime(2010, 1, 5, 0, 0, 0)
+ published_date=datetime.datetime(2010, 1, 5, 0, 0, 0),
)
BlogPost.objects.create(
blog=blog_2,
title="Blog Post #2",
is_published=True,
- published_date=datetime.datetime(2010, 1, 6, 0, 0, 0)
+ published_date=datetime.datetime(2010, 1, 6, 0, 0, 0),
)
BlogPost.objects.create(
blog=blog_3,
title="Blog Post #3",
is_published=True,
- published_date=datetime.datetime(2010, 1, 7, 0, 0, 0)
+ published_date=datetime.datetime(2010, 1, 7, 0, 0, 0),
)
# find all published blog posts before 2010-01-07
published_posts = BlogPost.published()
published_posts = published_posts.filter(
- published_date__lt=datetime.datetime(2010, 1, 7, 0, 0, 0))
+ published_date__lt=datetime.datetime(2010, 1, 7, 0, 0, 0)
+ )
self.assertEqual(published_posts.count(), 2)
blog_posts = BlogPost.objects
@@ -1247,11 +1264,11 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection()
def test_filter_chaining_with_regex(self):
- person = self.Person(name='Guido van Rossum')
+ person = self.Person(name="Guido van Rossum")
person.save()
people = self.Person.objects
- people = people.filter(name__startswith='Gui').filter(name__not__endswith='tum')
+ people = people.filter(name__startswith="Gui").filter(name__not__endswith="tum")
self.assertEqual(people.count(), 1)
def assertSequence(self, qs, expected):
@@ -1264,27 +1281,23 @@ class QuerySetTest(unittest.TestCase):
def test_ordering(self):
"""Ensure default ordering is applied and can be overridden.
"""
+
class BlogPost(Document):
title = StringField()
published_date = DateTimeField()
- meta = {
- 'ordering': ['-published_date']
- }
+ meta = {"ordering": ["-published_date"]}
BlogPost.drop_collection()
blog_post_1 = BlogPost.objects.create(
- title="Blog Post #1",
- published_date=datetime.datetime(2010, 1, 5, 0, 0, 0)
+ title="Blog Post #1", published_date=datetime.datetime(2010, 1, 5, 0, 0, 0)
)
blog_post_2 = BlogPost.objects.create(
- title="Blog Post #2",
- published_date=datetime.datetime(2010, 1, 6, 0, 0, 0)
+ title="Blog Post #2", published_date=datetime.datetime(2010, 1, 6, 0, 0, 0)
)
blog_post_3 = BlogPost.objects.create(
- title="Blog Post #3",
- published_date=datetime.datetime(2010, 1, 7, 0, 0, 0)
+ title="Blog Post #3", published_date=datetime.datetime(2010, 1, 7, 0, 0, 0)
)
# get the "first" BlogPost using default ordering
@@ -1307,39 +1320,35 @@ class QuerySetTest(unittest.TestCase):
title = StringField()
published_date = DateTimeField()
- meta = {
- 'ordering': ['-published_date']
- }
+ meta = {"ordering": ["-published_date"]}
BlogPost.drop_collection()
# default ordering should be used by default
with db_ops_tracker() as q:
- BlogPost.objects.filter(title='whatever').first()
+ BlogPost.objects.filter(title="whatever").first()
self.assertEqual(len(q.get_ops()), 1)
self.assertEqual(
- q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY],
- {'published_date': -1}
+ q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {"published_date": -1}
)
# calling order_by() should clear the default ordering
with db_ops_tracker() as q:
- BlogPost.objects.filter(title='whatever').order_by().first()
+ BlogPost.objects.filter(title="whatever").order_by().first()
self.assertEqual(len(q.get_ops()), 1)
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY])
# calling an explicit order_by should use a specified sort
with db_ops_tracker() as q:
- BlogPost.objects.filter(title='whatever').order_by('published_date').first()
+ BlogPost.objects.filter(title="whatever").order_by("published_date").first()
self.assertEqual(len(q.get_ops()), 1)
self.assertEqual(
- q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY],
- {'published_date': 1}
+ q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {"published_date": 1}
)
# calling order_by() after an explicit sort should clear it
with db_ops_tracker() as q:
- qs = BlogPost.objects.filter(title='whatever').order_by('published_date')
+ qs = BlogPost.objects.filter(title="whatever").order_by("published_date")
qs.order_by().first()
self.assertEqual(len(q.get_ops()), 1)
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY])
@@ -1353,21 +1362,20 @@ class QuerySetTest(unittest.TestCase):
title = StringField()
published_date = DateTimeField()
- meta = {
- 'ordering': ['-published_date']
- }
+ meta = {"ordering": ["-published_date"]}
BlogPost.objects.create(
- title='whatever', published_date=datetime.datetime.utcnow())
+ title="whatever", published_date=datetime.datetime.utcnow()
+ )
with db_ops_tracker() as q:
- BlogPost.objects.get(title='whatever')
+ BlogPost.objects.get(title="whatever")
self.assertEqual(len(q.get_ops()), 1)
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY])
# Ordering should be ignored for .get even if we set it explicitly
with db_ops_tracker() as q:
- BlogPost.objects.order_by('-title').get(title='whatever')
+ BlogPost.objects.order_by("-title").get(title="whatever")
self.assertEqual(len(q.get_ops()), 1)
self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY])
@@ -1375,6 +1383,7 @@ class QuerySetTest(unittest.TestCase):
"""Ensure that an embedded document is properly returned from
different manners of querying.
"""
+
class User(EmbeddedDocument):
name = StringField()
@@ -1384,23 +1393,20 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
- user = User(name='Test User')
- BlogPost.objects.create(
- author=user,
- content='Had a good coffee today...'
- )
+ user = User(name="Test User")
+ BlogPost.objects.create(author=user, content="Had a good coffee today...")
result = BlogPost.objects.first()
self.assertIsInstance(result.author, User)
- self.assertEqual(result.author.name, 'Test User')
+ self.assertEqual(result.author.name, "Test User")
result = BlogPost.objects.get(author__name=user.name)
self.assertIsInstance(result.author, User)
- self.assertEqual(result.author.name, 'Test User')
+ self.assertEqual(result.author.name, "Test User")
- result = BlogPost.objects.get(author={'name': user.name})
+ result = BlogPost.objects.get(author={"name": user.name})
self.assertIsInstance(result.author, User)
- self.assertEqual(result.author.name, 'Test User')
+ self.assertEqual(result.author.name, "Test User")
# Fails, since the string is not a type that is able to represent the
# author's document structure (should be dict)
@@ -1409,6 +1415,7 @@ class QuerySetTest(unittest.TestCase):
def test_find_empty_embedded(self):
"""Ensure that you can save and find an empty embedded document."""
+
class User(EmbeddedDocument):
name = StringField()
@@ -1418,7 +1425,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
- BlogPost.objects.create(content='Anonymous post...')
+ BlogPost.objects.create(content="Anonymous post...")
result = BlogPost.objects.get(author=None)
self.assertEqual(result.author, None)
@@ -1426,15 +1433,16 @@ class QuerySetTest(unittest.TestCase):
def test_find_dict_item(self):
"""Ensure that DictField items may be found.
"""
+
class BlogPost(Document):
info = DictField()
BlogPost.drop_collection()
- post = BlogPost(info={'title': 'test'})
+ post = BlogPost(info={"title": "test"})
post.save()
- post_obj = BlogPost.objects(info__title='test').first()
+ post_obj = BlogPost.objects(info__title="test").first()
self.assertEqual(post_obj.id, post.id)
BlogPost.drop_collection()
@@ -1442,6 +1450,7 @@ class QuerySetTest(unittest.TestCase):
def test_exec_js_query(self):
"""Ensure that queries are properly formed for use in exec_js.
"""
+
class BlogPost(Document):
hits = IntField()
published = BooleanField()
@@ -1468,10 +1477,10 @@ class QuerySetTest(unittest.TestCase):
"""
# Ensure that normal queries work
- c = BlogPost.objects(published=True).exec_js(js_func, 'hits')
+ c = BlogPost.objects(published=True).exec_js(js_func, "hits")
self.assertEqual(c, 2)
- c = BlogPost.objects(published=False).exec_js(js_func, 'hits')
+ c = BlogPost.objects(published=False).exec_js(js_func, "hits")
self.assertEqual(c, 1)
BlogPost.drop_collection()
@@ -1479,22 +1488,22 @@ class QuerySetTest(unittest.TestCase):
def test_exec_js_field_sub(self):
"""Ensure that field substitutions occur properly in exec_js functions.
"""
+
class Comment(EmbeddedDocument):
- content = StringField(db_field='body')
+ content = StringField(db_field="body")
class BlogPost(Document):
- name = StringField(db_field='doc-name')
- comments = ListField(EmbeddedDocumentField(Comment),
- db_field='cmnts')
+ name = StringField(db_field="doc-name")
+ comments = ListField(EmbeddedDocumentField(Comment), db_field="cmnts")
BlogPost.drop_collection()
- comments1 = [Comment(content='cool'), Comment(content='yay')]
- post1 = BlogPost(name='post1', comments=comments1)
+ comments1 = [Comment(content="cool"), Comment(content="yay")]
+ post1 = BlogPost(name="post1", comments=comments1)
post1.save()
- comments2 = [Comment(content='nice stuff')]
- post2 = BlogPost(name='post2', comments=comments2)
+ comments2 = [Comment(content="nice stuff")]
+ post2 = BlogPost(name="post2", comments=comments2)
post2.save()
code = """
@@ -1514,16 +1523,15 @@ class QuerySetTest(unittest.TestCase):
"""
sub_code = BlogPost.objects._sub_js_fields(code)
- code_chunks = ['doc["cmnts"];', 'doc["doc-name"],',
- 'doc["cmnts"][i]["body"]']
+ code_chunks = ['doc["cmnts"];', 'doc["doc-name"],', 'doc["cmnts"][i]["body"]']
for chunk in code_chunks:
self.assertIn(chunk, sub_code)
results = BlogPost.objects.exec_js(code)
expected_results = [
- {u'comment': u'cool', u'document': u'post1'},
- {u'comment': u'yay', u'document': u'post1'},
- {u'comment': u'nice stuff', u'document': u'post2'},
+ {u"comment": u"cool", u"document": u"post1"},
+ {u"comment": u"yay", u"document": u"post1"},
+ {u"comment": u"nice stuff", u"document": u"post2"},
]
self.assertEqual(results, expected_results)
@@ -1552,55 +1560,60 @@ class QuerySetTest(unittest.TestCase):
def test_reverse_delete_rule_cascade(self):
"""Ensure cascading deletion of referring documents from the database.
"""
+
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
+
BlogPost.drop_collection()
- me = self.Person(name='Test User')
+ me = self.Person(name="Test User")
me.save()
- someoneelse = self.Person(name='Some-one Else')
+ someoneelse = self.Person(name="Some-one Else")
someoneelse.save()
- BlogPost(content='Watching TV', author=me).save()
- BlogPost(content='Chilling out', author=me).save()
- BlogPost(content='Pro Testing', author=someoneelse).save()
+ BlogPost(content="Watching TV", author=me).save()
+ BlogPost(content="Chilling out", author=me).save()
+ BlogPost(content="Pro Testing", author=someoneelse).save()
self.assertEqual(3, BlogPost.objects.count())
- self.Person.objects(name='Test User').delete()
+ self.Person.objects(name="Test User").delete()
self.assertEqual(1, BlogPost.objects.count())
def test_reverse_delete_rule_cascade_on_abstract_document(self):
"""Ensure cascading deletion of referring documents from the database
does not fail on abstract document.
"""
+
class AbstractBlogPost(Document):
- meta = {'abstract': True}
+ meta = {"abstract": True}
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
class BlogPost(AbstractBlogPost):
content = StringField()
+
BlogPost.drop_collection()
- me = self.Person(name='Test User')
+ me = self.Person(name="Test User")
me.save()
- someoneelse = self.Person(name='Some-one Else')
+ someoneelse = self.Person(name="Some-one Else")
someoneelse.save()
- BlogPost(content='Watching TV', author=me).save()
- BlogPost(content='Chilling out', author=me).save()
- BlogPost(content='Pro Testing', author=someoneelse).save()
+ BlogPost(content="Watching TV", author=me).save()
+ BlogPost(content="Chilling out", author=me).save()
+ BlogPost(content="Pro Testing", author=someoneelse).save()
self.assertEqual(3, BlogPost.objects.count())
- self.Person.objects(name='Test User').delete()
+ self.Person.objects(name="Test User").delete()
self.assertEqual(1, BlogPost.objects.count())
def test_reverse_delete_rule_cascade_cycle(self):
"""Ensure reference cascading doesn't loop if reference graph isn't
a tree
"""
+
class Dummy(Document):
- reference = ReferenceField('self', reverse_delete_rule=CASCADE)
+ reference = ReferenceField("self", reverse_delete_rule=CASCADE)
base = Dummy().save()
other = Dummy(reference=base).save()
@@ -1616,14 +1629,15 @@ class QuerySetTest(unittest.TestCase):
"""Ensure reference cascading doesn't loop if reference graph isn't
a tree
"""
+
class Category(Document):
name = StringField()
class Dummy(Document):
- reference = ReferenceField('self', reverse_delete_rule=CASCADE)
+ reference = ReferenceField("self", reverse_delete_rule=CASCADE)
cat = ReferenceField(Category, reverse_delete_rule=CASCADE)
- cat = Category(name='cat').save()
+ cat = Category(name="cat").save()
base = Dummy(cat=cat).save()
other = Dummy(reference=base).save()
other2 = Dummy(reference=other).save()
@@ -1640,24 +1654,25 @@ class QuerySetTest(unittest.TestCase):
"""Ensure self-referencing CASCADE deletes do not result in infinite
loop
"""
+
class Category(Document):
name = StringField()
- parent = ReferenceField('self', reverse_delete_rule=CASCADE)
+ parent = ReferenceField("self", reverse_delete_rule=CASCADE)
Category.drop_collection()
num_children = 3
- base = Category(name='Root')
+ base = Category(name="Root")
base.save()
# Create a simple parent-child tree
for i in range(num_children):
- child_name = 'Child-%i' % i
+ child_name = "Child-%i" % i
child = Category(name=child_name, parent=base)
child.save()
for i in range(num_children):
- child_child_name = 'Child-Child-%i' % i
+ child_child_name = "Child-Child-%i" % i
child_child = Category(name=child_child_name, parent=child)
child_child.save()
@@ -1673,6 +1688,7 @@ class QuerySetTest(unittest.TestCase):
def test_reverse_delete_rule_nullify(self):
"""Ensure nullification of references to deleted documents.
"""
+
class Category(Document):
name = StringField()
@@ -1683,14 +1699,14 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
Category.drop_collection()
- lameness = Category(name='Lameness')
+ lameness = Category(name="Lameness")
lameness.save()
- post = BlogPost(content='Watching TV', category=lameness)
+ post = BlogPost(content="Watching TV", category=lameness)
post.save()
self.assertEqual(1, BlogPost.objects.count())
- self.assertEqual('Lameness', BlogPost.objects.first().category.name)
+ self.assertEqual("Lameness", BlogPost.objects.first().category.name)
Category.objects.delete()
self.assertEqual(1, BlogPost.objects.count())
self.assertEqual(None, BlogPost.objects.first().category)
@@ -1699,24 +1715,26 @@ class QuerySetTest(unittest.TestCase):
"""Ensure nullification of references to deleted documents when
reference is on an abstract document.
"""
+
class AbstractBlogPost(Document):
- meta = {'abstract': True}
+ meta = {"abstract": True}
author = ReferenceField(self.Person, reverse_delete_rule=NULLIFY)
class BlogPost(AbstractBlogPost):
content = StringField()
+
BlogPost.drop_collection()
- me = self.Person(name='Test User')
+ me = self.Person(name="Test User")
me.save()
- someoneelse = self.Person(name='Some-one Else')
+ someoneelse = self.Person(name="Some-one Else")
someoneelse.save()
- BlogPost(content='Watching TV', author=me).save()
+ BlogPost(content="Watching TV", author=me).save()
self.assertEqual(1, BlogPost.objects.count())
self.assertEqual(me, BlogPost.objects.first().author)
- self.Person.objects(name='Test User').delete()
+ self.Person.objects(name="Test User").delete()
self.assertEqual(1, BlogPost.objects.count())
self.assertEqual(None, BlogPost.objects.first().author)
@@ -1724,6 +1742,7 @@ class QuerySetTest(unittest.TestCase):
"""Ensure deletion gets denied on documents that still have references
to them.
"""
+
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=DENY)
@@ -1731,10 +1750,10 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
self.Person.drop_collection()
- me = self.Person(name='Test User')
+ me = self.Person(name="Test User")
me.save()
- post = BlogPost(content='Watching TV', author=me)
+ post = BlogPost(content="Watching TV", author=me)
post.save()
self.assertRaises(OperationError, self.Person.objects.delete)
@@ -1743,18 +1762,20 @@ class QuerySetTest(unittest.TestCase):
"""Ensure deletion gets denied on documents that still have references
to them, when reference is on an abstract document.
"""
+
class AbstractBlogPost(Document):
- meta = {'abstract': True}
+ meta = {"abstract": True}
author = ReferenceField(self.Person, reverse_delete_rule=DENY)
class BlogPost(AbstractBlogPost):
content = StringField()
+
BlogPost.drop_collection()
- me = self.Person(name='Test User')
+ me = self.Person(name="Test User")
me.save()
- BlogPost(content='Watching TV', author=me).save()
+ BlogPost(content="Watching TV", author=me).save()
self.assertEqual(1, BlogPost.objects.count())
self.assertRaises(OperationError, self.Person.objects.delete)
@@ -1762,24 +1783,24 @@ class QuerySetTest(unittest.TestCase):
def test_reverse_delete_rule_pull(self):
"""Ensure pulling of references to deleted documents.
"""
+
class BlogPost(Document):
content = StringField()
- authors = ListField(ReferenceField(self.Person,
- reverse_delete_rule=PULL))
+ authors = ListField(ReferenceField(self.Person, reverse_delete_rule=PULL))
BlogPost.drop_collection()
self.Person.drop_collection()
- me = self.Person(name='Test User')
+ me = self.Person(name="Test User")
me.save()
- someoneelse = self.Person(name='Some-one Else')
+ someoneelse = self.Person(name="Some-one Else")
someoneelse.save()
- post = BlogPost(content='Watching TV', authors=[me, someoneelse])
+ post = BlogPost(content="Watching TV", authors=[me, someoneelse])
post.save()
- another = BlogPost(content='Chilling Out', authors=[someoneelse])
+ another = BlogPost(content="Chilling Out", authors=[someoneelse])
another.save()
someoneelse.delete()
@@ -1793,10 +1814,10 @@ class QuerySetTest(unittest.TestCase):
"""Ensure pulling of references to deleted documents when reference
is defined on an abstract document..
"""
+
class AbstractBlogPost(Document):
- meta = {'abstract': True}
- authors = ListField(ReferenceField(self.Person,
- reverse_delete_rule=PULL))
+ meta = {"abstract": True}
+ authors = ListField(ReferenceField(self.Person, reverse_delete_rule=PULL))
class BlogPost(AbstractBlogPost):
content = StringField()
@@ -1804,16 +1825,16 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
self.Person.drop_collection()
- me = self.Person(name='Test User')
+ me = self.Person(name="Test User")
me.save()
- someoneelse = self.Person(name='Some-one Else')
+ someoneelse = self.Person(name="Some-one Else")
someoneelse.save()
- post = BlogPost(content='Watching TV', authors=[me, someoneelse])
+ post = BlogPost(content="Watching TV", authors=[me, someoneelse])
post.save()
- another = BlogPost(content='Chilling Out', authors=[someoneelse])
+ another = BlogPost(content="Chilling Out", authors=[someoneelse])
another.save()
someoneelse.delete()
@@ -1824,7 +1845,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(another.authors, [])
def test_delete_with_limits(self):
-
class Log(Document):
pass
@@ -1839,19 +1859,21 @@ class QuerySetTest(unittest.TestCase):
def test_delete_with_limit_handles_delete_rules(self):
"""Ensure cascading deletion of referring documents from the database.
"""
+
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
+
BlogPost.drop_collection()
- me = self.Person(name='Test User')
+ me = self.Person(name="Test User")
me.save()
- someoneelse = self.Person(name='Some-one Else')
+ someoneelse = self.Person(name="Some-one Else")
someoneelse.save()
- BlogPost(content='Watching TV', author=me).save()
- BlogPost(content='Chilling out', author=me).save()
- BlogPost(content='Pro Testing', author=someoneelse).save()
+ BlogPost(content="Watching TV", author=me).save()
+ BlogPost(content="Chilling out", author=me).save()
+ BlogPost(content="Pro Testing", author=someoneelse).save()
self.assertEqual(3, BlogPost.objects.count())
self.Person.objects()[:1].delete()
@@ -1870,6 +1892,7 @@ class QuerySetTest(unittest.TestCase):
def test_reference_field_find(self):
"""Ensure cascading deletion of referring documents from the database.
"""
+
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person)
@@ -1877,7 +1900,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
self.Person.drop_collection()
- me = self.Person(name='Test User').save()
+ me = self.Person(name="Test User").save()
BlogPost(content="test 123", author=me).save()
self.assertEqual(1, BlogPost.objects(author=me).count())
@@ -1886,12 +1909,12 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(1, BlogPost.objects(author__in=[me]).count())
self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count())
- self.assertEqual(
- 1, BlogPost.objects(author__in=["%s" % me.pk]).count())
+ self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count())
def test_reference_field_find_dbref(self):
"""Ensure cascading deletion of referring documents from the database.
"""
+
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person, dbref=True)
@@ -1899,7 +1922,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
self.Person.drop_collection()
- me = self.Person(name='Test User').save()
+ me = self.Person(name="Test User").save()
BlogPost(content="test 123", author=me).save()
self.assertEqual(1, BlogPost.objects(author=me).count())
@@ -1908,8 +1931,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(1, BlogPost.objects(author__in=[me]).count())
self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count())
- self.assertEqual(
- 1, BlogPost.objects(author__in=["%s" % me.pk]).count())
+ self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count())
def test_update_intfield_operator(self):
class BlogPost(Document):
@@ -1946,7 +1968,7 @@ class QuerySetTest(unittest.TestCase):
post = BlogPost(review=3.5)
post.save()
- BlogPost.objects.update_one(inc__review=0.1) # test with floats
+ BlogPost.objects.update_one(inc__review=0.1) # test with floats
post.reload()
self.assertEqual(float(post.review), 3.6)
@@ -1954,7 +1976,7 @@ class QuerySetTest(unittest.TestCase):
post.reload()
self.assertEqual(float(post.review), 3.5)
- BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal
+ BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal
post.reload()
self.assertEqual(float(post.review), 3.62)
@@ -1972,38 +1994,39 @@ class QuerySetTest(unittest.TestCase):
post.save()
with self.assertRaises(OperationError):
- BlogPost.objects.update_one(inc__review=0.1) # test with floats
+ BlogPost.objects.update_one(inc__review=0.1) # test with floats
def test_update_listfield_operator(self):
"""Ensure that atomic updates work properly.
"""
+
class BlogPost(Document):
tags = ListField(StringField())
BlogPost.drop_collection()
- post = BlogPost(tags=['test'])
+ post = BlogPost(tags=["test"])
post.save()
# ListField operator
- BlogPost.objects.update(push__tags='mongo')
+ BlogPost.objects.update(push__tags="mongo")
post.reload()
- self.assertIn('mongo', post.tags)
+ self.assertIn("mongo", post.tags)
- BlogPost.objects.update_one(push_all__tags=['db', 'nosql'])
+ BlogPost.objects.update_one(push_all__tags=["db", "nosql"])
post.reload()
- self.assertIn('db', post.tags)
- self.assertIn('nosql', post.tags)
+ self.assertIn("db", post.tags)
+ self.assertIn("nosql", post.tags)
tags = post.tags[:-1]
BlogPost.objects.update(pop__tags=1)
post.reload()
self.assertEqual(post.tags, tags)
- BlogPost.objects.update_one(add_to_set__tags='unique')
- BlogPost.objects.update_one(add_to_set__tags='unique')
+ BlogPost.objects.update_one(add_to_set__tags="unique")
+ BlogPost.objects.update_one(add_to_set__tags="unique")
post.reload()
- self.assertEqual(post.tags.count('unique'), 1)
+ self.assertEqual(post.tags.count("unique"), 1)
BlogPost.drop_collection()
@@ -2013,18 +2036,19 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
- post = BlogPost(title='garbage').save()
+ post = BlogPost(title="garbage").save()
self.assertNotEqual(post.title, None)
BlogPost.objects.update_one(unset__title=1)
post.reload()
self.assertEqual(post.title, None)
pymongo_doc = BlogPost.objects.as_pymongo().first()
- self.assertNotIn('title', pymongo_doc)
+ self.assertNotIn("title", pymongo_doc)
def test_update_push_with_position(self):
"""Ensure that the 'push' update with position works properly.
"""
+
class BlogPost(Document):
slug = StringField()
tags = ListField(StringField())
@@ -2036,20 +2060,21 @@ class QuerySetTest(unittest.TestCase):
BlogPost.objects.filter(id=post.id).update(push__tags="code")
BlogPost.objects.filter(id=post.id).update(push__tags__0=["mongodb", "python"])
post.reload()
- self.assertEqual(post.tags, ['mongodb', 'python', 'code'])
+ self.assertEqual(post.tags, ["mongodb", "python", "code"])
BlogPost.objects.filter(id=post.id).update(set__tags__2="java")
post.reload()
- self.assertEqual(post.tags, ['mongodb', 'python', 'java'])
+ self.assertEqual(post.tags, ["mongodb", "python", "java"])
# test push with singular value
- BlogPost.objects.filter(id=post.id).update(push__tags__0='scala')
+ BlogPost.objects.filter(id=post.id).update(push__tags__0="scala")
post.reload()
- self.assertEqual(post.tags, ['scala', 'mongodb', 'python', 'java'])
+ self.assertEqual(post.tags, ["scala", "mongodb", "python", "java"])
def test_update_push_list_of_list(self):
"""Ensure that the 'push' update operation works in the list of list
"""
+
class BlogPost(Document):
slug = StringField()
tags = ListField()
@@ -2065,6 +2090,7 @@ class QuerySetTest(unittest.TestCase):
def test_update_push_and_pull_add_to_set(self):
"""Ensure that the 'pull' update operation works correctly.
"""
+
class BlogPost(Document):
slug = StringField()
tags = ListField(StringField())
@@ -2078,8 +2104,7 @@ class QuerySetTest(unittest.TestCase):
post.reload()
self.assertEqual(post.tags, ["code"])
- BlogPost.objects.filter(id=post.id).update(
- push_all__tags=["mongodb", "code"])
+ BlogPost.objects.filter(id=post.id).update(push_all__tags=["mongodb", "code"])
post.reload()
self.assertEqual(post.tags, ["code", "mongodb", "code"])
@@ -2087,13 +2112,13 @@ class QuerySetTest(unittest.TestCase):
post.reload()
self.assertEqual(post.tags, ["mongodb"])
- BlogPost.objects(slug="test").update(
- pull_all__tags=["mongodb", "code"])
+ BlogPost.objects(slug="test").update(pull_all__tags=["mongodb", "code"])
post.reload()
self.assertEqual(post.tags, [])
BlogPost.objects(slug="test").update(
- __raw__={"$addToSet": {"tags": {"$each": ["code", "mongodb", "code"]}}})
+ __raw__={"$addToSet": {"tags": {"$each": ["code", "mongodb", "code"]}}}
+ )
post.reload()
self.assertEqual(post.tags, ["code", "mongodb"])
@@ -2101,13 +2126,13 @@ class QuerySetTest(unittest.TestCase):
class Item(Document):
name = StringField(required=True)
description = StringField(max_length=50)
- parents = ListField(ReferenceField('self'))
+ parents = ListField(ReferenceField("self"))
Item.drop_collection()
- item = Item(name='test item').save()
- parent_1 = Item(name='parent 1').save()
- parent_2 = Item(name='parent 2').save()
+ item = Item(name="test item").save()
+ parent_1 = Item(name="parent 1").save()
+ parent_2 = Item(name="parent 2").save()
item.update(add_to_set__parents=[parent_1, parent_2, parent_1])
item.reload()
@@ -2115,12 +2140,11 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual([parent_1, parent_2], item.parents)
def test_pull_nested(self):
-
class Collaborator(EmbeddedDocument):
user = StringField()
def __unicode__(self):
- return '%s' % self.user
+ return "%s" % self.user
class Site(Document):
name = StringField(max_length=75, unique=True, required=True)
@@ -2128,23 +2152,21 @@ class QuerySetTest(unittest.TestCase):
Site.drop_collection()
- c = Collaborator(user='Esteban')
+ c = Collaborator(user="Esteban")
s = Site(name="test", collaborators=[c]).save()
- Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban')
+ Site.objects(id=s.id).update_one(pull__collaborators__user="Esteban")
self.assertEqual(Site.objects.first().collaborators, [])
with self.assertRaises(InvalidQueryError):
- Site.objects(id=s.id).update_one(
- pull_all__collaborators__user=['Ross'])
+ Site.objects(id=s.id).update_one(pull_all__collaborators__user=["Ross"])
def test_pull_from_nested_embedded(self):
-
class User(EmbeddedDocument):
name = StringField()
def __unicode__(self):
- return '%s' % self.name
+ return "%s" % self.name
class Collaborator(EmbeddedDocument):
helpful = ListField(EmbeddedDocumentField(User))
@@ -2156,21 +2178,24 @@ class QuerySetTest(unittest.TestCase):
Site.drop_collection()
- c = User(name='Esteban')
- f = User(name='Frank')
- s = Site(name="test", collaborators=Collaborator(
- helpful=[c], unhelpful=[f])).save()
+ c = User(name="Esteban")
+ f = User(name="Frank")
+ s = Site(
+ name="test", collaborators=Collaborator(helpful=[c], unhelpful=[f])
+ ).save()
Site.objects(id=s.id).update_one(pull__collaborators__helpful=c)
- self.assertEqual(Site.objects.first().collaborators['helpful'], [])
+ self.assertEqual(Site.objects.first().collaborators["helpful"], [])
Site.objects(id=s.id).update_one(
- pull__collaborators__unhelpful={'name': 'Frank'})
- self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
+ pull__collaborators__unhelpful={"name": "Frank"}
+ )
+ self.assertEqual(Site.objects.first().collaborators["unhelpful"], [])
with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one(
- pull_all__collaborators__helpful__name=['Ross'])
+ pull_all__collaborators__helpful__name=["Ross"]
+ )
def test_pull_from_nested_embedded_using_in_nin(self):
"""Ensure that the 'pull' update operation works on embedded documents using 'in' and 'nin' operators.
@@ -2180,7 +2205,7 @@ class QuerySetTest(unittest.TestCase):
name = StringField()
def __unicode__(self):
- return '%s' % self.name
+ return "%s" % self.name
class Collaborator(EmbeddedDocument):
helpful = ListField(EmbeddedDocumentField(User))
@@ -2192,60 +2217,62 @@ class QuerySetTest(unittest.TestCase):
Site.drop_collection()
- a = User(name='Esteban')
- b = User(name='Frank')
- x = User(name='Harry')
- y = User(name='John')
+ a = User(name="Esteban")
+ b = User(name="Frank")
+ x = User(name="Harry")
+ y = User(name="John")
- s = Site(name="test", collaborators=Collaborator(
- helpful=[a, b], unhelpful=[x, y])).save()
+ s = Site(
+ name="test", collaborators=Collaborator(helpful=[a, b], unhelpful=[x, y])
+ ).save()
- Site.objects(id=s.id).update_one(pull__collaborators__helpful__name__in=['Esteban']) # Pull a
- self.assertEqual(Site.objects.first().collaborators['helpful'], [b])
+ Site.objects(id=s.id).update_one(
+ pull__collaborators__helpful__name__in=["Esteban"]
+ ) # Pull a
+ self.assertEqual(Site.objects.first().collaborators["helpful"], [b])
- Site.objects(id=s.id).update_one(pull__collaborators__unhelpful__name__nin=['John']) # Pull x
- self.assertEqual(Site.objects.first().collaborators['unhelpful'], [y])
+ Site.objects(id=s.id).update_one(
+ pull__collaborators__unhelpful__name__nin=["John"]
+ ) # Pull x
+ self.assertEqual(Site.objects.first().collaborators["unhelpful"], [y])
def test_pull_from_nested_mapfield(self):
-
class Collaborator(EmbeddedDocument):
user = StringField()
def __unicode__(self):
- return '%s' % self.user
+ return "%s" % self.user
class Site(Document):
name = StringField(max_length=75, unique=True, required=True)
- collaborators = MapField(
- ListField(EmbeddedDocumentField(Collaborator)))
+ collaborators = MapField(ListField(EmbeddedDocumentField(Collaborator)))
Site.drop_collection()
- c = Collaborator(user='Esteban')
- f = Collaborator(user='Frank')
- s = Site(name="test", collaborators={'helpful': [c], 'unhelpful': [f]})
+ c = Collaborator(user="Esteban")
+ f = Collaborator(user="Frank")
+ s = Site(name="test", collaborators={"helpful": [c], "unhelpful": [f]})
s.save()
- Site.objects(id=s.id).update_one(
- pull__collaborators__helpful__user='Esteban')
- self.assertEqual(Site.objects.first().collaborators['helpful'], [])
+ Site.objects(id=s.id).update_one(pull__collaborators__helpful__user="Esteban")
+ self.assertEqual(Site.objects.first().collaborators["helpful"], [])
Site.objects(id=s.id).update_one(
- pull__collaborators__unhelpful={'user': 'Frank'})
- self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
+ pull__collaborators__unhelpful={"user": "Frank"}
+ )
+ self.assertEqual(Site.objects.first().collaborators["unhelpful"], [])
with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one(
- pull_all__collaborators__helpful__user=['Ross'])
+ pull_all__collaborators__helpful__user=["Ross"]
+ )
def test_pull_in_genericembedded_field(self):
-
class Foo(EmbeddedDocument):
name = StringField()
class Bar(Document):
- foos = ListField(GenericEmbeddedDocumentField(
- choices=[Foo, ]))
+ foos = ListField(GenericEmbeddedDocumentField(choices=[Foo]))
Bar.drop_collection()
@@ -2261,15 +2288,14 @@ class QuerySetTest(unittest.TestCase):
BlogTag.drop_collection()
- BlogTag(name='garbage').save()
- default_update = BlogTag.objects.update_one(name='new')
+ BlogTag(name="garbage").save()
+ default_update = BlogTag.objects.update_one(name="new")
self.assertEqual(default_update, 1)
- full_result_update = BlogTag.objects.update_one(name='new', full_result=True)
+ full_result_update = BlogTag.objects.update_one(name="new", full_result=True)
self.assertIsInstance(full_result_update, UpdateResult)
def test_update_one_pop_generic_reference(self):
-
class BlogTag(Document):
name = StringField(required=True)
@@ -2280,9 +2306,9 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
BlogTag.drop_collection()
- tag_1 = BlogTag(name='code')
+ tag_1 = BlogTag(name="code")
tag_1.save()
- tag_2 = BlogTag(name='mongodb')
+ tag_2 = BlogTag(name="mongodb")
tag_2.save()
post = BlogPost(slug="test", tags=[tag_1])
@@ -2301,7 +2327,6 @@ class QuerySetTest(unittest.TestCase):
BlogTag.drop_collection()
def test_editting_embedded_objects(self):
-
class BlogTag(EmbeddedDocument):
name = StringField(required=True)
@@ -2311,8 +2336,8 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
- tag_1 = BlogTag(name='code')
- tag_2 = BlogTag(name='mongodb')
+ tag_1 = BlogTag(name="code")
+ tag_2 = BlogTag(name="mongodb")
post = BlogPost(slug="test", tags=[tag_1])
post.save()
@@ -2323,7 +2348,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.objects(slug="test-2").update_one(set__tags__0__name="python")
post.reload()
- self.assertEqual(post.tags[0].name, 'python')
+ self.assertEqual(post.tags[0].name, "python")
BlogPost.objects(slug="test-2").update_one(pop__tags=-1)
post.reload()
@@ -2332,13 +2357,12 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
def test_set_list_embedded_documents(self):
-
class Author(EmbeddedDocument):
name = StringField()
class Message(Document):
title = StringField()
- authors = ListField(EmbeddedDocumentField('Author'))
+ authors = ListField(EmbeddedDocumentField("Author"))
Message.drop_collection()
@@ -2346,15 +2370,19 @@ class QuerySetTest(unittest.TestCase):
message.save()
Message.objects(authors__name="Harry").update_one(
- set__authors__S=Author(name="Ross"))
+ set__authors__S=Author(name="Ross")
+ )
message = message.reload()
self.assertEqual(message.authors[0].name, "Ross")
Message.objects(authors__name="Ross").update_one(
- set__authors=[Author(name="Harry"),
- Author(name="Ross"),
- Author(name="Adam")])
+ set__authors=[
+ Author(name="Harry"),
+ Author(name="Ross"),
+ Author(name="Adam"),
+ ]
+ )
message = message.reload()
self.assertEqual(message.authors[0].name, "Harry")
@@ -2362,7 +2390,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(message.authors[2].name, "Adam")
def test_set_generic_embedded_documents(self):
-
class Bar(EmbeddedDocument):
name = StringField()
@@ -2372,15 +2399,13 @@ class QuerySetTest(unittest.TestCase):
User.drop_collection()
- User(username='abc').save()
- User.objects(username='abc').update(
- set__bar=Bar(name='test'), upsert=True)
+ User(username="abc").save()
+ User.objects(username="abc").update(set__bar=Bar(name="test"), upsert=True)
- user = User.objects(username='abc').first()
+ user = User.objects(username="abc").first()
self.assertEqual(user.bar.name, "test")
def test_reload_embedded_docs_instance(self):
-
class SubDoc(EmbeddedDocument):
val = IntField()
@@ -2393,7 +2418,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(doc.pk, doc.embedded._instance.pk)
def test_reload_list_embedded_docs_instance(self):
-
class SubDoc(EmbeddedDocument):
val = IntField()
@@ -2412,16 +2436,16 @@ class QuerySetTest(unittest.TestCase):
self.Person(name="User A", age=20).save()
self.Person(name="User C", age=30).save()
- names = [p.name for p in self.Person.objects.order_by('-age')]
- self.assertEqual(names, ['User B', 'User C', 'User A'])
+ names = [p.name for p in self.Person.objects.order_by("-age")]
+ self.assertEqual(names, ["User B", "User C", "User A"])
- names = [p.name for p in self.Person.objects.order_by('+age')]
- self.assertEqual(names, ['User A', 'User C', 'User B'])
+ names = [p.name for p in self.Person.objects.order_by("+age")]
+ self.assertEqual(names, ["User A", "User C", "User B"])
- names = [p.name for p in self.Person.objects.order_by('age')]
- self.assertEqual(names, ['User A', 'User C', 'User B'])
+ names = [p.name for p in self.Person.objects.order_by("age")]
+ self.assertEqual(names, ["User A", "User C", "User B"])
- ages = [p.age for p in self.Person.objects.order_by('-name')]
+ ages = [p.age for p in self.Person.objects.order_by("-name")]
self.assertEqual(ages, [30, 40, 20])
def test_order_by_optional(self):
@@ -2432,31 +2456,22 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
blog_post_3 = BlogPost.objects.create(
- title="Blog Post #3",
- published_date=datetime.datetime(2010, 1, 6, 0, 0, 0)
+ title="Blog Post #3", published_date=datetime.datetime(2010, 1, 6, 0, 0, 0)
)
blog_post_2 = BlogPost.objects.create(
- title="Blog Post #2",
- published_date=datetime.datetime(2010, 1, 5, 0, 0, 0)
+ title="Blog Post #2", published_date=datetime.datetime(2010, 1, 5, 0, 0, 0)
)
blog_post_4 = BlogPost.objects.create(
- title="Blog Post #4",
- published_date=datetime.datetime(2010, 1, 7, 0, 0, 0)
- )
- blog_post_1 = BlogPost.objects.create(
- title="Blog Post #1",
- published_date=None
+ title="Blog Post #4", published_date=datetime.datetime(2010, 1, 7, 0, 0, 0)
)
+ blog_post_1 = BlogPost.objects.create(title="Blog Post #1", published_date=None)
expected = [blog_post_1, blog_post_2, blog_post_3, blog_post_4]
- self.assertSequence(BlogPost.objects.order_by('published_date'),
- expected)
- self.assertSequence(BlogPost.objects.order_by('+published_date'),
- expected)
+ self.assertSequence(BlogPost.objects.order_by("published_date"), expected)
+ self.assertSequence(BlogPost.objects.order_by("+published_date"), expected)
expected.reverse()
- self.assertSequence(BlogPost.objects.order_by('-published_date'),
- expected)
+ self.assertSequence(BlogPost.objects.order_by("-published_date"), expected)
def test_order_by_list(self):
class BlogPost(Document):
@@ -2466,23 +2481,20 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
blog_post_1 = BlogPost.objects.create(
- title="A",
- published_date=datetime.datetime(2010, 1, 6, 0, 0, 0)
+ title="A", published_date=datetime.datetime(2010, 1, 6, 0, 0, 0)
)
blog_post_2 = BlogPost.objects.create(
- title="B",
- published_date=datetime.datetime(2010, 1, 6, 0, 0, 0)
+ title="B", published_date=datetime.datetime(2010, 1, 6, 0, 0, 0)
)
blog_post_3 = BlogPost.objects.create(
- title="C",
- published_date=datetime.datetime(2010, 1, 7, 0, 0, 0)
+ title="C", published_date=datetime.datetime(2010, 1, 7, 0, 0, 0)
)
- qs = BlogPost.objects.order_by('published_date', 'title')
+ qs = BlogPost.objects.order_by("published_date", "title")
expected = [blog_post_1, blog_post_2, blog_post_3]
self.assertSequence(qs, expected)
- qs = BlogPost.objects.order_by('-published_date', '-title')
+ qs = BlogPost.objects.order_by("-published_date", "-title")
expected.reverse()
self.assertSequence(qs, expected)
@@ -2493,7 +2505,7 @@ class QuerySetTest(unittest.TestCase):
self.Person(name="User A", age=20).save()
self.Person(name="User C", age=30).save()
- only_age = self.Person.objects.order_by('-age').only('age')
+ only_age = self.Person.objects.order_by("-age").only("age")
names = [p.name for p in only_age]
ages = [p.age for p in only_age]
@@ -2502,19 +2514,19 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(names, [None, None, None])
self.assertEqual(ages, [40, 30, 20])
- qs = self.Person.objects.all().order_by('-age')
+ qs = self.Person.objects.all().order_by("-age")
qs = qs.limit(10)
ages = [p.age for p in qs]
self.assertEqual(ages, [40, 30, 20])
qs = self.Person.objects.all().limit(10)
- qs = qs.order_by('-age')
+ qs = qs.order_by("-age")
ages = [p.age for p in qs]
self.assertEqual(ages, [40, 30, 20])
qs = self.Person.objects.all().skip(0)
- qs = qs.order_by('-age')
+ qs = qs.order_by("-age")
ages = [p.age for p in qs]
self.assertEqual(ages, [40, 30, 20])
@@ -2538,47 +2550,47 @@ class QuerySetTest(unittest.TestCase):
Author(author=person_b).save()
Author(author=person_c).save()
- names = [
- a.author.name for a in Author.objects.order_by('-author__age')]
- self.assertEqual(names, ['User A', 'User B', 'User C'])
+ names = [a.author.name for a in Author.objects.order_by("-author__age")]
+ self.assertEqual(names, ["User A", "User B", "User C"])
def test_comment(self):
"""Make sure adding a comment to the query gets added to the query"""
MONGO_VER = self.mongodb_version
_, CMD_QUERY_KEY = get_key_compat(MONGO_VER)
- QUERY_KEY = 'filter'
- COMMENT_KEY = 'comment'
+ QUERY_KEY = "filter"
+ COMMENT_KEY = "comment"
class User(Document):
age = IntField()
with db_ops_tracker() as q:
- adult1 = (User.objects.filter(age__gte=18)
- .comment('looking for an adult')
- .first())
+ adult1 = (
+ User.objects.filter(age__gte=18).comment("looking for an adult").first()
+ )
- adult2 = (User.objects.comment('looking for an adult')
- .filter(age__gte=18)
- .first())
+ adult2 = (
+ User.objects.comment("looking for an adult").filter(age__gte=18).first()
+ )
ops = q.get_ops()
self.assertEqual(len(ops), 2)
for op in ops:
- self.assertEqual(op[CMD_QUERY_KEY][QUERY_KEY], {'age': {'$gte': 18}})
- self.assertEqual(op[CMD_QUERY_KEY][COMMENT_KEY], 'looking for an adult')
+ self.assertEqual(op[CMD_QUERY_KEY][QUERY_KEY], {"age": {"$gte": 18}})
+ self.assertEqual(op[CMD_QUERY_KEY][COMMENT_KEY], "looking for an adult")
def test_map_reduce(self):
"""Ensure map/reduce is both mapping and reducing.
"""
+
class BlogPost(Document):
title = StringField()
- tags = ListField(StringField(), db_field='post-tag-list')
+ tags = ListField(StringField(), db_field="post-tag-list")
BlogPost.drop_collection()
- BlogPost(title="Post #1", tags=['music', 'film', 'print']).save()
- BlogPost(title="Post #2", tags=['music', 'film']).save()
- BlogPost(title="Post #3", tags=['film', 'photography']).save()
+ BlogPost(title="Post #1", tags=["music", "film", "print"]).save()
+ BlogPost(title="Post #2", tags=["music", "film"]).save()
+ BlogPost(title="Post #3", tags=["film", "photography"]).save()
map_f = """
function() {
@@ -2628,8 +2640,8 @@ class QuerySetTest(unittest.TestCase):
post2.save()
post3.save()
- self.assertEqual(BlogPost._fields['title'].db_field, '_id')
- self.assertEqual(BlogPost._meta['id_field'], 'title')
+ self.assertEqual(BlogPost._fields["title"].db_field, "_id")
+ self.assertEqual(BlogPost._meta["id_field"], "title")
map_f = """
function() {
@@ -2661,16 +2673,14 @@ class QuerySetTest(unittest.TestCase):
"""
Test map/reduce custom output
"""
- register_connection('test2', 'mongoenginetest2')
+ register_connection("test2", "mongoenginetest2")
class Family(Document):
- id = IntField(
- primary_key=True)
+ id = IntField(primary_key=True)
log = StringField()
class Person(Document):
- id = IntField(
- primary_key=True)
+ id = IntField(primary_key=True)
name = StringField()
age = IntField()
family = ReferenceField(Family)
@@ -2745,7 +2755,8 @@ class QuerySetTest(unittest.TestCase):
cursor = Family.objects.map_reduce(
map_f=map_family,
reduce_f=reduce_f,
- output={'replace': 'family_map', 'db_alias': 'test2'})
+ output={"replace": "family_map", "db_alias": "test2"},
+ )
# start a map/reduce
cursor.next()
@@ -2753,43 +2764,56 @@ class QuerySetTest(unittest.TestCase):
results = Person.objects.map_reduce(
map_f=map_person,
reduce_f=reduce_f,
- output={'reduce': 'family_map', 'db_alias': 'test2'})
+ output={"reduce": "family_map", "db_alias": "test2"},
+ )
results = list(results)
- collection = get_db('test2').family_map
+ collection = get_db("test2").family_map
self.assertEqual(
- collection.find_one({'_id': 1}), {
- '_id': 1,
- 'value': {
- 'persons': [
- {'age': 21, 'name': u'Wilson Jr'},
- {'age': 45, 'name': u'Wilson Father'},
- {'age': 40, 'name': u'Eliana Costa'},
- {'age': 17, 'name': u'Tayza Mariana'}],
- 'totalAge': 123}
- })
+ collection.find_one({"_id": 1}),
+ {
+ "_id": 1,
+ "value": {
+ "persons": [
+ {"age": 21, "name": u"Wilson Jr"},
+ {"age": 45, "name": u"Wilson Father"},
+ {"age": 40, "name": u"Eliana Costa"},
+ {"age": 17, "name": u"Tayza Mariana"},
+ ],
+ "totalAge": 123,
+ },
+ },
+ )
self.assertEqual(
- collection.find_one({'_id': 2}), {
- '_id': 2,
- 'value': {
- 'persons': [
- {'age': 16, 'name': u'Isabella Luanna'},
- {'age': 36, 'name': u'Sandra Mara'},
- {'age': 10, 'name': u'Igor Gabriel'}],
- 'totalAge': 62}
- })
+ collection.find_one({"_id": 2}),
+ {
+ "_id": 2,
+ "value": {
+ "persons": [
+ {"age": 16, "name": u"Isabella Luanna"},
+ {"age": 36, "name": u"Sandra Mara"},
+ {"age": 10, "name": u"Igor Gabriel"},
+ ],
+ "totalAge": 62,
+ },
+ },
+ )
self.assertEqual(
- collection.find_one({'_id': 3}), {
- '_id': 3,
- 'value': {
- 'persons': [
- {'age': 30, 'name': u'Arthur WA'},
- {'age': 25, 'name': u'Paula Leonel'}],
- 'totalAge': 55}
- })
+ collection.find_one({"_id": 3}),
+ {
+ "_id": 3,
+ "value": {
+ "persons": [
+ {"age": 30, "name": u"Arthur WA"},
+ {"age": 25, "name": u"Paula Leonel"},
+ ],
+ "totalAge": 55,
+ },
+ },
+ )
def test_map_reduce_finalize(self):
"""Ensure that map, reduce, and finalize run and introduce "scope"
@@ -2798,10 +2822,10 @@ class QuerySetTest(unittest.TestCase):
from time import mktime
class Link(Document):
- title = StringField(db_field='bpTitle')
+ title = StringField(db_field="bpTitle")
up_votes = IntField()
down_votes = IntField()
- submitted = DateTimeField(db_field='sTime')
+ submitted = DateTimeField(db_field="sTime")
Link.drop_collection()
@@ -2811,30 +2835,42 @@ class QuerySetTest(unittest.TestCase):
# Fri, 12 Feb 2010 14:36:00 -0600. Link ordering should
# reflect order of insertion below, but is not influenced
# by insertion order.
- Link(title="Google Buzz auto-followed a woman's abusive ex ...",
- up_votes=1079,
- down_votes=553,
- submitted=now - datetime.timedelta(hours=4)).save()
- Link(title="We did it! Barbie is a computer engineer.",
- up_votes=481,
- down_votes=124,
- submitted=now - datetime.timedelta(hours=2)).save()
- Link(title="This Is A Mosquito Getting Killed By A Laser",
- up_votes=1446,
- down_votes=530,
- submitted=now - datetime.timedelta(hours=13)).save()
- Link(title="Arabic flashcards land physics student in jail.",
- up_votes=215,
- down_votes=105,
- submitted=now - datetime.timedelta(hours=6)).save()
- Link(title="The Burger Lab: Presenting, the Flood Burger",
- up_votes=48,
- down_votes=17,
- submitted=now - datetime.timedelta(hours=5)).save()
- Link(title="How to see polarization with the naked eye",
- up_votes=74,
- down_votes=13,
- submitted=now - datetime.timedelta(hours=10)).save()
+ Link(
+ title="Google Buzz auto-followed a woman's abusive ex ...",
+ up_votes=1079,
+ down_votes=553,
+ submitted=now - datetime.timedelta(hours=4),
+ ).save()
+ Link(
+ title="We did it! Barbie is a computer engineer.",
+ up_votes=481,
+ down_votes=124,
+ submitted=now - datetime.timedelta(hours=2),
+ ).save()
+ Link(
+ title="This Is A Mosquito Getting Killed By A Laser",
+ up_votes=1446,
+ down_votes=530,
+ submitted=now - datetime.timedelta(hours=13),
+ ).save()
+ Link(
+ title="Arabic flashcards land physics student in jail.",
+ up_votes=215,
+ down_votes=105,
+ submitted=now - datetime.timedelta(hours=6),
+ ).save()
+ Link(
+ title="The Burger Lab: Presenting, the Flood Burger",
+ up_votes=48,
+ down_votes=17,
+ submitted=now - datetime.timedelta(hours=5),
+ ).save()
+ Link(
+ title="How to see polarization with the naked eye",
+ up_votes=74,
+ down_votes=13,
+ submitted=now - datetime.timedelta(hours=10),
+ ).save()
map_f = """
function() {
@@ -2885,17 +2921,15 @@ class QuerySetTest(unittest.TestCase):
# provide the reddit epoch (used for ranking) as a variable available
# to all phases of the map/reduce operation: map, reduce, and finalize.
reddit_epoch = mktime(datetime.datetime(2005, 12, 8, 7, 46, 43).timetuple())
- scope = {'reddit_epoch': reddit_epoch}
+ scope = {"reddit_epoch": reddit_epoch}
# run a map/reduce operation across all links. ordering is set
# to "-value", which orders the "weight" value returned from
# "finalize_f" in descending order.
results = Link.objects.order_by("-value")
- results = results.map_reduce(map_f,
- reduce_f,
- "myresults",
- finalize_f=finalize_f,
- scope=scope)
+ results = results.map_reduce(
+ map_f, reduce_f, "myresults", finalize_f=finalize_f, scope=scope
+ )
results = list(results)
# assert troublesome Buzz article is ranked 1st
@@ -2909,54 +2943,56 @@ class QuerySetTest(unittest.TestCase):
def test_item_frequencies(self):
"""Ensure that item frequencies are properly generated from lists.
"""
+
class BlogPost(Document):
hits = IntField()
- tags = ListField(StringField(), db_field='blogTags')
+ tags = ListField(StringField(), db_field="blogTags")
BlogPost.drop_collection()
- BlogPost(hits=1, tags=['music', 'film', 'actors', 'watch']).save()
- BlogPost(hits=2, tags=['music', 'watch']).save()
- BlogPost(hits=2, tags=['music', 'actors']).save()
+ BlogPost(hits=1, tags=["music", "film", "actors", "watch"]).save()
+ BlogPost(hits=2, tags=["music", "watch"]).save()
+ BlogPost(hits=2, tags=["music", "actors"]).save()
def test_assertions(f):
f = {key: int(val) for key, val in f.items()}
- self.assertEqual(
- set(['music', 'film', 'actors', 'watch']), set(f.keys()))
- self.assertEqual(f['music'], 3)
- self.assertEqual(f['actors'], 2)
- self.assertEqual(f['watch'], 2)
- self.assertEqual(f['film'], 1)
+ self.assertEqual(set(["music", "film", "actors", "watch"]), set(f.keys()))
+ self.assertEqual(f["music"], 3)
+ self.assertEqual(f["actors"], 2)
+ self.assertEqual(f["watch"], 2)
+ self.assertEqual(f["film"], 1)
- exec_js = BlogPost.objects.item_frequencies('tags')
- map_reduce = BlogPost.objects.item_frequencies('tags', map_reduce=True)
+ exec_js = BlogPost.objects.item_frequencies("tags")
+ map_reduce = BlogPost.objects.item_frequencies("tags", map_reduce=True)
test_assertions(exec_js)
test_assertions(map_reduce)
# Ensure query is taken into account
def test_assertions(f):
f = {key: int(val) for key, val in f.items()}
- self.assertEqual(set(['music', 'actors', 'watch']), set(f.keys()))
- self.assertEqual(f['music'], 2)
- self.assertEqual(f['actors'], 1)
- self.assertEqual(f['watch'], 1)
+ self.assertEqual(set(["music", "actors", "watch"]), set(f.keys()))
+ self.assertEqual(f["music"], 2)
+ self.assertEqual(f["actors"], 1)
+ self.assertEqual(f["watch"], 1)
- exec_js = BlogPost.objects(hits__gt=1).item_frequencies('tags')
- map_reduce = BlogPost.objects(
- hits__gt=1).item_frequencies('tags', map_reduce=True)
+ exec_js = BlogPost.objects(hits__gt=1).item_frequencies("tags")
+ map_reduce = BlogPost.objects(hits__gt=1).item_frequencies(
+ "tags", map_reduce=True
+ )
test_assertions(exec_js)
test_assertions(map_reduce)
# Check that normalization works
def test_assertions(f):
- self.assertAlmostEqual(f['music'], 3.0 / 8.0)
- self.assertAlmostEqual(f['actors'], 2.0 / 8.0)
- self.assertAlmostEqual(f['watch'], 2.0 / 8.0)
- self.assertAlmostEqual(f['film'], 1.0 / 8.0)
+ self.assertAlmostEqual(f["music"], 3.0 / 8.0)
+ self.assertAlmostEqual(f["actors"], 2.0 / 8.0)
+ self.assertAlmostEqual(f["watch"], 2.0 / 8.0)
+ self.assertAlmostEqual(f["film"], 1.0 / 8.0)
- exec_js = BlogPost.objects.item_frequencies('tags', normalize=True)
+ exec_js = BlogPost.objects.item_frequencies("tags", normalize=True)
map_reduce = BlogPost.objects.item_frequencies(
- 'tags', normalize=True, map_reduce=True)
+ "tags", normalize=True, map_reduce=True
+ )
test_assertions(exec_js)
test_assertions(map_reduce)
@@ -2966,8 +3002,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(f[1], 1)
self.assertEqual(f[2], 2)
- exec_js = BlogPost.objects.item_frequencies('hits')
- map_reduce = BlogPost.objects.item_frequencies('hits', map_reduce=True)
+ exec_js = BlogPost.objects.item_frequencies("hits")
+ map_reduce = BlogPost.objects.item_frequencies("hits", map_reduce=True)
test_assertions(exec_js)
test_assertions(map_reduce)
@@ -2987,57 +3023,56 @@ class QuerySetTest(unittest.TestCase):
Person.drop_collection()
doc = Person(name="Guido")
- doc.phone = Phone(number='62-3331-1656')
+ doc.phone = Phone(number="62-3331-1656")
doc.save()
doc = Person(name="Marr")
- doc.phone = Phone(number='62-3331-1656')
+ doc.phone = Phone(number="62-3331-1656")
doc.save()
doc = Person(name="WP Junior")
- doc.phone = Phone(number='62-3332-1656')
+ doc.phone = Phone(number="62-3332-1656")
doc.save()
def test_assertions(f):
f = {key: int(val) for key, val in f.items()}
- self.assertEqual(
- set(['62-3331-1656', '62-3332-1656']), set(f.keys()))
- self.assertEqual(f['62-3331-1656'], 2)
- self.assertEqual(f['62-3332-1656'], 1)
+ self.assertEqual(set(["62-3331-1656", "62-3332-1656"]), set(f.keys()))
+ self.assertEqual(f["62-3331-1656"], 2)
+ self.assertEqual(f["62-3332-1656"], 1)
- exec_js = Person.objects.item_frequencies('phone.number')
- map_reduce = Person.objects.item_frequencies(
- 'phone.number', map_reduce=True)
+ exec_js = Person.objects.item_frequencies("phone.number")
+ map_reduce = Person.objects.item_frequencies("phone.number", map_reduce=True)
test_assertions(exec_js)
test_assertions(map_reduce)
# Ensure query is taken into account
def test_assertions(f):
f = {key: int(val) for key, val in f.items()}
- self.assertEqual(set(['62-3331-1656']), set(f.keys()))
- self.assertEqual(f['62-3331-1656'], 2)
+ self.assertEqual(set(["62-3331-1656"]), set(f.keys()))
+ self.assertEqual(f["62-3331-1656"], 2)
- exec_js = Person.objects(
- phone__number='62-3331-1656').item_frequencies('phone.number')
- map_reduce = Person.objects(
- phone__number='62-3331-1656').item_frequencies('phone.number', map_reduce=True)
+ exec_js = Person.objects(phone__number="62-3331-1656").item_frequencies(
+ "phone.number"
+ )
+ map_reduce = Person.objects(phone__number="62-3331-1656").item_frequencies(
+ "phone.number", map_reduce=True
+ )
test_assertions(exec_js)
test_assertions(map_reduce)
# Check that normalization works
def test_assertions(f):
- self.assertEqual(f['62-3331-1656'], 2.0 / 3.0)
- self.assertEqual(f['62-3332-1656'], 1.0 / 3.0)
+ self.assertEqual(f["62-3331-1656"], 2.0 / 3.0)
+ self.assertEqual(f["62-3332-1656"], 1.0 / 3.0)
- exec_js = Person.objects.item_frequencies(
- 'phone.number', normalize=True)
+ exec_js = Person.objects.item_frequencies("phone.number", normalize=True)
map_reduce = Person.objects.item_frequencies(
- 'phone.number', normalize=True, map_reduce=True)
+ "phone.number", normalize=True, map_reduce=True
+ )
test_assertions(exec_js)
test_assertions(map_reduce)
def test_item_frequencies_null_values(self):
-
class Person(Document):
name = StringField()
city = StringField()
@@ -3047,16 +3082,15 @@ class QuerySetTest(unittest.TestCase):
Person(name="Wilson Snr", city="CRB").save()
Person(name="Wilson Jr").save()
- freq = Person.objects.item_frequencies('city')
- self.assertEqual(freq, {'CRB': 1.0, None: 1.0})
- freq = Person.objects.item_frequencies('city', normalize=True)
- self.assertEqual(freq, {'CRB': 0.5, None: 0.5})
+ freq = Person.objects.item_frequencies("city")
+ self.assertEqual(freq, {"CRB": 1.0, None: 1.0})
+ freq = Person.objects.item_frequencies("city", normalize=True)
+ self.assertEqual(freq, {"CRB": 0.5, None: 0.5})
- freq = Person.objects.item_frequencies('city', map_reduce=True)
- self.assertEqual(freq, {'CRB': 1.0, None: 1.0})
- freq = Person.objects.item_frequencies(
- 'city', normalize=True, map_reduce=True)
- self.assertEqual(freq, {'CRB': 0.5, None: 0.5})
+ freq = Person.objects.item_frequencies("city", map_reduce=True)
+ self.assertEqual(freq, {"CRB": 1.0, None: 1.0})
+ freq = Person.objects.item_frequencies("city", normalize=True, map_reduce=True)
+ self.assertEqual(freq, {"CRB": 0.5, None: 0.5})
def test_item_frequencies_with_null_embedded(self):
class Data(EmbeddedDocument):
@@ -3080,11 +3114,11 @@ class QuerySetTest(unittest.TestCase):
p.extra = Extra(tag="friend")
p.save()
- ot = Person.objects.item_frequencies('extra.tag', map_reduce=False)
- self.assertEqual(ot, {None: 1.0, u'friend': 1.0})
+ ot = Person.objects.item_frequencies("extra.tag", map_reduce=False)
+ self.assertEqual(ot, {None: 1.0, u"friend": 1.0})
- ot = Person.objects.item_frequencies('extra.tag', map_reduce=True)
- self.assertEqual(ot, {None: 1.0, u'friend': 1.0})
+ ot = Person.objects.item_frequencies("extra.tag", map_reduce=True)
+ self.assertEqual(ot, {None: 1.0, u"friend": 1.0})
def test_item_frequencies_with_0_values(self):
class Test(Document):
@@ -3095,9 +3129,9 @@ class QuerySetTest(unittest.TestCase):
t.val = 0
t.save()
- ot = Test.objects.item_frequencies('val', map_reduce=True)
+ ot = Test.objects.item_frequencies("val", map_reduce=True)
self.assertEqual(ot, {0: 1})
- ot = Test.objects.item_frequencies('val', map_reduce=False)
+ ot = Test.objects.item_frequencies("val", map_reduce=False)
self.assertEqual(ot, {0: 1})
def test_item_frequencies_with_False_values(self):
@@ -3109,9 +3143,9 @@ class QuerySetTest(unittest.TestCase):
t.val = False
t.save()
- ot = Test.objects.item_frequencies('val', map_reduce=True)
+ ot = Test.objects.item_frequencies("val", map_reduce=True)
self.assertEqual(ot, {False: 1})
- ot = Test.objects.item_frequencies('val', map_reduce=False)
+ ot = Test.objects.item_frequencies("val", map_reduce=False)
self.assertEqual(ot, {False: 1})
def test_item_frequencies_normalize(self):
@@ -3126,113 +3160,108 @@ class QuerySetTest(unittest.TestCase):
for i in range(20):
Test(val=2).save()
- freqs = Test.objects.item_frequencies(
- 'val', map_reduce=False, normalize=True)
+ freqs = Test.objects.item_frequencies("val", map_reduce=False, normalize=True)
self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70})
- freqs = Test.objects.item_frequencies(
- 'val', map_reduce=True, normalize=True)
+ freqs = Test.objects.item_frequencies("val", map_reduce=True, normalize=True)
self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70})
def test_average(self):
"""Ensure that field can be averaged correctly.
"""
- self.Person(name='person', age=0).save()
- self.assertEqual(int(self.Person.objects.average('age')), 0)
+ self.Person(name="person", age=0).save()
+ self.assertEqual(int(self.Person.objects.average("age")), 0)
ages = [23, 54, 12, 94, 27]
for i, age in enumerate(ages):
- self.Person(name='test%s' % i, age=age).save()
+ self.Person(name="test%s" % i, age=age).save()
avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0
- self.assertAlmostEqual(int(self.Person.objects.average('age')), avg)
+ self.assertAlmostEqual(int(self.Person.objects.average("age")), avg)
- self.Person(name='ageless person').save()
- self.assertEqual(int(self.Person.objects.average('age')), avg)
+ self.Person(name="ageless person").save()
+ self.assertEqual(int(self.Person.objects.average("age")), avg)
# dot notation
- self.Person(
- name='person meta', person_meta=self.PersonMeta(weight=0)).save()
+ self.Person(name="person meta", person_meta=self.PersonMeta(weight=0)).save()
self.assertAlmostEqual(
- int(self.Person.objects.average('person_meta.weight')), 0)
+ int(self.Person.objects.average("person_meta.weight")), 0
+ )
for i, weight in enumerate(ages):
self.Person(
- name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save()
+ name="test meta%i", person_meta=self.PersonMeta(weight=weight)
+ ).save()
self.assertAlmostEqual(
- int(self.Person.objects.average('person_meta.weight')), avg
+ int(self.Person.objects.average("person_meta.weight")), avg
)
- self.Person(name='test meta none').save()
- self.assertEqual(
- int(self.Person.objects.average('person_meta.weight')), avg
- )
+ self.Person(name="test meta none").save()
+ self.assertEqual(int(self.Person.objects.average("person_meta.weight")), avg)
# test summing over a filtered queryset
over_50 = [a for a in ages if a >= 50]
avg = float(sum(over_50)) / len(over_50)
- self.assertEqual(
- self.Person.objects.filter(age__gte=50).average('age'),
- avg
- )
+ self.assertEqual(self.Person.objects.filter(age__gte=50).average("age"), avg)
def test_sum(self):
"""Ensure that field can be summed over correctly.
"""
ages = [23, 54, 12, 94, 27]
for i, age in enumerate(ages):
- self.Person(name='test%s' % i, age=age).save()
+ self.Person(name="test%s" % i, age=age).save()
- self.assertEqual(self.Person.objects.sum('age'), sum(ages))
+ self.assertEqual(self.Person.objects.sum("age"), sum(ages))
- self.Person(name='ageless person').save()
- self.assertEqual(self.Person.objects.sum('age'), sum(ages))
+ self.Person(name="ageless person").save()
+ self.assertEqual(self.Person.objects.sum("age"), sum(ages))
for i, age in enumerate(ages):
- self.Person(name='test meta%s' %
- i, person_meta=self.PersonMeta(weight=age)).save()
+ self.Person(
+ name="test meta%s" % i, person_meta=self.PersonMeta(weight=age)
+ ).save()
- self.assertEqual(
- self.Person.objects.sum('person_meta.weight'), sum(ages)
- )
+ self.assertEqual(self.Person.objects.sum("person_meta.weight"), sum(ages))
- self.Person(name='weightless person').save()
- self.assertEqual(self.Person.objects.sum('age'), sum(ages))
+ self.Person(name="weightless person").save()
+ self.assertEqual(self.Person.objects.sum("age"), sum(ages))
# test summing over a filtered queryset
self.assertEqual(
- self.Person.objects.filter(age__gte=50).sum('age'),
- sum([a for a in ages if a >= 50])
+ self.Person.objects.filter(age__gte=50).sum("age"),
+ sum([a for a in ages if a >= 50]),
)
def test_sum_over_db_field(self):
"""Ensure that a field mapped to a db field with a different name
can be summed over correctly.
"""
+
class UserVisit(Document):
- num_visits = IntField(db_field='visits')
+ num_visits = IntField(db_field="visits")
UserVisit.drop_collection()
UserVisit.objects.create(num_visits=10)
UserVisit.objects.create(num_visits=5)
- self.assertEqual(UserVisit.objects.sum('num_visits'), 15)
+ self.assertEqual(UserVisit.objects.sum("num_visits"), 15)
def test_average_over_db_field(self):
"""Ensure that a field mapped to a db field with a different name
can have its average computed correctly.
"""
+
class UserVisit(Document):
- num_visits = IntField(db_field='visits')
+ num_visits = IntField(db_field="visits")
UserVisit.drop_collection()
UserVisit.objects.create(num_visits=20)
UserVisit.objects.create(num_visits=10)
- self.assertEqual(UserVisit.objects.average('num_visits'), 15)
+ self.assertEqual(UserVisit.objects.average("num_visits"), 15)
def test_embedded_average(self):
class Pay(EmbeddedDocument):
@@ -3240,17 +3269,16 @@ class QuerySetTest(unittest.TestCase):
class Doc(Document):
name = StringField()
- pay = EmbeddedDocumentField(
- Pay)
+ pay = EmbeddedDocumentField(Pay)
Doc.drop_collection()
- Doc(name='Wilson Junior', pay=Pay(value=150)).save()
- Doc(name='Isabella Luanna', pay=Pay(value=530)).save()
- Doc(name='Tayza mariana', pay=Pay(value=165)).save()
- Doc(name='Eliana Costa', pay=Pay(value=115)).save()
+ Doc(name="Wilson Junior", pay=Pay(value=150)).save()
+ Doc(name="Isabella Luanna", pay=Pay(value=530)).save()
+ Doc(name="Tayza mariana", pay=Pay(value=165)).save()
+ Doc(name="Eliana Costa", pay=Pay(value=115)).save()
- self.assertEqual(Doc.objects.average('pay.value'), 240)
+ self.assertEqual(Doc.objects.average("pay.value"), 240)
def test_embedded_array_average(self):
class Pay(EmbeddedDocument):
@@ -3262,12 +3290,12 @@ class QuerySetTest(unittest.TestCase):
Doc.drop_collection()
- Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save()
- Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save()
- Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save()
- Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save()
+ Doc(name="Wilson Junior", pay=Pay(values=[150, 100])).save()
+ Doc(name="Isabella Luanna", pay=Pay(values=[530, 100])).save()
+ Doc(name="Tayza mariana", pay=Pay(values=[165, 100])).save()
+ Doc(name="Eliana Costa", pay=Pay(values=[115, 100])).save()
- self.assertEqual(Doc.objects.average('pay.values'), 170)
+ self.assertEqual(Doc.objects.average("pay.values"), 170)
def test_array_average(self):
class Doc(Document):
@@ -3280,7 +3308,7 @@ class QuerySetTest(unittest.TestCase):
Doc(values=[165, 100]).save()
Doc(values=[115, 100]).save()
- self.assertEqual(Doc.objects.average('values'), 170)
+ self.assertEqual(Doc.objects.average("values"), 170)
def test_embedded_sum(self):
class Pay(EmbeddedDocument):
@@ -3292,12 +3320,12 @@ class QuerySetTest(unittest.TestCase):
Doc.drop_collection()
- Doc(name='Wilson Junior', pay=Pay(value=150)).save()
- Doc(name='Isabella Luanna', pay=Pay(value=530)).save()
- Doc(name='Tayza mariana', pay=Pay(value=165)).save()
- Doc(name='Eliana Costa', pay=Pay(value=115)).save()
+ Doc(name="Wilson Junior", pay=Pay(value=150)).save()
+ Doc(name="Isabella Luanna", pay=Pay(value=530)).save()
+ Doc(name="Tayza mariana", pay=Pay(value=165)).save()
+ Doc(name="Eliana Costa", pay=Pay(value=115)).save()
- self.assertEqual(Doc.objects.sum('pay.value'), 960)
+ self.assertEqual(Doc.objects.sum("pay.value"), 960)
def test_embedded_array_sum(self):
class Pay(EmbeddedDocument):
@@ -3309,12 +3337,12 @@ class QuerySetTest(unittest.TestCase):
Doc.drop_collection()
- Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save()
- Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save()
- Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save()
- Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save()
+ Doc(name="Wilson Junior", pay=Pay(values=[150, 100])).save()
+ Doc(name="Isabella Luanna", pay=Pay(values=[530, 100])).save()
+ Doc(name="Tayza mariana", pay=Pay(values=[165, 100])).save()
+ Doc(name="Eliana Costa", pay=Pay(values=[115, 100])).save()
- self.assertEqual(Doc.objects.sum('pay.values'), 1360)
+ self.assertEqual(Doc.objects.sum("pay.values"), 1360)
def test_array_sum(self):
class Doc(Document):
@@ -3327,21 +3355,24 @@ class QuerySetTest(unittest.TestCase):
Doc(values=[165, 100]).save()
Doc(values=[115, 100]).save()
- self.assertEqual(Doc.objects.sum('values'), 1360)
+ self.assertEqual(Doc.objects.sum("values"), 1360)
def test_distinct(self):
"""Ensure that the QuerySet.distinct method works.
"""
- self.Person(name='Mr Orange', age=20).save()
- self.Person(name='Mr White', age=20).save()
- self.Person(name='Mr Orange', age=30).save()
- self.Person(name='Mr Pink', age=30).save()
- self.assertEqual(set(self.Person.objects.distinct('name')),
- set(['Mr Orange', 'Mr White', 'Mr Pink']))
- self.assertEqual(set(self.Person.objects.distinct('age')),
- set([20, 30]))
- self.assertEqual(set(self.Person.objects(age=30).distinct('name')),
- set(['Mr Orange', 'Mr Pink']))
+ self.Person(name="Mr Orange", age=20).save()
+ self.Person(name="Mr White", age=20).save()
+ self.Person(name="Mr Orange", age=30).save()
+ self.Person(name="Mr Pink", age=30).save()
+ self.assertEqual(
+ set(self.Person.objects.distinct("name")),
+ set(["Mr Orange", "Mr White", "Mr Pink"]),
+ )
+ self.assertEqual(set(self.Person.objects.distinct("age")), set([20, 30]))
+ self.assertEqual(
+ set(self.Person.objects(age=30).distinct("name")),
+ set(["Mr Orange", "Mr Pink"]),
+ )
def test_distinct_handles_references(self):
class Foo(Document):
@@ -3367,53 +3398,58 @@ class QuerySetTest(unittest.TestCase):
content = StringField()
is_active = BooleanField(default=True)
- meta = {'indexes': [
- {'fields': ['$title', "$content"],
- 'default_language': 'portuguese',
- 'weights': {'title': 10, 'content': 2}
- }
- ]}
+ meta = {
+ "indexes": [
+ {
+ "fields": ["$title", "$content"],
+ "default_language": "portuguese",
+ "weights": {"title": 10, "content": 2},
+ }
+ ]
+ }
News.drop_collection()
info = News.objects._collection.index_information()
- self.assertIn('title_text_content_text', info)
- self.assertIn('textIndexVersion', info['title_text_content_text'])
+ self.assertIn("title_text_content_text", info)
+ self.assertIn("textIndexVersion", info["title_text_content_text"])
- News(title="Neymar quebrou a vertebra",
- content="O Brasil sofre com a perda de Neymar").save()
+ News(
+ title="Neymar quebrou a vertebra",
+ content="O Brasil sofre com a perda de Neymar",
+ ).save()
- News(title="Brasil passa para as quartas de finais",
- content="Com o brasil nas quartas de finais teremos um "
- "jogo complicado com a alemanha").save()
+ News(
+ title="Brasil passa para as quartas de finais",
+ content="Com o brasil nas quartas de finais teremos um "
+ "jogo complicado com a alemanha",
+ ).save()
- count = News.objects.search_text(
- "neymar", language="portuguese").count()
+ count = News.objects.search_text("neymar", language="portuguese").count()
self.assertEqual(count, 1)
- count = News.objects.search_text(
- "brasil -neymar").count()
+ count = News.objects.search_text("brasil -neymar").count()
self.assertEqual(count, 1)
- News(title=u"As eleições no Brasil já estão em planejamento",
- content=u"A candidata dilma roussef já começa o teu planejamento",
- is_active=False).save()
+ News(
+ title=u"As eleições no Brasil já estão em planejamento",
+ content=u"A candidata dilma roussef já começa o teu planejamento",
+ is_active=False,
+ ).save()
- new = News.objects(is_active=False).search_text(
- "dilma", language="pt").first()
+ new = News.objects(is_active=False).search_text("dilma", language="pt").first()
- query = News.objects(is_active=False).search_text(
- "dilma", language="pt")._query
+ query = News.objects(is_active=False).search_text("dilma", language="pt")._query
self.assertEqual(
- query, {'$text': {
- '$search': 'dilma', '$language': 'pt'},
- 'is_active': False})
+ query,
+ {"$text": {"$search": "dilma", "$language": "pt"}, "is_active": False},
+ )
self.assertFalse(new.is_active)
- self.assertIn('dilma', new.content)
- self.assertIn('planejamento', new.title)
+ self.assertIn("dilma", new.content)
+ self.assertIn("planejamento", new.title)
query = News.objects.search_text("candidata")
self.assertEqual(query._search_text, "candidata")
@@ -3422,15 +3458,14 @@ class QuerySetTest(unittest.TestCase):
self.assertIsInstance(new.get_text_score(), float)
# count
- query = News.objects.search_text('brasil').order_by('$text_score')
+ query = News.objects.search_text("brasil").order_by("$text_score")
self.assertEqual(query._search_text, "brasil")
self.assertEqual(query.count(), 3)
- self.assertEqual(query._query, {'$text': {'$search': 'brasil'}})
+ self.assertEqual(query._query, {"$text": {"$search": "brasil"}})
cursor_args = query._cursor_args
- cursor_args_fields = cursor_args['projection']
- self.assertEqual(
- cursor_args_fields, {'_text_score': {'$meta': 'textScore'}})
+ cursor_args_fields = cursor_args["projection"]
+ self.assertEqual(cursor_args_fields, {"_text_score": {"$meta": "textScore"}})
text_scores = [i.get_text_score() for i in query]
self.assertEqual(len(text_scores), 3)
@@ -3440,20 +3475,19 @@ class QuerySetTest(unittest.TestCase):
max_text_score = text_scores[0]
# get item
- item = News.objects.search_text(
- 'brasil').order_by('$text_score').first()
+ item = News.objects.search_text("brasil").order_by("$text_score").first()
self.assertEqual(item.get_text_score(), max_text_score)
def test_distinct_handles_references_to_alias(self):
- register_connection('testdb', 'mongoenginetest2')
+ register_connection("testdb", "mongoenginetest2")
class Foo(Document):
bar = ReferenceField("Bar")
- meta = {'db_alias': 'testdb'}
+ meta = {"db_alias": "testdb"}
class Bar(Document):
text = StringField()
- meta = {'db_alias': 'testdb'}
+ meta = {"db_alias": "testdb"}
Bar.drop_collection()
Foo.drop_collection()
@@ -3469,8 +3503,9 @@ class QuerySetTest(unittest.TestCase):
def test_distinct_handles_db_field(self):
"""Ensure that distinct resolves field name to db_field as expected.
"""
+
class Product(Document):
- product_id = IntField(db_field='pid')
+ product_id = IntField(db_field="pid")
Product.drop_collection()
@@ -3478,15 +3513,12 @@ class QuerySetTest(unittest.TestCase):
Product(product_id=2).save()
Product(product_id=1).save()
- self.assertEqual(set(Product.objects.distinct('product_id')),
- set([1, 2]))
- self.assertEqual(set(Product.objects.distinct('pid')),
- set([1, 2]))
+ self.assertEqual(set(Product.objects.distinct("product_id")), set([1, 2]))
+ self.assertEqual(set(Product.objects.distinct("pid")), set([1, 2]))
Product.drop_collection()
def test_distinct_ListField_EmbeddedDocumentField(self):
-
class Author(EmbeddedDocument):
name = StringField()
@@ -3524,8 +3556,8 @@ class QuerySetTest(unittest.TestCase):
Book.drop_collection()
- europe = Continent(continent_name='europe')
- asia = Continent(continent_name='asia')
+ europe = Continent(continent_name="europe")
+ asia = Continent(continent_name="asia")
scotland = Country(country_name="Scotland", continent=europe)
tibet = Country(country_name="Tibet", continent=asia)
@@ -3544,13 +3576,12 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(continent_list, [europe, asia])
def test_distinct_ListField_ReferenceField(self):
-
class Bar(Document):
text = StringField()
class Foo(Document):
- bar = ReferenceField('Bar')
- bar_lst = ListField(ReferenceField('Bar'))
+ bar = ReferenceField("Bar")
+ bar_lst = ListField(ReferenceField("Bar"))
Bar.drop_collection()
Foo.drop_collection()
@@ -3569,6 +3600,7 @@ class QuerySetTest(unittest.TestCase):
def test_custom_manager(self):
"""Ensure that custom QuerySetManager instances work as expected.
"""
+
class BlogPost(Document):
tags = ListField(StringField())
deleted = BooleanField(default=False)
@@ -3586,32 +3618,30 @@ class QuerySetTest(unittest.TestCase):
@queryset_manager
def music_posts(doc_cls, queryset, deleted=False):
- return queryset(tags='music',
- deleted=deleted).order_by('date')
+ return queryset(tags="music", deleted=deleted).order_by("date")
BlogPost.drop_collection()
- post1 = BlogPost(tags=['music', 'film']).save()
- post2 = BlogPost(tags=['music']).save()
- post3 = BlogPost(tags=['film', 'actors']).save()
- post4 = BlogPost(tags=['film', 'actors', 'music'], deleted=True).save()
+ post1 = BlogPost(tags=["music", "film"]).save()
+ post2 = BlogPost(tags=["music"]).save()
+ post3 = BlogPost(tags=["film", "actors"]).save()
+ post4 = BlogPost(tags=["film", "actors", "music"], deleted=True).save()
- self.assertEqual([p.id for p in BlogPost.objects()],
- [post1.id, post2.id, post3.id])
- self.assertEqual([p.id for p in BlogPost.objects_1_arg()],
- [post1.id, post2.id, post3.id])
- self.assertEqual([p.id for p in BlogPost.music_posts()],
- [post1.id, post2.id])
+ self.assertEqual(
+ [p.id for p in BlogPost.objects()], [post1.id, post2.id, post3.id]
+ )
+ self.assertEqual(
+ [p.id for p in BlogPost.objects_1_arg()], [post1.id, post2.id, post3.id]
+ )
+ self.assertEqual([p.id for p in BlogPost.music_posts()], [post1.id, post2.id])
- self.assertEqual([p.id for p in BlogPost.music_posts(True)],
- [post4.id])
+ self.assertEqual([p.id for p in BlogPost.music_posts(True)], [post4.id])
BlogPost.drop_collection()
def test_custom_manager_overriding_objects_works(self):
-
class Foo(Document):
- bar = StringField(default='bar')
+ bar = StringField(default="bar")
active = BooleanField(default=False)
@queryset_manager
@@ -3635,9 +3665,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(1, Foo.objects.count())
def test_inherit_objects(self):
-
class Foo(Document):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
active = BooleanField(default=True)
@queryset_manager
@@ -3652,9 +3681,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(0, Bar.objects.count())
def test_inherit_objects_override(self):
-
class Foo(Document):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
active = BooleanField(default=True)
@queryset_manager
@@ -3662,7 +3690,6 @@ class QuerySetTest(unittest.TestCase):
return queryset(active=True)
class Bar(Foo):
-
@queryset_manager
def objects(klass, queryset):
return queryset(active=False)
@@ -3675,12 +3702,13 @@ class QuerySetTest(unittest.TestCase):
def test_query_value_conversion(self):
"""Ensure that query values are properly converted when necessary.
"""
+
class BlogPost(Document):
author = ReferenceField(self.Person)
BlogPost.drop_collection()
- person = self.Person(name='test', age=30)
+ person = self.Person(name="test", age=30)
person.save()
post = BlogPost(author=person)
@@ -3701,14 +3729,15 @@ class QuerySetTest(unittest.TestCase):
def test_update_value_conversion(self):
"""Ensure that values used in updates are converted before use.
"""
+
class Group(Document):
members = ListField(ReferenceField(self.Person))
Group.drop_collection()
- user1 = self.Person(name='user1')
+ user1 = self.Person(name="user1")
user1.save()
- user2 = self.Person(name='user2')
+ user2 = self.Person(name="user2")
user2.save()
group = Group()
@@ -3726,6 +3755,7 @@ class QuerySetTest(unittest.TestCase):
def test_bulk(self):
"""Ensure bulk querying by object id returns a proper dict.
"""
+
class BlogPost(Document):
title = StringField()
@@ -3764,13 +3794,13 @@ class QuerySetTest(unittest.TestCase):
def test_custom_querysets(self):
"""Ensure that custom QuerySet classes may be used.
"""
- class CustomQuerySet(QuerySet):
+ class CustomQuerySet(QuerySet):
def not_empty(self):
return self.count() > 0
class Post(Document):
- meta = {'queryset_class': CustomQuerySet}
+ meta = {"queryset_class": CustomQuerySet}
Post.drop_collection()
@@ -3787,7 +3817,6 @@ class QuerySetTest(unittest.TestCase):
"""
class CustomQuerySet(QuerySet):
-
def not_empty(self):
return self.count() > 0
@@ -3812,7 +3841,6 @@ class QuerySetTest(unittest.TestCase):
"""
class CustomQuerySetManager(QuerySetManager):
-
@staticmethod
def get_queryset(doc_cls, queryset):
return queryset(is_published=True)
@@ -3835,12 +3863,11 @@ class QuerySetTest(unittest.TestCase):
"""
class CustomQuerySet(QuerySet):
-
def not_empty(self):
return self.count() > 0
class Base(Document):
- meta = {'abstract': True, 'queryset_class': CustomQuerySet}
+ meta = {"abstract": True, "queryset_class": CustomQuerySet}
class Post(Base):
pass
@@ -3859,7 +3886,6 @@ class QuerySetTest(unittest.TestCase):
"""
class CustomQuerySet(QuerySet):
-
def not_empty(self):
return self.count() > 0
@@ -3867,7 +3893,7 @@ class QuerySetTest(unittest.TestCase):
queryset_class = CustomQuerySet
class Base(Document):
- meta = {'abstract': True}
+ meta = {"abstract": True}
objects = CustomQuerySetManager()
class Post(Base):
@@ -3891,10 +3917,13 @@ class QuerySetTest(unittest.TestCase):
for i in range(10):
Post(title="Post %s" % i).save()
- self.assertEqual(5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True))
+ self.assertEqual(
+ 5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True)
+ )
self.assertEqual(
- 10, Post.objects.limit(5).skip(5).count(with_limit_and_skip=False))
+ 10, Post.objects.limit(5).skip(5).count(with_limit_and_skip=False)
+ )
def test_count_and_none(self):
"""Test count works with None()"""
@@ -3916,11 +3945,12 @@ class QuerySetTest(unittest.TestCase):
class A(Document):
b = ListField(EmbeddedDocumentField(B))
- self.assertEqual(A.objects(b=[{'c': 'c'}]).count(), 0)
+ self.assertEqual(A.objects(b=[{"c": "c"}]).count(), 0)
def test_call_after_limits_set(self):
"""Ensure that re-filtering after slicing works
"""
+
class Post(Document):
title = StringField()
@@ -3937,6 +3967,7 @@ class QuerySetTest(unittest.TestCase):
def test_order_then_filter(self):
"""Ensure that ordering still works after filtering.
"""
+
class Number(Document):
n = IntField()
@@ -3946,14 +3977,15 @@ class QuerySetTest(unittest.TestCase):
n1 = Number.objects.create(n=1)
self.assertEqual(list(Number.objects), [n2, n1])
- self.assertEqual(list(Number.objects.order_by('n')), [n1, n2])
- self.assertEqual(list(Number.objects.order_by('n').filter()), [n1, n2])
+ self.assertEqual(list(Number.objects.order_by("n")), [n1, n2])
+ self.assertEqual(list(Number.objects.order_by("n").filter()), [n1, n2])
Number.drop_collection()
def test_clone(self):
"""Ensure that cloning clones complex querysets
"""
+
class Number(Document):
n = IntField()
@@ -3983,19 +4015,20 @@ class QuerySetTest(unittest.TestCase):
def test_using(self):
"""Ensure that switching databases for a queryset is possible
"""
+
class Number2(Document):
n = IntField()
Number2.drop_collection()
- with switch_db(Number2, 'test2') as Number2:
+ with switch_db(Number2, "test2") as Number2:
Number2.drop_collection()
for i in range(1, 10):
t = Number2(n=i)
- t.switch_db('test2')
+ t.switch_db("test2")
t.save()
- self.assertEqual(len(Number2.objects.using('test2')), 9)
+ self.assertEqual(len(Number2.objects.using("test2")), 9)
def test_unset_reference(self):
class Comment(Document):
@@ -4007,7 +4040,7 @@ class QuerySetTest(unittest.TestCase):
Comment.drop_collection()
Post.drop_collection()
- comment = Comment.objects.create(text='test')
+ comment = Comment.objects.create(text="test")
post = Post.objects.create(comment=comment)
self.assertEqual(post.comment, comment)
@@ -4020,7 +4053,7 @@ class QuerySetTest(unittest.TestCase):
def test_order_works_with_custom_db_field_names(self):
class Number(Document):
- n = IntField(db_field='number')
+ n = IntField(db_field="number")
Number.drop_collection()
@@ -4028,13 +4061,14 @@ class QuerySetTest(unittest.TestCase):
n1 = Number.objects.create(n=1)
self.assertEqual(list(Number.objects), [n2, n1])
- self.assertEqual(list(Number.objects.order_by('n')), [n1, n2])
+ self.assertEqual(list(Number.objects.order_by("n")), [n1, n2])
Number.drop_collection()
def test_order_works_with_primary(self):
"""Ensure that order_by and primary work.
"""
+
class Number(Document):
n = IntField(primary_key=True)
@@ -4044,28 +4078,29 @@ class QuerySetTest(unittest.TestCase):
Number(n=2).save()
Number(n=3).save()
- numbers = [n.n for n in Number.objects.order_by('-n')]
+ numbers = [n.n for n in Number.objects.order_by("-n")]
self.assertEqual([3, 2, 1], numbers)
- numbers = [n.n for n in Number.objects.order_by('+n')]
+ numbers = [n.n for n in Number.objects.order_by("+n")]
self.assertEqual([1, 2, 3], numbers)
Number.drop_collection()
def test_ensure_index(self):
"""Ensure that manual creation of indexes works.
"""
+
class Comment(Document):
message = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
- Comment.ensure_index('message')
+ Comment.ensure_index("message")
info = Comment.objects._collection.index_information()
- info = [(value['key'],
- value.get('unique', False),
- value.get('sparse', False))
- for key, value in iteritems(info)]
- self.assertIn(([('_cls', 1), ('message', 1)], False, False), info)
+ info = [
+ (value["key"], value.get("unique", False), value.get("sparse", False))
+ for key, value in iteritems(info)
+ ]
+ self.assertIn(([("_cls", 1), ("message", 1)], False, False), info)
def test_where(self):
"""Ensure that where clauses work.
@@ -4084,23 +4119,25 @@ class QuerySetTest(unittest.TestCase):
b.save()
c.save()
- query = IntPair.objects.where('this[~fielda] >= this[~fieldb]')
- self.assertEqual(
- 'this["fielda"] >= this["fieldb"]', query._where_clause)
+ query = IntPair.objects.where("this[~fielda] >= this[~fieldb]")
+ self.assertEqual('this["fielda"] >= this["fieldb"]', query._where_clause)
results = list(query)
self.assertEqual(2, len(results))
self.assertIn(a, results)
self.assertIn(c, results)
- query = IntPair.objects.where('this[~fielda] == this[~fieldb]')
+ query = IntPair.objects.where("this[~fielda] == this[~fieldb]")
results = list(query)
self.assertEqual(1, len(results))
self.assertIn(a, results)
query = IntPair.objects.where(
- 'function() { return this[~fielda] >= this[~fieldb] }')
+ "function() { return this[~fielda] >= this[~fieldb] }"
+ )
self.assertEqual(
- 'function() { return this["fielda"] >= this["fieldb"] }', query._where_clause)
+ 'function() { return this["fielda"] >= this["fieldb"] }',
+ query._where_clause,
+ )
results = list(query)
self.assertEqual(2, len(results))
self.assertIn(a, results)
@@ -4110,7 +4147,6 @@ class QuerySetTest(unittest.TestCase):
list(IntPair.objects.where(fielda__gte=3))
def test_scalar(self):
-
class Organization(Document):
name = StringField()
@@ -4127,13 +4163,13 @@ class QuerySetTest(unittest.TestCase):
# Efficient way to get all unique organization names for a given
# set of users (Pretend this has additional filtering.)
- user_orgs = set(User.objects.scalar('organization'))
- orgs = Organization.objects(id__in=user_orgs).scalar('name')
- self.assertEqual(list(orgs), ['White House'])
+ user_orgs = set(User.objects.scalar("organization"))
+ orgs = Organization.objects(id__in=user_orgs).scalar("name")
+ self.assertEqual(list(orgs), ["White House"])
# Efficient for generating listings, too.
- orgs = Organization.objects.scalar('name').in_bulk(list(user_orgs))
- user_map = User.objects.scalar('name', 'organization')
+ orgs = Organization.objects.scalar("name").in_bulk(list(user_orgs))
+ user_map = User.objects.scalar("name", "organization")
user_listing = [(user, orgs[org]) for user, org in user_map]
self.assertEqual([("Bob Dole", "White House")], user_listing)
@@ -4148,7 +4184,7 @@ class QuerySetTest(unittest.TestCase):
TestDoc(x=20, y=False).save()
TestDoc(x=30, y=True).save()
- plist = list(TestDoc.objects.scalar('x', 'y'))
+ plist = list(TestDoc.objects.scalar("x", "y"))
self.assertEqual(len(plist), 3)
self.assertEqual(plist[0], (10, True))
@@ -4166,21 +4202,16 @@ class QuerySetTest(unittest.TestCase):
UserDoc(name="Eliana", age=37).save()
UserDoc(name="Tayza", age=15).save()
- ulist = list(UserDoc.objects.scalar('name', 'age'))
+ ulist = list(UserDoc.objects.scalar("name", "age"))
- self.assertEqual(ulist, [
- (u'Wilson Jr', 19),
- (u'Wilson', 43),
- (u'Eliana', 37),
- (u'Tayza', 15)])
+ self.assertEqual(
+ ulist,
+ [(u"Wilson Jr", 19), (u"Wilson", 43), (u"Eliana", 37), (u"Tayza", 15)],
+ )
- ulist = list(UserDoc.objects.scalar('name').order_by('age'))
+ ulist = list(UserDoc.objects.scalar("name").order_by("age"))
- self.assertEqual(ulist, [
- (u'Tayza'),
- (u'Wilson Jr'),
- (u'Eliana'),
- (u'Wilson')])
+ self.assertEqual(ulist, [(u"Tayza"), (u"Wilson Jr"), (u"Eliana"), (u"Wilson")])
def test_scalar_embedded(self):
class Profile(EmbeddedDocument):
@@ -4197,30 +4228,45 @@ class QuerySetTest(unittest.TestCase):
Person.drop_collection()
- Person(profile=Profile(name="Wilson Jr", age=19),
- locale=Locale(city="Corumba-GO", country="Brazil")).save()
+ Person(
+ profile=Profile(name="Wilson Jr", age=19),
+ locale=Locale(city="Corumba-GO", country="Brazil"),
+ ).save()
- Person(profile=Profile(name="Gabriel Falcao", age=23),
- locale=Locale(city="New York", country="USA")).save()
+ Person(
+ profile=Profile(name="Gabriel Falcao", age=23),
+ locale=Locale(city="New York", country="USA"),
+ ).save()
- Person(profile=Profile(name="Lincoln de souza", age=28),
- locale=Locale(city="Belo Horizonte", country="Brazil")).save()
+ Person(
+ profile=Profile(name="Lincoln de souza", age=28),
+ locale=Locale(city="Belo Horizonte", country="Brazil"),
+ ).save()
- Person(profile=Profile(name="Walter cruz", age=30),
- locale=Locale(city="Brasilia", country="Brazil")).save()
+ Person(
+ profile=Profile(name="Walter cruz", age=30),
+ locale=Locale(city="Brasilia", country="Brazil"),
+ ).save()
self.assertEqual(
- list(Person.objects.order_by(
- 'profile__age').scalar('profile__name')),
- [u'Wilson Jr', u'Gabriel Falcao', u'Lincoln de souza', u'Walter cruz'])
+ list(Person.objects.order_by("profile__age").scalar("profile__name")),
+ [u"Wilson Jr", u"Gabriel Falcao", u"Lincoln de souza", u"Walter cruz"],
+ )
- ulist = list(Person.objects.order_by('locale.city')
- .scalar('profile__name', 'profile__age', 'locale__city'))
- self.assertEqual(ulist,
- [(u'Lincoln de souza', 28, u'Belo Horizonte'),
- (u'Walter cruz', 30, u'Brasilia'),
- (u'Wilson Jr', 19, u'Corumba-GO'),
- (u'Gabriel Falcao', 23, u'New York')])
+ ulist = list(
+ Person.objects.order_by("locale.city").scalar(
+ "profile__name", "profile__age", "locale__city"
+ )
+ )
+ self.assertEqual(
+ ulist,
+ [
+ (u"Lincoln de souza", 28, u"Belo Horizonte"),
+ (u"Walter cruz", 30, u"Brasilia"),
+ (u"Wilson Jr", 19, u"Corumba-GO"),
+ (u"Gabriel Falcao", 23, u"New York"),
+ ],
+ )
def test_scalar_decimal(self):
from decimal import Decimal
@@ -4230,10 +4276,10 @@ class QuerySetTest(unittest.TestCase):
rating = DecimalField()
Person.drop_collection()
- Person(name="Wilson Jr", rating=Decimal('1.0')).save()
+ Person(name="Wilson Jr", rating=Decimal("1.0")).save()
- ulist = list(Person.objects.scalar('name', 'rating'))
- self.assertEqual(ulist, [(u'Wilson Jr', Decimal('1.0'))])
+ ulist = list(Person.objects.scalar("name", "rating"))
+ self.assertEqual(ulist, [(u"Wilson Jr", Decimal("1.0"))])
def test_scalar_reference_field(self):
class State(Document):
@@ -4251,8 +4297,8 @@ class QuerySetTest(unittest.TestCase):
Person(name="Wilson JR", state=s1).save()
- plist = list(Person.objects.scalar('name', 'state'))
- self.assertEqual(plist, [(u'Wilson JR', s1)])
+ plist = list(Person.objects.scalar("name", "state"))
+ self.assertEqual(plist, [(u"Wilson JR", s1)])
def test_scalar_generic_reference_field(self):
class State(Document):
@@ -4270,8 +4316,8 @@ class QuerySetTest(unittest.TestCase):
Person(name="Wilson JR", state=s1).save()
- plist = list(Person.objects.scalar('name', 'state'))
- self.assertEqual(plist, [(u'Wilson JR', s1)])
+ plist = list(Person.objects.scalar("name", "state"))
+ self.assertEqual(plist, [(u"Wilson JR", s1)])
def test_generic_reference_field_with_only_and_as_pymongo(self):
class TestPerson(Document):
@@ -4284,26 +4330,32 @@ class QuerySetTest(unittest.TestCase):
TestPerson.drop_collection()
TestActivity.drop_collection()
- person = TestPerson(name='owner')
+ person = TestPerson(name="owner")
person.save()
- a1 = TestActivity(name='a1', owner=person)
+ a1 = TestActivity(name="a1", owner=person)
a1.save()
- activity = TestActivity.objects(owner=person).scalar('id', 'owner').no_dereference().first()
+ activity = (
+ TestActivity.objects(owner=person)
+ .scalar("id", "owner")
+ .no_dereference()
+ .first()
+ )
self.assertEqual(activity[0], a1.pk)
- self.assertEqual(activity[1]['_ref'], DBRef('test_person', person.pk))
+ self.assertEqual(activity[1]["_ref"], DBRef("test_person", person.pk))
- activity = TestActivity.objects(owner=person).only('id', 'owner')[0]
+ activity = TestActivity.objects(owner=person).only("id", "owner")[0]
self.assertEqual(activity.pk, a1.pk)
self.assertEqual(activity.owner, person)
- activity = TestActivity.objects(owner=person).only('id', 'owner').as_pymongo().first()
- self.assertEqual(activity['_id'], a1.pk)
- self.assertTrue(activity['owner']['_ref'], DBRef('test_person', person.pk))
+ activity = (
+ TestActivity.objects(owner=person).only("id", "owner").as_pymongo().first()
+ )
+ self.assertEqual(activity["_id"], a1.pk)
+ self.assertTrue(activity["owner"]["_ref"], DBRef("test_person", person.pk))
def test_scalar_db_field(self):
-
class TestDoc(Document):
x = IntField()
y = BooleanField()
@@ -4314,14 +4366,13 @@ class QuerySetTest(unittest.TestCase):
TestDoc(x=20, y=False).save()
TestDoc(x=30, y=True).save()
- plist = list(TestDoc.objects.scalar('x', 'y'))
+ plist = list(TestDoc.objects.scalar("x", "y"))
self.assertEqual(len(plist), 3)
self.assertEqual(plist[0], (10, True))
self.assertEqual(plist[1], (20, False))
self.assertEqual(plist[2], (30, True))
def test_scalar_primary_key(self):
-
class SettingValue(Document):
key = StringField(primary_key=True)
value = StringField()
@@ -4330,8 +4381,8 @@ class QuerySetTest(unittest.TestCase):
s = SettingValue(key="test", value="test value")
s.save()
- val = SettingValue.objects.scalar('key', 'value')
- self.assertEqual(list(val), [('test', 'test value')])
+ val = SettingValue.objects.scalar("key", "value")
+ self.assertEqual(list(val), [("test", "test value")])
def test_scalar_cursor_behaviour(self):
"""Ensure that a query returns a valid set of results.
@@ -4342,83 +4393,94 @@ class QuerySetTest(unittest.TestCase):
person2.save()
# Find all people in the collection
- people = self.Person.objects.scalar('name')
+ people = self.Person.objects.scalar("name")
self.assertEqual(people.count(), 2)
results = list(people)
self.assertEqual(results[0], "User A")
self.assertEqual(results[1], "User B")
# Use a query to filter the people found to just person1
- people = self.Person.objects(age=20).scalar('name')
+ people = self.Person.objects(age=20).scalar("name")
self.assertEqual(people.count(), 1)
person = people.next()
self.assertEqual(person, "User A")
# Test limit
- people = list(self.Person.objects.limit(1).scalar('name'))
+ people = list(self.Person.objects.limit(1).scalar("name"))
self.assertEqual(len(people), 1)
- self.assertEqual(people[0], 'User A')
+ self.assertEqual(people[0], "User A")
# Test skip
- people = list(self.Person.objects.skip(1).scalar('name'))
+ people = list(self.Person.objects.skip(1).scalar("name"))
self.assertEqual(len(people), 1)
- self.assertEqual(people[0], 'User B')
+ self.assertEqual(people[0], "User B")
person3 = self.Person(name="User C", age=40)
person3.save()
# Test slice limit
- people = list(self.Person.objects[:2].scalar('name'))
+ people = list(self.Person.objects[:2].scalar("name"))
self.assertEqual(len(people), 2)
- self.assertEqual(people[0], 'User A')
- self.assertEqual(people[1], 'User B')
+ self.assertEqual(people[0], "User A")
+ self.assertEqual(people[1], "User B")
# Test slice skip
- people = list(self.Person.objects[1:].scalar('name'))
+ people = list(self.Person.objects[1:].scalar("name"))
self.assertEqual(len(people), 2)
- self.assertEqual(people[0], 'User B')
- self.assertEqual(people[1], 'User C')
+ self.assertEqual(people[0], "User B")
+ self.assertEqual(people[1], "User C")
# Test slice limit and skip
- people = list(self.Person.objects[1:2].scalar('name'))
+ people = list(self.Person.objects[1:2].scalar("name"))
self.assertEqual(len(people), 1)
- self.assertEqual(people[0], 'User B')
+ self.assertEqual(people[0], "User B")
- people = list(self.Person.objects[1:1].scalar('name'))
+ people = list(self.Person.objects[1:1].scalar("name"))
self.assertEqual(len(people), 0)
# Test slice out of range
- people = list(self.Person.objects.scalar('name')[80000:80001])
+ people = list(self.Person.objects.scalar("name")[80000:80001])
self.assertEqual(len(people), 0)
# Test larger slice __repr__
self.Person.objects.delete()
for i in range(55):
- self.Person(name='A%s' % i, age=i).save()
+ self.Person(name="A%s" % i, age=i).save()
- self.assertEqual(self.Person.objects.scalar('name').count(), 55)
+ self.assertEqual(self.Person.objects.scalar("name").count(), 55)
self.assertEqual(
- "A0", "%s" % self.Person.objects.order_by('name').scalar('name').first())
+ "A0", "%s" % self.Person.objects.order_by("name").scalar("name").first()
+ )
self.assertEqual(
- "A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0])
+ "A0", "%s" % self.Person.objects.scalar("name").order_by("name")[0]
+ )
if six.PY3:
- self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by(
- 'age').scalar('name')[1:3])
- self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by(
- 'age').scalar('name')[51:53])
+ self.assertEqual(
+ "['A1', 'A2']",
+ "%s" % self.Person.objects.order_by("age").scalar("name")[1:3],
+ )
+ self.assertEqual(
+ "['A51', 'A52']",
+ "%s" % self.Person.objects.order_by("age").scalar("name")[51:53],
+ )
else:
- self.assertEqual("[u'A1', u'A2']", "%s" % self.Person.objects.order_by(
- 'age').scalar('name')[1:3])
- self.assertEqual("[u'A51', u'A52']", "%s" % self.Person.objects.order_by(
- 'age').scalar('name')[51:53])
+ self.assertEqual(
+ "[u'A1', u'A2']",
+ "%s" % self.Person.objects.order_by("age").scalar("name")[1:3],
+ )
+ self.assertEqual(
+ "[u'A51', u'A52']",
+ "%s" % self.Person.objects.order_by("age").scalar("name")[51:53],
+ )
# with_id and in_bulk
- person = self.Person.objects.order_by('name').first()
- self.assertEqual("A0", "%s" %
- self.Person.objects.scalar('name').with_id(person.id))
+ person = self.Person.objects.order_by("name").first()
+ self.assertEqual(
+ "A0", "%s" % self.Person.objects.scalar("name").with_id(person.id)
+ )
- pks = self.Person.objects.order_by('age').scalar('pk')[1:3]
- names = self.Person.objects.scalar('name').in_bulk(list(pks)).values()
+ pks = self.Person.objects.order_by("age").scalar("pk")[1:3]
+ names = self.Person.objects.scalar("name").in_bulk(list(pks)).values()
if six.PY3:
expected = "['A1', 'A2']"
else:
@@ -4430,51 +4492,61 @@ class QuerySetTest(unittest.TestCase):
shape = StringField()
color = StringField()
thick = BooleanField()
- meta = {'allow_inheritance': False}
+ meta = {"allow_inheritance": False}
class Bar(Document):
foo = ListField(EmbeddedDocumentField(Foo))
- meta = {'allow_inheritance': False}
+ meta = {"allow_inheritance": False}
Bar.drop_collection()
- b1 = Bar(foo=[Foo(shape="square", color="purple", thick=False),
- Foo(shape="circle", color="red", thick=True)])
+ b1 = Bar(
+ foo=[
+ Foo(shape="square", color="purple", thick=False),
+ Foo(shape="circle", color="red", thick=True),
+ ]
+ )
b1.save()
- b2 = Bar(foo=[Foo(shape="square", color="red", thick=True),
- Foo(shape="circle", color="purple", thick=False)])
+ b2 = Bar(
+ foo=[
+ Foo(shape="square", color="red", thick=True),
+ Foo(shape="circle", color="purple", thick=False),
+ ]
+ )
b2.save()
- b3 = Bar(foo=[Foo(shape="square", thick=True),
- Foo(shape="circle", color="purple", thick=False)])
+ b3 = Bar(
+ foo=[
+ Foo(shape="square", thick=True),
+ Foo(shape="circle", color="purple", thick=False),
+ ]
+ )
b3.save()
- ak = list(
- Bar.objects(foo__match={'shape': "square", "color": "purple"}))
+ ak = list(Bar.objects(foo__match={"shape": "square", "color": "purple"}))
self.assertEqual([b1], ak)
- ak = list(
- Bar.objects(foo__elemMatch={'shape': "square", "color": "purple"}))
+ ak = list(Bar.objects(foo__elemMatch={"shape": "square", "color": "purple"}))
self.assertEqual([b1], ak)
ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple")))
self.assertEqual([b1], ak)
ak = list(
- Bar.objects(foo__elemMatch={'shape': "square", "color__exists": True}))
+ Bar.objects(foo__elemMatch={"shape": "square", "color__exists": True})
+ )
+ self.assertEqual([b1, b2], ak)
+
+ ak = list(Bar.objects(foo__match={"shape": "square", "color__exists": True}))
self.assertEqual([b1, b2], ak)
ak = list(
- Bar.objects(foo__match={'shape': "square", "color__exists": True}))
- self.assertEqual([b1, b2], ak)
-
- ak = list(
- Bar.objects(foo__elemMatch={'shape': "square", "color__exists": False}))
+ Bar.objects(foo__elemMatch={"shape": "square", "color__exists": False})
+ )
self.assertEqual([b3], ak)
- ak = list(
- Bar.objects(foo__match={'shape': "square", "color__exists": False}))
+ ak = list(Bar.objects(foo__match={"shape": "square", "color__exists": False}))
self.assertEqual([b3], ak)
def test_upsert_includes_cls(self):
@@ -4485,24 +4557,25 @@ class QuerySetTest(unittest.TestCase):
test = StringField()
Test.drop_collection()
- Test.objects(test='foo').update_one(upsert=True, set__test='foo')
- self.assertNotIn('_cls', Test._collection.find_one())
+ Test.objects(test="foo").update_one(upsert=True, set__test="foo")
+ self.assertNotIn("_cls", Test._collection.find_one())
class Test(Document):
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
test = StringField()
Test.drop_collection()
- Test.objects(test='foo').update_one(upsert=True, set__test='foo')
- self.assertIn('_cls', Test._collection.find_one())
+ Test.objects(test="foo").update_one(upsert=True, set__test="foo")
+ self.assertIn("_cls", Test._collection.find_one())
def test_update_upsert_looks_like_a_digit(self):
class MyDoc(DynamicDocument):
pass
+
MyDoc.drop_collection()
self.assertEqual(1, MyDoc.objects.update_one(upsert=True, inc__47=1))
- self.assertEqual(MyDoc.objects.get()['47'], 1)
+ self.assertEqual(MyDoc.objects.get()["47"], 1)
def test_dictfield_key_looks_like_a_digit(self):
"""Only should work with DictField even if they have numeric keys."""
@@ -4511,86 +4584,84 @@ class QuerySetTest(unittest.TestCase):
test = DictField()
MyDoc.drop_collection()
- doc = MyDoc(test={'47': 1})
+ doc = MyDoc(test={"47": 1})
doc.save()
- self.assertEqual(MyDoc.objects.only('test__47').get().test['47'], 1)
+ self.assertEqual(MyDoc.objects.only("test__47").get().test["47"], 1)
def test_read_preference(self):
class Bar(Document):
txt = StringField()
- meta = {
- 'indexes': ['txt']
- }
+ meta = {"indexes": ["txt"]}
Bar.drop_collection()
bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY))
self.assertEqual([], bars)
- self.assertRaises(TypeError, Bar.objects, read_preference='Primary')
+ self.assertRaises(TypeError, Bar.objects, read_preference="Primary")
# read_preference as a kwarg
bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._read_preference,
- ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._cursor._Cursor__read_preference,
- ReadPreference.SECONDARY_PREFERRED)
+ self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
+ self.assertEqual(
+ bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED
+ )
# read_preference as a query set method
bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._read_preference,
- ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._cursor._Cursor__read_preference,
- ReadPreference.SECONDARY_PREFERRED)
+ self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
+ self.assertEqual(
+ bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED
+ )
# read_preference after skip
- bars = Bar.objects.skip(1) \
- .read_preference(ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._read_preference,
- ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._cursor._Cursor__read_preference,
- ReadPreference.SECONDARY_PREFERRED)
+ bars = Bar.objects.skip(1).read_preference(ReadPreference.SECONDARY_PREFERRED)
+ self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
+ self.assertEqual(
+ bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED
+ )
# read_preference after limit
- bars = Bar.objects.limit(1) \
- .read_preference(ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._read_preference,
- ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._cursor._Cursor__read_preference,
- ReadPreference.SECONDARY_PREFERRED)
+ bars = Bar.objects.limit(1).read_preference(ReadPreference.SECONDARY_PREFERRED)
+ self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
+ self.assertEqual(
+ bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED
+ )
# read_preference after order_by
- bars = Bar.objects.order_by('txt') \
- .read_preference(ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._read_preference,
- ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._cursor._Cursor__read_preference,
- ReadPreference.SECONDARY_PREFERRED)
+ bars = Bar.objects.order_by("txt").read_preference(
+ ReadPreference.SECONDARY_PREFERRED
+ )
+ self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
+ self.assertEqual(
+ bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED
+ )
# read_preference after hint
- bars = Bar.objects.hint([('txt', 1)]) \
- .read_preference(ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._read_preference,
- ReadPreference.SECONDARY_PREFERRED)
- self.assertEqual(bars._cursor._Cursor__read_preference,
- ReadPreference.SECONDARY_PREFERRED)
+ bars = Bar.objects.hint([("txt", 1)]).read_preference(
+ ReadPreference.SECONDARY_PREFERRED
+ )
+ self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
+ self.assertEqual(
+ bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED
+ )
def test_read_preference_aggregation_framework(self):
class Bar(Document):
txt = StringField()
- meta = {
- 'indexes': ['txt']
- }
+ meta = {"indexes": ["txt"]}
+
# Aggregates with read_preference
- bars = Bar.objects \
- .read_preference(ReadPreference.SECONDARY_PREFERRED) \
- .aggregate()
- self.assertEqual(bars._CommandCursor__collection.read_preference,
- ReadPreference.SECONDARY_PREFERRED)
+ bars = Bar.objects.read_preference(
+ ReadPreference.SECONDARY_PREFERRED
+ ).aggregate()
+ self.assertEqual(
+ bars._CommandCursor__collection.read_preference,
+ ReadPreference.SECONDARY_PREFERRED,
+ )
def test_json_simple(self):
-
class Embedded(EmbeddedDocument):
string = StringField()
@@ -4603,7 +4674,7 @@ class QuerySetTest(unittest.TestCase):
Doc(string="Bye", embedded_field=Embedded(string="Bye")).save()
Doc().save()
- json_data = Doc.objects.to_json(sort_keys=True, separators=(',', ':'))
+ json_data = Doc.objects.to_json(sort_keys=True, separators=(",", ":"))
doc_objects = list(Doc.objects)
self.assertEqual(doc_objects, Doc.objects.from_json(json_data))
@@ -4616,33 +4687,34 @@ class QuerySetTest(unittest.TestCase):
pass
class Doc(Document):
- string_field = StringField(default='1')
+ string_field = StringField(default="1")
int_field = IntField(default=1)
float_field = FloatField(default=1.1)
boolean_field = BooleanField(default=True)
datetime_field = DateTimeField(default=datetime.datetime.now)
embedded_document_field = EmbeddedDocumentField(
- EmbeddedDoc, default=lambda: EmbeddedDoc())
+ EmbeddedDoc, default=lambda: EmbeddedDoc()
+ )
list_field = ListField(default=lambda: [1, 2, 3])
dict_field = DictField(default=lambda: {"hello": "world"})
objectid_field = ObjectIdField(default=ObjectId)
- reference_field = ReferenceField(
- Simple, default=lambda: Simple().save())
+ reference_field = ReferenceField(Simple, default=lambda: Simple().save())
map_field = MapField(IntField(), default=lambda: {"simple": 1})
decimal_field = DecimalField(default=1.0)
complex_datetime_field = ComplexDateTimeField(default=datetime.datetime.now)
url_field = URLField(default="http://mongoengine.org")
dynamic_field = DynamicField(default=1)
generic_reference_field = GenericReferenceField(
- default=lambda: Simple().save())
- sorted_list_field = SortedListField(IntField(),
- default=lambda: [1, 2, 3])
+ default=lambda: Simple().save()
+ )
+ sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3])
email_field = EmailField(default="ross@example.com")
geo_point_field = GeoPointField(default=lambda: [1, 2])
sequence_field = SequenceField()
uuid_field = UUIDField(default=uuid.uuid4)
generic_embedded_document_field = GenericEmbeddedDocumentField(
- default=lambda: EmbeddedDoc())
+ default=lambda: EmbeddedDoc()
+ )
Simple.drop_collection()
Doc.drop_collection()
@@ -4667,111 +4739,96 @@ class QuerySetTest(unittest.TestCase):
User.drop_collection()
- User.objects.create(id='Bob', name="Bob Dole", age=89, price=Decimal('1.11'))
+ User.objects.create(id="Bob", name="Bob Dole", age=89, price=Decimal("1.11"))
User.objects.create(
- id='Barak',
+ id="Barak",
name="Barak Obama",
age=51,
- price=Decimal('2.22'),
- last_login=LastLogin(
- location='White House',
- ip='104.107.108.116'
- )
+ price=Decimal("2.22"),
+ last_login=LastLogin(location="White House", ip="104.107.108.116"),
)
results = User.objects.as_pymongo()
+ self.assertEqual(set(results[0].keys()), set(["_id", "name", "age", "price"]))
self.assertEqual(
- set(results[0].keys()),
- set(['_id', 'name', 'age', 'price'])
- )
- self.assertEqual(
- set(results[1].keys()),
- set(['_id', 'name', 'age', 'price', 'last_login'])
+ set(results[1].keys()), set(["_id", "name", "age", "price", "last_login"])
)
- results = User.objects.only('id', 'name').as_pymongo()
- self.assertEqual(set(results[0].keys()), set(['_id', 'name']))
+ results = User.objects.only("id", "name").as_pymongo()
+ self.assertEqual(set(results[0].keys()), set(["_id", "name"]))
- users = User.objects.only('name', 'price').as_pymongo()
+ users = User.objects.only("name", "price").as_pymongo()
results = list(users)
self.assertIsInstance(results[0], dict)
self.assertIsInstance(results[1], dict)
- self.assertEqual(results[0]['name'], 'Bob Dole')
- self.assertEqual(results[0]['price'], 1.11)
- self.assertEqual(results[1]['name'], 'Barak Obama')
- self.assertEqual(results[1]['price'], 2.22)
+ self.assertEqual(results[0]["name"], "Bob Dole")
+ self.assertEqual(results[0]["price"], 1.11)
+ self.assertEqual(results[1]["name"], "Barak Obama")
+ self.assertEqual(results[1]["price"], 2.22)
- users = User.objects.only('name', 'last_login').as_pymongo()
+ users = User.objects.only("name", "last_login").as_pymongo()
results = list(users)
self.assertIsInstance(results[0], dict)
self.assertIsInstance(results[1], dict)
- self.assertEqual(results[0], {
- '_id': 'Bob',
- 'name': 'Bob Dole'
- })
- self.assertEqual(results[1], {
- '_id': 'Barak',
- 'name': 'Barak Obama',
- 'last_login': {
- 'location': 'White House',
- 'ip': '104.107.108.116'
- }
- })
+ self.assertEqual(results[0], {"_id": "Bob", "name": "Bob Dole"})
+ self.assertEqual(
+ results[1],
+ {
+ "_id": "Barak",
+ "name": "Barak Obama",
+ "last_login": {"location": "White House", "ip": "104.107.108.116"},
+ },
+ )
def test_as_pymongo_returns_cls_attribute_when_using_inheritance(self):
class User(Document):
name = StringField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
User.drop_collection()
user = User(name="Bob Dole").save()
result = User.objects.as_pymongo().first()
- self.assertEqual(
- result,
- {
- '_cls': 'User',
- '_id': user.id,
- 'name': 'Bob Dole'
- }
- )
+ self.assertEqual(result, {"_cls": "User", "_id": user.id, "name": "Bob Dole"})
def test_as_pymongo_json_limit_fields(self):
-
class User(Document):
email = EmailField(unique=True, required=True)
- password_hash = StringField(
- db_field='password_hash', required=True)
- password_salt = StringField(
- db_field='password_salt', required=True)
+ password_hash = StringField(db_field="password_hash", required=True)
+ password_salt = StringField(db_field="password_salt", required=True)
User.drop_collection()
- User(email="ross@example.com", password_salt="SomeSalt",
- password_hash="SomeHash").save()
+ User(
+ email="ross@example.com", password_salt="SomeSalt", password_hash="SomeHash"
+ ).save()
serialized_user = User.objects.exclude(
- 'password_salt', 'password_hash').as_pymongo()[0]
- self.assertEqual({'_id', 'email'}, set(serialized_user.keys()))
+ "password_salt", "password_hash"
+ ).as_pymongo()[0]
+ self.assertEqual({"_id", "email"}, set(serialized_user.keys()))
serialized_user = User.objects.exclude(
- 'id', 'password_salt', 'password_hash').to_json()
+ "id", "password_salt", "password_hash"
+ ).to_json()
self.assertEqual('[{"email": "ross@example.com"}]', serialized_user)
- serialized_user = User.objects.only('email').as_pymongo()[0]
- self.assertEqual({'_id', 'email'}, set(serialized_user.keys()))
+ serialized_user = User.objects.only("email").as_pymongo()[0]
+ self.assertEqual({"_id", "email"}, set(serialized_user.keys()))
- serialized_user = User.objects.exclude(
- 'password_salt').only('email').as_pymongo()[0]
- self.assertEqual({'_id', 'email'}, set(serialized_user.keys()))
+ serialized_user = (
+ User.objects.exclude("password_salt").only("email").as_pymongo()[0]
+ )
+ self.assertEqual({"_id", "email"}, set(serialized_user.keys()))
- serialized_user = User.objects.exclude(
- 'password_salt', 'id').only('email').as_pymongo()[0]
- self.assertEqual({'email'}, set(serialized_user.keys()))
+ serialized_user = (
+ User.objects.exclude("password_salt", "id").only("email").as_pymongo()[0]
+ )
+ self.assertEqual({"email"}, set(serialized_user.keys()))
- serialized_user = User.objects.exclude(
- 'password_salt', 'id').only('email').to_json()
- self.assertEqual('[{"email": "ross@example.com"}]',
- serialized_user)
+ serialized_user = (
+ User.objects.exclude("password_salt", "id").only("email").to_json()
+ )
+ self.assertEqual('[{"email": "ross@example.com"}]', serialized_user)
def test_only_after_count(self):
"""Test that only() works after count()"""
@@ -4780,9 +4837,9 @@ class QuerySetTest(unittest.TestCase):
name = StringField()
age = IntField()
address = StringField()
+
User.drop_collection()
- user = User(name="User", age=50,
- address="Moscow, Russia").save()
+ user = User(name="User", age=50, address="Moscow, Russia").save()
user_queryset = User.objects(age=50)
@@ -4796,7 +4853,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(result, {"_id": user.id, "name": "User", "age": 50})
def test_no_dereference(self):
-
class Organization(Document):
name = StringField()
@@ -4842,12 +4898,14 @@ class QuerySetTest(unittest.TestCase):
self.assertFalse(qs_no_deref._auto_dereference)
# Make sure the instance field is different from the class field
- instance_org_field = user_no_deref._fields['organization']
+ instance_org_field = user_no_deref._fields["organization"]
self.assertIsNot(instance_org_field, cls_organization_field)
self.assertFalse(instance_org_field._auto_dereference)
self.assertIsInstance(user_no_deref.organization, DBRef)
- self.assertTrue(cls_organization_field._auto_dereference, True) # Make sure the class Field wasn't altered
+ self.assertTrue(
+ cls_organization_field._auto_dereference, True
+ ) # Make sure the class Field wasn't altered
def test_no_dereference_no_side_effect_on_existing_instance(self):
# Relates to issue #1677 - ensures no regression of the bug
@@ -4863,8 +4921,7 @@ class QuerySetTest(unittest.TestCase):
Organization.drop_collection()
org = Organization(name="whatever").save()
- User(organization=org,
- organization_gen=org).save()
+ User(organization=org, organization_gen=org).save()
qs = User.objects()
user = qs.first()
@@ -4873,7 +4930,7 @@ class QuerySetTest(unittest.TestCase):
user_no_deref = qs_no_deref.first()
# ReferenceField
- no_derf_org = user_no_deref.organization # was triggering the bug
+ no_derf_org = user_no_deref.organization # was triggering the bug
self.assertIsInstance(no_derf_org, DBRef)
self.assertIsInstance(user.organization, Organization)
@@ -4883,7 +4940,6 @@ class QuerySetTest(unittest.TestCase):
self.assertIsInstance(user.organization_gen, Organization)
def test_no_dereference_embedded_doc(self):
-
class User(Document):
name = StringField()
@@ -4906,17 +4962,15 @@ class QuerySetTest(unittest.TestCase):
member = Member(name="Flash", user=user)
- company = Organization(name="Mongo Inc",
- ceo=user,
- member=member,
- admins=[user],
- members=[member])
+ company = Organization(
+ name="Mongo Inc", ceo=user, member=member, admins=[user], members=[member]
+ )
company.save()
org = Organization.objects().no_dereference().first()
- self.assertNotEqual(id(org._fields['admins']), id(Organization.admins))
- self.assertFalse(org._fields['admins']._auto_dereference)
+ self.assertNotEqual(id(org._fields["admins"]), id(Organization.admins))
+ self.assertFalse(org._fields["admins"]._auto_dereference)
admin = org.admins[0]
self.assertIsInstance(admin, DBRef)
@@ -4981,14 +5035,14 @@ class QuerySetTest(unittest.TestCase):
Person.drop_collection()
qs = Person.objects.no_cache()
- self.assertEqual(repr(qs), '[]')
+ self.assertEqual(repr(qs), "[]")
def test_no_cached_on_a_cached_queryset_raise_error(self):
class Person(Document):
name = StringField()
Person.drop_collection()
- Person(name='a').save()
+ Person(name="a").save()
qs = Person.objects()
_ = list(qs)
with self.assertRaises(OperationError) as ctx_err:
@@ -5008,7 +5062,6 @@ class QuerySetTest(unittest.TestCase):
self.assertIsInstance(qs, QuerySet)
def test_cache_not_cloned(self):
-
class User(Document):
name = StringField()
@@ -5020,7 +5073,7 @@ class QuerySetTest(unittest.TestCase):
User(name="Alice").save()
User(name="Bob").save()
- users = User.objects.all().order_by('name')
+ users = User.objects.all().order_by("name")
self.assertEqual("%s" % users, "[, ]")
self.assertEqual(2, len(users._result_cache))
@@ -5030,6 +5083,7 @@ class QuerySetTest(unittest.TestCase):
def test_no_cache(self):
"""Ensure you can add meta data to file"""
+
class Noddy(Document):
fields = DictField()
@@ -5063,7 +5117,7 @@ class QuerySetTest(unittest.TestCase):
def test_nested_queryset_iterator(self):
# Try iterating the same queryset twice, nested.
- names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George']
+ names = ["Alice", "Bob", "Chuck", "David", "Eric", "Francis", "George"]
class User(Document):
name = StringField()
@@ -5076,7 +5130,7 @@ class QuerySetTest(unittest.TestCase):
for name in names:
User(name=name).save()
- users = User.objects.all().order_by('name')
+ users = User.objects.all().order_by("name")
outer_count = 0
inner_count = 0
inner_total_count = 0
@@ -5114,7 +5168,7 @@ class QuerySetTest(unittest.TestCase):
x = IntField()
y = IntField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class B(A):
z = IntField()
@@ -5151,6 +5205,7 @@ class QuerySetTest(unittest.TestCase):
def test_query_generic_embedded_document(self):
"""Ensure that querying sub field on generic_embedded_field works
"""
+
class A(EmbeddedDocument):
a_name = StringField()
@@ -5161,19 +5216,16 @@ class QuerySetTest(unittest.TestCase):
document = GenericEmbeddedDocumentField(choices=(A, B))
Doc.drop_collection()
- Doc(document=A(a_name='A doc')).save()
- Doc(document=B(b_name='B doc')).save()
+ Doc(document=A(a_name="A doc")).save()
+ Doc(document=B(b_name="B doc")).save()
# Using raw in filter working fine
- self.assertEqual(Doc.objects(
- __raw__={'document.a_name': 'A doc'}).count(), 1)
- self.assertEqual(Doc.objects(
- __raw__={'document.b_name': 'B doc'}).count(), 1)
- self.assertEqual(Doc.objects(document__a_name='A doc').count(), 1)
- self.assertEqual(Doc.objects(document__b_name='B doc').count(), 1)
+ self.assertEqual(Doc.objects(__raw__={"document.a_name": "A doc"}).count(), 1)
+ self.assertEqual(Doc.objects(__raw__={"document.b_name": "B doc"}).count(), 1)
+ self.assertEqual(Doc.objects(document__a_name="A doc").count(), 1)
+ self.assertEqual(Doc.objects(document__b_name="B doc").count(), 1)
def test_query_reference_to_custom_pk_doc(self):
-
class A(Document):
id = StringField(primary_key=True)
@@ -5183,7 +5235,7 @@ class QuerySetTest(unittest.TestCase):
A.drop_collection()
B.drop_collection()
- a = A.objects.create(id='custom_id')
+ a = A.objects.create(id="custom_id")
B.objects.create(a=a)
self.assertEqual(B.objects.count(), 1)
@@ -5191,13 +5243,10 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(B.objects.get(a=a.id).a, a)
def test_cls_query_in_subclassed_docs(self):
-
class Animal(Document):
name = StringField()
- meta = {
- 'allow_inheritance': True
- }
+ meta = {"allow_inheritance": True}
class Dog(Animal):
pass
@@ -5205,21 +5254,23 @@ class QuerySetTest(unittest.TestCase):
class Cat(Animal):
pass
- self.assertEqual(Animal.objects(name='Charlie')._query, {
- 'name': 'Charlie',
- '_cls': {'$in': ('Animal', 'Animal.Dog', 'Animal.Cat')}
- })
- self.assertEqual(Dog.objects(name='Charlie')._query, {
- 'name': 'Charlie',
- '_cls': 'Animal.Dog'
- })
- self.assertEqual(Cat.objects(name='Charlie')._query, {
- 'name': 'Charlie',
- '_cls': 'Animal.Cat'
- })
+ self.assertEqual(
+ Animal.objects(name="Charlie")._query,
+ {
+ "name": "Charlie",
+ "_cls": {"$in": ("Animal", "Animal.Dog", "Animal.Cat")},
+ },
+ )
+ self.assertEqual(
+ Dog.objects(name="Charlie")._query,
+ {"name": "Charlie", "_cls": "Animal.Dog"},
+ )
+ self.assertEqual(
+ Cat.objects(name="Charlie")._query,
+ {"name": "Charlie", "_cls": "Animal.Cat"},
+ )
def test_can_have_field_same_name_as_query_operator(self):
-
class Size(Document):
name = StringField()
@@ -5236,7 +5287,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1)
def test_cursor_in_an_if_stmt(self):
-
class Test(Document):
test_field = StringField()
@@ -5244,23 +5294,23 @@ class QuerySetTest(unittest.TestCase):
queryset = Test.objects
if queryset:
- raise AssertionError('Empty cursor returns True')
+ raise AssertionError("Empty cursor returns True")
test = Test()
- test.test_field = 'test'
+ test.test_field = "test"
test.save()
queryset = Test.objects
if not test:
- raise AssertionError('Cursor has data and returned False')
+ raise AssertionError("Cursor has data and returned False")
queryset.next()
if not queryset:
- raise AssertionError('Cursor has data and it must returns True,'
- ' even in the last item.')
+ raise AssertionError(
+ "Cursor has data and it must return True, even in the last item."
+ )
def test_bool_performance(self):
-
class Person(Document):
name = StringField()
@@ -5273,10 +5323,11 @@ class QuerySetTest(unittest.TestCase):
pass
self.assertEqual(q, 1)
- op = q.db.system.profile.find({"ns":
- {"$ne": "%s.system.indexes" % q.db.name}})[0]
+ op = q.db.system.profile.find(
+ {"ns": {"$ne": "%s.system.indexes" % q.db.name}}
+ )[0]
- self.assertEqual(op['nreturned'], 1)
+ self.assertEqual(op["nreturned"], 1)
def test_bool_with_ordering(self):
ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version)
@@ -5289,26 +5340,28 @@ class QuerySetTest(unittest.TestCase):
Person(name="Test").save()
# Check that bool(queryset) does not uses the orderby
- qs = Person.objects.order_by('name')
+ qs = Person.objects.order_by("name")
with query_counter() as q:
if bool(qs):
pass
- op = q.db.system.profile.find({"ns":
- {"$ne": "%s.system.indexes" % q.db.name}})[0]
+ op = q.db.system.profile.find(
+ {"ns": {"$ne": "%s.system.indexes" % q.db.name}}
+ )[0]
self.assertNotIn(ORDER_BY_KEY, op[CMD_QUERY_KEY])
# Check that normal query uses orderby
- qs2 = Person.objects.order_by('name')
+ qs2 = Person.objects.order_by("name")
with query_counter() as q:
for x in qs2:
pass
- op = q.db.system.profile.find({"ns":
- {"$ne": "%s.system.indexes" % q.db.name}})[0]
+ op = q.db.system.profile.find(
+ {"ns": {"$ne": "%s.system.indexes" % q.db.name}}
+ )[0]
self.assertIn(ORDER_BY_KEY, op[CMD_QUERY_KEY])
@@ -5317,9 +5370,7 @@ class QuerySetTest(unittest.TestCase):
class Person(Document):
name = StringField()
- meta = {
- 'ordering': ['name']
- }
+ meta = {"ordering": ["name"]}
Person.drop_collection()
@@ -5332,15 +5383,20 @@ class QuerySetTest(unittest.TestCase):
if Person.objects:
pass
- op = q.db.system.profile.find({"ns":
- {"$ne": "%s.system.indexes" % q.db.name}})[0]
+ op = q.db.system.profile.find(
+ {"ns": {"$ne": "%s.system.indexes" % q.db.name}}
+ )[0]
- self.assertNotIn('$orderby', op[CMD_QUERY_KEY],
- 'BaseQuerySet must remove orderby from meta in boolen test')
+ self.assertNotIn(
+ "$orderby",
+ op[CMD_QUERY_KEY],
+ "BaseQuerySet must remove orderby from meta in boolen test",
+ )
- self.assertEqual(Person.objects.first().name, 'A')
- self.assertTrue(Person.objects._has_data(),
- 'Cursor has data and returned False')
+ self.assertEqual(Person.objects.first().name, "A")
+ self.assertTrue(
+ Person.objects._has_data(), "Cursor has data and returned False"
+ )
def test_queryset_aggregation_framework(self):
class Person(Document):
@@ -5355,40 +5411,44 @@ class QuerySetTest(unittest.TestCase):
Person.objects.insert([p1, p2, p3])
data = Person.objects(age__lte=22).aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}}
+ {"$project": {"name": {"$toUpper": "$name"}}}
)
- self.assertEqual(list(data), [
- {'_id': p1.pk, 'name': "ISABELLA LUANNA"},
- {'_id': p2.pk, 'name': "WILSON JUNIOR"}
- ])
-
- data = Person.objects(age__lte=22).order_by('-name').aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}}
+ self.assertEqual(
+ list(data),
+ [
+ {"_id": p1.pk, "name": "ISABELLA LUANNA"},
+ {"_id": p2.pk, "name": "WILSON JUNIOR"},
+ ],
)
- self.assertEqual(list(data), [
- {'_id': p2.pk, 'name': "WILSON JUNIOR"},
- {'_id': p1.pk, 'name': "ISABELLA LUANNA"}
- ])
+ data = (
+ Person.objects(age__lte=22)
+ .order_by("-name")
+ .aggregate({"$project": {"name": {"$toUpper": "$name"}}})
+ )
- data = Person.objects(age__gte=17, age__lte=40).order_by('-age').aggregate({
- '$group': {
- '_id': None,
- 'total': {'$sum': 1},
- 'avg': {'$avg': '$age'}
- }
- })
- self.assertEqual(list(data), [
- {'_id': None, 'avg': 29, 'total': 2}
- ])
+ self.assertEqual(
+ list(data),
+ [
+ {"_id": p2.pk, "name": "WILSON JUNIOR"},
+ {"_id": p1.pk, "name": "ISABELLA LUANNA"},
+ ],
+ )
- data = Person.objects().aggregate({'$match': {'name': 'Isabella Luanna'}})
- self.assertEqual(list(data), [
- {u'_id': p1.pk,
- u'age': 16,
- u'name': u'Isabella Luanna'}]
- )
+ data = (
+ Person.objects(age__gte=17, age__lte=40)
+ .order_by("-age")
+ .aggregate(
+ {"$group": {"_id": None, "total": {"$sum": 1}, "avg": {"$avg": "$age"}}}
+ )
+ )
+ self.assertEqual(list(data), [{"_id": None, "avg": 29, "total": 2}])
+
+ data = Person.objects().aggregate({"$match": {"name": "Isabella Luanna"}})
+ self.assertEqual(
+ list(data), [{u"_id": p1.pk, u"age": 16, u"name": u"Isabella Luanna"}]
+ )
def test_queryset_aggregation_with_skip(self):
class Person(Document):
@@ -5403,13 +5463,16 @@ class QuerySetTest(unittest.TestCase):
Person.objects.insert([p1, p2, p3])
data = Person.objects.skip(1).aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}}
+ {"$project": {"name": {"$toUpper": "$name"}}}
)
- self.assertEqual(list(data), [
- {'_id': p2.pk, 'name': "WILSON JUNIOR"},
- {'_id': p3.pk, 'name': "SANDRA MARA"}
- ])
+ self.assertEqual(
+ list(data),
+ [
+ {"_id": p2.pk, "name": "WILSON JUNIOR"},
+ {"_id": p3.pk, "name": "SANDRA MARA"},
+ ],
+ )
def test_queryset_aggregation_with_limit(self):
class Person(Document):
@@ -5424,12 +5487,10 @@ class QuerySetTest(unittest.TestCase):
Person.objects.insert([p1, p2, p3])
data = Person.objects.limit(1).aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}}
+ {"$project": {"name": {"$toUpper": "$name"}}}
)
- self.assertEqual(list(data), [
- {'_id': p1.pk, 'name': "ISABELLA LUANNA"}
- ])
+ self.assertEqual(list(data), [{"_id": p1.pk, "name": "ISABELLA LUANNA"}])
def test_queryset_aggregation_with_sort(self):
class Person(Document):
@@ -5443,15 +5504,18 @@ class QuerySetTest(unittest.TestCase):
p3 = Person(name="Sandra Mara", age=37)
Person.objects.insert([p1, p2, p3])
- data = Person.objects.order_by('name').aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}}
+ data = Person.objects.order_by("name").aggregate(
+ {"$project": {"name": {"$toUpper": "$name"}}}
)
- self.assertEqual(list(data), [
- {'_id': p1.pk, 'name': "ISABELLA LUANNA"},
- {'_id': p3.pk, 'name': "SANDRA MARA"},
- {'_id': p2.pk, 'name': "WILSON JUNIOR"}
- ])
+ self.assertEqual(
+ list(data),
+ [
+ {"_id": p1.pk, "name": "ISABELLA LUANNA"},
+ {"_id": p3.pk, "name": "SANDRA MARA"},
+ {"_id": p2.pk, "name": "WILSON JUNIOR"},
+ ],
+ )
def test_queryset_aggregation_with_skip_with_limit(self):
class Person(Document):
@@ -5466,18 +5530,18 @@ class QuerySetTest(unittest.TestCase):
Person.objects.insert([p1, p2, p3])
data = list(
- Person.objects.skip(1).limit(1).aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}}
- )
+ Person.objects.skip(1)
+ .limit(1)
+ .aggregate({"$project": {"name": {"$toUpper": "$name"}}})
)
- self.assertEqual(list(data), [
- {'_id': p2.pk, 'name': "WILSON JUNIOR"},
- ])
+ self.assertEqual(list(data), [{"_id": p2.pk, "name": "WILSON JUNIOR"}])
# Make sure limit/skip chaining order has no impact
- data2 = Person.objects.limit(1).skip(1).aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}}
+ data2 = (
+ Person.objects.limit(1)
+ .skip(1)
+ .aggregate({"$project": {"name": {"$toUpper": "$name"}}})
)
self.assertEqual(data, list(data2))
@@ -5494,34 +5558,40 @@ class QuerySetTest(unittest.TestCase):
p3 = Person(name="Sandra Mara", age=37)
Person.objects.insert([p1, p2, p3])
- data = Person.objects.order_by('name').limit(2).aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}}
+ data = (
+ Person.objects.order_by("name")
+ .limit(2)
+ .aggregate({"$project": {"name": {"$toUpper": "$name"}}})
)
- self.assertEqual(list(data), [
- {'_id': p1.pk, 'name': "ISABELLA LUANNA"},
- {'_id': p3.pk, 'name': "SANDRA MARA"}
- ])
+ self.assertEqual(
+ list(data),
+ [
+ {"_id": p1.pk, "name": "ISABELLA LUANNA"},
+ {"_id": p3.pk, "name": "SANDRA MARA"},
+ ],
+ )
# Verify adding limit/skip steps works as expected
- data = Person.objects.order_by('name').limit(2).aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}},
- {'$limit': 1},
+ data = (
+ Person.objects.order_by("name")
+ .limit(2)
+ .aggregate({"$project": {"name": {"$toUpper": "$name"}}}, {"$limit": 1})
)
- self.assertEqual(list(data), [
- {'_id': p1.pk, 'name': "ISABELLA LUANNA"},
- ])
+ self.assertEqual(list(data), [{"_id": p1.pk, "name": "ISABELLA LUANNA"}])
- data = Person.objects.order_by('name').limit(2).aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}},
- {'$skip': 1},
- {'$limit': 1},
+ data = (
+ Person.objects.order_by("name")
+ .limit(2)
+ .aggregate(
+ {"$project": {"name": {"$toUpper": "$name"}}},
+ {"$skip": 1},
+ {"$limit": 1},
+ )
)
- self.assertEqual(list(data), [
- {'_id': p3.pk, 'name': "SANDRA MARA"},
- ])
+ self.assertEqual(list(data), [{"_id": p3.pk, "name": "SANDRA MARA"}])
def test_queryset_aggregation_with_sort_with_skip(self):
class Person(Document):
@@ -5535,13 +5605,13 @@ class QuerySetTest(unittest.TestCase):
p3 = Person(name="Sandra Mara", age=37)
Person.objects.insert([p1, p2, p3])
- data = Person.objects.order_by('name').skip(2).aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}}
+ data = (
+ Person.objects.order_by("name")
+ .skip(2)
+ .aggregate({"$project": {"name": {"$toUpper": "$name"}}})
)
- self.assertEqual(list(data), [
- {'_id': p2.pk, 'name': "WILSON JUNIOR"}
- ])
+ self.assertEqual(list(data), [{"_id": p2.pk, "name": "WILSON JUNIOR"}])
def test_queryset_aggregation_with_sort_with_skip_with_limit(self):
class Person(Document):
@@ -5555,35 +5625,42 @@ class QuerySetTest(unittest.TestCase):
p3 = Person(name="Sandra Mara", age=37)
Person.objects.insert([p1, p2, p3])
- data = Person.objects.order_by('name').skip(1).limit(1).aggregate(
- {'$project': {'name': {'$toUpper': '$name'}}}
+ data = (
+ Person.objects.order_by("name")
+ .skip(1)
+ .limit(1)
+ .aggregate({"$project": {"name": {"$toUpper": "$name"}}})
)
- self.assertEqual(list(data), [
- {'_id': p3.pk, 'name': "SANDRA MARA"}
- ])
+ self.assertEqual(list(data), [{"_id": p3.pk, "name": "SANDRA MARA"}])
def test_delete_count(self):
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)]
- self.assertEqual(self.Person.objects().delete(), 3) # test ordinary QuerySey delete count
+ self.assertEqual(
+ self.Person.objects().delete(), 3
+ ) # test ordinary QuerySet delete count
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)]
- self.assertEqual(self.Person.objects().skip(1).delete(), 2) # test Document delete with existing documents
+ self.assertEqual(
+ self.Person.objects().skip(1).delete(), 2
+ ) # test Document delete with existing documents
self.Person.objects().delete()
- self.assertEqual(self.Person.objects().skip(1).delete(), 0) # test Document delete without existing documents
+ self.assertEqual(
+ self.Person.objects().skip(1).delete(), 0
+ ) # test Document delete without existing documents
def test_max_time_ms(self):
# 778: max_time_ms can get only int or None as input
- self.assertRaises(TypeError,
- self.Person.objects(name="name").max_time_ms,
- 'not a number')
+ self.assertRaises(
+ TypeError, self.Person.objects(name="name").max_time_ms, "not a number"
+ )
def test_subclass_field_query(self):
class Animal(Document):
is_mamal = BooleanField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class Cat(Animal):
whiskers_length = FloatField()
@@ -5605,14 +5682,15 @@ class QuerySetTest(unittest.TestCase):
Person.drop_collection()
- Person._get_collection().insert_one({'name': 'a', 'id': ''})
+ Person._get_collection().insert_one({"name": "a", "id": ""})
for p in Person.objects():
- self.assertEqual(p.name, 'a')
+ self.assertEqual(p.name, "a")
def test_len_during_iteration(self):
"""Tests that calling len on a queyset during iteration doesn't
stop paging.
"""
+
class Data(Document):
pass
@@ -5645,6 +5723,7 @@ class QuerySetTest(unittest.TestCase):
in a given queryset even if there are multiple iterations of it
happening at the same time.
"""
+
class Data(Document):
pass
@@ -5663,6 +5742,7 @@ class QuerySetTest(unittest.TestCase):
"""Ensure that using the `__in` operator on a non-iterable raises an
error.
"""
+
class User(Document):
name = StringField()
@@ -5673,9 +5753,10 @@ class QuerySetTest(unittest.TestCase):
User.drop_collection()
BlogPost.drop_collection()
- author = User.objects.create(name='Test User')
- post = BlogPost.objects.create(content='Had a good coffee today...',
- authors=[author])
+ author = User.objects.create(name="Test User")
+ post = BlogPost.objects.create(
+ content="Had a good coffee today...", authors=[author]
+ )
# Make sure using `__in` with a list works
blog_posts = BlogPost.objects(authors__in=[author])
@@ -5699,5 +5780,5 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(self.Person.objects.count(with_limit_and_skip=True), 4)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py
index 2c2d018c..cfcd8c22 100644
--- a/tests/queryset/transform.py
+++ b/tests/queryset/transform.py
@@ -9,25 +9,29 @@ __all__ = ("TransformTest",)
class TransformTest(unittest.TestCase):
-
def setUp(self):
- connect(db='mongoenginetest')
+ connect(db="mongoenginetest")
def test_transform_query(self):
"""Ensure that the _transform_query function operates correctly.
"""
- self.assertEqual(transform.query(name='test', age=30),
- {'name': 'test', 'age': 30})
- self.assertEqual(transform.query(age__lt=30),
- {'age': {'$lt': 30}})
- self.assertEqual(transform.query(age__gt=20, age__lt=50),
- {'age': {'$gt': 20, '$lt': 50}})
- self.assertEqual(transform.query(age=20, age__gt=50),
- {'$and': [{'age': {'$gt': 50}}, {'age': 20}]})
- self.assertEqual(transform.query(friend__age__gte=30),
- {'friend.age': {'$gte': 30}})
- self.assertEqual(transform.query(name__exists=True),
- {'name': {'$exists': True}})
+ self.assertEqual(
+ transform.query(name="test", age=30), {"name": "test", "age": 30}
+ )
+ self.assertEqual(transform.query(age__lt=30), {"age": {"$lt": 30}})
+ self.assertEqual(
+ transform.query(age__gt=20, age__lt=50), {"age": {"$gt": 20, "$lt": 50}}
+ )
+ self.assertEqual(
+ transform.query(age=20, age__gt=50),
+ {"$and": [{"age": {"$gt": 50}}, {"age": 20}]},
+ )
+ self.assertEqual(
+ transform.query(friend__age__gte=30), {"friend.age": {"$gte": 30}}
+ )
+ self.assertEqual(
+ transform.query(name__exists=True), {"name": {"$exists": True}}
+ )
def test_transform_update(self):
class LisDoc(Document):
@@ -46,7 +50,11 @@ class TransformTest(unittest.TestCase):
DicDoc().save()
doc = Doc().save()
- for k, v in (("set", "$set"), ("set_on_insert", "$setOnInsert"), ("push", "$push")):
+ for k, v in (
+ ("set", "$set"),
+ ("set_on_insert", "$setOnInsert"),
+ ("push", "$push"),
+ ):
update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc})
self.assertIsInstance(update[v]["dictField.test"], dict)
@@ -57,55 +65,61 @@ class TransformTest(unittest.TestCase):
update = transform.update(DicDoc, pull__dictField__test=doc)
self.assertIsInstance(update["$pull"]["dictField"]["test"], dict)
- update = transform.update(LisDoc, pull__foo__in=['a'])
- self.assertEqual(update, {'$pull': {'foo': {'$in': ['a']}}})
+ update = transform.update(LisDoc, pull__foo__in=["a"])
+ self.assertEqual(update, {"$pull": {"foo": {"$in": ["a"]}}})
def test_transform_update_push(self):
"""Ensure the differences in behvaior between 'push' and 'push_all'"""
+
class BlogPost(Document):
tags = ListField(StringField())
- update = transform.update(BlogPost, push__tags=['mongo', 'db'])
- self.assertEqual(update, {'$push': {'tags': ['mongo', 'db']}})
+ update = transform.update(BlogPost, push__tags=["mongo", "db"])
+ self.assertEqual(update, {"$push": {"tags": ["mongo", "db"]}})
- update = transform.update(BlogPost, push_all__tags=['mongo', 'db'])
- self.assertEqual(update, {'$push': {'tags': {'$each': ['mongo', 'db']}}})
+ update = transform.update(BlogPost, push_all__tags=["mongo", "db"])
+ self.assertEqual(update, {"$push": {"tags": {"$each": ["mongo", "db"]}}})
def test_transform_update_no_operator_default_to_set(self):
"""Ensure the differences in behvaior between 'push' and 'push_all'"""
+
class BlogPost(Document):
tags = ListField(StringField())
- update = transform.update(BlogPost, tags=['mongo', 'db'])
- self.assertEqual(update, {'$set': {'tags': ['mongo', 'db']}})
+ update = transform.update(BlogPost, tags=["mongo", "db"])
+ self.assertEqual(update, {"$set": {"tags": ["mongo", "db"]}})
def test_query_field_name(self):
"""Ensure that the correct field name is used when querying.
"""
+
class Comment(EmbeddedDocument):
- content = StringField(db_field='commentContent')
+ content = StringField(db_field="commentContent")
class BlogPost(Document):
- title = StringField(db_field='postTitle')
- comments = ListField(EmbeddedDocumentField(Comment),
- db_field='postComments')
+ title = StringField(db_field="postTitle")
+ comments = ListField(
+ EmbeddedDocumentField(Comment), db_field="postComments"
+ )
BlogPost.drop_collection()
- data = {'title': 'Post 1', 'comments': [Comment(content='test')]}
+ data = {"title": "Post 1", "comments": [Comment(content="test")]}
post = BlogPost(**data)
post.save()
- self.assertIn('postTitle', BlogPost.objects(title=data['title'])._query)
- self.assertFalse('title' in
- BlogPost.objects(title=data['title'])._query)
- self.assertEqual(BlogPost.objects(title=data['title']).count(), 1)
+ self.assertIn("postTitle", BlogPost.objects(title=data["title"])._query)
+ self.assertFalse("title" in BlogPost.objects(title=data["title"])._query)
+ self.assertEqual(BlogPost.objects(title=data["title"]).count(), 1)
- self.assertIn('_id', BlogPost.objects(pk=post.id)._query)
+ self.assertIn("_id", BlogPost.objects(pk=post.id)._query)
self.assertEqual(BlogPost.objects(pk=post.id).count(), 1)
- self.assertIn('postComments.commentContent', BlogPost.objects(comments__content='test')._query)
- self.assertEqual(BlogPost.objects(comments__content='test').count(), 1)
+ self.assertIn(
+ "postComments.commentContent",
+ BlogPost.objects(comments__content="test")._query,
+ )
+ self.assertEqual(BlogPost.objects(comments__content="test").count(), 1)
BlogPost.drop_collection()
@@ -113,18 +127,19 @@ class TransformTest(unittest.TestCase):
"""Ensure that the correct "primary key" field name is used when
querying
"""
+
class BlogPost(Document):
- title = StringField(primary_key=True, db_field='postTitle')
+ title = StringField(primary_key=True, db_field="postTitle")
BlogPost.drop_collection()
- data = {'title': 'Post 1'}
+ data = {"title": "Post 1"}
post = BlogPost(**data)
post.save()
- self.assertIn('_id', BlogPost.objects(pk=data['title'])._query)
- self.assertIn('_id', BlogPost.objects(title=data['title'])._query)
- self.assertEqual(BlogPost.objects(pk=data['title']).count(), 1)
+ self.assertIn("_id", BlogPost.objects(pk=data["title"])._query)
+ self.assertIn("_id", BlogPost.objects(title=data["title"])._query)
+ self.assertEqual(BlogPost.objects(pk=data["title"]).count(), 1)
BlogPost.drop_collection()
@@ -156,78 +171,125 @@ class TransformTest(unittest.TestCase):
"""
Test raw plays nicely
"""
+
class Foo(Document):
name = StringField()
a = StringField()
b = StringField()
c = StringField()
- meta = {
- 'allow_inheritance': False
- }
+ meta = {"allow_inheritance": False}
- query = Foo.objects(__raw__={'$nor': [{'name': 'bar'}]})._query
- self.assertEqual(query, {'$nor': [{'name': 'bar'}]})
+ query = Foo.objects(__raw__={"$nor": [{"name": "bar"}]})._query
+ self.assertEqual(query, {"$nor": [{"name": "bar"}]})
- q1 = {'$or': [{'a': 1}, {'b': 1}]}
+ q1 = {"$or": [{"a": 1}, {"b": 1}]}
query = Foo.objects(Q(__raw__=q1) & Q(c=1))._query
- self.assertEqual(query, {'$or': [{'a': 1}, {'b': 1}], 'c': 1})
+ self.assertEqual(query, {"$or": [{"a": 1}, {"b": 1}], "c": 1})
def test_raw_and_merging(self):
class Doc(Document):
- meta = {'allow_inheritance': False}
+ meta = {"allow_inheritance": False}
- raw_query = Doc.objects(__raw__={
- 'deleted': False,
- 'scraped': 'yes',
- '$nor': [
- {'views.extracted': 'no'},
- {'attachments.views.extracted': 'no'}
- ]
- })._query
+ raw_query = Doc.objects(
+ __raw__={
+ "deleted": False,
+ "scraped": "yes",
+ "$nor": [
+ {"views.extracted": "no"},
+ {"attachments.views.extracted": "no"},
+ ],
+ }
+ )._query
- self.assertEqual(raw_query, {
- 'deleted': False,
- 'scraped': 'yes',
- '$nor': [
- {'views.extracted': 'no'},
- {'attachments.views.extracted': 'no'}
- ]
- })
+ self.assertEqual(
+ raw_query,
+ {
+ "deleted": False,
+ "scraped": "yes",
+ "$nor": [
+ {"views.extracted": "no"},
+ {"attachments.views.extracted": "no"},
+ ],
+ },
+ )
def test_geojson_PointField(self):
class Location(Document):
loc = PointField()
update = transform.update(Location, set__loc=[1, 2])
- self.assertEqual(update, {'$set': {'loc': {"type": "Point", "coordinates": [1, 2]}}})
+ self.assertEqual(
+ update, {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}}
+ )
- update = transform.update(Location, set__loc={"type": "Point", "coordinates": [1, 2]})
- self.assertEqual(update, {'$set': {'loc': {"type": "Point", "coordinates": [1, 2]}}})
+ update = transform.update(
+ Location, set__loc={"type": "Point", "coordinates": [1, 2]}
+ )
+ self.assertEqual(
+ update, {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}}
+ )
def test_geojson_LineStringField(self):
class Location(Document):
line = LineStringField()
update = transform.update(Location, set__line=[[1, 2], [2, 2]])
- self.assertEqual(update, {'$set': {'line': {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}})
+ self.assertEqual(
+ update,
+ {"$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}},
+ )
- update = transform.update(Location, set__line={"type": "LineString", "coordinates": [[1, 2], [2, 2]]})
- self.assertEqual(update, {'$set': {'line': {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}})
+ update = transform.update(
+ Location, set__line={"type": "LineString", "coordinates": [[1, 2], [2, 2]]}
+ )
+ self.assertEqual(
+ update,
+ {"$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}},
+ )
def test_geojson_PolygonField(self):
class Location(Document):
poly = PolygonField()
- update = transform.update(Location, set__poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]])
- self.assertEqual(update, {'$set': {'poly': {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]}}})
+ update = transform.update(
+ Location, set__poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]]
+ )
+ self.assertEqual(
+ update,
+ {
+ "$set": {
+ "poly": {
+ "type": "Polygon",
+ "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]],
+ }
+ }
+ },
+ )
- update = transform.update(Location, set__poly={"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]})
- self.assertEqual(update, {'$set': {'poly': {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]}}})
+ update = transform.update(
+ Location,
+ set__poly={
+ "type": "Polygon",
+ "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]],
+ },
+ )
+ self.assertEqual(
+ update,
+ {
+ "$set": {
+ "poly": {
+ "type": "Polygon",
+ "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]],
+ }
+ }
+ },
+ )
def test_type(self):
class Doc(Document):
df = DynamicField()
+
Doc(df=True).save()
Doc(df=7).save()
Doc(df="df").save()
@@ -252,7 +314,7 @@ class TransformTest(unittest.TestCase):
self.assertEqual(1, Doc.objects(item__type__="axe").count())
self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count())
- Doc.objects(id=doc.id).update(set__item__type__='sword')
+ Doc.objects(id=doc.id).update(set__item__type__="sword")
self.assertEqual(1, Doc.objects(item__type__="sword").count())
self.assertEqual(0, Doc.objects(item__type__="axe").count())
@@ -272,6 +334,7 @@ class TransformTest(unittest.TestCase):
Test added to check pull operation in update for
EmbeddedDocumentListField which is inside a EmbeddedDocumentField
"""
+
class Word(EmbeddedDocument):
word = StringField()
index = IntField()
@@ -284,18 +347,27 @@ class TransformTest(unittest.TestCase):
title = StringField()
content = EmbeddedDocumentField(SubDoc)
- word = Word(word='abc', index=1)
+ word = Word(word="abc", index=1)
update = transform.update(MainDoc, pull__content__text=word)
- self.assertEqual(update, {'$pull': {'content.text': SON([('word', u'abc'), ('index', 1)])}})
+ self.assertEqual(
+ update, {"$pull": {"content.text": SON([("word", u"abc"), ("index", 1)])}}
+ )
- update = transform.update(MainDoc, pull__content__heading='xyz')
- self.assertEqual(update, {'$pull': {'content.heading': 'xyz'}})
+ update = transform.update(MainDoc, pull__content__heading="xyz")
+ self.assertEqual(update, {"$pull": {"content.heading": "xyz"}})
- update = transform.update(MainDoc, pull__content__text__word__in=['foo', 'bar'])
- self.assertEqual(update, {'$pull': {'content.text': {'word': {'$in': ['foo', 'bar']}}}})
+ update = transform.update(MainDoc, pull__content__text__word__in=["foo", "bar"])
+ self.assertEqual(
+ update, {"$pull": {"content.text": {"word": {"$in": ["foo", "bar"]}}}}
+ )
- update = transform.update(MainDoc, pull__content__text__word__nin=['foo', 'bar'])
- self.assertEqual(update, {'$pull': {'content.text': {'word': {'$nin': ['foo', 'bar']}}}})
+ update = transform.update(
+ MainDoc, pull__content__text__word__nin=["foo", "bar"]
+ )
+ self.assertEqual(
+ update, {"$pull": {"content.text": {"word": {"$nin": ["foo", "bar"]}}}}
+ )
-if __name__ == '__main__':
+
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/queryset/visitor.py b/tests/queryset/visitor.py
index 22d274a8..0a22416f 100644
--- a/tests/queryset/visitor.py
+++ b/tests/queryset/visitor.py
@@ -12,14 +12,13 @@ __all__ = ("QTest",)
class QTest(unittest.TestCase):
-
def setUp(self):
- connect(db='mongoenginetest')
+ connect(db="mongoenginetest")
class Person(Document):
name = StringField()
age = IntField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
Person.drop_collection()
self.Person = Person
@@ -30,22 +29,22 @@ class QTest(unittest.TestCase):
q1 = Q()
q2 = Q(age__gte=18)
q3 = Q()
- q4 = Q(name='test')
+ q4 = Q(name="test")
q5 = Q()
class Person(Document):
name = StringField()
age = IntField()
- query = {'$or': [{'age': {'$gte': 18}}, {'name': 'test'}]}
+ query = {"$or": [{"age": {"$gte": 18}}, {"name": "test"}]}
self.assertEqual((q1 | q2 | q3 | q4 | q5).to_query(Person), query)
- query = {'age': {'$gte': 18}, 'name': 'test'}
+ query = {"age": {"$gte": 18}, "name": "test"}
self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query)
def test_q_with_dbref(self):
"""Ensure Q objects handle DBRefs correctly"""
- connect(db='mongoenginetest')
+ connect(db="mongoenginetest")
class User(Document):
pass
@@ -62,15 +61,18 @@ class QTest(unittest.TestCase):
def test_and_combination(self):
"""Ensure that Q-objects correctly AND together.
"""
+
class TestDoc(Document):
x = IntField()
y = StringField()
query = (Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc)
- self.assertEqual(query, {'$and': [{'x': {'$lt': 7}}, {'x': {'$lt': 3}}]})
+ self.assertEqual(query, {"$and": [{"x": {"$lt": 7}}, {"x": {"$lt": 3}}]})
query = (Q(y="a") & Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc)
- self.assertEqual(query, {'$and': [{'y': "a"}, {'x': {'$lt': 7}}, {'x': {'$lt': 3}}]})
+ self.assertEqual(
+ query, {"$and": [{"y": "a"}, {"x": {"$lt": 7}}, {"x": {"$lt": 3}}]}
+ )
# Check normal cases work without an error
query = Q(x__lt=7) & Q(x__gt=3)
@@ -78,69 +80,74 @@ class QTest(unittest.TestCase):
q1 = Q(x__lt=7)
q2 = Q(x__gt=3)
query = (q1 & q2).to_query(TestDoc)
- self.assertEqual(query, {'x': {'$lt': 7, '$gt': 3}})
+ self.assertEqual(query, {"x": {"$lt": 7, "$gt": 3}})
# More complex nested example
- query = Q(x__lt=100) & Q(y__ne='NotMyString')
- query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100)
+ query = Q(x__lt=100) & Q(y__ne="NotMyString")
+ query &= Q(y__in=["a", "b", "c"]) & Q(x__gt=-100)
mongo_query = {
- 'x': {'$lt': 100, '$gt': -100},
- 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']},
+ "x": {"$lt": 100, "$gt": -100},
+ "y": {"$ne": "NotMyString", "$in": ["a", "b", "c"]},
}
self.assertEqual(query.to_query(TestDoc), mongo_query)
def test_or_combination(self):
"""Ensure that Q-objects correctly OR together.
"""
+
class TestDoc(Document):
x = IntField()
q1 = Q(x__lt=3)
q2 = Q(x__gt=7)
query = (q1 | q2).to_query(TestDoc)
- self.assertEqual(query, {
- '$or': [
- {'x': {'$lt': 3}},
- {'x': {'$gt': 7}},
- ]
- })
+ self.assertEqual(query, {"$or": [{"x": {"$lt": 3}}, {"x": {"$gt": 7}}]})
def test_and_or_combination(self):
"""Ensure that Q-objects handle ANDing ORed components.
"""
+
class TestDoc(Document):
x = IntField()
y = BooleanField()
TestDoc.drop_collection()
- query = (Q(x__gt=0) | Q(x__exists=False))
+ query = Q(x__gt=0) | Q(x__exists=False)
query &= Q(x__lt=100)
- self.assertEqual(query.to_query(TestDoc), {'$and': [
- {'$or': [{'x': {'$gt': 0}},
- {'x': {'$exists': False}}]},
- {'x': {'$lt': 100}}]
- })
+ self.assertEqual(
+ query.to_query(TestDoc),
+ {
+ "$and": [
+ {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]},
+ {"x": {"$lt": 100}},
+ ]
+ },
+ )
- q1 = (Q(x__gt=0) | Q(x__exists=False))
- q2 = (Q(x__lt=100) | Q(y=True))
+ q1 = Q(x__gt=0) | Q(x__exists=False)
+ q2 = Q(x__lt=100) | Q(y=True)
query = (q1 & q2).to_query(TestDoc)
TestDoc(x=101).save()
TestDoc(x=10).save()
TestDoc(y=True).save()
- self.assertEqual(query, {
- '$and': [
- {'$or': [{'x': {'$gt': 0}}, {'x': {'$exists': False}}]},
- {'$or': [{'x': {'$lt': 100}}, {'y': True}]}
- ]
- })
+ self.assertEqual(
+ query,
+ {
+ "$and": [
+ {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]},
+ {"$or": [{"x": {"$lt": 100}}, {"y": True}]},
+ ]
+ },
+ )
self.assertEqual(2, TestDoc.objects(q1 & q2).count())
def test_or_and_or_combination(self):
"""Ensure that Q-objects handle ORing ANDed ORed components. :)
"""
+
class TestDoc(Document):
x = IntField()
y = BooleanField()
@@ -151,18 +158,29 @@ class QTest(unittest.TestCase):
TestDoc(x=99, y=False).save()
TestDoc(x=101, y=False).save()
- q1 = (Q(x__gt=0) & (Q(y=True) | Q(y__exists=False)))
- q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False)))
+ q1 = Q(x__gt=0) & (Q(y=True) | Q(y__exists=False))
+ q2 = Q(x__lt=100) & (Q(y=False) | Q(y__exists=False))
query = (q1 | q2).to_query(TestDoc)
- self.assertEqual(query, {
- '$or': [
- {'$and': [{'x': {'$gt': 0}},
- {'$or': [{'y': True}, {'y': {'$exists': False}}]}]},
- {'$and': [{'x': {'$lt': 100}},
- {'$or': [{'y': False}, {'y': {'$exists': False}}]}]}
- ]
- })
+ self.assertEqual(
+ query,
+ {
+ "$or": [
+ {
+ "$and": [
+ {"x": {"$gt": 0}},
+ {"$or": [{"y": True}, {"y": {"$exists": False}}]},
+ ]
+ },
+ {
+ "$and": [
+ {"x": {"$lt": 100}},
+ {"$or": [{"y": False}, {"y": {"$exists": False}}]},
+ ]
+ },
+ ]
+ },
+ )
self.assertEqual(2, TestDoc.objects(q1 | q2).count())
def test_multiple_occurence_in_field(self):
@@ -170,8 +188,8 @@ class QTest(unittest.TestCase):
name = StringField(max_length=40)
title = StringField(max_length=40)
- q1 = Q(name__contains='te') | Q(title__contains='te')
- q2 = Q(name__contains='12') | Q(title__contains='12')
+ q1 = Q(name__contains="te") | Q(title__contains="te")
+ q2 = Q(name__contains="12") | Q(title__contains="12")
q3 = q1 & q2
@@ -180,7 +198,6 @@ class QTest(unittest.TestCase):
self.assertEqual(query["$and"][1], q2.to_query(Test))
def test_q_clone(self):
-
class TestDoc(Document):
x = IntField()
@@ -205,6 +222,7 @@ class QTest(unittest.TestCase):
def test_q(self):
"""Ensure that Q objects may be used to query for documents.
"""
+
class BlogPost(Document):
title = StringField()
publish_date = DateTimeField()
@@ -212,22 +230,26 @@ class QTest(unittest.TestCase):
BlogPost.drop_collection()
- post1 = BlogPost(title='Test 1', publish_date=datetime.datetime(2010, 1, 8), published=False)
+ post1 = BlogPost(
+ title="Test 1", publish_date=datetime.datetime(2010, 1, 8), published=False
+ )
post1.save()
- post2 = BlogPost(title='Test 2', publish_date=datetime.datetime(2010, 1, 15), published=True)
+ post2 = BlogPost(
+ title="Test 2", publish_date=datetime.datetime(2010, 1, 15), published=True
+ )
post2.save()
- post3 = BlogPost(title='Test 3', published=True)
+ post3 = BlogPost(title="Test 3", published=True)
post3.save()
- post4 = BlogPost(title='Test 4', publish_date=datetime.datetime(2010, 1, 8))
+ post4 = BlogPost(title="Test 4", publish_date=datetime.datetime(2010, 1, 8))
post4.save()
- post5 = BlogPost(title='Test 1', publish_date=datetime.datetime(2010, 1, 15))
+ post5 = BlogPost(title="Test 1", publish_date=datetime.datetime(2010, 1, 15))
post5.save()
- post6 = BlogPost(title='Test 1', published=False)
+ post6 = BlogPost(title="Test 1", published=False)
post6.save()
# Check ObjectId lookup works
@@ -235,13 +257,13 @@ class QTest(unittest.TestCase):
self.assertEqual(obj, post1)
# Check Q object combination with one does not exist
- q = BlogPost.objects(Q(title='Test 5') | Q(published=True))
+ q = BlogPost.objects(Q(title="Test 5") | Q(published=True))
posts = [post.id for post in q]
published_posts = (post2, post3)
self.assertTrue(all(obj.id in posts for obj in published_posts))
- q = BlogPost.objects(Q(title='Test 1') | Q(published=True))
+ q = BlogPost.objects(Q(title="Test 1") | Q(published=True))
posts = [post.id for post in q]
published_posts = (post1, post2, post3, post5, post6)
self.assertTrue(all(obj.id in posts for obj in published_posts))
@@ -259,85 +281,91 @@ class QTest(unittest.TestCase):
BlogPost.drop_collection()
# Check the 'in' operator
- self.Person(name='user1', age=20).save()
- self.Person(name='user2', age=20).save()
- self.Person(name='user3', age=30).save()
- self.Person(name='user4', age=40).save()
+ self.Person(name="user1", age=20).save()
+ self.Person(name="user2", age=20).save()
+ self.Person(name="user3", age=30).save()
+ self.Person(name="user4", age=40).save()
self.assertEqual(self.Person.objects(Q(age__in=[20])).count(), 2)
self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3)
# Test invalid query objs
with self.assertRaises(InvalidQueryError):
- self.Person.objects('user1')
+ self.Person.objects("user1")
# filter should fail, too
with self.assertRaises(InvalidQueryError):
- self.Person.objects.filter('user1')
+ self.Person.objects.filter("user1")
def test_q_regex(self):
"""Ensure that Q objects can be queried using regexes.
"""
- person = self.Person(name='Guido van Rossum')
+ person = self.Person(name="Guido van Rossum")
person.save()
- obj = self.Person.objects(Q(name=re.compile('^Gui'))).first()
+ obj = self.Person.objects(Q(name=re.compile("^Gui"))).first()
self.assertEqual(obj, person)
- obj = self.Person.objects(Q(name=re.compile('^gui'))).first()
+ obj = self.Person.objects(Q(name=re.compile("^gui"))).first()
self.assertEqual(obj, None)
- obj = self.Person.objects(Q(name=re.compile('^gui', re.I))).first()
+ obj = self.Person.objects(Q(name=re.compile("^gui", re.I))).first()
self.assertEqual(obj, person)
- obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first()
+ obj = self.Person.objects(Q(name__not=re.compile("^bob"))).first()
self.assertEqual(obj, person)
- obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first()
+ obj = self.Person.objects(Q(name__not=re.compile("^Gui"))).first()
self.assertEqual(obj, None)
def test_q_repr(self):
- self.assertEqual(repr(Q()), 'Q(**{})')
- self.assertEqual(repr(Q(name='test')), "Q(**{'name': 'test'})")
+ self.assertEqual(repr(Q()), "Q(**{})")
+ self.assertEqual(repr(Q(name="test")), "Q(**{'name': 'test'})")
self.assertEqual(
- repr(Q(name='test') & Q(age__gte=18)),
- "(Q(**{'name': 'test'}) & Q(**{'age__gte': 18}))")
+ repr(Q(name="test") & Q(age__gte=18)),
+ "(Q(**{'name': 'test'}) & Q(**{'age__gte': 18}))",
+ )
self.assertEqual(
- repr(Q(name='test') | Q(age__gte=18)),
- "(Q(**{'name': 'test'}) | Q(**{'age__gte': 18}))")
+ repr(Q(name="test") | Q(age__gte=18)),
+ "(Q(**{'name': 'test'}) | Q(**{'age__gte': 18}))",
+ )
def test_q_lists(self):
"""Ensure that Q objects query ListFields correctly.
"""
+
class BlogPost(Document):
tags = ListField(StringField())
BlogPost.drop_collection()
- BlogPost(tags=['python', 'mongo']).save()
- BlogPost(tags=['python']).save()
+ BlogPost(tags=["python", "mongo"]).save()
+ BlogPost(tags=["python"]).save()
- self.assertEqual(BlogPost.objects(Q(tags='mongo')).count(), 1)
- self.assertEqual(BlogPost.objects(Q(tags='python')).count(), 2)
+ self.assertEqual(BlogPost.objects(Q(tags="mongo")).count(), 1)
+ self.assertEqual(BlogPost.objects(Q(tags="python")).count(), 2)
BlogPost.drop_collection()
def test_q_merge_queries_edge_case(self):
-
class User(Document):
email = EmailField(required=False)
name = StringField()
User.drop_collection()
pk = ObjectId()
- User(email='example@example.com', pk=pk).save()
+ User(email="example@example.com", pk=pk).save()
- self.assertEqual(1, User.objects.filter(Q(email='example@example.com') |
- Q(name='John Doe')).limit(2).filter(pk=pk).count())
+ self.assertEqual(
+ 1,
+ User.objects.filter(Q(email="example@example.com") | Q(name="John Doe"))
+ .limit(2)
+ .filter(pk=pk)
+ .count(),
+ )
def test_chained_q_or_filtering(self):
-
class Post(EmbeddedDocument):
name = StringField(required=True)
@@ -350,9 +378,16 @@ class QTest(unittest.TestCase):
Item(postables=[Post(name="a"), Post(name="c")]).save()
Item(postables=[Post(name="a"), Post(name="b"), Post(name="c")]).save()
- self.assertEqual(Item.objects(Q(postables__name="a") & Q(postables__name="b")).count(), 2)
- self.assertEqual(Item.objects.filter(postables__name="a").filter(postables__name="b").count(), 2)
+ self.assertEqual(
+ Item.objects(Q(postables__name="a") & Q(postables__name="b")).count(), 2
+ )
+ self.assertEqual(
+ Item.objects.filter(postables__name="a")
+ .filter(postables__name="b")
+ .count(),
+ 2,
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_common.py b/tests/test_common.py
index 04ad5b34..5d702668 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -5,7 +5,6 @@ from mongoengine import Document
class TestCommon(unittest.TestCase):
-
def test__import_class(self):
doc_cls = _import_class("Document")
self.assertIs(doc_cls, Document)
diff --git a/tests/test_connection.py b/tests/test_connection.py
index d3fcc395..25007132 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -14,12 +14,21 @@ import pymongo
from bson.tz_util import utc
from mongoengine import (
- connect, register_connection,
- Document, DateTimeField,
- disconnect_all, StringField)
+ connect,
+ register_connection,
+ Document,
+ DateTimeField,
+ disconnect_all,
+ StringField,
+)
import mongoengine.connection
-from mongoengine.connection import (MongoEngineConnectionError, get_db,
- get_connection, disconnect, DEFAULT_DATABASE_NAME)
+from mongoengine.connection import (
+ MongoEngineConnectionError,
+ get_db,
+ get_connection,
+ disconnect,
+ DEFAULT_DATABASE_NAME,
+)
def get_tz_awareness(connection):
@@ -27,7 +36,6 @@ def get_tz_awareness(connection):
class ConnectionTest(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
disconnect_all()
@@ -43,44 +51,46 @@ class ConnectionTest(unittest.TestCase):
def test_connect(self):
"""Ensure that the connect() method works properly."""
- connect('mongoenginetest')
+ connect("mongoenginetest")
conn = get_connection()
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
db = get_db()
self.assertIsInstance(db, pymongo.database.Database)
- self.assertEqual(db.name, 'mongoenginetest')
+ self.assertEqual(db.name, "mongoenginetest")
- connect('mongoenginetest2', alias='testdb')
- conn = get_connection('testdb')
+ connect("mongoenginetest2", alias="testdb")
+ conn = get_connection("testdb")
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
def test_connect_disconnect_works_properly(self):
class History1(Document):
name = StringField()
- meta = {'db_alias': 'db1'}
+ meta = {"db_alias": "db1"}
class History2(Document):
name = StringField()
- meta = {'db_alias': 'db2'}
+ meta = {"db_alias": "db2"}
- connect('db1', alias='db1')
- connect('db2', alias='db2')
+ connect("db1", alias="db1")
+ connect("db2", alias="db2")
History1.drop_collection()
History2.drop_collection()
- h = History1(name='default').save()
- h1 = History2(name='db1').save()
+ h = History1(name="default").save()
+ h1 = History2(name="db1").save()
- self.assertEqual(list(History1.objects().as_pymongo()),
- [{'_id': h.id, 'name': 'default'}])
- self.assertEqual(list(History2.objects().as_pymongo()),
- [{'_id': h1.id, 'name': 'db1'}])
+ self.assertEqual(
+ list(History1.objects().as_pymongo()), [{"_id": h.id, "name": "default"}]
+ )
+ self.assertEqual(
+ list(History2.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}]
+ )
- disconnect('db1')
- disconnect('db2')
+ disconnect("db1")
+ disconnect("db2")
with self.assertRaises(MongoEngineConnectionError):
list(History1.objects().as_pymongo())
@@ -88,13 +98,15 @@ class ConnectionTest(unittest.TestCase):
with self.assertRaises(MongoEngineConnectionError):
list(History2.objects().as_pymongo())
- connect('db1', alias='db1')
- connect('db2', alias='db2')
+ connect("db1", alias="db1")
+ connect("db2", alias="db2")
- self.assertEqual(list(History1.objects().as_pymongo()),
- [{'_id': h.id, 'name': 'default'}])
- self.assertEqual(list(History2.objects().as_pymongo()),
- [{'_id': h1.id, 'name': 'db1'}])
+ self.assertEqual(
+ list(History1.objects().as_pymongo()), [{"_id": h.id, "name": "default"}]
+ )
+ self.assertEqual(
+ list(History2.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}]
+ )
def test_connect_different_documents_to_different_database(self):
class History(Document):
@@ -102,99 +114,110 @@ class ConnectionTest(unittest.TestCase):
class History1(Document):
name = StringField()
- meta = {'db_alias': 'db1'}
+ meta = {"db_alias": "db1"}
class History2(Document):
name = StringField()
- meta = {'db_alias': 'db2'}
+ meta = {"db_alias": "db2"}
connect()
- connect('db1', alias='db1')
- connect('db2', alias='db2')
+ connect("db1", alias="db1")
+ connect("db2", alias="db2")
History.drop_collection()
History1.drop_collection()
History2.drop_collection()
- h = History(name='default').save()
- h1 = History1(name='db1').save()
- h2 = History2(name='db2').save()
+ h = History(name="default").save()
+ h1 = History1(name="db1").save()
+ h2 = History2(name="db2").save()
self.assertEqual(History._collection.database.name, DEFAULT_DATABASE_NAME)
- self.assertEqual(History1._collection.database.name, 'db1')
- self.assertEqual(History2._collection.database.name, 'db2')
+ self.assertEqual(History1._collection.database.name, "db1")
+ self.assertEqual(History2._collection.database.name, "db2")
- self.assertEqual(list(History.objects().as_pymongo()),
- [{'_id': h.id, 'name': 'default'}])
- self.assertEqual(list(History1.objects().as_pymongo()),
- [{'_id': h1.id, 'name': 'db1'}])
- self.assertEqual(list(History2.objects().as_pymongo()),
- [{'_id': h2.id, 'name': 'db2'}])
+ self.assertEqual(
+ list(History.objects().as_pymongo()), [{"_id": h.id, "name": "default"}]
+ )
+ self.assertEqual(
+ list(History1.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}]
+ )
+ self.assertEqual(
+ list(History2.objects().as_pymongo()), [{"_id": h2.id, "name": "db2"}]
+ )
def test_connect_fails_if_connect_2_times_with_default_alias(self):
- connect('mongoenginetest')
+ connect("mongoenginetest")
with self.assertRaises(MongoEngineConnectionError) as ctx_err:
- connect('mongoenginetest2')
- self.assertEqual("A different connection with alias `default` was already registered. Use disconnect() first", str(ctx_err.exception))
+ connect("mongoenginetest2")
+ self.assertEqual(
+ "A different connection with alias `default` was already registered. Use disconnect() first",
+ str(ctx_err.exception),
+ )
def test_connect_fails_if_connect_2_times_with_custom_alias(self):
- connect('mongoenginetest', alias='alias1')
+ connect("mongoenginetest", alias="alias1")
with self.assertRaises(MongoEngineConnectionError) as ctx_err:
- connect('mongoenginetest2', alias='alias1')
+ connect("mongoenginetest2", alias="alias1")
- self.assertEqual("A different connection with alias `alias1` was already registered. Use disconnect() first", str(ctx_err.exception))
+ self.assertEqual(
+ "A different connection with alias `alias1` was already registered. Use disconnect() first",
+ str(ctx_err.exception),
+ )
- def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way(self):
+ def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way(
+ self
+ ):
"""Intended to keep the detecton function simple but robust"""
- db_name = 'mongoenginetest'
- db_alias = 'alias1'
- connect(db=db_name, alias=db_alias, host='localhost', port=27017)
+ db_name = "mongoenginetest"
+ db_alias = "alias1"
+ connect(db=db_name, alias=db_alias, host="localhost", port=27017)
with self.assertRaises(MongoEngineConnectionError):
- connect(host='mongodb://localhost:27017/%s' % db_name, alias=db_alias)
+ connect(host="mongodb://localhost:27017/%s" % db_name, alias=db_alias)
def test_connect_passes_silently_connect_multiple_times_with_same_config(self):
# test default connection to `test`
connect()
connect()
self.assertEqual(len(mongoengine.connection._connections), 1)
- connect('test01', alias='test01')
- connect('test01', alias='test01')
+ connect("test01", alias="test01")
+ connect("test01", alias="test01")
self.assertEqual(len(mongoengine.connection._connections), 2)
- connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02')
- connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02')
+ connect(host="mongodb://localhost:27017/mongoenginetest02", alias="test02")
+ connect(host="mongodb://localhost:27017/mongoenginetest02", alias="test02")
self.assertEqual(len(mongoengine.connection._connections), 3)
def test_connect_with_invalid_db_name(self):
"""Ensure that connect() method fails fast if db name is invalid
"""
with self.assertRaises(InvalidName):
- connect('mongomock://localhost')
+ connect("mongomock://localhost")
def test_connect_with_db_name_external(self):
"""Ensure that connect() works if db name is $external
"""
"""Ensure that the connect() method works properly."""
- connect('$external')
+ connect("$external")
conn = get_connection()
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
db = get_db()
self.assertIsInstance(db, pymongo.database.Database)
- self.assertEqual(db.name, '$external')
+ self.assertEqual(db.name, "$external")
- connect('$external', alias='testdb')
- conn = get_connection('testdb')
+ connect("$external", alias="testdb")
+ conn = get_connection("testdb")
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
def test_connect_with_invalid_db_name_type(self):
"""Ensure that connect() method fails fast if db name has invalid type
"""
with self.assertRaises(TypeError):
- non_string_db_name = ['e. g. list instead of a string']
+ non_string_db_name = ["e. g. list instead of a string"]
connect(non_string_db_name)
def test_connect_in_mocking(self):
@@ -203,34 +226,47 @@ class ConnectionTest(unittest.TestCase):
try:
import mongomock
except ImportError:
- raise SkipTest('you need mongomock installed to run this testcase')
+ raise SkipTest("you need mongomock installed to run this testcase")
- connect('mongoenginetest', host='mongomock://localhost')
+ connect("mongoenginetest", host="mongomock://localhost")
conn = get_connection()
self.assertIsInstance(conn, mongomock.MongoClient)
- connect('mongoenginetest2', host='mongomock://localhost', alias='testdb2')
- conn = get_connection('testdb2')
+ connect("mongoenginetest2", host="mongomock://localhost", alias="testdb2")
+ conn = get_connection("testdb2")
self.assertIsInstance(conn, mongomock.MongoClient)
- connect('mongoenginetest3', host='mongodb://localhost', is_mock=True, alias='testdb3')
- conn = get_connection('testdb3')
+ connect(
+ "mongoenginetest3",
+ host="mongodb://localhost",
+ is_mock=True,
+ alias="testdb3",
+ )
+ conn = get_connection("testdb3")
self.assertIsInstance(conn, mongomock.MongoClient)
- connect('mongoenginetest4', is_mock=True, alias='testdb4')
- conn = get_connection('testdb4')
+ connect("mongoenginetest4", is_mock=True, alias="testdb4")
+ conn = get_connection("testdb4")
self.assertIsInstance(conn, mongomock.MongoClient)
- connect(host='mongodb://localhost:27017/mongoenginetest5', is_mock=True, alias='testdb5')
- conn = get_connection('testdb5')
+ connect(
+ host="mongodb://localhost:27017/mongoenginetest5",
+ is_mock=True,
+ alias="testdb5",
+ )
+ conn = get_connection("testdb5")
self.assertIsInstance(conn, mongomock.MongoClient)
- connect(host='mongomock://localhost:27017/mongoenginetest6', alias='testdb6')
- conn = get_connection('testdb6')
+ connect(host="mongomock://localhost:27017/mongoenginetest6", alias="testdb6")
+ conn = get_connection("testdb6")
self.assertIsInstance(conn, mongomock.MongoClient)
- connect(host='mongomock://localhost:27017/mongoenginetest7', is_mock=True, alias='testdb7')
- conn = get_connection('testdb7')
+ connect(
+ host="mongomock://localhost:27017/mongoenginetest7",
+ is_mock=True,
+ alias="testdb7",
+ )
+ conn = get_connection("testdb7")
self.assertIsInstance(conn, mongomock.MongoClient)
def test_connect_with_host_list(self):
@@ -241,30 +277,39 @@ class ConnectionTest(unittest.TestCase):
try:
import mongomock
except ImportError:
- raise SkipTest('you need mongomock installed to run this testcase')
+ raise SkipTest("you need mongomock installed to run this testcase")
- connect(host=['mongomock://localhost'])
+ connect(host=["mongomock://localhost"])
conn = get_connection()
self.assertIsInstance(conn, mongomock.MongoClient)
- connect(host=['mongodb://localhost'], is_mock=True, alias='testdb2')
- conn = get_connection('testdb2')
+ connect(host=["mongodb://localhost"], is_mock=True, alias="testdb2")
+ conn = get_connection("testdb2")
self.assertIsInstance(conn, mongomock.MongoClient)
- connect(host=['localhost'], is_mock=True, alias='testdb3')
- conn = get_connection('testdb3')
+ connect(host=["localhost"], is_mock=True, alias="testdb3")
+ conn = get_connection("testdb3")
self.assertIsInstance(conn, mongomock.MongoClient)
- connect(host=['mongomock://localhost:27017', 'mongomock://localhost:27018'], alias='testdb4')
- conn = get_connection('testdb4')
+ connect(
+ host=["mongomock://localhost:27017", "mongomock://localhost:27018"],
+ alias="testdb4",
+ )
+ conn = get_connection("testdb4")
self.assertIsInstance(conn, mongomock.MongoClient)
- connect(host=['mongodb://localhost:27017', 'mongodb://localhost:27018'], is_mock=True, alias='testdb5')
- conn = get_connection('testdb5')
+ connect(
+ host=["mongodb://localhost:27017", "mongodb://localhost:27018"],
+ is_mock=True,
+ alias="testdb5",
+ )
+ conn = get_connection("testdb5")
self.assertIsInstance(conn, mongomock.MongoClient)
- connect(host=['localhost:27017', 'localhost:27018'], is_mock=True, alias='testdb6')
- conn = get_connection('testdb6')
+ connect(
+ host=["localhost:27017", "localhost:27018"], is_mock=True, alias="testdb6"
+ )
+ conn = get_connection("testdb6")
self.assertIsInstance(conn, mongomock.MongoClient)
def test_disconnect_cleans_globals(self):
@@ -273,7 +318,7 @@ class ConnectionTest(unittest.TestCase):
dbs = mongoengine.connection._dbs
connection_settings = mongoengine.connection._connection_settings
- connect('mongoenginetest')
+ connect("mongoenginetest")
self.assertEqual(len(connections), 1)
self.assertEqual(len(dbs), 0)
@@ -292,7 +337,7 @@ class ConnectionTest(unittest.TestCase):
def test_disconnect_cleans_cached_collection_attribute_in_document(self):
"""Ensure that the disconnect() method works properly"""
- conn1 = connect('mongoenginetest')
+ conn1 = connect("mongoenginetest")
class History(Document):
pass
@@ -301,7 +346,7 @@ class ConnectionTest(unittest.TestCase):
History.drop_collection()
- History.objects.first() # will trigger the caching of _collection attribute
+ History.objects.first() # will trigger the caching of _collection attribute
self.assertIsNotNone(History._collection)
disconnect()
@@ -310,15 +355,17 @@ class ConnectionTest(unittest.TestCase):
with self.assertRaises(MongoEngineConnectionError) as ctx_err:
History.objects.first()
- self.assertEqual("You have not defined a default connection", str(ctx_err.exception))
+ self.assertEqual(
+ "You have not defined a default connection", str(ctx_err.exception)
+ )
def test_connect_disconnect_works_on_same_document(self):
"""Ensure that the connect/disconnect works properly with a single Document"""
- db1 = 'db1'
- db2 = 'db2'
+ db1 = "db1"
+ db2 = "db2"
# Ensure freshness of the 2 databases through pymongo
- client = MongoClient('localhost', 27017)
+ client = MongoClient("localhost", 27017)
client.drop_database(db1)
client.drop_database(db2)
@@ -328,44 +375,44 @@ class ConnectionTest(unittest.TestCase):
class User(Document):
name = StringField(required=True)
- user1 = User(name='John is in db1').save()
+ user1 = User(name="John is in db1").save()
disconnect()
# Make sure save doesnt work at this stage
with self.assertRaises(MongoEngineConnectionError):
- User(name='Wont work').save()
+ User(name="Wont work").save()
# Save in db2
connect(db2)
- user2 = User(name='Bob is in db2').save()
+ user2 = User(name="Bob is in db2").save()
disconnect()
db1_users = list(client[db1].user.find())
- self.assertEqual(db1_users, [{'_id': user1.id, 'name': 'John is in db1'}])
+ self.assertEqual(db1_users, [{"_id": user1.id, "name": "John is in db1"}])
db2_users = list(client[db2].user.find())
- self.assertEqual(db2_users, [{'_id': user2.id, 'name': 'Bob is in db2'}])
+ self.assertEqual(db2_users, [{"_id": user2.id, "name": "Bob is in db2"}])
def test_disconnect_silently_pass_if_alias_does_not_exist(self):
connections = mongoengine.connection._connections
self.assertEqual(len(connections), 0)
- disconnect(alias='not_exist')
+ disconnect(alias="not_exist")
def test_disconnect_all(self):
connections = mongoengine.connection._connections
dbs = mongoengine.connection._dbs
connection_settings = mongoengine.connection._connection_settings
- connect('mongoenginetest')
- connect('mongoenginetest2', alias='db1')
+ connect("mongoenginetest")
+ connect("mongoenginetest2", alias="db1")
class History(Document):
pass
class History1(Document):
name = StringField()
- meta = {'db_alias': 'db1'}
+ meta = {"db_alias": "db1"}
- History.drop_collection() # will trigger the caching of _collection attribute
+ History.drop_collection() # will trigger the caching of _collection attribute
History.objects.first()
History1.drop_collection()
History1.objects.first()
@@ -398,11 +445,11 @@ class ConnectionTest(unittest.TestCase):
def test_sharing_connections(self):
"""Ensure that connections are shared when the connection settings are exactly the same
"""
- connect('mongoenginetests', alias='testdb1')
- expected_connection = get_connection('testdb1')
+ connect("mongoenginetests", alias="testdb1")
+ expected_connection = get_connection("testdb1")
- connect('mongoenginetests', alias='testdb2')
- actual_connection = get_connection('testdb2')
+ connect("mongoenginetests", alias="testdb2")
+ actual_connection = get_connection("testdb2")
expected_connection.server_info()
@@ -410,7 +457,7 @@ class ConnectionTest(unittest.TestCase):
def test_connect_uri(self):
"""Ensure that the connect() method works properly with URIs."""
- c = connect(db='mongoenginetest', alias='admin')
+ c = connect(db="mongoenginetest", alias="admin")
c.admin.system.users.delete_many({})
c.mongoenginetest.system.users.delete_many({})
@@ -418,14 +465,16 @@ class ConnectionTest(unittest.TestCase):
c.admin.authenticate("admin", "password")
c.admin.command("createUser", "username", pwd="password", roles=["dbOwner"])
- connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')
+ connect(
+ "testdb_uri", host="mongodb://username:password@localhost/mongoenginetest"
+ )
conn = get_connection()
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
db = get_db()
self.assertIsInstance(db, pymongo.database.Database)
- self.assertEqual(db.name, 'mongoenginetest')
+ self.assertEqual(db.name, "mongoenginetest")
c.admin.system.users.delete_many({})
c.mongoenginetest.system.users.delete_many({})
@@ -434,35 +483,35 @@ class ConnectionTest(unittest.TestCase):
"""Ensure connect() method works properly if the URI doesn't
include a database name.
"""
- connect("mongoenginetest", host='mongodb://localhost/')
+ connect("mongoenginetest", host="mongodb://localhost/")
conn = get_connection()
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
db = get_db()
self.assertIsInstance(db, pymongo.database.Database)
- self.assertEqual(db.name, 'mongoenginetest')
+ self.assertEqual(db.name, "mongoenginetest")
def test_connect_uri_default_db(self):
"""Ensure connect() defaults to the right database name if
the URI and the database_name don't explicitly specify it.
"""
- connect(host='mongodb://localhost/')
+ connect(host="mongodb://localhost/")
conn = get_connection()
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
db = get_db()
self.assertIsInstance(db, pymongo.database.Database)
- self.assertEqual(db.name, 'test')
+ self.assertEqual(db.name, "test")
def test_uri_without_credentials_doesnt_override_conn_settings(self):
"""Ensure connect() uses the username & password params if the URI
doesn't explicitly specify them.
"""
- c = connect(host='mongodb://localhost/mongoenginetest',
- username='user',
- password='pass')
+ c = connect(
+ host="mongodb://localhost/mongoenginetest", username="user", password="pass"
+ )
# OperationFailure means that mongoengine attempted authentication
# w/ the provided username/password and failed - that's the desired
@@ -474,27 +523,31 @@ class ConnectionTest(unittest.TestCase):
option in the URI.
"""
# Create users
- c = connect('mongoenginetest')
+ c = connect("mongoenginetest")
c.admin.system.users.delete_many({})
c.admin.command("createUser", "username2", pwd="password", roles=["dbOwner"])
# Authentication fails without "authSource"
test_conn = connect(
- 'mongoenginetest', alias='test1',
- host='mongodb://username2:password@localhost/mongoenginetest'
+ "mongoenginetest",
+ alias="test1",
+ host="mongodb://username2:password@localhost/mongoenginetest",
)
self.assertRaises(OperationFailure, test_conn.server_info)
# Authentication succeeds with "authSource"
authd_conn = connect(
- 'mongoenginetest', alias='test2',
- host=('mongodb://username2:password@localhost/'
- 'mongoenginetest?authSource=admin')
+ "mongoenginetest",
+ alias="test2",
+ host=(
+ "mongodb://username2:password@localhost/"
+ "mongoenginetest?authSource=admin"
+ ),
)
- db = get_db('test2')
+ db = get_db("test2")
self.assertIsInstance(db, pymongo.database.Database)
- self.assertEqual(db.name, 'mongoenginetest')
+ self.assertEqual(db.name, "mongoenginetest")
# Clear all users
authd_conn.admin.system.users.delete_many({})
@@ -502,82 +555,86 @@ class ConnectionTest(unittest.TestCase):
def test_register_connection(self):
"""Ensure that connections with different aliases may be registered.
"""
- register_connection('testdb', 'mongoenginetest2')
+ register_connection("testdb", "mongoenginetest2")
self.assertRaises(MongoEngineConnectionError, get_connection)
- conn = get_connection('testdb')
+ conn = get_connection("testdb")
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
- db = get_db('testdb')
+ db = get_db("testdb")
self.assertIsInstance(db, pymongo.database.Database)
- self.assertEqual(db.name, 'mongoenginetest2')
+ self.assertEqual(db.name, "mongoenginetest2")
def test_register_connection_defaults(self):
"""Ensure that defaults are used when the host and port are None.
"""
- register_connection('testdb', 'mongoenginetest', host=None, port=None)
+ register_connection("testdb", "mongoenginetest", host=None, port=None)
- conn = get_connection('testdb')
+ conn = get_connection("testdb")
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
def test_connection_kwargs(self):
"""Ensure that connection kwargs get passed to pymongo."""
- connect('mongoenginetest', alias='t1', tz_aware=True)
- conn = get_connection('t1')
+ connect("mongoenginetest", alias="t1", tz_aware=True)
+ conn = get_connection("t1")
self.assertTrue(get_tz_awareness(conn))
- connect('mongoenginetest2', alias='t2')
- conn = get_connection('t2')
+ connect("mongoenginetest2", alias="t2")
+ conn = get_connection("t2")
self.assertFalse(get_tz_awareness(conn))
def test_connection_pool_via_kwarg(self):
"""Ensure we can specify a max connection pool size using
a connection kwarg.
"""
- pool_size_kwargs = {'maxpoolsize': 100}
+ pool_size_kwargs = {"maxpoolsize": 100}
- conn = connect('mongoenginetest', alias='max_pool_size_via_kwarg', **pool_size_kwargs)
+ conn = connect(
+ "mongoenginetest", alias="max_pool_size_via_kwarg", **pool_size_kwargs
+ )
self.assertEqual(conn.max_pool_size, 100)
def test_connection_pool_via_uri(self):
"""Ensure we can specify a max connection pool size using
an option in a connection URI.
"""
- conn = connect(host='mongodb://localhost/test?maxpoolsize=100', alias='max_pool_size_via_uri')
+ conn = connect(
+ host="mongodb://localhost/test?maxpoolsize=100",
+ alias="max_pool_size_via_uri",
+ )
self.assertEqual(conn.max_pool_size, 100)
def test_write_concern(self):
"""Ensure write concern can be specified in connect() via
a kwarg or as part of the connection URI.
"""
- conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true')
- conn2 = connect('testing', alias='conn2', w=1, j=True)
- self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True})
- self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True})
+ conn1 = connect(alias="conn1", host="mongodb://localhost/testing?w=1&j=true")
+ conn2 = connect("testing", alias="conn2", w=1, j=True)
+ self.assertEqual(conn1.write_concern.document, {"w": 1, "j": True})
+ self.assertEqual(conn2.write_concern.document, {"w": 1, "j": True})
def test_connect_with_replicaset_via_uri(self):
"""Ensure connect() works when specifying a replicaSet via the
MongoDB URI.
"""
- c = connect(host='mongodb://localhost/test?replicaSet=local-rs')
+ c = connect(host="mongodb://localhost/test?replicaSet=local-rs")
db = get_db()
self.assertIsInstance(db, pymongo.database.Database)
- self.assertEqual(db.name, 'test')
+ self.assertEqual(db.name, "test")
def test_connect_with_replicaset_via_kwargs(self):
"""Ensure connect() works when specifying a replicaSet via the
connection kwargs
"""
- c = connect(replicaset='local-rs')
- self.assertEqual(c._MongoClient__options.replica_set_name,
- 'local-rs')
+ c = connect(replicaset="local-rs")
+ self.assertEqual(c._MongoClient__options.replica_set_name, "local-rs")
db = get_db()
self.assertIsInstance(db, pymongo.database.Database)
- self.assertEqual(db.name, 'test')
+ self.assertEqual(db.name, "test")
def test_connect_tz_aware(self):
- connect('mongoenginetest', tz_aware=True)
+ connect("mongoenginetest", tz_aware=True)
d = datetime.datetime(2010, 5, 5, tzinfo=utc)
class DateDoc(Document):
@@ -590,37 +647,39 @@ class ConnectionTest(unittest.TestCase):
self.assertEqual(d, date_doc.the_date)
def test_read_preference_from_parse(self):
- conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred")
+ conn = connect(
+ host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred"
+ )
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED)
def test_multiple_connection_settings(self):
- connect('mongoenginetest', alias='t1', host="localhost")
+ connect("mongoenginetest", alias="t1", host="localhost")
- connect('mongoenginetest2', alias='t2', host="127.0.0.1")
+ connect("mongoenginetest2", alias="t2", host="127.0.0.1")
mongo_connections = mongoengine.connection._connections
self.assertEqual(len(mongo_connections.items()), 2)
- self.assertIn('t1', mongo_connections.keys())
- self.assertIn('t2', mongo_connections.keys())
+ self.assertIn("t1", mongo_connections.keys())
+ self.assertIn("t2", mongo_connections.keys())
# Handle PyMongo 3+ Async Connection
# Ensure we are connected, throws ServerSelectionTimeoutError otherwise.
# Purposely not catching exception to fail test if thrown.
- mongo_connections['t1'].server_info()
- mongo_connections['t2'].server_info()
- self.assertEqual(mongo_connections['t1'].address[0], 'localhost')
- self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1')
+ mongo_connections["t1"].server_info()
+ mongo_connections["t2"].server_info()
+ self.assertEqual(mongo_connections["t1"].address[0], "localhost")
+ self.assertEqual(mongo_connections["t2"].address[0], "127.0.0.1")
def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self):
- c1 = connect(alias='testdb1', db='testdb1')
- c2 = connect(alias='testdb2', db='testdb2')
+ c1 = connect(alias="testdb1", db="testdb1")
+ c2 = connect(alias="testdb2", db="testdb2")
self.assertIs(c1, c2)
def test_connect_2_databases_uses_different_client_if_different_parameters(self):
- c1 = connect(alias='testdb1', db='testdb1', username='u1')
- c2 = connect(alias='testdb2', db='testdb2', username='u2')
+ c1 = connect(alias="testdb1", db="testdb1", username="u1")
+ c2 = connect(alias="testdb2", db="testdb2", username="u2")
self.assertIsNot(c1, c2)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py
index 529032fe..dc9b9bf3 100644
--- a/tests/test_context_managers.py
+++ b/tests/test_context_managers.py
@@ -2,17 +2,20 @@ import unittest
from mongoengine import *
from mongoengine.connection import get_db
-from mongoengine.context_managers import (switch_db, switch_collection,
- no_sub_classes, no_dereference,
- query_counter)
+from mongoengine.context_managers import (
+ switch_db,
+ switch_collection,
+ no_sub_classes,
+ no_dereference,
+ query_counter,
+)
from mongoengine.pymongo_support import count_documents
class ContextManagersTest(unittest.TestCase):
-
def test_switch_db_context_manager(self):
- connect('mongoenginetest')
- register_connection('testdb-1', 'mongoenginetest2')
+ connect("mongoenginetest")
+ register_connection("testdb-1", "mongoenginetest2")
class Group(Document):
name = StringField()
@@ -22,7 +25,7 @@ class ContextManagersTest(unittest.TestCase):
Group(name="hello - default").save()
self.assertEqual(1, Group.objects.count())
- with switch_db(Group, 'testdb-1') as Group:
+ with switch_db(Group, "testdb-1") as Group:
self.assertEqual(0, Group.objects.count())
@@ -36,21 +39,21 @@ class ContextManagersTest(unittest.TestCase):
self.assertEqual(1, Group.objects.count())
def test_switch_collection_context_manager(self):
- connect('mongoenginetest')
- register_connection(alias='testdb-1', db='mongoenginetest2')
+ connect("mongoenginetest")
+ register_connection(alias="testdb-1", db="mongoenginetest2")
class Group(Document):
name = StringField()
- Group.drop_collection() # drops in default
+ Group.drop_collection() # drops in default
- with switch_collection(Group, 'group1') as Group:
- Group.drop_collection() # drops in group1
+ with switch_collection(Group, "group1") as Group:
+ Group.drop_collection() # drops in group1
Group(name="hello - group").save()
self.assertEqual(1, Group.objects.count())
- with switch_collection(Group, 'group1') as Group:
+ with switch_collection(Group, "group1") as Group:
self.assertEqual(0, Group.objects.count())
@@ -66,7 +69,7 @@ class ContextManagersTest(unittest.TestCase):
def test_no_dereference_context_manager_object_id(self):
"""Ensure that DBRef items in ListFields aren't dereferenced.
"""
- connect('mongoenginetest')
+ connect("mongoenginetest")
class User(Document):
name = StringField()
@@ -80,14 +83,14 @@ class ContextManagersTest(unittest.TestCase):
Group.drop_collection()
for i in range(1, 51):
- User(name='user %s' % i).save()
+ User(name="user %s" % i).save()
user = User.objects.first()
Group(ref=user, members=User.objects, generic=user).save()
with no_dereference(Group) as NoDeRefGroup:
- self.assertTrue(Group._fields['members']._auto_dereference)
- self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
+ self.assertTrue(Group._fields["members"]._auto_dereference)
+ self.assertFalse(NoDeRefGroup._fields["members"]._auto_dereference)
with no_dereference(Group) as Group:
group = Group.objects.first()
@@ -104,7 +107,7 @@ class ContextManagersTest(unittest.TestCase):
def test_no_dereference_context_manager_dbref(self):
"""Ensure that DBRef items in ListFields aren't dereferenced.
"""
- connect('mongoenginetest')
+ connect("mongoenginetest")
class User(Document):
name = StringField()
@@ -118,31 +121,29 @@ class ContextManagersTest(unittest.TestCase):
Group.drop_collection()
for i in range(1, 51):
- User(name='user %s' % i).save()
+ User(name="user %s" % i).save()
user = User.objects.first()
Group(ref=user, members=User.objects, generic=user).save()
with no_dereference(Group) as NoDeRefGroup:
- self.assertTrue(Group._fields['members']._auto_dereference)
- self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
+ self.assertTrue(Group._fields["members"]._auto_dereference)
+ self.assertFalse(NoDeRefGroup._fields["members"]._auto_dereference)
with no_dereference(Group) as Group:
group = Group.objects.first()
- self.assertTrue(all([not isinstance(m, User)
- for m in group.members]))
+ self.assertTrue(all([not isinstance(m, User) for m in group.members]))
self.assertNotIsInstance(group.ref, User)
self.assertNotIsInstance(group.generic, User)
- self.assertTrue(all([isinstance(m, User)
- for m in group.members]))
+ self.assertTrue(all([isinstance(m, User) for m in group.members]))
self.assertIsInstance(group.ref, User)
self.assertIsInstance(group.generic, User)
def test_no_sub_classes(self):
class A(Document):
x = IntField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class B(A):
z = IntField()
@@ -188,20 +189,20 @@ class ContextManagersTest(unittest.TestCase):
def test_no_sub_classes_modification_to_document_class_are_temporary(self):
class A(Document):
x = IntField()
- meta = {'allow_inheritance': True}
+ meta = {"allow_inheritance": True}
class B(A):
z = IntField()
- self.assertEqual(A._subclasses, ('A', 'A.B'))
+ self.assertEqual(A._subclasses, ("A", "A.B"))
with no_sub_classes(A):
- self.assertEqual(A._subclasses, ('A',))
- self.assertEqual(A._subclasses, ('A', 'A.B'))
+ self.assertEqual(A._subclasses, ("A",))
+ self.assertEqual(A._subclasses, ("A", "A.B"))
- self.assertEqual(B._subclasses, ('A.B',))
+ self.assertEqual(B._subclasses, ("A.B",))
with no_sub_classes(B):
- self.assertEqual(B._subclasses, ('A.B',))
- self.assertEqual(B._subclasses, ('A.B',))
+ self.assertEqual(B._subclasses, ("A.B",))
+ self.assertEqual(B._subclasses, ("A.B",))
def test_no_subclass_context_manager_does_not_swallow_exception(self):
class User(Document):
@@ -218,7 +219,7 @@ class ContextManagersTest(unittest.TestCase):
raise TypeError()
def test_query_counter_temporarily_modifies_profiling_level(self):
- connect('mongoenginetest')
+ connect("mongoenginetest")
db = get_db()
initial_profiling_level = db.profiling_level()
@@ -231,11 +232,13 @@ class ContextManagersTest(unittest.TestCase):
self.assertEqual(db.profiling_level(), 2)
self.assertEqual(db.profiling_level(), NEW_LEVEL)
except Exception:
- db.set_profiling_level(initial_profiling_level) # Ensures it gets reseted no matter the outcome of the test
+ db.set_profiling_level(
+ initial_profiling_level
+ ) # Ensures it gets reset no matter the outcome of the test
raise
def test_query_counter(self):
- connect('mongoenginetest')
+ connect("mongoenginetest")
db = get_db()
collection = db.query_counter
@@ -245,7 +248,7 @@ class ContextManagersTest(unittest.TestCase):
count_documents(collection, {})
def issue_1_insert_query():
- collection.insert_one({'test': 'garbage'})
+ collection.insert_one({"test": "garbage"})
def issue_1_find_query():
collection.find_one()
@@ -253,7 +256,9 @@ class ContextManagersTest(unittest.TestCase):
counter = 0
with query_counter() as q:
self.assertEqual(q, counter)
- self.assertEqual(q, counter) # Ensures previous count query did not get counted
+ self.assertEqual(
+ q, counter
+ ) # Ensures previous count query did not get counted
for _ in range(10):
issue_1_insert_query()
@@ -270,23 +275,25 @@ class ContextManagersTest(unittest.TestCase):
counter += 1
self.assertEqual(q, counter)
- self.assertEqual(int(q), counter) # test __int__
+ self.assertEqual(int(q), counter) # test __int__
self.assertEqual(repr(q), str(int(q))) # test __repr__
- self.assertGreater(q, -1) # test __gt__
- self.assertGreaterEqual(q, int(q)) # test __gte__
+ self.assertGreater(q, -1) # test __gt__
+ self.assertGreaterEqual(q, int(q)) # test __gte__
self.assertNotEqual(q, -1)
self.assertLess(q, 1000)
self.assertLessEqual(q, int(q))
def test_query_counter_counts_getmore_queries(self):
- connect('mongoenginetest')
+ connect("mongoenginetest")
db = get_db()
collection = db.query_counter
collection.drop()
- many_docs = [{'test': 'garbage %s' % i} for i in range(150)]
- collection.insert_many(many_docs) # first batch of documents contains 101 documents
+ many_docs = [{"test": "garbage %s" % i} for i in range(150)]
+ collection.insert_many(
+ many_docs
+ ) # first batch of documents contains 101 documents
with query_counter() as q:
self.assertEqual(q, 0)
@@ -294,24 +301,26 @@ class ContextManagersTest(unittest.TestCase):
self.assertEqual(q, 2) # 1st select + 1 getmore
def test_query_counter_ignores_particular_queries(self):
- connect('mongoenginetest')
+ connect("mongoenginetest")
db = get_db()
collection = db.query_counter
- collection.insert_many([{'test': 'garbage %s' % i} for i in range(10)])
+ collection.insert_many([{"test": "garbage %s" % i} for i in range(10)])
with query_counter() as q:
self.assertEqual(q, 0)
cursor = collection.find()
- self.assertEqual(q, 0) # cursor wasn't opened yet
- _ = next(cursor) # opens the cursor and fires the find query
+ self.assertEqual(q, 0) # cursor wasn't opened yet
+ _ = next(cursor) # opens the cursor and fires the find query
self.assertEqual(q, 1)
- cursor.close() # issues a `killcursors` query that is ignored by the context
+ cursor.close() # issues a `killcursors` query that is ignored by the context
self.assertEqual(q, 1)
- _ = db.system.indexes.find_one() # queries on db.system.indexes are ignored as well
+ _ = (
+ db.system.indexes.find_one()
+ ) # queries on db.system.indexes are ignored as well
self.assertEqual(q, 1)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py
index a9ef98e7..7def2ac7 100644
--- a/tests/test_datastructures.py
+++ b/tests/test_datastructures.py
@@ -14,128 +14,129 @@ class DocumentStub(object):
class TestBaseDict(unittest.TestCase):
-
@staticmethod
def _get_basedict(dict_items):
"""Get a BaseList bound to a fake document instance"""
fake_doc = DocumentStub()
- base_list = BaseDict(dict_items, instance=None, name='my_name')
- base_list._instance = fake_doc # hack to inject the mock, it does not work in the constructor
+ base_list = BaseDict(dict_items, instance=None, name="my_name")
+ base_list._instance = (
+ fake_doc
+ ) # hack to inject the mock, it does not work in the constructor
return base_list
def test___init___(self):
class MyDoc(Document):
pass
- dict_items = {'k': 'v'}
+ dict_items = {"k": "v"}
doc = MyDoc()
- base_dict = BaseDict(dict_items, instance=doc, name='my_name')
+ base_dict = BaseDict(dict_items, instance=doc, name="my_name")
self.assertIsInstance(base_dict._instance, Document)
- self.assertEqual(base_dict._name, 'my_name')
+ self.assertEqual(base_dict._name, "my_name")
self.assertEqual(base_dict, dict_items)
def test_setdefault_calls_mark_as_changed(self):
base_dict = self._get_basedict({})
- base_dict.setdefault('k', 'v')
+ base_dict.setdefault("k", "v")
self.assertEqual(base_dict._instance._changed_fields, [base_dict._name])
def test_popitems_calls_mark_as_changed(self):
- base_dict = self._get_basedict({'k': 'v'})
- self.assertEqual(base_dict.popitem(), ('k', 'v'))
+ base_dict = self._get_basedict({"k": "v"})
+ self.assertEqual(base_dict.popitem(), ("k", "v"))
self.assertEqual(base_dict._instance._changed_fields, [base_dict._name])
self.assertFalse(base_dict)
def test_pop_calls_mark_as_changed(self):
- base_dict = self._get_basedict({'k': 'v'})
- self.assertEqual(base_dict.pop('k'), 'v')
+ base_dict = self._get_basedict({"k": "v"})
+ self.assertEqual(base_dict.pop("k"), "v")
self.assertEqual(base_dict._instance._changed_fields, [base_dict._name])
self.assertFalse(base_dict)
def test_pop_calls_does_not_mark_as_changed_when_it_fails(self):
- base_dict = self._get_basedict({'k': 'v'})
+ base_dict = self._get_basedict({"k": "v"})
with self.assertRaises(KeyError):
- base_dict.pop('X')
+ base_dict.pop("X")
self.assertFalse(base_dict._instance._changed_fields)
def test_clear_calls_mark_as_changed(self):
- base_dict = self._get_basedict({'k': 'v'})
+ base_dict = self._get_basedict({"k": "v"})
base_dict.clear()
- self.assertEqual(base_dict._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_dict._instance._changed_fields, ["my_name"])
self.assertEqual(base_dict, {})
def test___delitem___calls_mark_as_changed(self):
- base_dict = self._get_basedict({'k': 'v'})
- del base_dict['k']
- self.assertEqual(base_dict._instance._changed_fields, ['my_name.k'])
+ base_dict = self._get_basedict({"k": "v"})
+ del base_dict["k"]
+ self.assertEqual(base_dict._instance._changed_fields, ["my_name.k"])
self.assertEqual(base_dict, {})
def test___getitem____KeyError(self):
base_dict = self._get_basedict({})
with self.assertRaises(KeyError):
- base_dict['new']
+ base_dict["new"]
def test___getitem____simple_value(self):
- base_dict = self._get_basedict({'k': 'v'})
- base_dict['k'] = 'v'
+ base_dict = self._get_basedict({"k": "v"})
+ base_dict["k"] = "v"
def test___getitem____sublist_gets_converted_to_BaseList(self):
- base_dict = self._get_basedict({'k': [0, 1, 2]})
- sub_list = base_dict['k']
+ base_dict = self._get_basedict({"k": [0, 1, 2]})
+ sub_list = base_dict["k"]
self.assertEqual(sub_list, [0, 1, 2])
self.assertIsInstance(sub_list, BaseList)
self.assertIs(sub_list._instance, base_dict._instance)
- self.assertEqual(sub_list._name, 'my_name.k')
+ self.assertEqual(sub_list._name, "my_name.k")
self.assertEqual(base_dict._instance._changed_fields, [])
# Challenge mark_as_changed from sublist
sub_list[1] = None
- self.assertEqual(base_dict._instance._changed_fields, ['my_name.k.1'])
+ self.assertEqual(base_dict._instance._changed_fields, ["my_name.k.1"])
def test___getitem____subdict_gets_converted_to_BaseDict(self):
- base_dict = self._get_basedict({'k': {'subk': 'subv'}})
- sub_dict = base_dict['k']
- self.assertEqual(sub_dict, {'subk': 'subv'})
+ base_dict = self._get_basedict({"k": {"subk": "subv"}})
+ sub_dict = base_dict["k"]
+ self.assertEqual(sub_dict, {"subk": "subv"})
self.assertIsInstance(sub_dict, BaseDict)
self.assertIs(sub_dict._instance, base_dict._instance)
- self.assertEqual(sub_dict._name, 'my_name.k')
+ self.assertEqual(sub_dict._name, "my_name.k")
self.assertEqual(base_dict._instance._changed_fields, [])
# Challenge mark_as_changed from subdict
- sub_dict['subk'] = None
- self.assertEqual(base_dict._instance._changed_fields, ['my_name.k.subk'])
+ sub_dict["subk"] = None
+ self.assertEqual(base_dict._instance._changed_fields, ["my_name.k.subk"])
def test_get_sublist_gets_converted_to_BaseList_just_like__getitem__(self):
- base_dict = self._get_basedict({'k': [0, 1, 2]})
- sub_list = base_dict.get('k')
+ base_dict = self._get_basedict({"k": [0, 1, 2]})
+ sub_list = base_dict.get("k")
self.assertEqual(sub_list, [0, 1, 2])
self.assertIsInstance(sub_list, BaseList)
def test_get_returns_the_same_as___getitem__(self):
- base_dict = self._get_basedict({'k': [0, 1, 2]})
- get_ = base_dict.get('k')
- getitem_ = base_dict['k']
+ base_dict = self._get_basedict({"k": [0, 1, 2]})
+ get_ = base_dict.get("k")
+ getitem_ = base_dict["k"]
self.assertEqual(get_, getitem_)
def test_get_default(self):
base_dict = self._get_basedict({})
sentinel = object()
- self.assertEqual(base_dict.get('new'), None)
- self.assertIs(base_dict.get('new', sentinel), sentinel)
+ self.assertEqual(base_dict.get("new"), None)
+ self.assertIs(base_dict.get("new", sentinel), sentinel)
def test___setitem___calls_mark_as_changed(self):
base_dict = self._get_basedict({})
- base_dict['k'] = 'v'
- self.assertEqual(base_dict._instance._changed_fields, ['my_name.k'])
- self.assertEqual(base_dict, {'k': 'v'})
+ base_dict["k"] = "v"
+ self.assertEqual(base_dict._instance._changed_fields, ["my_name.k"])
+ self.assertEqual(base_dict, {"k": "v"})
def test_update_calls_mark_as_changed(self):
base_dict = self._get_basedict({})
- base_dict.update({'k': 'v'})
- self.assertEqual(base_dict._instance._changed_fields, ['my_name'])
+ base_dict.update({"k": "v"})
+ self.assertEqual(base_dict._instance._changed_fields, ["my_name"])
def test___setattr____not_tracked_by_changes(self):
base_dict = self._get_basedict({})
- base_dict.a_new_attr = 'test'
+ base_dict.a_new_attr = "test"
self.assertEqual(base_dict._instance._changed_fields, [])
def test___delattr____tracked_by_changes(self):
@@ -143,19 +144,20 @@ class TestBaseDict(unittest.TestCase):
# This is even bad because it could be that there is an attribute
# with the same name as a key
base_dict = self._get_basedict({})
- base_dict.a_new_attr = 'test'
+ base_dict.a_new_attr = "test"
del base_dict.a_new_attr
- self.assertEqual(base_dict._instance._changed_fields, ['my_name.a_new_attr'])
+ self.assertEqual(base_dict._instance._changed_fields, ["my_name.a_new_attr"])
class TestBaseList(unittest.TestCase):
-
@staticmethod
def _get_baselist(list_items):
"""Get a BaseList bound to a fake document instance"""
fake_doc = DocumentStub()
- base_list = BaseList(list_items, instance=None, name='my_name')
- base_list._instance = fake_doc # hack to inject the mock, it does not work in the constructor
+ base_list = BaseList(list_items, instance=None, name="my_name")
+ base_list._instance = (
+ fake_doc
+ ) # hack to inject the mock, it does not work in the constructor
return base_list
def test___init___(self):
@@ -164,19 +166,19 @@ class TestBaseList(unittest.TestCase):
list_items = [True]
doc = MyDoc()
- base_list = BaseList(list_items, instance=doc, name='my_name')
+ base_list = BaseList(list_items, instance=doc, name="my_name")
self.assertIsInstance(base_list._instance, Document)
- self.assertEqual(base_list._name, 'my_name')
+ self.assertEqual(base_list._name, "my_name")
self.assertEqual(base_list, list_items)
def test___iter__(self):
values = [True, False, True, False]
- base_list = BaseList(values, instance=None, name='my_name')
+ base_list = BaseList(values, instance=None, name="my_name")
self.assertEqual(values, list(base_list))
def test___iter___allow_modification_while_iterating_withou_error(self):
# regular list allows for this, thus this subclass must comply to that
- base_list = BaseList([True, False, True, False], instance=None, name='my_name')
+ base_list = BaseList([True, False, True, False], instance=None, name="my_name")
for idx, val in enumerate(base_list):
if val:
base_list.pop(idx)
@@ -185,7 +187,7 @@ class TestBaseList(unittest.TestCase):
base_list = self._get_baselist([])
self.assertFalse(base_list._instance._changed_fields)
base_list.append(True)
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
def test_subclass_append(self):
# Due to the way mark_as_changed_wrapper is implemented
@@ -193,7 +195,7 @@ class TestBaseList(unittest.TestCase):
class SubBaseList(BaseList):
pass
- base_list = SubBaseList([], instance=None, name='my_name')
+ base_list = SubBaseList([], instance=None, name="my_name")
base_list.append(True)
def test___getitem__using_simple_index(self):
@@ -217,54 +219,45 @@ class TestBaseList(unittest.TestCase):
self.assertEqual(base_list._instance._changed_fields, [])
def test___getitem__sublist_returns_BaseList_bound_to_instance(self):
- base_list = self._get_baselist(
- [
- [1, 2],
- [3, 4]
- ]
- )
+ base_list = self._get_baselist([[1, 2], [3, 4]])
sub_list = base_list[0]
self.assertEqual(sub_list, [1, 2])
self.assertIsInstance(sub_list, BaseList)
self.assertIs(sub_list._instance, base_list._instance)
- self.assertEqual(sub_list._name, 'my_name.0')
+ self.assertEqual(sub_list._name, "my_name.0")
self.assertEqual(base_list._instance._changed_fields, [])
# Challenge mark_as_changed from sublist
sub_list[1] = None
- self.assertEqual(base_list._instance._changed_fields, ['my_name.0.1'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name.0.1"])
def test___getitem__subdict_returns_BaseList_bound_to_instance(self):
- base_list = self._get_baselist(
- [
- {'subk': 'subv'}
- ]
- )
+ base_list = self._get_baselist([{"subk": "subv"}])
sub_dict = base_list[0]
- self.assertEqual(sub_dict, {'subk': 'subv'})
+ self.assertEqual(sub_dict, {"subk": "subv"})
self.assertIsInstance(sub_dict, BaseDict)
self.assertIs(sub_dict._instance, base_list._instance)
- self.assertEqual(sub_dict._name, 'my_name.0')
+ self.assertEqual(sub_dict._name, "my_name.0")
self.assertEqual(base_list._instance._changed_fields, [])
# Challenge mark_as_changed from subdict
- sub_dict['subk'] = None
- self.assertEqual(base_list._instance._changed_fields, ['my_name.0.subk'])
+ sub_dict["subk"] = None
+ self.assertEqual(base_list._instance._changed_fields, ["my_name.0.subk"])
def test_extend_calls_mark_as_changed(self):
base_list = self._get_baselist([])
base_list.extend([True])
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
def test_insert_calls_mark_as_changed(self):
base_list = self._get_baselist([])
base_list.insert(0, True)
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
def test_remove_calls_mark_as_changed(self):
base_list = self._get_baselist([True])
base_list.remove(True)
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
def test_remove_not_mark_as_changed_when_it_fails(self):
base_list = self._get_baselist([True])
@@ -275,70 +268,76 @@ class TestBaseList(unittest.TestCase):
def test_pop_calls_mark_as_changed(self):
base_list = self._get_baselist([True])
base_list.pop()
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
def test_reverse_calls_mark_as_changed(self):
base_list = self._get_baselist([True, False])
base_list.reverse()
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
def test___delitem___calls_mark_as_changed(self):
base_list = self._get_baselist([True])
del base_list[0]
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
def test___setitem___calls_with_full_slice_mark_as_changed(self):
base_list = self._get_baselist([])
- base_list[:] = [0, 1] # Will use __setslice__ under py2 and __setitem__ under py3
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ base_list[:] = [
+ 0,
+ 1,
+ ] # Will use __setslice__ under py2 and __setitem__ under py3
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
self.assertEqual(base_list, [0, 1])
def test___setitem___calls_with_partial_slice_mark_as_changed(self):
base_list = self._get_baselist([0, 1, 2])
- base_list[0:2] = [1, 0] # Will use __setslice__ under py2 and __setitem__ under py3
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ base_list[0:2] = [
+ 1,
+ 0,
+ ] # Will use __setslice__ under py2 and __setitem__ under py3
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
self.assertEqual(base_list, [1, 0, 2])
def test___setitem___calls_with_step_slice_mark_as_changed(self):
base_list = self._get_baselist([0, 1, 2])
- base_list[0:3:2] = [-1, -2] # uses __setitem__ in both py2 & 3
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ base_list[0:3:2] = [-1, -2] # uses __setitem__ in both py2 & 3
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
self.assertEqual(base_list, [-1, 1, -2])
def test___setitem___with_slice(self):
base_list = self._get_baselist([0, 1, 2, 3, 4, 5])
base_list[0:6:2] = [None, None, None]
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
self.assertEqual(base_list, [None, 1, None, 3, None, 5])
def test___setitem___item_0_calls_mark_as_changed(self):
base_list = self._get_baselist([True])
base_list[0] = False
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
self.assertEqual(base_list, [False])
def test___setitem___item_1_calls_mark_as_changed(self):
base_list = self._get_baselist([True, True])
base_list[1] = False
- self.assertEqual(base_list._instance._changed_fields, ['my_name.1'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name.1"])
self.assertEqual(base_list, [True, False])
def test___delslice___calls_mark_as_changed(self):
base_list = self._get_baselist([0, 1])
del base_list[0:1]
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
self.assertEqual(base_list, [1])
def test___iadd___calls_mark_as_changed(self):
base_list = self._get_baselist([True])
base_list += [False]
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
def test___imul___calls_mark_as_changed(self):
base_list = self._get_baselist([True])
self.assertEqual(base_list._instance._changed_fields, [])
base_list *= 2
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
def test_sort_calls_not_marked_as_changed_when_it_fails(self):
base_list = self._get_baselist([True])
@@ -350,7 +349,7 @@ class TestBaseList(unittest.TestCase):
def test_sort_calls_mark_as_changed(self):
base_list = self._get_baselist([True, False])
base_list.sort()
- self.assertEqual(base_list._instance._changed_fields, ['my_name'])
+ self.assertEqual(base_list._instance._changed_fields, ["my_name"])
def test_sort_calls_with_key(self):
base_list = self._get_baselist([1, 2, 11])
@@ -371,7 +370,7 @@ class TestStrictDict(unittest.TestCase):
def test_iterkeys(self):
d = self.dtype(a=1)
- self.assertEqual(list(iterkeys(d)), ['a'])
+ self.assertEqual(list(iterkeys(d)), ["a"])
def test_len(self):
d = self.dtype(a=1)
@@ -379,9 +378,9 @@ class TestStrictDict(unittest.TestCase):
def test_pop(self):
d = self.dtype(a=1)
- self.assertIn('a', d)
- d.pop('a')
- self.assertNotIn('a', d)
+ self.assertIn("a", d)
+ d.pop("a")
+ self.assertNotIn("a", d)
def test_repr(self):
d = self.dtype(a=1, b=2, c=3)
@@ -416,7 +415,7 @@ class TestStrictDict(unittest.TestCase):
d = self.dtype()
d.a = 1
self.assertEqual(d.a, 1)
- self.assertRaises(AttributeError, getattr, d, 'b')
+ self.assertRaises(AttributeError, getattr, d, "b")
def test_setattr_raises_on_nonexisting_attr(self):
d = self.dtype()
@@ -430,20 +429,20 @@ class TestStrictDict(unittest.TestCase):
def test_get(self):
d = self.dtype(a=1)
- self.assertEqual(d.get('a'), 1)
- self.assertEqual(d.get('b', 'bla'), 'bla')
+ self.assertEqual(d.get("a"), 1)
+ self.assertEqual(d.get("b", "bla"), "bla")
def test_items(self):
d = self.dtype(a=1)
- self.assertEqual(d.items(), [('a', 1)])
+ self.assertEqual(d.items(), [("a", 1)])
d = self.dtype(a=1, b=2)
- self.assertEqual(d.items(), [('a', 1), ('b', 2)])
+ self.assertEqual(d.items(), [("a", 1), ("b", 2)])
def test_mappings_protocol(self):
d = self.dtype(a=1, b=2)
- self.assertEqual(dict(d), {'a': 1, 'b': 2})
- self.assertEqual(dict(**d), {'a': 1, 'b': 2})
+ self.assertEqual(dict(d), {"a": 1, "b": 2})
+ self.assertEqual(dict(**d), {"a": 1, "b": 2})
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_dereference.py b/tests/test_dereference.py
index 9c565810..4730e2e3 100644
--- a/tests/test_dereference.py
+++ b/tests/test_dereference.py
@@ -10,18 +10,18 @@ from mongoengine.context_managers import query_counter
class FieldTest(unittest.TestCase):
-
@classmethod
def setUpClass(cls):
- cls.db = connect(db='mongoenginetest')
+ cls.db = connect(db="mongoenginetest")
@classmethod
def tearDownClass(cls):
- cls.db.drop_database('mongoenginetest')
+ cls.db.drop_database("mongoenginetest")
def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
+
class User(Document):
name = StringField()
@@ -32,7 +32,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
for i in range(1, 51):
- user = User(name='user %s' % i)
+ user = User(name="user %s" % i)
user.save()
group = Group(members=User.objects)
@@ -47,7 +47,7 @@ class FieldTest(unittest.TestCase):
group_obj = Group.objects.first()
self.assertEqual(q, 1)
- len(group_obj._data['members'])
+ len(group_obj._data["members"])
self.assertEqual(q, 1)
len(group_obj.members)
@@ -80,6 +80,7 @@ class FieldTest(unittest.TestCase):
def test_list_item_dereference_dref_false(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
+
class User(Document):
name = StringField()
@@ -90,7 +91,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
for i in range(1, 51):
- user = User(name='user %s' % i)
+ user = User(name="user %s" % i)
user.save()
group = Group(members=User.objects)
@@ -105,14 +106,14 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members]
self.assertEqual(q, 2)
- self.assertTrue(group_obj._data['members']._dereferenced)
+ self.assertTrue(group_obj._data["members"]._dereferenced)
# verifies that no additional queries gets executed
# if we re-iterate over the ListField once it is
# dereferenced
[m for m in group_obj.members]
self.assertEqual(q, 2)
- self.assertTrue(group_obj._data['members']._dereferenced)
+ self.assertTrue(group_obj._data["members"]._dereferenced)
# Document select_related
with query_counter() as q:
@@ -136,6 +137,7 @@ class FieldTest(unittest.TestCase):
def test_list_item_dereference_orphan_dbref(self):
"""Ensure that orphan DBRef items in ListFields are dereferenced.
"""
+
class User(Document):
name = StringField()
@@ -146,7 +148,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
for i in range(1, 51):
- user = User(name='user %s' % i)
+ user = User(name="user %s" % i)
user.save()
group = Group(members=User.objects)
@@ -164,14 +166,14 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members]
self.assertEqual(q, 2)
- self.assertTrue(group_obj._data['members']._dereferenced)
+ self.assertTrue(group_obj._data["members"]._dereferenced)
# verifies that no additional queries gets executed
# if we re-iterate over the ListField once it is
# dereferenced
[m for m in group_obj.members]
self.assertEqual(q, 2)
- self.assertTrue(group_obj._data['members']._dereferenced)
+ self.assertTrue(group_obj._data["members"]._dereferenced)
User.drop_collection()
Group.drop_collection()
@@ -179,6 +181,7 @@ class FieldTest(unittest.TestCase):
def test_list_item_dereference_dref_false_stores_as_type(self):
"""Ensure that DBRef items are stored as their type
"""
+
class User(Document):
my_id = IntField(primary_key=True)
name = StringField()
@@ -189,17 +192,18 @@ class FieldTest(unittest.TestCase):
User.drop_collection()
Group.drop_collection()
- user = User(my_id=1, name='user 1').save()
+ user = User(my_id=1, name="user 1").save()
Group(members=User.objects).save()
group = Group.objects.first()
- self.assertEqual(Group._get_collection().find_one()['members'], [1])
+ self.assertEqual(Group._get_collection().find_one()["members"], [1])
self.assertEqual(group.members, [user])
def test_handle_old_style_references(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
+
class User(Document):
name = StringField()
@@ -210,7 +214,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
for i in range(1, 26):
- user = User(name='user %s' % i)
+ user = User(name="user %s" % i)
user.save()
group = Group(members=User.objects)
@@ -227,8 +231,8 @@ class FieldTest(unittest.TestCase):
group.save()
group = Group.objects.first()
- self.assertEqual(group.members[0].name, 'user 1')
- self.assertEqual(group.members[-1].name, 'String!')
+ self.assertEqual(group.members[0].name, "user 1")
+ self.assertEqual(group.members[-1].name, "String!")
def test_migrate_references(self):
"""Example of migrating ReferenceField storage
@@ -249,8 +253,8 @@ class FieldTest(unittest.TestCase):
group = Group(author=user, members=[user]).save()
raw_data = Group._get_collection().find_one()
- self.assertIsInstance(raw_data['author'], DBRef)
- self.assertIsInstance(raw_data['members'][0], DBRef)
+ self.assertIsInstance(raw_data["author"], DBRef)
+ self.assertIsInstance(raw_data["members"][0], DBRef)
group = Group.objects.first()
self.assertEqual(group.author, user)
@@ -264,8 +268,8 @@ class FieldTest(unittest.TestCase):
# Migrate the data
for g in Group.objects():
# Explicitly mark as changed so resets
- g._mark_as_changed('author')
- g._mark_as_changed('members')
+ g._mark_as_changed("author")
+ g._mark_as_changed("members")
g.save()
group = Group.objects.first()
@@ -273,35 +277,36 @@ class FieldTest(unittest.TestCase):
self.assertEqual(group.members, [user])
raw_data = Group._get_collection().find_one()
- self.assertIsInstance(raw_data['author'], ObjectId)
- self.assertIsInstance(raw_data['members'][0], ObjectId)
+ self.assertIsInstance(raw_data["author"], ObjectId)
+ self.assertIsInstance(raw_data["members"][0], ObjectId)
def test_recursive_reference(self):
"""Ensure that ReferenceFields can reference their own documents.
"""
+
class Employee(Document):
name = StringField()
- boss = ReferenceField('self')
- friends = ListField(ReferenceField('self'))
+ boss = ReferenceField("self")
+ friends = ListField(ReferenceField("self"))
Employee.drop_collection()
- bill = Employee(name='Bill Lumbergh')
+ bill = Employee(name="Bill Lumbergh")
bill.save()
- michael = Employee(name='Michael Bolton')
+ michael = Employee(name="Michael Bolton")
michael.save()
- samir = Employee(name='Samir Nagheenanajar')
+ samir = Employee(name="Samir Nagheenanajar")
samir.save()
friends = [michael, samir]
- peter = Employee(name='Peter Gibbons', boss=bill, friends=friends)
+ peter = Employee(name="Peter Gibbons", boss=bill, friends=friends)
peter.save()
- Employee(name='Funky Gibbon', boss=bill, friends=friends).save()
- Employee(name='Funky Gibbon', boss=bill, friends=friends).save()
- Employee(name='Funky Gibbon', boss=bill, friends=friends).save()
+ Employee(name="Funky Gibbon", boss=bill, friends=friends).save()
+ Employee(name="Funky Gibbon", boss=bill, friends=friends).save()
+ Employee(name="Funky Gibbon", boss=bill, friends=friends).save()
with query_counter() as q:
self.assertEqual(q, 0)
@@ -343,7 +348,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2)
def test_list_of_lists_of_references(self):
-
class User(Document):
name = StringField()
@@ -357,9 +361,9 @@ class FieldTest(unittest.TestCase):
Post.drop_collection()
SimpleList.drop_collection()
- u1 = User.objects.create(name='u1')
- u2 = User.objects.create(name='u2')
- u3 = User.objects.create(name='u3')
+ u1 = User.objects.create(name="u1")
+ u2 = User.objects.create(name="u2")
+ u3 = User.objects.create(name="u3")
SimpleList.objects.create(users=[u1, u2, u3])
self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3])
@@ -370,13 +374,14 @@ class FieldTest(unittest.TestCase):
def test_circular_reference(self):
"""Ensure you can handle circular references
"""
+
class Relation(EmbeddedDocument):
name = StringField()
- person = ReferenceField('Person')
+ person = ReferenceField("Person")
class Person(Document):
name = StringField()
- relations = ListField(EmbeddedDocumentField('Relation'))
+ relations = ListField(EmbeddedDocumentField("Relation"))
def __repr__(self):
return "" % self.name
@@ -398,14 +403,17 @@ class FieldTest(unittest.TestCase):
daughter.relations.append(self_rel)
daughter.save()
- self.assertEqual("[, ]", "%s" % Person.objects())
+ self.assertEqual(
+ "[, ]", "%s" % Person.objects()
+ )
def test_circular_reference_on_self(self):
"""Ensure you can handle circular references
"""
+
class Person(Document):
name = StringField()
- relations = ListField(ReferenceField('self'))
+ relations = ListField(ReferenceField("self"))
def __repr__(self):
return "" % self.name
@@ -424,14 +432,17 @@ class FieldTest(unittest.TestCase):
daughter.relations.append(daughter)
daughter.save()
- self.assertEqual("[, ]", "%s" % Person.objects())
+ self.assertEqual(
+ "[, ]", "%s" % Person.objects()
+ )
def test_circular_tree_reference(self):
"""Ensure you can handle circular references with more than one level
"""
+
class Other(EmbeddedDocument):
name = StringField()
- friends = ListField(ReferenceField('Person'))
+ friends = ListField(ReferenceField("Person"))
class Person(Document):
name = StringField()
@@ -443,8 +454,8 @@ class FieldTest(unittest.TestCase):
Person.drop_collection()
paul = Person(name="Paul").save()
maria = Person(name="Maria").save()
- julia = Person(name='Julia').save()
- anna = Person(name='Anna').save()
+ julia = Person(name="Julia").save()
+ anna = Person(name="Anna").save()
paul.other.friends = [maria, julia, anna]
paul.other.name = "Paul's friends"
@@ -464,11 +475,10 @@ class FieldTest(unittest.TestCase):
self.assertEqual(
"[, , , ]",
- "%s" % Person.objects()
+ "%s" % Person.objects(),
)
def test_generic_reference(self):
-
class UserA(Document):
name = StringField()
@@ -488,13 +498,13 @@ class FieldTest(unittest.TestCase):
members = []
for i in range(1, 51):
- a = UserA(name='User A %s' % i)
+ a = UserA(name="User A %s" % i)
a.save()
- b = UserB(name='User B %s' % i)
+ b = UserB(name="User B %s" % i)
b.save()
- c = UserC(name='User C %s' % i)
+ c = UserC(name="User C %s" % i)
c.save()
members += [a, b, c]
@@ -518,7 +528,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for m in group_obj.members:
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
# Document select_related
with query_counter() as q:
@@ -534,7 +544,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for m in group_obj.members:
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
# Queryset select_related
with query_counter() as q:
@@ -551,8 +561,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for m in group_obj.members:
- self.assertIn('User', m.__class__.__name__)
-
+ self.assertIn("User", m.__class__.__name__)
def test_generic_reference_orphan_dbref(self):
"""Ensure that generic orphan DBRef items in ListFields are dereferenced.
@@ -577,13 +586,13 @@ class FieldTest(unittest.TestCase):
members = []
for i in range(1, 51):
- a = UserA(name='User A %s' % i)
+ a = UserA(name="User A %s" % i)
a.save()
- b = UserB(name='User B %s' % i)
+ b = UserB(name="User B %s" % i)
b.save()
- c = UserC(name='User C %s' % i)
+ c = UserC(name="User C %s" % i)
c.save()
members += [a, b, c]
@@ -602,11 +611,11 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members]
self.assertEqual(q, 4)
- self.assertTrue(group_obj._data['members']._dereferenced)
+ self.assertTrue(group_obj._data["members"]._dereferenced)
[m for m in group_obj.members]
self.assertEqual(q, 4)
- self.assertTrue(group_obj._data['members']._dereferenced)
+ self.assertTrue(group_obj._data["members"]._dereferenced)
UserA.drop_collection()
UserB.drop_collection()
@@ -614,7 +623,6 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
def test_list_field_complex(self):
-
class UserA(Document):
name = StringField()
@@ -634,13 +642,13 @@ class FieldTest(unittest.TestCase):
members = []
for i in range(1, 51):
- a = UserA(name='User A %s' % i)
+ a = UserA(name="User A %s" % i)
a.save()
- b = UserB(name='User B %s' % i)
+ b = UserB(name="User B %s" % i)
b.save()
- c = UserC(name='User C %s' % i)
+ c = UserC(name="User C %s" % i)
c.save()
members += [a, b, c]
@@ -664,7 +672,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for m in group_obj.members:
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
# Document select_related
with query_counter() as q:
@@ -680,7 +688,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for m in group_obj.members:
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
# Queryset select_related
with query_counter() as q:
@@ -697,7 +705,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for m in group_obj.members:
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
UserA.drop_collection()
UserB.drop_collection()
@@ -705,7 +713,6 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
def test_map_field_reference(self):
-
class User(Document):
name = StringField()
@@ -717,7 +724,7 @@ class FieldTest(unittest.TestCase):
members = []
for i in range(1, 51):
- user = User(name='user %s' % i)
+ user = User(name="user %s" % i)
user.save()
members.append(user)
@@ -752,7 +759,7 @@ class FieldTest(unittest.TestCase):
for k, m in iteritems(group_obj.members):
self.assertIsInstance(m, User)
- # Queryset select_related
+ # Queryset select_related
with query_counter() as q:
self.assertEqual(q, 0)
@@ -770,7 +777,6 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
def test_dict_field(self):
-
class UserA(Document):
name = StringField()
@@ -790,13 +796,13 @@ class FieldTest(unittest.TestCase):
members = []
for i in range(1, 51):
- a = UserA(name='User A %s' % i)
+ a = UserA(name="User A %s" % i)
a.save()
- b = UserB(name='User B %s' % i)
+ b = UserB(name="User B %s" % i)
b.save()
- c = UserC(name='User C %s' % i)
+ c = UserC(name="User C %s" % i)
c.save()
members += [a, b, c]
@@ -819,7 +825,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for k, m in iteritems(group_obj.members):
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
# Document select_related
with query_counter() as q:
@@ -835,7 +841,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for k, m in iteritems(group_obj.members):
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
# Queryset select_related
with query_counter() as q:
@@ -852,7 +858,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for k, m in iteritems(group_obj.members):
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
Group.objects.delete()
Group().save()
@@ -873,10 +879,9 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
def test_dict_field_no_field_inheritance(self):
-
class UserA(Document):
name = StringField()
- meta = {'allow_inheritance': False}
+ meta = {"allow_inheritance": False}
class Group(Document):
members = DictField()
@@ -886,7 +891,7 @@ class FieldTest(unittest.TestCase):
members = []
for i in range(1, 51):
- a = UserA(name='User A %s' % i)
+ a = UserA(name="User A %s" % i)
a.save()
members += [a]
@@ -949,7 +954,6 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
def test_generic_reference_map_field(self):
-
class UserA(Document):
name = StringField()
@@ -969,13 +973,13 @@ class FieldTest(unittest.TestCase):
members = []
for i in range(1, 51):
- a = UserA(name='User A %s' % i)
+ a = UserA(name="User A %s" % i)
a.save()
- b = UserB(name='User B %s' % i)
+ b = UserB(name="User B %s" % i)
b.save()
- c = UserC(name='User C %s' % i)
+ c = UserC(name="User C %s" % i)
c.save()
members += [a, b, c]
@@ -998,7 +1002,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for k, m in iteritems(group_obj.members):
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
# Document select_related
with query_counter() as q:
@@ -1014,7 +1018,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for k, m in iteritems(group_obj.members):
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
# Queryset select_related
with query_counter() as q:
@@ -1031,7 +1035,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 4)
for k, m in iteritems(group_obj.members):
- self.assertIn('User', m.__class__.__name__)
+ self.assertIn("User", m.__class__.__name__)
Group.objects.delete()
Group().save()
@@ -1051,7 +1055,6 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
def test_multidirectional_lists(self):
-
class Asset(Document):
name = StringField(max_length=250, required=True)
path = StringField()
@@ -1062,10 +1065,10 @@ class FieldTest(unittest.TestCase):
Asset.drop_collection()
- root = Asset(name='', path="/", title="Site Root")
+ root = Asset(name="", path="/", title="Site Root")
root.save()
- company = Asset(name='company', title='Company', parent=root, parents=[root])
+ company = Asset(name="company", title="Company", parent=root, parents=[root])
company.save()
root.children = [company]
@@ -1076,7 +1079,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(company.parents, [root])
def test_dict_in_dbref_instance(self):
-
class Person(Document):
name = StringField(max_length=250, required=True)
@@ -1087,34 +1089,35 @@ class FieldTest(unittest.TestCase):
Person.drop_collection()
Room.drop_collection()
- bob = Person.objects.create(name='Bob')
+ bob = Person.objects.create(name="Bob")
bob.save()
- sarah = Person.objects.create(name='Sarah')
+ sarah = Person.objects.create(name="Sarah")
sarah.save()
room_101 = Room.objects.create(number="101")
room_101.staffs_with_position = [
- {'position_key': 'window', 'staff': sarah},
- {'position_key': 'door', 'staff': bob.to_dbref()}]
+ {"position_key": "window", "staff": sarah},
+ {"position_key": "door", "staff": bob.to_dbref()},
+ ]
room_101.save()
room = Room.objects.first().select_related()
- self.assertEqual(room.staffs_with_position[0]['staff'], sarah)
- self.assertEqual(room.staffs_with_position[1]['staff'], bob)
+ self.assertEqual(room.staffs_with_position[0]["staff"], sarah)
+ self.assertEqual(room.staffs_with_position[1]["staff"], bob)
def test_document_reload_no_inheritance(self):
class Foo(Document):
- meta = {'allow_inheritance': False}
- bar = ReferenceField('Bar')
- baz = ReferenceField('Baz')
+ meta = {"allow_inheritance": False}
+ bar = ReferenceField("Bar")
+ baz = ReferenceField("Baz")
class Bar(Document):
- meta = {'allow_inheritance': False}
- msg = StringField(required=True, default='Blammo!')
+ meta = {"allow_inheritance": False}
+ msg = StringField(required=True, default="Blammo!")
class Baz(Document):
- meta = {'allow_inheritance': False}
- msg = StringField(required=True, default='Kaboom!')
+ meta = {"allow_inheritance": False}
+ msg = StringField(required=True, default="Kaboom!")
Foo.drop_collection()
Bar.drop_collection()
@@ -1138,11 +1141,14 @@ class FieldTest(unittest.TestCase):
Ensure reloading a document with multiple similar id
in different collections doesn't mix them.
"""
+
class Topic(Document):
id = IntField(primary_key=True)
+
class User(Document):
id = IntField(primary_key=True)
name = StringField()
+
class Message(Document):
id = IntField(primary_key=True)
topic = ReferenceField(Topic)
@@ -1154,23 +1160,24 @@ class FieldTest(unittest.TestCase):
# All objects share the same id, but each in a different collection
topic = Topic(id=1).save()
- user = User(id=1, name='user-name').save()
+ user = User(id=1, name="user-name").save()
Message(id=1, topic=topic, author=user).save()
concurrent_change_user = User.objects.get(id=1)
- concurrent_change_user.name = 'new-name'
+ concurrent_change_user.name = "new-name"
concurrent_change_user.save()
- self.assertNotEqual(user.name, 'new-name')
+ self.assertNotEqual(user.name, "new-name")
msg = Message.objects.get(id=1)
msg.reload()
self.assertEqual(msg.topic, topic)
self.assertEqual(msg.author, user)
- self.assertEqual(msg.author.name, 'new-name')
+ self.assertEqual(msg.author.name, "new-name")
def test_list_lookup_not_checked_in_map(self):
"""Ensure we dereference list data correctly
"""
+
class Comment(Document):
id = IntField(primary_key=True)
text = StringField()
@@ -1182,8 +1189,8 @@ class FieldTest(unittest.TestCase):
Comment.drop_collection()
Message.drop_collection()
- c1 = Comment(id=0, text='zero').save()
- c2 = Comment(id=1, text='one').save()
+ c1 = Comment(id=0, text="zero").save()
+ c2 = Comment(id=1, text="one").save()
Message(id=1, comments=[c1, c2]).save()
msg = Message.objects.get(id=1)
@@ -1193,6 +1200,7 @@ class FieldTest(unittest.TestCase):
def test_list_item_dereference_dref_false_save_doesnt_cause_extra_queries(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
+
class User(Document):
name = StringField()
@@ -1204,7 +1212,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
for i in range(1, 51):
- User(name='user %s' % i).save()
+ User(name="user %s" % i).save()
Group(name="Test", members=User.objects).save()
@@ -1222,6 +1230,7 @@ class FieldTest(unittest.TestCase):
def test_list_item_dereference_dref_true_save_doesnt_cause_extra_queries(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
+
class User(Document):
name = StringField()
@@ -1233,7 +1242,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
for i in range(1, 51):
- User(name='user %s' % i).save()
+ User(name="user %s" % i).save()
Group(name="Test", members=User.objects).save()
@@ -1249,7 +1258,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2)
def test_generic_reference_save_doesnt_cause_extra_queries(self):
-
class UserA(Document):
name = StringField()
@@ -1270,9 +1278,9 @@ class FieldTest(unittest.TestCase):
members = []
for i in range(1, 51):
- a = UserA(name='User A %s' % i).save()
- b = UserB(name='User B %s' % i).save()
- c = UserC(name='User C %s' % i).save()
+ a = UserA(name="User A %s" % i).save()
+ b = UserB(name="User B %s" % i).save()
+ c = UserC(name="User C %s" % i).save()
members += [a, b, c]
@@ -1292,7 +1300,7 @@ class FieldTest(unittest.TestCase):
def test_objectid_reference_across_databases(self):
# mongoenginetest - Is default connection alias from setUp()
# Register Aliases
- register_connection('testdb-1', 'mongoenginetest2')
+ register_connection("testdb-1", "mongoenginetest2")
class User(Document):
name = StringField()
@@ -1311,16 +1319,17 @@ class FieldTest(unittest.TestCase):
# Can't use query_counter across databases - so test the _data object
book = Book.objects.first()
- self.assertNotIsInstance(book._data['author'], User)
+ self.assertNotIsInstance(book._data["author"], User)
book.select_related()
- self.assertIsInstance(book._data['author'], User)
+ self.assertIsInstance(book._data["author"], User)
def test_non_ascii_pk(self):
"""
Ensure that dbref conversion to string does not fail when
non-ascii characters are used in primary key
"""
+
class Brand(Document):
title = StringField(max_length=255, primary_key=True)
@@ -1341,7 +1350,7 @@ class FieldTest(unittest.TestCase):
def test_dereferencing_embedded_listfield_referencefield(self):
class Tag(Document):
- meta = {'collection': 'tags'}
+ meta = {"collection": "tags"}
name = StringField()
class Post(EmbeddedDocument):
@@ -1349,22 +1358,21 @@ class FieldTest(unittest.TestCase):
tags = ListField(ReferenceField("Tag", dbref=True))
class Page(Document):
- meta = {'collection': 'pages'}
+ meta = {"collection": "pages"}
tags = ListField(ReferenceField("Tag", dbref=True))
posts = ListField(EmbeddedDocumentField(Post))
Tag.drop_collection()
Page.drop_collection()
- tag = Tag(name='test').save()
- post = Post(body='test body', tags=[tag])
+ tag = Tag(name="test").save()
+ post = Post(body="test body", tags=[tag])
Page(tags=[tag], posts=[post]).save()
page = Page.objects.first()
self.assertEqual(page.tags[0], page.posts[0].tags[0])
def test_select_related_follows_embedded_referencefields(self):
-
class Song(Document):
title = StringField()
@@ -1390,5 +1398,5 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py
index cacdce8b..5e3aa493 100644
--- a/tests/test_replicaset_connection.py
+++ b/tests/test_replicaset_connection.py
@@ -12,7 +12,6 @@ READ_PREF = ReadPreference.SECONDARY
class ConnectionTest(unittest.TestCase):
-
def setUp(self):
mongoengine.connection._connection_settings = {}
mongoengine.connection._connections = {}
@@ -28,9 +27,11 @@ class ConnectionTest(unittest.TestCase):
"""
try:
- conn = mongoengine.connect(db='mongoenginetest',
- host="mongodb://localhost/mongoenginetest?replicaSet=rs",
- read_preference=READ_PREF)
+ conn = mongoengine.connect(
+ db="mongoenginetest",
+ host="mongodb://localhost/mongoenginetest?replicaSet=rs",
+ read_preference=READ_PREF,
+ )
except MongoEngineConnectionError as e:
return
@@ -41,5 +42,5 @@ class ConnectionTest(unittest.TestCase):
self.assertEqual(conn.read_preference, READ_PREF)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_signals.py b/tests/test_signals.py
index 34cb43c3..1d0607d7 100644
--- a/tests/test_signals.py
+++ b/tests/test_signals.py
@@ -20,7 +20,7 @@ class SignalTests(unittest.TestCase):
return signal_output
def setUp(self):
- connect(db='mongoenginetest')
+ connect(db="mongoenginetest")
class Author(Document):
# Make the id deterministic for easier testing
@@ -32,60 +32,63 @@ class SignalTests(unittest.TestCase):
@classmethod
def pre_init(cls, sender, document, *args, **kwargs):
- signal_output.append('pre_init signal, %s' % cls.__name__)
- signal_output.append(kwargs['values'])
+ signal_output.append("pre_init signal, %s" % cls.__name__)
+ signal_output.append(kwargs["values"])
@classmethod
def post_init(cls, sender, document, **kwargs):
- signal_output.append('post_init signal, %s, document._created = %s' % (document, document._created))
+ signal_output.append(
+ "post_init signal, %s, document._created = %s"
+ % (document, document._created)
+ )
@classmethod
def pre_save(cls, sender, document, **kwargs):
- signal_output.append('pre_save signal, %s' % document)
+ signal_output.append("pre_save signal, %s" % document)
signal_output.append(kwargs)
@classmethod
def pre_save_post_validation(cls, sender, document, **kwargs):
- signal_output.append('pre_save_post_validation signal, %s' % document)
- if kwargs.pop('created', False):
- signal_output.append('Is created')
+ signal_output.append("pre_save_post_validation signal, %s" % document)
+ if kwargs.pop("created", False):
+ signal_output.append("Is created")
else:
- signal_output.append('Is updated')
+ signal_output.append("Is updated")
signal_output.append(kwargs)
@classmethod
def post_save(cls, sender, document, **kwargs):
dirty_keys = document._delta()[0].keys() + document._delta()[1].keys()
- signal_output.append('post_save signal, %s' % document)
- signal_output.append('post_save dirty keys, %s' % dirty_keys)
- if kwargs.pop('created', False):
- signal_output.append('Is created')
+ signal_output.append("post_save signal, %s" % document)
+ signal_output.append("post_save dirty keys, %s" % dirty_keys)
+ if kwargs.pop("created", False):
+ signal_output.append("Is created")
else:
- signal_output.append('Is updated')
+ signal_output.append("Is updated")
signal_output.append(kwargs)
@classmethod
def pre_delete(cls, sender, document, **kwargs):
- signal_output.append('pre_delete signal, %s' % document)
+ signal_output.append("pre_delete signal, %s" % document)
signal_output.append(kwargs)
@classmethod
def post_delete(cls, sender, document, **kwargs):
- signal_output.append('post_delete signal, %s' % document)
+ signal_output.append("post_delete signal, %s" % document)
signal_output.append(kwargs)
@classmethod
def pre_bulk_insert(cls, sender, documents, **kwargs):
- signal_output.append('pre_bulk_insert signal, %s' % documents)
+ signal_output.append("pre_bulk_insert signal, %s" % documents)
signal_output.append(kwargs)
@classmethod
def post_bulk_insert(cls, sender, documents, **kwargs):
- signal_output.append('post_bulk_insert signal, %s' % documents)
- if kwargs.pop('loaded', False):
- signal_output.append('Is loaded')
+ signal_output.append("post_bulk_insert signal, %s" % documents)
+ if kwargs.pop("loaded", False):
+ signal_output.append("Is loaded")
else:
- signal_output.append('Not loaded')
+ signal_output.append("Not loaded")
signal_output.append(kwargs)
self.Author = Author
@@ -101,12 +104,12 @@ class SignalTests(unittest.TestCase):
@classmethod
def pre_delete(cls, sender, document, **kwargs):
- signal_output.append('pre_delete signal, %s' % document)
+ signal_output.append("pre_delete signal, %s" % document)
signal_output.append(kwargs)
@classmethod
def post_delete(cls, sender, document, **kwargs):
- signal_output.append('post_delete signal, %s' % document)
+ signal_output.append("post_delete signal, %s" % document)
signal_output.append(kwargs)
self.Another = Another
@@ -117,11 +120,11 @@ class SignalTests(unittest.TestCase):
@classmethod
def post_save(cls, sender, document, **kwargs):
- if 'created' in kwargs:
- if kwargs['created']:
- signal_output.append('Is created')
+ if "created" in kwargs:
+ if kwargs["created"]:
+ signal_output.append("Is created")
else:
- signal_output.append('Is updated')
+ signal_output.append("Is updated")
self.ExplicitId = ExplicitId
ExplicitId.drop_collection()
@@ -136,9 +139,13 @@ class SignalTests(unittest.TestCase):
@classmethod
def pre_bulk_insert(cls, sender, documents, **kwargs):
- signal_output.append('pre_bulk_insert signal, %s' %
- [(doc, {'active': documents[n].active})
- for n, doc in enumerate(documents)])
+ signal_output.append(
+ "pre_bulk_insert signal, %s"
+ % [
+ (doc, {"active": documents[n].active})
+ for n, doc in enumerate(documents)
+ ]
+ )
# make changes here, this is just an example -
# it could be anything that needs pre-validation or looks-ups before bulk bulk inserting
@@ -149,13 +156,17 @@ class SignalTests(unittest.TestCase):
@classmethod
def post_bulk_insert(cls, sender, documents, **kwargs):
- signal_output.append('post_bulk_insert signal, %s' %
- [(doc, {'active': documents[n].active})
- for n, doc in enumerate(documents)])
- if kwargs.pop('loaded', False):
- signal_output.append('Is loaded')
+ signal_output.append(
+ "post_bulk_insert signal, %s"
+ % [
+ (doc, {"active": documents[n].active})
+ for n, doc in enumerate(documents)
+ ]
+ )
+ if kwargs.pop("loaded", False):
+ signal_output.append("Is loaded")
else:
- signal_output.append('Not loaded')
+ signal_output.append("Not loaded")
signal_output.append(kwargs)
self.Post = Post
@@ -178,7 +189,9 @@ class SignalTests(unittest.TestCase):
signals.pre_init.connect(Author.pre_init, sender=Author)
signals.post_init.connect(Author.post_init, sender=Author)
signals.pre_save.connect(Author.pre_save, sender=Author)
- signals.pre_save_post_validation.connect(Author.pre_save_post_validation, sender=Author)
+ signals.pre_save_post_validation.connect(
+ Author.pre_save_post_validation, sender=Author
+ )
signals.post_save.connect(Author.post_save, sender=Author)
signals.pre_delete.connect(Author.pre_delete, sender=Author)
signals.post_delete.connect(Author.post_delete, sender=Author)
@@ -199,7 +212,9 @@ class SignalTests(unittest.TestCase):
signals.post_delete.disconnect(self.Author.post_delete)
signals.pre_delete.disconnect(self.Author.pre_delete)
signals.post_save.disconnect(self.Author.post_save)
- signals.pre_save_post_validation.disconnect(self.Author.pre_save_post_validation)
+ signals.pre_save_post_validation.disconnect(
+ self.Author.pre_save_post_validation
+ )
signals.pre_save.disconnect(self.Author.pre_save)
signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert)
signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert)
@@ -236,203 +251,236 @@ class SignalTests(unittest.TestCase):
""" Model saves should throw some signals. """
def create_author():
- self.Author(name='Bill Shakespeare')
+ self.Author(name="Bill Shakespeare")
def bulk_create_author_with_load():
- a1 = self.Author(name='Bill Shakespeare')
+ a1 = self.Author(name="Bill Shakespeare")
self.Author.objects.insert([a1], load_bulk=True)
def bulk_create_author_without_load():
- a1 = self.Author(name='Bill Shakespeare')
+ a1 = self.Author(name="Bill Shakespeare")
self.Author.objects.insert([a1], load_bulk=False)
def load_existing_author():
- a = self.Author(name='Bill Shakespeare')
+ a = self.Author(name="Bill Shakespeare")
a.save()
self.get_signal_output(lambda: None) # eliminate signal output
- a1 = self.Author.objects(name='Bill Shakespeare')[0]
+ a1 = self.Author.objects(name="Bill Shakespeare")[0]
- self.assertEqual(self.get_signal_output(create_author), [
- "pre_init signal, Author",
- {'name': 'Bill Shakespeare'},
- "post_init signal, Bill Shakespeare, document._created = True",
- ])
+ self.assertEqual(
+ self.get_signal_output(create_author),
+ [
+ "pre_init signal, Author",
+ {"name": "Bill Shakespeare"},
+ "post_init signal, Bill Shakespeare, document._created = True",
+ ],
+ )
- a1 = self.Author(name='Bill Shakespeare')
- self.assertEqual(self.get_signal_output(a1.save), [
- "pre_save signal, Bill Shakespeare",
- {},
- "pre_save_post_validation signal, Bill Shakespeare",
- "Is created",
- {},
- "post_save signal, Bill Shakespeare",
- "post_save dirty keys, ['name']",
- "Is created",
- {}
- ])
+ a1 = self.Author(name="Bill Shakespeare")
+ self.assertEqual(
+ self.get_signal_output(a1.save),
+ [
+ "pre_save signal, Bill Shakespeare",
+ {},
+ "pre_save_post_validation signal, Bill Shakespeare",
+ "Is created",
+ {},
+ "post_save signal, Bill Shakespeare",
+ "post_save dirty keys, ['name']",
+ "Is created",
+ {},
+ ],
+ )
a1.reload()
- a1.name = 'William Shakespeare'
- self.assertEqual(self.get_signal_output(a1.save), [
- "pre_save signal, William Shakespeare",
- {},
- "pre_save_post_validation signal, William Shakespeare",
- "Is updated",
- {},
- "post_save signal, William Shakespeare",
- "post_save dirty keys, ['name']",
- "Is updated",
- {}
- ])
+ a1.name = "William Shakespeare"
+ self.assertEqual(
+ self.get_signal_output(a1.save),
+ [
+ "pre_save signal, William Shakespeare",
+ {},
+ "pre_save_post_validation signal, William Shakespeare",
+ "Is updated",
+ {},
+ "post_save signal, William Shakespeare",
+ "post_save dirty keys, ['name']",
+ "Is updated",
+ {},
+ ],
+ )
- self.assertEqual(self.get_signal_output(a1.delete), [
- 'pre_delete signal, William Shakespeare',
- {},
- 'post_delete signal, William Shakespeare',
- {}
- ])
+ self.assertEqual(
+ self.get_signal_output(a1.delete),
+ [
+ "pre_delete signal, William Shakespeare",
+ {},
+ "post_delete signal, William Shakespeare",
+ {},
+ ],
+ )
- self.assertEqual(self.get_signal_output(load_existing_author), [
- "pre_init signal, Author",
- {'id': 2, 'name': 'Bill Shakespeare'},
- "post_init signal, Bill Shakespeare, document._created = False"
- ])
+ self.assertEqual(
+ self.get_signal_output(load_existing_author),
+ [
+ "pre_init signal, Author",
+ {"id": 2, "name": "Bill Shakespeare"},
+ "post_init signal, Bill Shakespeare, document._created = False",
+ ],
+ )
- self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [
- 'pre_init signal, Author',
- {'name': 'Bill Shakespeare'},
- 'post_init signal, Bill Shakespeare, document._created = True',
- 'pre_bulk_insert signal, []',
- {},
- 'pre_init signal, Author',
- {'id': 3, 'name': 'Bill Shakespeare'},
- 'post_init signal, Bill Shakespeare, document._created = False',
- 'post_bulk_insert signal, []',
- 'Is loaded',
- {}
- ])
+ self.assertEqual(
+ self.get_signal_output(bulk_create_author_with_load),
+ [
+ "pre_init signal, Author",
+ {"name": "Bill Shakespeare"},
+ "post_init signal, Bill Shakespeare, document._created = True",
+ "pre_bulk_insert signal, []",
+ {},
+ "pre_init signal, Author",
+ {"id": 3, "name": "Bill Shakespeare"},
+ "post_init signal, Bill Shakespeare, document._created = False",
+ "post_bulk_insert signal, []",
+ "Is loaded",
+ {},
+ ],
+ )
- self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [
- "pre_init signal, Author",
- {'name': 'Bill Shakespeare'},
- "post_init signal, Bill Shakespeare, document._created = True",
- "pre_bulk_insert signal, []",
- {},
- "post_bulk_insert signal, []",
- "Not loaded",
- {}
- ])
+ self.assertEqual(
+ self.get_signal_output(bulk_create_author_without_load),
+ [
+ "pre_init signal, Author",
+ {"name": "Bill Shakespeare"},
+ "post_init signal, Bill Shakespeare, document._created = True",
+ "pre_bulk_insert signal, []",
+ {},
+ "post_bulk_insert signal, []",
+ "Not loaded",
+ {},
+ ],
+ )
def test_signal_kwargs(self):
""" Make sure signal_kwargs is passed to signals calls. """
def live_and_let_die():
- a = self.Author(name='Bill Shakespeare')
- a.save(signal_kwargs={'live': True, 'die': False})
- a.delete(signal_kwargs={'live': False, 'die': True})
+ a = self.Author(name="Bill Shakespeare")
+ a.save(signal_kwargs={"live": True, "die": False})
+ a.delete(signal_kwargs={"live": False, "die": True})
- self.assertEqual(self.get_signal_output(live_and_let_die), [
- "pre_init signal, Author",
- {'name': 'Bill Shakespeare'},
- "post_init signal, Bill Shakespeare, document._created = True",
- "pre_save signal, Bill Shakespeare",
- {'die': False, 'live': True},
- "pre_save_post_validation signal, Bill Shakespeare",
- "Is created",
- {'die': False, 'live': True},
- "post_save signal, Bill Shakespeare",
- "post_save dirty keys, ['name']",
- "Is created",
- {'die': False, 'live': True},
- 'pre_delete signal, Bill Shakespeare',
- {'die': True, 'live': False},
- 'post_delete signal, Bill Shakespeare',
- {'die': True, 'live': False}
- ])
+ self.assertEqual(
+ self.get_signal_output(live_and_let_die),
+ [
+ "pre_init signal, Author",
+ {"name": "Bill Shakespeare"},
+ "post_init signal, Bill Shakespeare, document._created = True",
+ "pre_save signal, Bill Shakespeare",
+ {"die": False, "live": True},
+ "pre_save_post_validation signal, Bill Shakespeare",
+ "Is created",
+ {"die": False, "live": True},
+ "post_save signal, Bill Shakespeare",
+ "post_save dirty keys, ['name']",
+ "Is created",
+ {"die": False, "live": True},
+ "pre_delete signal, Bill Shakespeare",
+ {"die": True, "live": False},
+ "post_delete signal, Bill Shakespeare",
+ {"die": True, "live": False},
+ ],
+ )
def bulk_create_author():
- a1 = self.Author(name='Bill Shakespeare')
- self.Author.objects.insert([a1], signal_kwargs={'key': True})
+ a1 = self.Author(name="Bill Shakespeare")
+ self.Author.objects.insert([a1], signal_kwargs={"key": True})
- self.assertEqual(self.get_signal_output(bulk_create_author), [
- 'pre_init signal, Author',
- {'name': 'Bill Shakespeare'},
- 'post_init signal, Bill Shakespeare, document._created = True',
- 'pre_bulk_insert signal, []',
- {'key': True},
- 'pre_init signal, Author',
- {'id': 2, 'name': 'Bill Shakespeare'},
- 'post_init signal, Bill Shakespeare, document._created = False',
- 'post_bulk_insert signal, []',
- 'Is loaded',
- {'key': True}
- ])
+ self.assertEqual(
+ self.get_signal_output(bulk_create_author),
+ [
+ "pre_init signal, Author",
+ {"name": "Bill Shakespeare"},
+ "post_init signal, Bill Shakespeare, document._created = True",
+ "pre_bulk_insert signal, []",
+ {"key": True},
+ "pre_init signal, Author",
+ {"id": 2, "name": "Bill Shakespeare"},
+ "post_init signal, Bill Shakespeare, document._created = False",
+ "post_bulk_insert signal, []",
+ "Is loaded",
+ {"key": True},
+ ],
+ )
def test_queryset_delete_signals(self):
""" Queryset delete should throw some signals. """
- self.Another(name='Bill Shakespeare').save()
- self.assertEqual(self.get_signal_output(self.Another.objects.delete), [
- 'pre_delete signal, Bill Shakespeare',
- {},
- 'post_delete signal, Bill Shakespeare',
- {}
- ])
+ self.Another(name="Bill Shakespeare").save()
+ self.assertEqual(
+ self.get_signal_output(self.Another.objects.delete),
+ [
+ "pre_delete signal, Bill Shakespeare",
+ {},
+ "post_delete signal, Bill Shakespeare",
+ {},
+ ],
+ )
def test_signals_with_explicit_doc_ids(self):
""" Model saves must have a created flag the first time."""
ei = self.ExplicitId(id=123)
# post save must received the created flag, even if there's already
# an object id present
- self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
+ self.assertEqual(self.get_signal_output(ei.save), ["Is created"])
# second time, it must be an update
- self.assertEqual(self.get_signal_output(ei.save), ['Is updated'])
+ self.assertEqual(self.get_signal_output(ei.save), ["Is updated"])
def test_signals_with_switch_collection(self):
ei = self.ExplicitId(id=123)
ei.switch_collection("explicit__1")
- self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
+ self.assertEqual(self.get_signal_output(ei.save), ["Is created"])
ei.switch_collection("explicit__1")
- self.assertEqual(self.get_signal_output(ei.save), ['Is updated'])
+ self.assertEqual(self.get_signal_output(ei.save), ["Is updated"])
ei.switch_collection("explicit__1", keep_created=False)
- self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
+ self.assertEqual(self.get_signal_output(ei.save), ["Is created"])
ei.switch_collection("explicit__1", keep_created=False)
- self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
+ self.assertEqual(self.get_signal_output(ei.save), ["Is created"])
def test_signals_with_switch_db(self):
- connect('mongoenginetest')
- register_connection('testdb-1', 'mongoenginetest2')
+ connect("mongoenginetest")
+ register_connection("testdb-1", "mongoenginetest2")
ei = self.ExplicitId(id=123)
ei.switch_db("testdb-1")
- self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
+ self.assertEqual(self.get_signal_output(ei.save), ["Is created"])
ei.switch_db("testdb-1")
- self.assertEqual(self.get_signal_output(ei.save), ['Is updated'])
+ self.assertEqual(self.get_signal_output(ei.save), ["Is updated"])
ei.switch_db("testdb-1", keep_created=False)
- self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
+ self.assertEqual(self.get_signal_output(ei.save), ["Is created"])
ei.switch_db("testdb-1", keep_created=False)
- self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
+ self.assertEqual(self.get_signal_output(ei.save), ["Is created"])
def test_signals_bulk_insert(self):
def bulk_set_active_post():
posts = [
- self.Post(title='Post 1'),
- self.Post(title='Post 2'),
- self.Post(title='Post 3')
+ self.Post(title="Post 1"),
+ self.Post(title="Post 2"),
+ self.Post(title="Post 3"),
]
self.Post.objects.insert(posts)
results = self.get_signal_output(bulk_set_active_post)
- self.assertEqual(results, [
- "pre_bulk_insert signal, [(, {'active': False}), (, {'active': False}), (, {'active': False})]",
- {},
- "post_bulk_insert signal, [(, {'active': True}), (, {'active': True}), (, {'active': True})]",
- 'Is loaded',
- {}
- ])
+ self.assertEqual(
+ results,
+ [
+ "pre_bulk_insert signal, [(, {'active': False}), (, {'active': False}), (, {'active': False})]",
+ {},
+ "post_bulk_insert signal, [(, {'active': True}), (, {'active': True}), (, {'active': True})]",
+ "Is loaded",
+ {},
+ ],
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 562cc1ff..2d1e8b00 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -7,32 +7,33 @@ signal_output = []
class LazyRegexCompilerTest(unittest.TestCase):
-
def test_lazy_regex_compiler_verify_laziness_of_descriptor(self):
class UserEmail(object):
- EMAIL_REGEX = LazyRegexCompiler('@', flags=32)
+ EMAIL_REGEX = LazyRegexCompiler("@", flags=32)
- descriptor = UserEmail.__dict__['EMAIL_REGEX']
+ descriptor = UserEmail.__dict__["EMAIL_REGEX"]
self.assertIsNone(descriptor._compiled_regex)
regex = UserEmail.EMAIL_REGEX
- self.assertEqual(regex, re.compile('@', flags=32))
- self.assertEqual(regex.search('user@domain.com').group(), '@')
+ self.assertEqual(regex, re.compile("@", flags=32))
+ self.assertEqual(regex.search("user@domain.com").group(), "@")
user_email = UserEmail()
self.assertIs(user_email.EMAIL_REGEX, UserEmail.EMAIL_REGEX)
def test_lazy_regex_compiler_verify_cannot_set_descriptor_on_instance(self):
class UserEmail(object):
- EMAIL_REGEX = LazyRegexCompiler('@')
+ EMAIL_REGEX = LazyRegexCompiler("@")
user_email = UserEmail()
with self.assertRaises(AttributeError):
- user_email.EMAIL_REGEX = re.compile('@')
+ user_email.EMAIL_REGEX = re.compile("@")
def test_lazy_regex_compiler_verify_can_override_class_attr(self):
class UserEmail(object):
- EMAIL_REGEX = LazyRegexCompiler('@')
+ EMAIL_REGEX = LazyRegexCompiler("@")
- UserEmail.EMAIL_REGEX = re.compile('cookies')
- self.assertEqual(UserEmail.EMAIL_REGEX.search('Cake & cookies').group(), 'cookies')
+ UserEmail.EMAIL_REGEX = re.compile("cookies")
+ self.assertEqual(
+ UserEmail.EMAIL_REGEX.search("Cake & cookies").group(), "cookies"
+ )
diff --git a/tests/utils.py b/tests/utils.py
index 27d5ada7..eb3f016f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -8,7 +8,7 @@ from mongoengine.connection import get_db, disconnect_all
from mongoengine.mongodb_support import get_mongodb_version
-MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database
+MONGO_TEST_DB = "mongoenginetest" # standard name for the test database
class MongoDBTestCase(unittest.TestCase):
@@ -53,12 +53,15 @@ def _decorated_with_ver_requirement(func, mongo_version_req, oper):
:param mongo_version_req: The mongodb version requirement (tuple(int, int))
:param oper: The operator to apply (e.g: operator.ge)
"""
+
def _inner(*args, **kwargs):
mongodb_v = get_mongodb_version()
if oper(mongodb_v, mongo_version_req):
return func(*args, **kwargs)
- raise SkipTest('Needs MongoDB v{}+'.format('.'.join(str(n) for n in mongo_version_req)))
+ raise SkipTest(
+ "Needs MongoDB v{}+".format(".".join(str(n) for n in mongo_version_req))
+ )
_inner.__name__ = func.__name__
_inner.__doc__ = func.__doc__