From b47669403b49428f00716ee983ffd269e402557f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20W=C3=B3jcik?= Date: Thu, 27 Jun 2019 13:05:54 +0200 Subject: [PATCH] Format the codebase using Black (#2109) This commit: 1. Formats all of our existing code using `black`. 2. Adds a note about using `black` to `CONTRIBUTING.rst`. 3. Runs `black --check` as part of CI (failing builds that aren't properly formatted). --- .travis.yml | 9 +- CONTRIBUTING.rst | 8 +- benchmarks/test_basic_doc_ops.py | 92 +- benchmarks/test_inserts.py | 42 +- docs/code/tumblelog.py | 36 +- docs/conf.py | 93 +- mongoengine/__init__.py | 13 +- mongoengine/base/__init__.py | 23 +- mongoengine/base/common.py | 51 +- mongoengine/base/datastructures.py | 70 +- mongoengine/base/document.py | 528 ++-- mongoengine/base/fields.py | 278 ++- mongoengine/base/metaclasses.py | 333 +-- mongoengine/common.py | 22 +- mongoengine/connection.py | 190 +- mongoengine/context_managers.py | 51 +- mongoengine/dereference.py | 110 +- mongoengine/document.py | 400 +-- mongoengine/errors.py | 44 +- mongoengine/fields.py | 890 ++++--- mongoengine/mongodb_support.py | 2 +- mongoengine/pymongo_support.py | 2 +- mongoengine/queryset/__init__.py | 21 +- mongoengine/queryset/base.py | 538 ++-- mongoengine/queryset/field_list.py | 13 +- mongoengine/queryset/manager.py | 4 +- mongoengine/queryset/queryset.py | 36 +- mongoengine/queryset/transform.py | 346 +-- mongoengine/queryset/visitor.py | 14 +- mongoengine/signals.py | 43 +- requirements.txt | 1 + setup.cfg | 2 +- setup.py | 64 +- tests/all_warnings/__init__.py | 18 +- tests/document/__init__.py | 2 +- tests/document/class_methods.py | 229 +- tests/document/delta.py | 712 +++--- tests/document/dynamic.py | 255 +- tests/document/indexes.py | 697 +++--- tests/document/inheritance.py | 477 ++-- tests/document/instance.py | 1273 +++++----- tests/document/json_serialisation.py | 45 +- tests/document/validation.py | 108 +- tests/fields/fields.py | 974 ++++---- tests/fields/file_tests.py | 141 +- tests/fields/geo.py | 149 +- tests/fields/test_binary_field.py | 33 +- tests/fields/test_boolean_field.py | 15 +- tests/fields/test_cached_reference_field.py | 312 +-- tests/fields/test_complex_datetime_field.py | 31 +- tests/fields/test_date_field.py | 22 +- tests/fields/test_datetime_field.py | 45 +- tests/fields/test_decimal_field.py | 55 +- tests/fields/test_dict_field.py | 205 +- tests/fields/test_email_field.py | 45 +- tests/fields/test_embedded_document_field.py | 152 +- tests/fields/test_float_field.py | 6 +- tests/fields/test_int_field.py | 4 +- tests/fields/test_lazy_reference_field.py | 120 +- tests/fields/test_long_field.py | 9 +- tests/fields/test_map_field.py | 59 +- tests/fields/test_reference_field.py | 51 +- tests/fields/test_sequence_field.py | 148 +- tests/fields/test_url_field.py | 24 +- tests/fields/test_uuid_field.py | 19 +- tests/fixtures.py | 8 +- tests/queryset/field_list.py | 249 +- tests/queryset/geo.py | 178 +- tests/queryset/modify.py | 34 +- tests/queryset/pickable.py | 24 +- tests/queryset/queryset.py | 2345 +++++++++--------- tests/queryset/transform.py | 240 +- tests/queryset/visitor.py | 205 +- tests/test_common.py | 1 - tests/test_connection.py | 393 +-- tests/test_context_managers.py | 111 +- tests/test_datastructures.py | 205 +- tests/test_dereference.py | 252 +- tests/test_replicaset_connection.py | 11 +- tests/test_signals.py | 392 +-- tests/test_utils.py | 21 +- tests/utils.py | 7 +- 82 files changed, 8405 insertions(+), 7075 deletions(-) diff --git a/.travis.yml b/.travis.yml index f9993e79..8af73c6b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -52,19 +52,22 @@ install: - wget http://fastdl.mongodb.org/linux/mongodb-linux-x86_64-${MONGODB}.tgz - tar xzf mongodb-linux-x86_64-${MONGODB}.tgz - ${PWD}/mongodb-linux-x86_64-${MONGODB}/bin/mongod --version - # Install python dependencies + # Install Python dependencies. - pip install --upgrade pip - pip install coveralls - pip install flake8 flake8-import-order - pip install tox # tox 3.11.0 has requirement virtualenv>=14.0.0 - pip install virtualenv # virtualenv>=14.0.0 has dropped Python 3.2 support (and pypy3 is based on py32) - # Install the tox venv + # Install the tox venv. - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test + # Install black for Python v3.7 only. + - if [[ $TRAVIS_PYTHON_VERSION == '3.7' ]]; then pip install black; fi before_script: - mkdir ${PWD}/mongodb-linux-x86_64-${MONGODB}/data - ${PWD}/mongodb-linux-x86_64-${MONGODB}/bin/mongod --dbpath ${PWD}/mongodb-linux-x86_64-${MONGODB}/data --logpath ${PWD}/mongodb-linux-x86_64-${MONGODB}/mongodb.log --fork - - if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then flake8 .; else echo "flake8 only runs on py27"; fi # Run flake8 for py27 + - if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then flake8 .; else echo "flake8 only runs on py27"; fi # Run flake8 for Python 2.7 only + - if [[ $TRAVIS_PYTHON_VERSION == '3.7' ]]; then black --check .; else echo "black only runs on py37"; fi # Run black for Python 3.7 only - mongo --eval 'db.version();' # Make sure mongo is awake script: diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index f7b15c85..4711c1d3 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -31,12 +31,8 @@ build. You should ensure that your code is properly converted by Style Guide ----------- -MongoEngine aims to follow `PEP8 `_ -including 4 space indents. When possible we try to stick to 79 character line -limits. However, screens got bigger and an ORM has a strong focus on -readability and if it can help, we accept 119 as maximum line length, in a -similar way as `django does -`_ +MongoEngine uses `black `_ for code +formatting. Testing ------- diff --git a/benchmarks/test_basic_doc_ops.py b/benchmarks/test_basic_doc_ops.py index 06f0538b..e840f97a 100644 --- a/benchmarks/test_basic_doc_ops.py +++ b/benchmarks/test_basic_doc_ops.py @@ -1,11 +1,18 @@ from timeit import repeat import mongoengine -from mongoengine import (BooleanField, Document, EmailField, EmbeddedDocument, - EmbeddedDocumentField, IntField, ListField, - StringField) +from mongoengine import ( + BooleanField, + Document, + EmailField, + EmbeddedDocument, + EmbeddedDocumentField, + IntField, + ListField, + StringField, +) -mongoengine.connect(db='mongoengine_benchmark_test') +mongoengine.connect(db="mongoengine_benchmark_test") def timeit(f, n=10000): @@ -24,46 +31,41 @@ def test_basic(): def init_book(): return Book( - name='Always be closing', + name="Always be closing", pages=100, - tags=['self-help', 'sales'], + tags=["self-help", "sales"], is_published=True, - author_email='alec@example.com', + author_email="alec@example.com", ) - print('Doc initialization: %.3fus' % (timeit(init_book, 1000) * 10**6)) + print("Doc initialization: %.3fus" % (timeit(init_book, 1000) * 10 ** 6)) b = init_book() - print('Doc getattr: %.3fus' % (timeit(lambda: b.name, 10000) * 10**6)) + print("Doc getattr: %.3fus" % (timeit(lambda: b.name, 10000) * 10 ** 6)) print( - 'Doc setattr: %.3fus' % ( - timeit(lambda: setattr(b, 'name', 'New name'), 10000) * 10**6 - ) + "Doc setattr: %.3fus" + % (timeit(lambda: setattr(b, "name", "New name"), 10000) * 10 ** 6) ) - print('Doc to mongo: %.3fus' % (timeit(b.to_mongo, 1000) * 10**6)) + print("Doc to mongo: %.3fus" % (timeit(b.to_mongo, 1000) * 10 ** 6)) - print('Doc validation: %.3fus' % (timeit(b.validate, 1000) * 10**6)) + print("Doc validation: %.3fus" % (timeit(b.validate, 1000) * 10 ** 6)) def save_book(): - b._mark_as_changed('name') - b._mark_as_changed('tags') + b._mark_as_changed("name") + b._mark_as_changed("tags") b.save() - print('Save to database: %.3fus' % (timeit(save_book, 100) * 10**6)) + print("Save to database: %.3fus" % (timeit(save_book, 100) * 10 ** 6)) son = b.to_mongo() print( - 'Load from SON: %.3fus' % ( - timeit(lambda: Book._from_son(son), 1000) * 10**6 - ) + "Load from SON: %.3fus" % (timeit(lambda: Book._from_son(son), 1000) * 10 ** 6) ) print( - 'Load from database: %.3fus' % ( - timeit(lambda: Book.objects[0], 100) * 10**6 - ) + "Load from database: %.3fus" % (timeit(lambda: Book.objects[0], 100) * 10 ** 6) ) def create_and_delete_book(): @@ -72,9 +74,8 @@ def test_basic(): b.delete() print( - 'Init + save to database + delete: %.3fms' % ( - timeit(create_and_delete_book, 10) * 10**3 - ) + "Init + save to database + delete: %.3fms" + % (timeit(create_and_delete_book, 10) * 10 ** 3) ) @@ -92,42 +93,36 @@ def test_big_doc(): def init_company(): return Company( - name='MongoDB, Inc.', + name="MongoDB, Inc.", contacts=[ - Contact( - name='Contact %d' % x, - title='CEO', - address='Address %d' % x, - ) + Contact(name="Contact %d" % x, title="CEO", address="Address %d" % x) for x in range(1000) - ] + ], ) company = init_company() - print('Big doc to mongo: %.3fms' % (timeit(company.to_mongo, 100) * 10**3)) + print("Big doc to mongo: %.3fms" % (timeit(company.to_mongo, 100) * 10 ** 3)) - print('Big doc validation: %.3fms' % (timeit(company.validate, 1000) * 10**3)) + print("Big doc validation: %.3fms" % (timeit(company.validate, 1000) * 10 ** 3)) company.save() def save_company(): - company._mark_as_changed('name') - company._mark_as_changed('contacts') + company._mark_as_changed("name") + company._mark_as_changed("contacts") company.save() - print('Save to database: %.3fms' % (timeit(save_company, 100) * 10**3)) + print("Save to database: %.3fms" % (timeit(save_company, 100) * 10 ** 3)) son = company.to_mongo() print( - 'Load from SON: %.3fms' % ( - timeit(lambda: Company._from_son(son), 100) * 10**3 - ) + "Load from SON: %.3fms" + % (timeit(lambda: Company._from_son(son), 100) * 10 ** 3) ) print( - 'Load from database: %.3fms' % ( - timeit(lambda: Company.objects[0], 100) * 10**3 - ) + "Load from database: %.3fms" + % (timeit(lambda: Company.objects[0], 100) * 10 ** 3) ) def create_and_delete_company(): @@ -136,13 +131,12 @@ def test_big_doc(): c.delete() print( - 'Init + save to database + delete: %.3fms' % ( - timeit(create_and_delete_company, 10) * 10**3 - ) + "Init + save to database + delete: %.3fms" + % (timeit(create_and_delete_company, 10) * 10 ** 3) ) -if __name__ == '__main__': +if __name__ == "__main__": test_basic() - print('-' * 100) + print("-" * 100) test_big_doc() diff --git a/benchmarks/test_inserts.py b/benchmarks/test_inserts.py index 8113d988..fd017bae 100644 --- a/benchmarks/test_inserts.py +++ b/benchmarks/test_inserts.py @@ -26,10 +26,10 @@ myNoddys = noddy.find() [n for n in myNoddys] # iterate """ - print('-' * 100) - print('PyMongo: Creating 10000 dictionaries.') + print("-" * 100) + print("PyMongo: Creating 10000 dictionaries.") t = timeit.Timer(stmt=stmt, setup=setup) - print('{}s'.format(t.timeit(1))) + print("{}s".format(t.timeit(1))) stmt = """ from pymongo import MongoClient, WriteConcern @@ -49,10 +49,10 @@ myNoddys = noddy.find() [n for n in myNoddys] # iterate """ - print('-' * 100) + print("-" * 100) print('PyMongo: Creating 10000 dictionaries (write_concern={"w": 0}).') t = timeit.Timer(stmt=stmt, setup=setup) - print('{}s'.format(t.timeit(1))) + print("{}s".format(t.timeit(1))) setup = """ from pymongo import MongoClient @@ -78,10 +78,10 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print('-' * 100) - print('MongoEngine: Creating 10000 dictionaries.') + print("-" * 100) + print("MongoEngine: Creating 10000 dictionaries.") t = timeit.Timer(stmt=stmt, setup=setup) - print('{}s'.format(t.timeit(1))) + print("{}s".format(t.timeit(1))) stmt = """ for i in range(10000): @@ -96,10 +96,10 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print('-' * 100) - print('MongoEngine: Creating 10000 dictionaries (using a single field assignment).') + print("-" * 100) + print("MongoEngine: Creating 10000 dictionaries (using a single field assignment).") t = timeit.Timer(stmt=stmt, setup=setup) - print('{}s'.format(t.timeit(1))) + print("{}s".format(t.timeit(1))) stmt = """ for i in range(10000): @@ -112,10 +112,10 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print('-' * 100) + print("-" * 100) print('MongoEngine: Creating 10000 dictionaries (write_concern={"w": 0}).') t = timeit.Timer(stmt=stmt, setup=setup) - print('{}s'.format(t.timeit(1))) + print("{}s".format(t.timeit(1))) stmt = """ for i in range(10000): @@ -128,10 +128,12 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print('-' * 100) - print('MongoEngine: Creating 10000 dictionaries (write_concern={"w": 0}, validate=False).') + print("-" * 100) + print( + 'MongoEngine: Creating 10000 dictionaries (write_concern={"w": 0}, validate=False).' + ) t = timeit.Timer(stmt=stmt, setup=setup) - print('{}s'.format(t.timeit(1))) + print("{}s".format(t.timeit(1))) stmt = """ for i in range(10000): @@ -144,10 +146,12 @@ myNoddys = Noddy.objects() [n for n in myNoddys] # iterate """ - print('-' * 100) - print('MongoEngine: Creating 10000 dictionaries (force_insert=True, write_concern={"w": 0}, validate=False).') + print("-" * 100) + print( + 'MongoEngine: Creating 10000 dictionaries (force_insert=True, write_concern={"w": 0}, validate=False).' + ) t = timeit.Timer(stmt=stmt, setup=setup) - print('{}s'.format(t.timeit(1))) + print("{}s".format(t.timeit(1))) if __name__ == "__main__": diff --git a/docs/code/tumblelog.py b/docs/code/tumblelog.py index 796336e6..3ca2384c 100644 --- a/docs/code/tumblelog.py +++ b/docs/code/tumblelog.py @@ -1,16 +1,19 @@ from mongoengine import * -connect('tumblelog') +connect("tumblelog") + class Comment(EmbeddedDocument): content = StringField() name = StringField(max_length=120) + class User(Document): email = StringField(required=True) first_name = StringField(max_length=50) last_name = StringField(max_length=50) + class Post(Document): title = StringField(max_length=120, required=True) author = ReferenceField(User) @@ -18,54 +21,57 @@ class Post(Document): comments = ListField(EmbeddedDocumentField(Comment)) # bugfix - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class TextPost(Post): content = StringField() + class ImagePost(Post): image_path = StringField() + class LinkPost(Post): link_url = StringField() + Post.drop_collection() -john = User(email='jdoe@example.com', first_name='John', last_name='Doe') +john = User(email="jdoe@example.com", first_name="John", last_name="Doe") john.save() -post1 = TextPost(title='Fun with MongoEngine', author=john) -post1.content = 'Took a look at MongoEngine today, looks pretty cool.' -post1.tags = ['mongodb', 'mongoengine'] +post1 = TextPost(title="Fun with MongoEngine", author=john) +post1.content = "Took a look at MongoEngine today, looks pretty cool." +post1.tags = ["mongodb", "mongoengine"] post1.save() -post2 = LinkPost(title='MongoEngine Documentation', author=john) -post2.link_url = 'http://tractiondigital.com/labs/mongoengine/docs' -post2.tags = ['mongoengine'] +post2 = LinkPost(title="MongoEngine Documentation", author=john) +post2.link_url = "http://tractiondigital.com/labs/mongoengine/docs" +post2.tags = ["mongoengine"] post2.save() -print('ALL POSTS') +print("ALL POSTS") print() for post in Post.objects: print(post.title) - #print '=' * post.title.count() + # print '=' * post.title.count() print("=" * 20) if isinstance(post, TextPost): print(post.content) if isinstance(post, LinkPost): - print('Link:', post.link_url) + print("Link:", post.link_url) print() print() -print('POSTS TAGGED \'MONGODB\'') +print("POSTS TAGGED 'MONGODB'") print() -for post in Post.objects(tags='mongodb'): +for post in Post.objects(tags="mongodb"): print(post.title) print() -num_posts = Post.objects(tags='mongodb').count() +num_posts = Post.objects(tags="mongodb").count() print('Found %d posts with tag "mongodb"' % num_posts) diff --git a/docs/conf.py b/docs/conf.py index 468e71e0..0d642e0c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -20,29 +20,29 @@ import mongoengine # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) # -- General configuration ----------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.todo'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.todo"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8' +# source_encoding = 'utf-8' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'MongoEngine' -copyright = u'2009, MongoEngine Authors' +project = u"MongoEngine" +copyright = u"2009, MongoEngine Authors" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -55,68 +55,66 @@ release = mongoengine.get_version() # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of documents that shouldn't be included in the build. -#unused_docs = [] +# unused_docs = [] # List of directories, relative to source directory, that shouldn't be searched # for source files. -exclude_trees = ['_build'] +exclude_trees = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. Major themes that come with # Sphinx are currently 'default' and 'sphinxdoc'. -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -html_theme_options = { - 'canonical_url': 'http://docs.mongoengine.org/en/latest/' -} +html_theme_options = {"canonical_url": "http://docs.mongoengine.org/en/latest/"} # Add any paths that contain custom themes here, relative to this directory. html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -126,11 +124,11 @@ html_favicon = "favicon.ico" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] +# html_static_path = ['_static'] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. @@ -138,69 +136,68 @@ html_use_smartypants = True # Custom sidebar templates, maps document names to template names. html_sidebars = { - 'index': ['globaltoc.html', 'searchbox.html'], - '**': ['localtoc.html', 'relations.html', 'searchbox.html'] + "index": ["globaltoc.html", "searchbox.html"], + "**": ["localtoc.html", "relations.html", "searchbox.html"], } # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_use_modindex = True +# html_use_modindex = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = '' +# html_file_suffix = '' # Output file base name for HTML help builder. -htmlhelp_basename = 'MongoEnginedoc' +htmlhelp_basename = "MongoEnginedoc" # -- Options for LaTeX output -------------------------------------------------- # The paper size ('letter' or 'a4'). -latex_paper_size = 'a4' +latex_paper_size = "a4" # The font size ('10pt', '11pt' or '12pt'). -#latex_font_size = '10pt' +# latex_font_size = '10pt' # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'MongoEngine.tex', 'MongoEngine Documentation', - 'Ross Lawley', 'manual'), + ("index", "MongoEngine.tex", "MongoEngine Documentation", "Ross Lawley", "manual") ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # Additional stuff for the LaTeX preamble. -#latex_preamble = '' +# latex_preamble = '' # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_use_modindex = True +# latex_use_modindex = True -autoclass_content = 'both' +autoclass_content = "both" diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index bb7a4e57..d7093d28 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -18,9 +18,14 @@ from mongoengine.queryset import * from mongoengine.signals import * -__all__ = (list(document.__all__) + list(fields.__all__) + - list(connection.__all__) + list(queryset.__all__) + - list(signals.__all__) + list(errors.__all__)) +__all__ = ( + list(document.__all__) + + list(fields.__all__) + + list(connection.__all__) + + list(queryset.__all__) + + list(signals.__all__) + + list(errors.__all__) +) VERSION = (0, 18, 2) @@ -31,7 +36,7 @@ def get_version(): For example, if `VERSION == (0, 10, 7)`, return '0.10.7'. """ - return '.'.join(map(str, VERSION)) + return ".".join(map(str, VERSION)) __version__ = get_version() diff --git a/mongoengine/base/__init__.py b/mongoengine/base/__init__.py index e069a147..dca0c4bb 100644 --- a/mongoengine/base/__init__.py +++ b/mongoengine/base/__init__.py @@ -12,17 +12,22 @@ from mongoengine.base.metaclasses import * __all__ = ( # common - 'UPDATE_OPERATORS', '_document_registry', 'get_document', - + "UPDATE_OPERATORS", + "_document_registry", + "get_document", # datastructures - 'BaseDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference', - + "BaseDict", + "BaseList", + "EmbeddedDocumentList", + "LazyReference", # document - 'BaseDocument', - + "BaseDocument", # fields - 'BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField', - + "BaseField", + "ComplexBaseField", + "ObjectIdField", + "GeoJsonBaseField", # metaclasses - 'DocumentMetaclass', 'TopLevelDocumentMetaclass' + "DocumentMetaclass", + "TopLevelDocumentMetaclass", ) diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py index 999fd23a..85897324 100644 --- a/mongoengine/base/common.py +++ b/mongoengine/base/common.py @@ -1,12 +1,25 @@ from mongoengine.errors import NotRegistered -__all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry') +__all__ = ("UPDATE_OPERATORS", "get_document", "_document_registry") -UPDATE_OPERATORS = {'set', 'unset', 'inc', 'dec', 'mul', - 'pop', 'push', 'push_all', 'pull', - 'pull_all', 'add_to_set', 'set_on_insert', - 'min', 'max', 'rename'} +UPDATE_OPERATORS = { + "set", + "unset", + "inc", + "dec", + "mul", + "pop", + "push", + "push_all", + "pull", + "pull_all", + "add_to_set", + "set_on_insert", + "min", + "max", + "rename", +} _document_registry = {} @@ -17,25 +30,33 @@ def get_document(name): doc = _document_registry.get(name, None) if not doc: # Possible old style name - single_end = name.split('.')[-1] - compound_end = '.%s' % single_end - possible_match = [k for k in _document_registry - if k.endswith(compound_end) or k == single_end] + single_end = name.split(".")[-1] + compound_end = ".%s" % single_end + possible_match = [ + k for k in _document_registry if k.endswith(compound_end) or k == single_end + ] if len(possible_match) == 1: doc = _document_registry.get(possible_match.pop(), None) if not doc: - raise NotRegistered(""" + raise NotRegistered( + """ `%s` has not been registered in the document registry. Importing the document class automatically registers it, has it been imported? - """.strip() % name) + """.strip() + % name + ) return doc def _get_documents_by_db(connection_alias, default_connection_alias): """Get all registered Documents class attached to a given database""" - def get_doc_alias(doc_cls): - return doc_cls._meta.get('db_alias', default_connection_alias) - return [doc_cls for doc_cls in _document_registry.values() - if get_doc_alias(doc_cls) == connection_alias] + def get_doc_alias(doc_cls): + return doc_cls._meta.get("db_alias", default_connection_alias) + + return [ + doc_cls + for doc_cls in _document_registry.values() + if get_doc_alias(doc_cls) == connection_alias + ] diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index cce71846..d1b5ae76 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -7,26 +7,36 @@ from six import iteritems from mongoengine.common import _import_class from mongoengine.errors import DoesNotExist, MultipleObjectsReturned -__all__ = ('BaseDict', 'StrictDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference') +__all__ = ( + "BaseDict", + "StrictDict", + "BaseList", + "EmbeddedDocumentList", + "LazyReference", +) def mark_as_changed_wrapper(parent_method): """Decorator that ensures _mark_as_changed method gets called.""" + def wrapper(self, *args, **kwargs): # Can't use super() in the decorator. result = parent_method(self, *args, **kwargs) self._mark_as_changed() return result + return wrapper def mark_key_as_changed_wrapper(parent_method): """Decorator that ensures _mark_as_changed method gets called with the key argument""" + def wrapper(self, key, *args, **kwargs): # Can't use super() in the decorator. result = parent_method(self, key, *args, **kwargs) self._mark_as_changed(key) return result + return wrapper @@ -38,7 +48,7 @@ class BaseDict(dict): _name = None def __init__(self, dict_items, instance, name): - BaseDocument = _import_class('BaseDocument') + BaseDocument = _import_class("BaseDocument") if isinstance(instance, BaseDocument): self._instance = weakref.proxy(instance) @@ -55,15 +65,15 @@ class BaseDict(dict): def __getitem__(self, key): value = super(BaseDict, self).__getitem__(key) - EmbeddedDocument = _import_class('EmbeddedDocument') + EmbeddedDocument = _import_class("EmbeddedDocument") if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance elif isinstance(value, dict) and not isinstance(value, BaseDict): - value = BaseDict(value, None, '%s.%s' % (self._name, key)) + value = BaseDict(value, None, "%s.%s" % (self._name, key)) super(BaseDict, self).__setitem__(key, value) value._instance = self._instance elif isinstance(value, list) and not isinstance(value, BaseList): - value = BaseList(value, None, '%s.%s' % (self._name, key)) + value = BaseList(value, None, "%s.%s" % (self._name, key)) super(BaseDict, self).__setitem__(key, value) value._instance = self._instance return value @@ -87,9 +97,9 @@ class BaseDict(dict): setdefault = mark_as_changed_wrapper(dict.setdefault) def _mark_as_changed(self, key=None): - if hasattr(self._instance, '_mark_as_changed'): + if hasattr(self._instance, "_mark_as_changed"): if key: - self._instance._mark_as_changed('%s.%s' % (self._name, key)) + self._instance._mark_as_changed("%s.%s" % (self._name, key)) else: self._instance._mark_as_changed(self._name) @@ -102,7 +112,7 @@ class BaseList(list): _name = None def __init__(self, list_items, instance, name): - BaseDocument = _import_class('BaseDocument') + BaseDocument = _import_class("BaseDocument") if isinstance(instance, BaseDocument): self._instance = weakref.proxy(instance) @@ -117,17 +127,17 @@ class BaseList(list): # to parent's instance. This is buggy for now but would require more work to be handled properly return value - EmbeddedDocument = _import_class('EmbeddedDocument') + EmbeddedDocument = _import_class("EmbeddedDocument") if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance elif isinstance(value, dict) and not isinstance(value, BaseDict): # Replace dict by BaseDict - value = BaseDict(value, None, '%s.%s' % (self._name, key)) + value = BaseDict(value, None, "%s.%s" % (self._name, key)) super(BaseList, self).__setitem__(key, value) value._instance = self._instance elif isinstance(value, list) and not isinstance(value, BaseList): # Replace list by BaseList - value = BaseList(value, None, '%s.%s' % (self._name, key)) + value = BaseList(value, None, "%s.%s" % (self._name, key)) super(BaseList, self).__setitem__(key, value) value._instance = self._instance return value @@ -181,17 +191,14 @@ class BaseList(list): return self.__getitem__(slice(i, j)) def _mark_as_changed(self, key=None): - if hasattr(self._instance, '_mark_as_changed'): + if hasattr(self._instance, "_mark_as_changed"): if key: - self._instance._mark_as_changed( - '%s.%s' % (self._name, key % len(self)) - ) + self._instance._mark_as_changed("%s.%s" % (self._name, key % len(self))) else: self._instance._mark_as_changed(self._name) class EmbeddedDocumentList(BaseList): - def __init__(self, list_items, instance, name): super(EmbeddedDocumentList, self).__init__(list_items, instance, name) self._instance = instance @@ -276,12 +283,10 @@ class EmbeddedDocumentList(BaseList): """ values = self.__only_matches(self, kwargs) if len(values) == 0: - raise DoesNotExist( - '%s matching query does not exist.' % self._name - ) + raise DoesNotExist("%s matching query does not exist." % self._name) elif len(values) > 1: raise MultipleObjectsReturned( - '%d items returned, instead of 1' % len(values) + "%d items returned, instead of 1" % len(values) ) return values[0] @@ -362,7 +367,7 @@ class EmbeddedDocumentList(BaseList): class StrictDict(object): __slots__ = () - _special_fields = {'get', 'pop', 'iteritems', 'items', 'keys', 'create'} + _special_fields = {"get", "pop", "iteritems", "items", "keys", "create"} _classes = {} def __init__(self, **kwargs): @@ -370,14 +375,14 @@ class StrictDict(object): setattr(self, k, v) def __getitem__(self, key): - key = '_reserved_' + key if key in self._special_fields else key + key = "_reserved_" + key if key in self._special_fields else key try: return getattr(self, key) except AttributeError: raise KeyError(key) def __setitem__(self, key, value): - key = '_reserved_' + key if key in self._special_fields else key + key = "_reserved_" + key if key in self._special_fields else key return setattr(self, key, value) def __contains__(self, key): @@ -424,27 +429,32 @@ class StrictDict(object): @classmethod def create(cls, allowed_keys): - allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys) + allowed_keys_tuple = tuple( + ("_reserved_" + k if k in cls._special_fields else k) for k in allowed_keys + ) allowed_keys = frozenset(allowed_keys_tuple) if allowed_keys not in cls._classes: + class SpecificStrictDict(cls): __slots__ = allowed_keys_tuple def __repr__(self): - return '{%s}' % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) + return "{%s}" % ", ".join( + '"{0!s}": {1!r}'.format(k, v) for k, v in self.items() + ) cls._classes[allowed_keys] = SpecificStrictDict return cls._classes[allowed_keys] class LazyReference(DBRef): - __slots__ = ('_cached_doc', 'passthrough', 'document_type') + __slots__ = ("_cached_doc", "passthrough", "document_type") def fetch(self, force=False): if not self._cached_doc or force: self._cached_doc = self.document_type.objects.get(pk=self.pk) if not self._cached_doc: - raise DoesNotExist('Trying to dereference unknown document %s' % (self)) + raise DoesNotExist("Trying to dereference unknown document %s" % (self)) return self._cached_doc @property @@ -455,7 +465,9 @@ class LazyReference(DBRef): self.document_type = document_type self._cached_doc = cached_doc self.passthrough = passthrough - super(LazyReference, self).__init__(self.document_type._get_collection_name(), pk) + super(LazyReference, self).__init__( + self.document_type._get_collection_name(), pk + ) def __getitem__(self, name): if not self.passthrough: @@ -464,7 +476,7 @@ class LazyReference(DBRef): return document[name] def __getattr__(self, name): - if not object.__getattribute__(self, 'passthrough'): + if not object.__getattribute__(self, "passthrough"): raise AttributeError() document = self.fetch() try: diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 047d50a4..928a00c2 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -9,19 +9,27 @@ from six import iteritems from mongoengine import signals from mongoengine.base.common import get_document -from mongoengine.base.datastructures import (BaseDict, BaseList, - EmbeddedDocumentList, - LazyReference, - StrictDict) +from mongoengine.base.datastructures import ( + BaseDict, + BaseList, + EmbeddedDocumentList, + LazyReference, + StrictDict, +) from mongoengine.base.fields import ComplexBaseField from mongoengine.common import _import_class -from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, - LookUpError, OperationError, ValidationError) +from mongoengine.errors import ( + FieldDoesNotExist, + InvalidDocumentError, + LookUpError, + OperationError, + ValidationError, +) from mongoengine.python_support import Hashable -__all__ = ('BaseDocument', 'NON_FIELD_ERRORS') +__all__ = ("BaseDocument", "NON_FIELD_ERRORS") -NON_FIELD_ERRORS = '__all__' +NON_FIELD_ERRORS = "__all__" class BaseDocument(object): @@ -35,9 +43,16 @@ class BaseDocument(object): # field is primarily set via `_from_son` or `_clear_changed_fields`, # though there are also other methods that manipulate it. # 4. The codebase is littered with `hasattr` calls for `_changed_fields`. - __slots__ = ('_changed_fields', '_initialised', '_created', '_data', - '_dynamic_fields', '_auto_id_field', '_db_field_map', - '__weakref__') + __slots__ = ( + "_changed_fields", + "_initialised", + "_created", + "_data", + "_dynamic_fields", + "_auto_id_field", + "_db_field_map", + "__weakref__", + ) _dynamic = False _dynamic_lock = True @@ -61,27 +76,28 @@ class BaseDocument(object): if args: raise TypeError( - 'Instantiating a document with positional arguments is not ' - 'supported. Please use `field_name=value` keyword arguments.' + "Instantiating a document with positional arguments is not " + "supported. Please use `field_name=value` keyword arguments." ) - __auto_convert = values.pop('__auto_convert', True) + __auto_convert = values.pop("__auto_convert", True) - __only_fields = set(values.pop('__only_fields', values)) + __only_fields = set(values.pop("__only_fields", values)) - _created = values.pop('_created', True) + _created = values.pop("_created", True) signals.pre_init.send(self.__class__, document=self, values=values) # Check if there are undefined fields supplied to the constructor, # if so raise an Exception. - if not self._dynamic and (self._meta.get('strict', True) or _created): + if not self._dynamic and (self._meta.get("strict", True) or _created): _undefined_fields = set(values.keys()) - set( - self._fields.keys() + ['id', 'pk', '_cls', '_text_score']) + self._fields.keys() + ["id", "pk", "_cls", "_text_score"] + ) if _undefined_fields: - msg = ( - 'The fields "{0}" do not exist on the document "{1}"' - ).format(_undefined_fields, self._class_name) + msg = ('The fields "{0}" do not exist on the document "{1}"').format( + _undefined_fields, self._class_name + ) raise FieldDoesNotExist(msg) if self.STRICT and not self._dynamic: @@ -100,22 +116,22 @@ class BaseDocument(object): value = getattr(self, key, None) setattr(self, key, value) - if '_cls' not in values: + if "_cls" not in values: self._cls = self._class_name # Set passed values after initialisation if self._dynamic: dynamic_data = {} for key, value in iteritems(values): - if key in self._fields or key == '_id': + if key in self._fields or key == "_id": setattr(self, key, value) else: dynamic_data[key] = value else: - FileField = _import_class('FileField') + FileField = _import_class("FileField") for key, value in iteritems(values): key = self._reverse_db_field_map.get(key, key) - if key in self._fields or key in ('id', 'pk', '_cls'): + if key in self._fields or key in ("id", "pk", "_cls"): if __auto_convert and value is not None: field = self._fields.get(key) if field and not isinstance(field, FileField): @@ -153,20 +169,20 @@ class BaseDocument(object): # Handle dynamic data only if an initialised dynamic document if self._dynamic and not self._dynamic_lock: - if not hasattr(self, name) and not name.startswith('_'): - DynamicField = _import_class('DynamicField') + if not hasattr(self, name) and not name.startswith("_"): + DynamicField = _import_class("DynamicField") field = DynamicField(db_field=name, null=True) field.name = name self._dynamic_fields[name] = field self._fields_ordered += (name,) - if not name.startswith('_'): + if not name.startswith("_"): value = self.__expand_dynamic_values(name, value) # Handle marking data as changed if name in self._dynamic_fields: self._data[name] = value - if hasattr(self, '_changed_fields'): + if hasattr(self, "_changed_fields"): self._mark_as_changed(name) try: self__created = self._created @@ -174,12 +190,12 @@ class BaseDocument(object): self__created = True if ( - self._is_document and - not self__created and - name in self._meta.get('shard_key', tuple()) and - self._data.get(name) != value + self._is_document + and not self__created + and name in self._meta.get("shard_key", tuple()) + and self._data.get(name) != value ): - msg = 'Shard Keys are immutable. Tried to update %s' % name + msg = "Shard Keys are immutable. Tried to update %s" % name raise OperationError(msg) try: @@ -187,38 +203,52 @@ class BaseDocument(object): except AttributeError: self__initialised = False # Check if the user has created a new instance of a class - if (self._is_document and self__initialised and - self__created and name == self._meta.get('id_field')): - super(BaseDocument, self).__setattr__('_created', False) + if ( + self._is_document + and self__initialised + and self__created + and name == self._meta.get("id_field") + ): + super(BaseDocument, self).__setattr__("_created", False) super(BaseDocument, self).__setattr__(name, value) def __getstate__(self): data = {} - for k in ('_changed_fields', '_initialised', '_created', - '_dynamic_fields', '_fields_ordered'): + for k in ( + "_changed_fields", + "_initialised", + "_created", + "_dynamic_fields", + "_fields_ordered", + ): if hasattr(self, k): data[k] = getattr(self, k) - data['_data'] = self.to_mongo() + data["_data"] = self.to_mongo() return data def __setstate__(self, data): - if isinstance(data['_data'], SON): - data['_data'] = self.__class__._from_son(data['_data'])._data - for k in ('_changed_fields', '_initialised', '_created', '_data', - '_dynamic_fields'): + if isinstance(data["_data"], SON): + data["_data"] = self.__class__._from_son(data["_data"])._data + for k in ( + "_changed_fields", + "_initialised", + "_created", + "_data", + "_dynamic_fields", + ): if k in data: setattr(self, k, data[k]) - if '_fields_ordered' in data: + if "_fields_ordered" in data: if self._dynamic: - setattr(self, '_fields_ordered', data['_fields_ordered']) + setattr(self, "_fields_ordered", data["_fields_ordered"]) else: _super_fields_ordered = type(self)._fields_ordered - setattr(self, '_fields_ordered', _super_fields_ordered) + setattr(self, "_fields_ordered", _super_fields_ordered) - dynamic_fields = data.get('_dynamic_fields') or SON() + dynamic_fields = data.get("_dynamic_fields") or SON() for k in dynamic_fields.keys(): - setattr(self, k, data['_data'].get(k)) + setattr(self, k, data["_data"].get(k)) def __iter__(self): return iter(self._fields_ordered) @@ -255,24 +285,30 @@ class BaseDocument(object): try: u = self.__str__() except (UnicodeEncodeError, UnicodeDecodeError): - u = '[Bad Unicode data]' + u = "[Bad Unicode data]" repr_type = str if u is None else type(u) - return repr_type('<%s: %s>' % (self.__class__.__name__, u)) + return repr_type("<%s: %s>" % (self.__class__.__name__, u)) def __str__(self): # TODO this could be simpler? - if hasattr(self, '__unicode__'): + if hasattr(self, "__unicode__"): if six.PY3: return self.__unicode__() else: - return six.text_type(self).encode('utf-8') - return six.text_type('%s object' % self.__class__.__name__) + return six.text_type(self).encode("utf-8") + return six.text_type("%s object" % self.__class__.__name__) def __eq__(self, other): - if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None: + if ( + isinstance(other, self.__class__) + and hasattr(other, "id") + and other.id is not None + ): return self.id == other.id if isinstance(other, DBRef): - return self._get_collection_name() == other.collection and self.id == other.id + return ( + self._get_collection_name() == other.collection and self.id == other.id + ) if self.id is None: return self is other return False @@ -295,10 +331,12 @@ class BaseDocument(object): Get text score from text query """ - if '_text_score' not in self._data: - raise InvalidDocumentError('This document is not originally built from a text query') + if "_text_score" not in self._data: + raise InvalidDocumentError( + "This document is not originally built from a text query" + ) - return self._data['_text_score'] + return self._data["_text_score"] def to_mongo(self, use_db_field=True, fields=None): """ @@ -307,11 +345,11 @@ class BaseDocument(object): fields = fields or [] data = SON() - data['_id'] = None - data['_cls'] = self._class_name + data["_id"] = None + data["_cls"] = self._class_name # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] - root_fields = {f.split('.')[0] for f in fields} + root_fields = {f.split(".")[0] for f in fields} for field_name in self: if root_fields and field_name not in root_fields: @@ -326,16 +364,16 @@ class BaseDocument(object): if value is not None: f_inputs = field.to_mongo.__code__.co_varnames ex_vars = {} - if fields and 'fields' in f_inputs: - key = '%s.' % field_name + if fields and "fields" in f_inputs: + key = "%s." % field_name embedded_fields = [ - i.replace(key, '') for i in fields - if i.startswith(key)] + i.replace(key, "") for i in fields if i.startswith(key) + ] - ex_vars['fields'] = embedded_fields + ex_vars["fields"] = embedded_fields - if 'use_db_field' in f_inputs: - ex_vars['use_db_field'] = use_db_field + if "use_db_field" in f_inputs: + ex_vars["use_db_field"] = use_db_field value = field.to_mongo(value, **ex_vars) @@ -351,8 +389,8 @@ class BaseDocument(object): data[field.name] = value # Only add _cls if allow_inheritance is True - if not self._meta.get('allow_inheritance'): - data.pop('_cls') + if not self._meta.get("allow_inheritance"): + data.pop("_cls") return data @@ -372,18 +410,23 @@ class BaseDocument(object): errors[NON_FIELD_ERRORS] = error # Get a list of tuples of field names and their current values - fields = [(self._fields.get(name, self._dynamic_fields.get(name)), - self._data.get(name)) for name in self._fields_ordered] + fields = [ + ( + self._fields.get(name, self._dynamic_fields.get(name)), + self._data.get(name), + ) + for name in self._fields_ordered + ] - EmbeddedDocumentField = _import_class('EmbeddedDocumentField') - GenericEmbeddedDocumentField = _import_class( - 'GenericEmbeddedDocumentField') + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") for field, value in fields: if value is not None: try: - if isinstance(field, (EmbeddedDocumentField, - GenericEmbeddedDocumentField)): + if isinstance( + field, (EmbeddedDocumentField, GenericEmbeddedDocumentField) + ): field._validate(value, clean=clean) else: field._validate(value) @@ -391,17 +434,18 @@ class BaseDocument(object): errors[field.name] = error.errors or error except (ValueError, AttributeError, AssertionError) as error: errors[field.name] = error - elif field.required and not getattr(field, '_auto_gen', False): - errors[field.name] = ValidationError('Field is required', - field_name=field.name) + elif field.required and not getattr(field, "_auto_gen", False): + errors[field.name] = ValidationError( + "Field is required", field_name=field.name + ) if errors: - pk = 'None' - if hasattr(self, 'pk'): + pk = "None" + if hasattr(self, "pk"): pk = self.pk - elif self._instance and hasattr(self._instance, 'pk'): + elif self._instance and hasattr(self._instance, "pk"): pk = self._instance.pk - message = 'ValidationError (%s:%s) ' % (self._class_name, pk) + message = "ValidationError (%s:%s) " % (self._class_name, pk) raise ValidationError(message, errors=errors) def to_json(self, *args, **kwargs): @@ -411,7 +455,7 @@ class BaseDocument(object): MongoDB (as opposed to attribute names on this document). Defaults to True. """ - use_db_field = kwargs.pop('use_db_field', True) + use_db_field = kwargs.pop("use_db_field", True) return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs) @classmethod @@ -434,22 +478,18 @@ class BaseDocument(object): # If the value is a dict with '_cls' in it, turn it into a document is_dict = isinstance(value, dict) - if is_dict and '_cls' in value: - cls = get_document(value['_cls']) + if is_dict and "_cls" in value: + cls = get_document(value["_cls"]) return cls(**value) if is_dict: - value = { - k: self.__expand_dynamic_values(k, v) - for k, v in value.items() - } + value = {k: self.__expand_dynamic_values(k, v) for k, v in value.items()} else: value = [self.__expand_dynamic_values(name, v) for v in value] # Convert lists / values so we can watch for any changes on them - EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') - if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): + EmbeddedDocumentListField = _import_class("EmbeddedDocumentListField") + if isinstance(value, (list, tuple)) and not isinstance(value, BaseList): if issubclass(type(self), EmbeddedDocumentListField): value = EmbeddedDocumentList(value, self, name) else: @@ -464,26 +504,26 @@ class BaseDocument(object): if not key: return - if not hasattr(self, '_changed_fields'): + if not hasattr(self, "_changed_fields"): return - if '.' in key: - key, rest = key.split('.', 1) + if "." in key: + key, rest = key.split(".", 1) key = self._db_field_map.get(key, key) - key = '%s.%s' % (key, rest) + key = "%s.%s" % (key, rest) else: key = self._db_field_map.get(key, key) if key not in self._changed_fields: - levels, idx = key.split('.'), 1 + levels, idx = key.split("."), 1 while idx <= len(levels): - if '.'.join(levels[:idx]) in self._changed_fields: + if ".".join(levels[:idx]) in self._changed_fields: break idx += 1 else: self._changed_fields.append(key) # remove lower level changed fields - level = '.'.join(levels[:idx]) + '.' + level = ".".join(levels[:idx]) + "." remove = self._changed_fields.remove for field in self._changed_fields[:]: if field.startswith(level): @@ -494,7 +534,7 @@ class BaseDocument(object): are marked as changed. """ for changed in self._get_changed_fields(): - parts = changed.split('.') + parts = changed.split(".") data = self for part in parts: if isinstance(data, list): @@ -507,8 +547,10 @@ class BaseDocument(object): else: data = getattr(data, part, None) - if not isinstance(data, LazyReference) and hasattr(data, '_changed_fields'): - if getattr(data, '_is_document', False): + if not isinstance(data, LazyReference) and hasattr( + data, "_changed_fields" + ): + if getattr(data, "_is_document", False): continue data._changed_fields = [] @@ -524,39 +566,38 @@ class BaseDocument(object): """ # Loop list / dict fields as they contain documents # Determine the iterator to use - if not hasattr(data, 'items'): + if not hasattr(data, "items"): iterator = enumerate(data) else: iterator = iteritems(data) for index_or_key, value in iterator: - item_key = '%s%s.' % (base_key, index_or_key) + item_key = "%s%s." % (base_key, index_or_key) # don't check anything lower if this key is already marked # as changed. if item_key[:-1] in changed_fields: continue - if hasattr(value, '_get_changed_fields'): + if hasattr(value, "_get_changed_fields"): changed = value._get_changed_fields() - changed_fields += ['%s%s' % (item_key, k) for k in changed if k] + changed_fields += ["%s%s" % (item_key, k) for k in changed if k] elif isinstance(value, (list, tuple, dict)): - self._nestable_types_changed_fields( - changed_fields, item_key, value) + self._nestable_types_changed_fields(changed_fields, item_key, value) def _get_changed_fields(self): """Return a list of all fields that have explicitly been changed. """ - EmbeddedDocument = _import_class('EmbeddedDocument') - ReferenceField = _import_class('ReferenceField') - GenericReferenceField = _import_class('GenericReferenceField') - SortedListField = _import_class('SortedListField') + EmbeddedDocument = _import_class("EmbeddedDocument") + ReferenceField = _import_class("ReferenceField") + GenericReferenceField = _import_class("GenericReferenceField") + SortedListField = _import_class("SortedListField") changed_fields = [] - changed_fields += getattr(self, '_changed_fields', []) + changed_fields += getattr(self, "_changed_fields", []) for field_name in self._fields_ordered: db_field_name = self._db_field_map.get(field_name, field_name) - key = '%s.' % db_field_name + key = "%s." % db_field_name data = self._data.get(field_name, None) field = self._fields.get(field_name) @@ -564,16 +605,17 @@ class BaseDocument(object): # Whole field already marked as changed, no need to go further continue - if isinstance(field, ReferenceField): # Don't follow referenced documents + if isinstance(field, ReferenceField): # Don't follow referenced documents continue if isinstance(data, EmbeddedDocument): # Find all embedded fields that have been changed changed = data._get_changed_fields() - changed_fields += ['%s%s' % (key, k) for k in changed if k] + changed_fields += ["%s%s" % (key, k) for k in changed if k] elif isinstance(data, (list, tuple, dict)): - if (hasattr(field, 'field') and - isinstance(field.field, (ReferenceField, GenericReferenceField))): + if hasattr(field, "field") and isinstance( + field.field, (ReferenceField, GenericReferenceField) + ): continue elif isinstance(field, SortedListField) and field._ordering: # if ordering is affected whole list is changed @@ -581,8 +623,7 @@ class BaseDocument(object): changed_fields.append(db_field_name) continue - self._nestable_types_changed_fields( - changed_fields, key, data) + self._nestable_types_changed_fields(changed_fields, key, data) return changed_fields def _delta(self): @@ -594,11 +635,11 @@ class BaseDocument(object): set_fields = self._get_changed_fields() unset_data = {} - if hasattr(self, '_changed_fields'): + if hasattr(self, "_changed_fields"): set_data = {} # Fetch each set item from its path for path in set_fields: - parts = path.split('.') + parts = path.split(".") d = doc new_path = [] for p in parts: @@ -608,26 +649,27 @@ class BaseDocument(object): elif isinstance(d, list) and p.isdigit(): # An item of a list (identified by its index) is updated d = d[int(p)] - elif hasattr(d, 'get'): + elif hasattr(d, "get"): # dict-like (dict, embedded document) d = d.get(p) new_path.append(p) - path = '.'.join(new_path) + path = ".".join(new_path) set_data[path] = d else: set_data = doc - if '_id' in set_data: - del set_data['_id'] + if "_id" in set_data: + del set_data["_id"] # Determine if any changed items were actually unset. for path, value in set_data.items(): - if value or isinstance(value, (numbers.Number, bool)): # Account for 0 and True that are truthy + if value or isinstance( + value, (numbers.Number, bool) + ): # Account for 0 and True that are truthy continue - parts = path.split('.') + parts = path.split(".") - if (self._dynamic and len(parts) and parts[0] in - self._dynamic_fields): + if self._dynamic and len(parts) and parts[0] in self._dynamic_fields: del set_data[path] unset_data[path] = 1 continue @@ -642,16 +684,16 @@ class BaseDocument(object): for p in parts: if isinstance(d, list) and p.isdigit(): d = d[int(p)] - elif (hasattr(d, '__getattribute__') and - not isinstance(d, dict)): + elif hasattr(d, "__getattribute__") and not isinstance(d, dict): real_path = d._reverse_db_field_map.get(p, p) d = getattr(d, real_path) else: d = d.get(p) - if hasattr(d, '_fields'): - field_name = d._reverse_db_field_map.get(db_field_name, - db_field_name) + if hasattr(d, "_fields"): + field_name = d._reverse_db_field_map.get( + db_field_name, db_field_name + ) if field_name in d._fields: default = d._fields.get(field_name).default else: @@ -672,7 +714,7 @@ class BaseDocument(object): """Return the collection name for this class. None for abstract class. """ - return cls._meta.get('collection', None) + return cls._meta.get("collection", None) @classmethod def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False): @@ -685,7 +727,7 @@ class BaseDocument(object): # Get the class name from the document, falling back to the given # class if unavailable - class_name = son.get('_cls', cls._class_name) + class_name = son.get("_cls", cls._class_name) # Convert SON to a data dict, making sure each key is a string and # corresponds to the right db field. @@ -710,18 +752,20 @@ class BaseDocument(object): if field.db_field in data: value = data[field.db_field] try: - data[field_name] = (value if value is None - else field.to_python(value)) + data[field_name] = ( + value if value is None else field.to_python(value) + ) if field_name != field.db_field: del data[field.db_field] except (AttributeError, ValueError) as e: errors_dict[field_name] = e if errors_dict: - errors = '\n'.join(['%s - %s' % (k, v) - for k, v in errors_dict.items()]) - msg = ('Invalid data to create a `%s` instance.\n%s' - % (cls._class_name, errors)) + errors = "\n".join(["%s - %s" % (k, v) for k, v in errors_dict.items()]) + msg = "Invalid data to create a `%s` instance.\n%s" % ( + cls._class_name, + errors, + ) raise InvalidDocumentError(msg) # In STRICT documents, remove any keys that aren't in cls._fields @@ -729,10 +773,7 @@ class BaseDocument(object): data = {k: v for k, v in iteritems(data) if k in cls._fields} obj = cls( - __auto_convert=False, - _created=created, - __only_fields=only_fields, - **data + __auto_convert=False, _created=created, __only_fields=only_fields, **data ) obj._changed_fields = [] if not _auto_dereference: @@ -754,15 +795,13 @@ class BaseDocument(object): # Create a map of index fields to index spec. We're converting # the fields from a list to a tuple so that it's hashable. - spec_fields = { - tuple(index['fields']): index for index in index_specs - } + spec_fields = {tuple(index["fields"]): index for index in index_specs} # For each new index, if there's an existing index with the same # fields list, update the existing spec with all data from the # new spec. for new_index in indices: - candidate = spec_fields.get(tuple(new_index['fields'])) + candidate = spec_fields.get(tuple(new_index["fields"])) if candidate is None: index_specs.append(new_index) else: @@ -779,9 +818,9 @@ class BaseDocument(object): def _build_index_spec(cls, spec): """Build a PyMongo index spec from a MongoEngine index spec.""" if isinstance(spec, six.string_types): - spec = {'fields': [spec]} + spec = {"fields": [spec]} elif isinstance(spec, (list, tuple)): - spec = {'fields': list(spec)} + spec = {"fields": list(spec)} elif isinstance(spec, dict): spec = dict(spec) @@ -789,19 +828,21 @@ class BaseDocument(object): direction = None # Check to see if we need to include _cls - allow_inheritance = cls._meta.get('allow_inheritance') + allow_inheritance = cls._meta.get("allow_inheritance") include_cls = ( - allow_inheritance and - not spec.get('sparse', False) and - spec.get('cls', True) and - '_cls' not in spec['fields'] + allow_inheritance + and not spec.get("sparse", False) + and spec.get("cls", True) + and "_cls" not in spec["fields"] ) # 733: don't include cls if index_cls is False unless there is an explicit cls with the index - include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True)) - if 'cls' in spec: - spec.pop('cls') - for key in spec['fields']: + include_cls = include_cls and ( + spec.get("cls", False) or cls._meta.get("index_cls", True) + ) + if "cls" in spec: + spec.pop("cls") + for key in spec["fields"]: # If inherited spec continue if isinstance(key, (list, tuple)): continue @@ -814,51 +855,54 @@ class BaseDocument(object): # GEOHAYSTACK from ) # GEO2D from * direction = pymongo.ASCENDING - if key.startswith('-'): + if key.startswith("-"): direction = pymongo.DESCENDING - elif key.startswith('$'): + elif key.startswith("$"): direction = pymongo.TEXT - elif key.startswith('#'): + elif key.startswith("#"): direction = pymongo.HASHED - elif key.startswith('('): + elif key.startswith("("): direction = pymongo.GEOSPHERE - elif key.startswith(')'): + elif key.startswith(")"): direction = pymongo.GEOHAYSTACK - elif key.startswith('*'): + elif key.startswith("*"): direction = pymongo.GEO2D - if key.startswith(('+', '-', '*', '$', '#', '(', ')')): + if key.startswith(("+", "-", "*", "$", "#", "(", ")")): key = key[1:] # Use real field name, do it manually because we need field # objects for the next part (list field checking) - parts = key.split('.') - if parts in (['pk'], ['id'], ['_id']): - key = '_id' + parts = key.split(".") + if parts in (["pk"], ["id"], ["_id"]): + key = "_id" else: fields = cls._lookup_field(parts) parts = [] for field in fields: try: - if field != '_id': + if field != "_id": field = field.db_field except AttributeError: pass parts.append(field) - key = '.'.join(parts) + key = ".".join(parts) index_list.append((key, direction)) # Don't add cls to a geo index if include_cls and direction not in ( - pymongo.GEO2D, pymongo.GEOHAYSTACK, pymongo.GEOSPHERE): - index_list.insert(0, ('_cls', 1)) + pymongo.GEO2D, + pymongo.GEOHAYSTACK, + pymongo.GEOSPHERE, + ): + index_list.insert(0, ("_cls", 1)) if index_list: - spec['fields'] = index_list + spec["fields"] = index_list return spec @classmethod - def _unique_with_indexes(cls, namespace=''): + def _unique_with_indexes(cls, namespace=""): """Find unique indexes in the document schema and return them.""" unique_indexes = [] for field_name, field in cls._fields.items(): @@ -876,36 +920,39 @@ class BaseDocument(object): # Convert unique_with field names to real field names unique_with = [] for other_name in field.unique_with: - parts = other_name.split('.') + parts = other_name.split(".") # Lookup real name parts = cls._lookup_field(parts) name_parts = [part.db_field for part in parts] - unique_with.append('.'.join(name_parts)) + unique_with.append(".".join(name_parts)) # Unique field should be required parts[-1].required = True - sparse = (not sparse and - parts[-1].name not in cls.__dict__) + sparse = not sparse and parts[-1].name not in cls.__dict__ unique_fields += unique_with # Add the new index to the list fields = [ - ('%s%s' % (namespace, f), pymongo.ASCENDING) - for f in unique_fields + ("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields ] - index = {'fields': fields, 'unique': True, 'sparse': sparse} + index = {"fields": fields, "unique": True, "sparse": sparse} unique_indexes.append(index) - if field.__class__.__name__ in {'EmbeddedDocumentListField', - 'ListField', 'SortedListField'}: + if field.__class__.__name__ in { + "EmbeddedDocumentListField", + "ListField", + "SortedListField", + }: field = field.field # Grab any embedded document field unique indexes - if (field.__class__.__name__ == 'EmbeddedDocumentField' and - field.document_type != cls): - field_namespace = '%s.' % field_name + if ( + field.__class__.__name__ == "EmbeddedDocumentField" + and field.document_type != cls + ): + field_namespace = "%s." % field_name doc_cls = field.document_type unique_indexes += doc_cls._unique_with_indexes(field_namespace) @@ -917,32 +964,36 @@ class BaseDocument(object): geo_indices = [] inspected.append(cls) - geo_field_type_names = ('EmbeddedDocumentField', 'GeoPointField', - 'PointField', 'LineStringField', - 'PolygonField') + geo_field_type_names = ( + "EmbeddedDocumentField", + "GeoPointField", + "PointField", + "LineStringField", + "PolygonField", + ) - geo_field_types = tuple([_import_class(field) - for field in geo_field_type_names]) + geo_field_types = tuple( + [_import_class(field) for field in geo_field_type_names] + ) for field in cls._fields.values(): if not isinstance(field, geo_field_types): continue - if hasattr(field, 'document_type'): + if hasattr(field, "document_type"): field_cls = field.document_type if field_cls in inspected: continue - if hasattr(field_cls, '_geo_indices'): + if hasattr(field_cls, "_geo_indices"): geo_indices += field_cls._geo_indices( - inspected, parent_field=field.db_field) + inspected, parent_field=field.db_field + ) elif field._geo_index: field_name = field.db_field if parent_field: - field_name = '%s.%s' % (parent_field, field_name) - geo_indices.append({ - 'fields': [(field_name, field._geo_index)] - }) + field_name = "%s.%s" % (parent_field, field_name) + geo_indices.append({"fields": [(field_name, field._geo_index)]}) return geo_indices @@ -983,8 +1034,8 @@ class BaseDocument(object): # TODO this method is WAY too complicated. Simplify it. # TODO don't think returning a string for embedded non-existent fields is desired - ListField = _import_class('ListField') - DynamicField = _import_class('DynamicField') + ListField = _import_class("ListField") + DynamicField = _import_class("DynamicField") if not isinstance(parts, (list, tuple)): parts = [parts] @@ -1000,15 +1051,17 @@ class BaseDocument(object): # Look up first field from the document if field is None: - if field_name == 'pk': + if field_name == "pk": # Deal with "primary key" alias - field_name = cls._meta['id_field'] + field_name = cls._meta["id_field"] if field_name in cls._fields: field = cls._fields[field_name] elif cls._dynamic: field = DynamicField(db_field=field_name) - elif cls._meta.get('allow_inheritance') or cls._meta.get('abstract', False): + elif cls._meta.get("allow_inheritance") or cls._meta.get( + "abstract", False + ): # 744: in case the field is defined in a subclass for subcls in cls.__subclasses__(): try: @@ -1023,38 +1076,41 @@ class BaseDocument(object): else: raise LookUpError('Cannot resolve field "%s"' % field_name) else: - ReferenceField = _import_class('ReferenceField') - GenericReferenceField = _import_class('GenericReferenceField') + ReferenceField = _import_class("ReferenceField") + GenericReferenceField = _import_class("GenericReferenceField") # If previous field was a reference, throw an error (we # cannot look up fields that are on references). if isinstance(field, (ReferenceField, GenericReferenceField)): - raise LookUpError('Cannot perform join in mongoDB: %s' % - '__'.join(parts)) + raise LookUpError( + "Cannot perform join in mongoDB: %s" % "__".join(parts) + ) # If the parent field has a "field" attribute which has a # lookup_member method, call it to find the field # corresponding to this iteration. - if hasattr(getattr(field, 'field', None), 'lookup_member'): + if hasattr(getattr(field, "field", None), "lookup_member"): new_field = field.field.lookup_member(field_name) # If the parent field is a DynamicField or if it's part of # a DynamicDocument, mark current field as a DynamicField # with db_name equal to the field name. - elif cls._dynamic and (isinstance(field, DynamicField) or - getattr(getattr(field, 'document_type', None), '_dynamic', None)): + elif cls._dynamic and ( + isinstance(field, DynamicField) + or getattr(getattr(field, "document_type", None), "_dynamic", None) + ): new_field = DynamicField(db_field=field_name) # Else, try to use the parent field's lookup_member method # to find the subfield. - elif hasattr(field, 'lookup_member'): + elif hasattr(field, "lookup_member"): new_field = field.lookup_member(field_name) # Raise a LookUpError if all the other conditions failed. else: raise LookUpError( - 'Cannot resolve subfield or operator {} ' - 'on the field {}'.format(field_name, field.name) + "Cannot resolve subfield or operator {} " + "on the field {}".format(field_name, field.name) ) # If current field still wasn't found and the parent field @@ -1073,23 +1129,24 @@ class BaseDocument(object): return fields @classmethod - def _translate_field_name(cls, field, sep='.'): + def _translate_field_name(cls, field, sep="."): """Translate a field attribute name to a database field name. """ parts = field.split(sep) parts = [f.db_field for f in cls._lookup_field(parts)] - return '.'.join(parts) + return ".".join(parts) def __set_field_display(self): """For each field that specifies choices, create a get__display method. """ - fields_with_choices = [(n, f) for n, f in self._fields.items() - if f.choices] + fields_with_choices = [(n, f) for n, f in self._fields.items() if f.choices] for attr_name, field in fields_with_choices: - setattr(self, - 'get_%s_display' % attr_name, - partial(self.__get_field_display, field=field)) + setattr( + self, + "get_%s_display" % attr_name, + partial(self.__get_field_display, field=field), + ) def __get_field_display(self, field): """Return the display value for a choice field""" @@ -1097,9 +1154,16 @@ class BaseDocument(object): if field.choices and isinstance(field.choices[0], (list, tuple)): if value is None: return None - sep = getattr(field, 'display_sep', ' ') - values = value if field.__class__.__name__ in ('ListField', 'SortedListField') else [value] - return sep.join([ - six.text_type(dict(field.choices).get(val, val)) - for val in values or []]) + sep = getattr(field, "display_sep", " ") + values = ( + value + if field.__class__.__name__ in ("ListField", "SortedListField") + else [value] + ) + return sep.join( + [ + six.text_type(dict(field.choices).get(val, val)) + for val in values or [] + ] + ) return value diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 9ce426c9..cd1039cb 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -8,13 +8,11 @@ import six from six import iteritems from mongoengine.base.common import UPDATE_OPERATORS -from mongoengine.base.datastructures import (BaseDict, BaseList, - EmbeddedDocumentList) +from mongoengine.base.datastructures import BaseDict, BaseList, EmbeddedDocumentList from mongoengine.common import _import_class from mongoengine.errors import DeprecatedError, ValidationError -__all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField', - 'GeoJsonBaseField') +__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") class BaseField(object): @@ -23,6 +21,7 @@ class BaseField(object): .. versionchanged:: 0.5 - added verbose and help text """ + name = None _geo_index = False _auto_gen = False # Call `generate` to generate a value @@ -34,10 +33,21 @@ class BaseField(object): creation_counter = 0 auto_creation_counter = -1 - def __init__(self, db_field=None, name=None, required=False, default=None, - unique=False, unique_with=None, primary_key=False, - validation=None, choices=None, null=False, sparse=False, - **kwargs): + def __init__( + self, + db_field=None, + name=None, + required=False, + default=None, + unique=False, + unique_with=None, + primary_key=False, + validation=None, + choices=None, + null=False, + sparse=False, + **kwargs + ): """ :param db_field: The database field to store this field in (defaults to the name of the field) @@ -65,7 +75,7 @@ class BaseField(object): existing attributes. Common metadata includes `verbose_name` and `help_text`. """ - self.db_field = (db_field or name) if not primary_key else '_id' + self.db_field = (db_field or name) if not primary_key else "_id" if name: msg = 'Field\'s "name" attribute deprecated in favour of "db_field"' @@ -82,17 +92,16 @@ class BaseField(object): self._owner_document = None # Make sure db_field is a string (if it's explicitly defined). - if ( - self.db_field is not None and - not isinstance(self.db_field, six.string_types) + if self.db_field is not None and not isinstance( + self.db_field, six.string_types ): - raise TypeError('db_field should be a string.') + raise TypeError("db_field should be a string.") # Make sure db_field doesn't contain any forbidden characters. if isinstance(self.db_field, six.string_types) and ( - '.' in self.db_field or - '\0' in self.db_field or - self.db_field.startswith('$') + "." in self.db_field + or "\0" in self.db_field + or self.db_field.startswith("$") ): raise ValueError( 'field names cannot contain dots (".") or null characters ' @@ -102,15 +111,17 @@ class BaseField(object): # Detect and report conflicts between metadata and base properties. conflicts = set(dir(self)) & set(kwargs) if conflicts: - raise TypeError('%s already has attribute(s): %s' % ( - self.__class__.__name__, ', '.join(conflicts))) + raise TypeError( + "%s already has attribute(s): %s" + % (self.__class__.__name__, ", ".join(conflicts)) + ) # Assign metadata to the instance # This efficient method is available because no __slots__ are defined. self.__dict__.update(kwargs) # Adjust the appropriate creation counter, and save our local copy. - if self.db_field == '_id': + if self.db_field == "_id": self.creation_counter = BaseField.auto_creation_counter BaseField.auto_creation_counter -= 1 else: @@ -142,8 +153,8 @@ class BaseField(object): if instance._initialised: try: value_has_changed = ( - self.name not in instance._data or - instance._data[self.name] != value + self.name not in instance._data + or instance._data[self.name] != value ) if value_has_changed: instance._mark_as_changed(self.name) @@ -153,7 +164,7 @@ class BaseField(object): # Mark the field as changed in such cases. instance._mark_as_changed(self.name) - EmbeddedDocument = _import_class('EmbeddedDocument') + EmbeddedDocument = _import_class("EmbeddedDocument") if isinstance(value, EmbeddedDocument): value._instance = weakref.proxy(instance) elif isinstance(value, (list, tuple)): @@ -163,7 +174,7 @@ class BaseField(object): instance._data[self.name] = value - def error(self, message='', errors=None, field_name=None): + def error(self, message="", errors=None, field_name=None): """Raise a ValidationError.""" field_name = field_name if field_name else self.name raise ValidationError(message, errors=errors, field_name=field_name) @@ -180,11 +191,11 @@ class BaseField(object): """Helper method to call to_mongo with proper inputs.""" f_inputs = self.to_mongo.__code__.co_varnames ex_vars = {} - if 'fields' in f_inputs: - ex_vars['fields'] = fields + if "fields" in f_inputs: + ex_vars["fields"] = fields - if 'use_db_field' in f_inputs: - ex_vars['use_db_field'] = use_db_field + if "use_db_field" in f_inputs: + ex_vars["use_db_field"] = use_db_field return self.to_mongo(value, **ex_vars) @@ -199,8 +210,8 @@ class BaseField(object): pass def _validate_choices(self, value): - Document = _import_class('Document') - EmbeddedDocument = _import_class('EmbeddedDocument') + Document = _import_class("Document") + EmbeddedDocument = _import_class("EmbeddedDocument") choice_list = self.choices if isinstance(next(iter(choice_list)), (list, tuple)): @@ -211,15 +222,13 @@ class BaseField(object): if isinstance(value, (Document, EmbeddedDocument)): if not any(isinstance(value, c) for c in choice_list): self.error( - 'Value must be an instance of %s' % ( - six.text_type(choice_list) - ) + "Value must be an instance of %s" % (six.text_type(choice_list)) ) # Choices which are types other than Documents else: values = value if isinstance(value, (list, tuple)) else [value] if len(set(values) - set(choice_list)): - self.error('Value must be one of %s' % six.text_type(choice_list)) + self.error("Value must be one of %s" % six.text_type(choice_list)) def _validate(self, value, **kwargs): # Check the Choices Constraint @@ -235,13 +244,17 @@ class BaseField(object): # in favor of having validation raising a ValidationError ret = self.validation(value) if ret is not None: - raise DeprecatedError('validation argument for `%s` must not return anything, ' - 'it should raise a ValidationError if validation fails' % self.name) + raise DeprecatedError( + "validation argument for `%s` must not return anything, " + "it should raise a ValidationError if validation fails" + % self.name + ) except ValidationError as ex: self.error(str(ex)) else: - raise ValueError('validation argument for `"%s"` must be a ' - 'callable.' % self.name) + raise ValueError( + 'validation argument for `"%s"` must be a ' "callable." % self.name + ) self.validate(value, **kwargs) @@ -275,35 +288,41 @@ class ComplexBaseField(BaseField): # Document class being used rather than a document object return self - ReferenceField = _import_class('ReferenceField') - GenericReferenceField = _import_class('GenericReferenceField') - EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') + ReferenceField = _import_class("ReferenceField") + GenericReferenceField = _import_class("GenericReferenceField") + EmbeddedDocumentListField = _import_class("EmbeddedDocumentListField") auto_dereference = instance._fields[self.name]._auto_dereference - dereference = (auto_dereference and - (self.field is None or isinstance(self.field, - (GenericReferenceField, ReferenceField)))) + dereference = auto_dereference and ( + self.field is None + or isinstance(self.field, (GenericReferenceField, ReferenceField)) + ) - _dereference = _import_class('DeReference')() + _dereference = _import_class("DeReference")() - if (instance._initialised and - dereference and - instance._data.get(self.name) and - not getattr(instance._data[self.name], '_dereferenced', False)): + if ( + instance._initialised + and dereference + and instance._data.get(self.name) + and not getattr(instance._data[self.name], "_dereferenced", False) + ): instance._data[self.name] = _dereference( - instance._data.get(self.name), max_depth=1, instance=instance, - name=self.name + instance._data.get(self.name), + max_depth=1, + instance=instance, + name=self.name, ) - if hasattr(instance._data[self.name], '_dereferenced'): + if hasattr(instance._data[self.name], "_dereferenced"): instance._data[self.name]._dereferenced = True value = super(ComplexBaseField, self).__get__(instance, owner) # Convert lists / values so we can watch for any changes on them if isinstance(value, (list, tuple)): - if (issubclass(type(self), EmbeddedDocumentListField) and - not isinstance(value, EmbeddedDocumentList)): + if issubclass(type(self), EmbeddedDocumentListField) and not isinstance( + value, EmbeddedDocumentList + ): value = EmbeddedDocumentList(value, instance, self.name) elif not isinstance(value, BaseList): value = BaseList(value, instance, self.name) @@ -312,12 +331,13 @@ class ComplexBaseField(BaseField): value = BaseDict(value, instance, self.name) instance._data[self.name] = value - if (auto_dereference and instance._initialised and - isinstance(value, (BaseList, BaseDict)) and - not value._dereferenced): - value = _dereference( - value, max_depth=1, instance=instance, name=self.name - ) + if ( + auto_dereference + and instance._initialised + and isinstance(value, (BaseList, BaseDict)) + and not value._dereferenced + ): + value = _dereference(value, max_depth=1, instance=instance, name=self.name) value._dereferenced = True instance._data[self.name] = value @@ -328,16 +348,16 @@ class ComplexBaseField(BaseField): if isinstance(value, six.string_types): return value - if hasattr(value, 'to_python'): + if hasattr(value, "to_python"): return value.to_python() - BaseDocument = _import_class('BaseDocument') + BaseDocument = _import_class("BaseDocument") if isinstance(value, BaseDocument): # Something is wrong, return the value as it is return value is_list = False - if not hasattr(value, 'items'): + if not hasattr(value, "items"): try: is_list = True value = {idx: v for idx, v in enumerate(value)} @@ -346,50 +366,54 @@ class ComplexBaseField(BaseField): if self.field: self.field._auto_dereference = self._auto_dereference - value_dict = {key: self.field.to_python(item) - for key, item in value.items()} + value_dict = { + key: self.field.to_python(item) for key, item in value.items() + } else: - Document = _import_class('Document') + Document = _import_class("Document") value_dict = {} for k, v in value.items(): if isinstance(v, Document): # We need the id from the saved object to create the DBRef if v.pk is None: - self.error('You can only reference documents once they' - ' have been saved to the database') + self.error( + "You can only reference documents once they" + " have been saved to the database" + ) collection = v._get_collection_name() value_dict[k] = DBRef(collection, v.pk) - elif hasattr(v, 'to_python'): + elif hasattr(v, "to_python"): value_dict[k] = v.to_python() else: value_dict[k] = self.to_python(v) if is_list: # Convert back to a list - return [v for _, v in sorted(value_dict.items(), - key=operator.itemgetter(0))] + return [ + v for _, v in sorted(value_dict.items(), key=operator.itemgetter(0)) + ] return value_dict def to_mongo(self, value, use_db_field=True, fields=None): """Convert a Python type to a MongoDB-compatible type.""" - Document = _import_class('Document') - EmbeddedDocument = _import_class('EmbeddedDocument') - GenericReferenceField = _import_class('GenericReferenceField') + Document = _import_class("Document") + EmbeddedDocument = _import_class("EmbeddedDocument") + GenericReferenceField = _import_class("GenericReferenceField") if isinstance(value, six.string_types): return value - if hasattr(value, 'to_mongo'): + if hasattr(value, "to_mongo"): if isinstance(value, Document): return GenericReferenceField().to_mongo(value) cls = value.__class__ val = value.to_mongo(use_db_field, fields) # If it's a document that is not inherited add _cls if isinstance(value, EmbeddedDocument): - val['_cls'] = cls.__name__ + val["_cls"] = cls.__name__ return val is_list = False - if not hasattr(value, 'items'): + if not hasattr(value, "items"): try: is_list = True value = {k: v for k, v in enumerate(value)} @@ -407,39 +431,42 @@ class ComplexBaseField(BaseField): if isinstance(v, Document): # We need the id from the saved object to create the DBRef if v.pk is None: - self.error('You can only reference documents once they' - ' have been saved to the database') + self.error( + "You can only reference documents once they" + " have been saved to the database" + ) # If its a document that is not inheritable it won't have # any _cls data so make it a generic reference allows # us to dereference - meta = getattr(v, '_meta', {}) - allow_inheritance = meta.get('allow_inheritance') + meta = getattr(v, "_meta", {}) + allow_inheritance = meta.get("allow_inheritance") if not allow_inheritance and not self.field: value_dict[k] = GenericReferenceField().to_mongo(v) else: collection = v._get_collection_name() value_dict[k] = DBRef(collection, v.pk) - elif hasattr(v, 'to_mongo'): + elif hasattr(v, "to_mongo"): cls = v.__class__ val = v.to_mongo(use_db_field, fields) # If it's a document that is not inherited add _cls if isinstance(v, (Document, EmbeddedDocument)): - val['_cls'] = cls.__name__ + val["_cls"] = cls.__name__ value_dict[k] = val else: value_dict[k] = self.to_mongo(v, use_db_field, fields) if is_list: # Convert back to a list - return [v for _, v in sorted(value_dict.items(), - key=operator.itemgetter(0))] + return [ + v for _, v in sorted(value_dict.items(), key=operator.itemgetter(0)) + ] return value_dict def validate(self, value): """If field is provided ensure the value is valid.""" errors = {} if self.field: - if hasattr(value, 'iteritems') or hasattr(value, 'items'): + if hasattr(value, "iteritems") or hasattr(value, "items"): sequence = iteritems(value) else: sequence = enumerate(value) @@ -453,11 +480,10 @@ class ComplexBaseField(BaseField): if errors: field_class = self.field.__class__.__name__ - self.error('Invalid %s item (%s)' % (field_class, value), - errors=errors) + self.error("Invalid %s item (%s)" % (field_class, value), errors=errors) # Don't allow empty values if required if self.required and not value: - self.error('Field is required and cannot be empty') + self.error("Field is required and cannot be empty") def prepare_query_value(self, op, value): return self.to_mongo(value) @@ -500,7 +526,7 @@ class ObjectIdField(BaseField): try: ObjectId(six.text_type(value)) except Exception: - self.error('Invalid Object ID') + self.error("Invalid Object ID") class GeoJsonBaseField(BaseField): @@ -510,14 +536,14 @@ class GeoJsonBaseField(BaseField): """ _geo_index = pymongo.GEOSPHERE - _type = 'GeoBase' + _type = "GeoBase" def __init__(self, auto_index=True, *args, **kwargs): """ :param bool auto_index: Automatically create a '2dsphere' index.\ Defaults to `True`. """ - self._name = '%sField' % self._type + self._name = "%sField" % self._type if not auto_index: self._geo_index = False super(GeoJsonBaseField, self).__init__(*args, **kwargs) @@ -525,57 +551,58 @@ class GeoJsonBaseField(BaseField): def validate(self, value): """Validate the GeoJson object based on its type.""" if isinstance(value, dict): - if set(value.keys()) == {'type', 'coordinates'}: - if value['type'] != self._type: - self.error('%s type must be "%s"' % - (self._name, self._type)) - return self.validate(value['coordinates']) + if set(value.keys()) == {"type", "coordinates"}: + if value["type"] != self._type: + self.error('%s type must be "%s"' % (self._name, self._type)) + return self.validate(value["coordinates"]) else: - self.error('%s can only accept a valid GeoJson dictionary' - ' or lists of (x, y)' % self._name) + self.error( + "%s can only accept a valid GeoJson dictionary" + " or lists of (x, y)" % self._name + ) return elif not isinstance(value, (list, tuple)): - self.error('%s can only accept lists of [x, y]' % self._name) + self.error("%s can only accept lists of [x, y]" % self._name) return - validate = getattr(self, '_validate_%s' % self._type.lower()) + validate = getattr(self, "_validate_%s" % self._type.lower()) error = validate(value) if error: self.error(error) def _validate_polygon(self, value, top_level=True): if not isinstance(value, (list, tuple)): - return 'Polygons must contain list of linestrings' + return "Polygons must contain list of linestrings" # Quick and dirty validator try: value[0][0][0] except (TypeError, IndexError): - return 'Invalid Polygon must contain at least one valid linestring' + return "Invalid Polygon must contain at least one valid linestring" errors = [] for val in value: error = self._validate_linestring(val, False) if not error and val[0] != val[-1]: - error = 'LineStrings must start and end at the same point' + error = "LineStrings must start and end at the same point" if error and error not in errors: errors.append(error) if errors: if top_level: - return 'Invalid Polygon:\n%s' % ', '.join(errors) + return "Invalid Polygon:\n%s" % ", ".join(errors) else: - return '%s' % ', '.join(errors) + return "%s" % ", ".join(errors) def _validate_linestring(self, value, top_level=True): """Validate a linestring.""" if not isinstance(value, (list, tuple)): - return 'LineStrings must contain list of coordinate pairs' + return "LineStrings must contain list of coordinate pairs" # Quick and dirty validator try: value[0][0] except (TypeError, IndexError): - return 'Invalid LineString must contain at least one valid point' + return "Invalid LineString must contain at least one valid point" errors = [] for val in value: @@ -584,29 +611,30 @@ class GeoJsonBaseField(BaseField): errors.append(error) if errors: if top_level: - return 'Invalid LineString:\n%s' % ', '.join(errors) + return "Invalid LineString:\n%s" % ", ".join(errors) else: - return '%s' % ', '.join(errors) + return "%s" % ", ".join(errors) def _validate_point(self, value): """Validate each set of coords""" if not isinstance(value, (list, tuple)): - return 'Points must be a list of coordinate pairs' + return "Points must be a list of coordinate pairs" elif not len(value) == 2: - return 'Value (%s) must be a two-dimensional point' % repr(value) - elif (not isinstance(value[0], (float, int)) or - not isinstance(value[1], (float, int))): - return 'Both values (%s) in point must be float or int' % repr(value) + return "Value (%s) must be a two-dimensional point" % repr(value) + elif not isinstance(value[0], (float, int)) or not isinstance( + value[1], (float, int) + ): + return "Both values (%s) in point must be float or int" % repr(value) def _validate_multipoint(self, value): if not isinstance(value, (list, tuple)): - return 'MultiPoint must be a list of Point' + return "MultiPoint must be a list of Point" # Quick and dirty validator try: value[0][0] except (TypeError, IndexError): - return 'Invalid MultiPoint must contain at least one valid point' + return "Invalid MultiPoint must contain at least one valid point" errors = [] for point in value: @@ -615,17 +643,17 @@ class GeoJsonBaseField(BaseField): errors.append(error) if errors: - return '%s' % ', '.join(errors) + return "%s" % ", ".join(errors) def _validate_multilinestring(self, value, top_level=True): if not isinstance(value, (list, tuple)): - return 'MultiLineString must be a list of LineString' + return "MultiLineString must be a list of LineString" # Quick and dirty validator try: value[0][0][0] except (TypeError, IndexError): - return 'Invalid MultiLineString must contain at least one valid linestring' + return "Invalid MultiLineString must contain at least one valid linestring" errors = [] for linestring in value: @@ -635,19 +663,19 @@ class GeoJsonBaseField(BaseField): if errors: if top_level: - return 'Invalid MultiLineString:\n%s' % ', '.join(errors) + return "Invalid MultiLineString:\n%s" % ", ".join(errors) else: - return '%s' % ', '.join(errors) + return "%s" % ", ".join(errors) def _validate_multipolygon(self, value): if not isinstance(value, (list, tuple)): - return 'MultiPolygon must be a list of Polygon' + return "MultiPolygon must be a list of Polygon" # Quick and dirty validator try: value[0][0][0][0] except (TypeError, IndexError): - return 'Invalid MultiPolygon must contain at least one valid Polygon' + return "Invalid MultiPolygon must contain at least one valid Polygon" errors = [] for polygon in value: @@ -656,9 +684,9 @@ class GeoJsonBaseField(BaseField): errors.append(error) if errors: - return 'Invalid MultiPolygon:\n%s' % ', '.join(errors) + return "Invalid MultiPolygon:\n%s" % ", ".join(errors) def to_mongo(self, value): if isinstance(value, dict): return value - return SON([('type', self._type), ('coordinates', value)]) + return SON([("type", self._type), ("coordinates", value)]) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index c3ced5bb..e4d26811 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -8,12 +8,15 @@ from mongoengine.base.common import _document_registry from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField from mongoengine.common import _import_class from mongoengine.errors import InvalidDocumentError -from mongoengine.queryset import (DO_NOTHING, DoesNotExist, - MultipleObjectsReturned, - QuerySetManager) +from mongoengine.queryset import ( + DO_NOTHING, + DoesNotExist, + MultipleObjectsReturned, + QuerySetManager, +) -__all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass') +__all__ = ("DocumentMetaclass", "TopLevelDocumentMetaclass") class DocumentMetaclass(type): @@ -25,44 +28,46 @@ class DocumentMetaclass(type): super_new = super(DocumentMetaclass, mcs).__new__ # If a base class just call super - metaclass = attrs.get('my_metaclass') + metaclass = attrs.get("my_metaclass") if metaclass and issubclass(metaclass, DocumentMetaclass): return super_new(mcs, name, bases, attrs) - attrs['_is_document'] = attrs.get('_is_document', False) - attrs['_cached_reference_fields'] = [] + attrs["_is_document"] = attrs.get("_is_document", False) + attrs["_cached_reference_fields"] = [] # EmbeddedDocuments could have meta data for inheritance - if 'meta' in attrs: - attrs['_meta'] = attrs.pop('meta') + if "meta" in attrs: + attrs["_meta"] = attrs.pop("meta") # EmbeddedDocuments should inherit meta data - if '_meta' not in attrs: + if "_meta" not in attrs: meta = MetaDict() for base in flattened_bases[::-1]: # Add any mixin metadata from plain objects - if hasattr(base, 'meta'): + if hasattr(base, "meta"): meta.merge(base.meta) - elif hasattr(base, '_meta'): + elif hasattr(base, "_meta"): meta.merge(base._meta) - attrs['_meta'] = meta - attrs['_meta']['abstract'] = False # 789: EmbeddedDocument shouldn't inherit abstract + attrs["_meta"] = meta + attrs["_meta"][ + "abstract" + ] = False # 789: EmbeddedDocument shouldn't inherit abstract # If allow_inheritance is True, add a "_cls" string field to the attrs - if attrs['_meta'].get('allow_inheritance'): - StringField = _import_class('StringField') - attrs['_cls'] = StringField() + if attrs["_meta"].get("allow_inheritance"): + StringField = _import_class("StringField") + attrs["_cls"] = StringField() # Handle document Fields # Merge all fields from subclasses doc_fields = {} for base in flattened_bases[::-1]: - if hasattr(base, '_fields'): + if hasattr(base, "_fields"): doc_fields.update(base._fields) # Standard object mixin - merge in any Fields - if not hasattr(base, '_meta'): + if not hasattr(base, "_meta"): base_fields = {} for attr_name, attr_value in iteritems(base.__dict__): if not isinstance(attr_value, BaseField): @@ -85,27 +90,31 @@ class DocumentMetaclass(type): doc_fields[attr_name] = attr_value # Count names to ensure no db_field redefinitions - field_names[attr_value.db_field] = field_names.get( - attr_value.db_field, 0) + 1 + field_names[attr_value.db_field] = ( + field_names.get(attr_value.db_field, 0) + 1 + ) # Ensure no duplicate db_fields duplicate_db_fields = [k for k, v in field_names.items() if v > 1] if duplicate_db_fields: - msg = ('Multiple db_fields defined for: %s ' % - ', '.join(duplicate_db_fields)) + msg = "Multiple db_fields defined for: %s " % ", ".join(duplicate_db_fields) raise InvalidDocumentError(msg) # Set _fields and db_field maps - attrs['_fields'] = doc_fields - attrs['_db_field_map'] = {k: getattr(v, 'db_field', k) - for k, v in doc_fields.items()} - attrs['_reverse_db_field_map'] = { - v: k for k, v in attrs['_db_field_map'].items() + attrs["_fields"] = doc_fields + attrs["_db_field_map"] = { + k: getattr(v, "db_field", k) for k, v in doc_fields.items() + } + attrs["_reverse_db_field_map"] = { + v: k for k, v in attrs["_db_field_map"].items() } - attrs['_fields_ordered'] = tuple(i[1] for i in sorted( - (v.creation_counter, v.name) - for v in itervalues(doc_fields))) + attrs["_fields_ordered"] = tuple( + i[1] + for i in sorted( + (v.creation_counter, v.name) for v in itervalues(doc_fields) + ) + ) # # Set document hierarchy @@ -113,32 +122,34 @@ class DocumentMetaclass(type): superclasses = () class_name = [name] for base in flattened_bases: - if (not getattr(base, '_is_base_cls', True) and - not getattr(base, '_meta', {}).get('abstract', True)): + if not getattr(base, "_is_base_cls", True) and not getattr( + base, "_meta", {} + ).get("abstract", True): # Collate hierarchy for _cls and _subclasses class_name.append(base.__name__) - if hasattr(base, '_meta'): + if hasattr(base, "_meta"): # Warn if allow_inheritance isn't set and prevent # inheritance of classes where inheritance is set to False - allow_inheritance = base._meta.get('allow_inheritance') - if not allow_inheritance and not base._meta.get('abstract'): - raise ValueError('Document %s may not be subclassed. ' - 'To enable inheritance, use the "allow_inheritance" meta attribute.' % - base.__name__) + allow_inheritance = base._meta.get("allow_inheritance") + if not allow_inheritance and not base._meta.get("abstract"): + raise ValueError( + "Document %s may not be subclassed. " + 'To enable inheritance, use the "allow_inheritance" meta attribute.' + % base.__name__ + ) # Get superclasses from last base superclass - document_bases = [b for b in flattened_bases - if hasattr(b, '_class_name')] + document_bases = [b for b in flattened_bases if hasattr(b, "_class_name")] if document_bases: superclasses = document_bases[0]._superclasses - superclasses += (document_bases[0]._class_name, ) + superclasses += (document_bases[0]._class_name,) - _cls = '.'.join(reversed(class_name)) - attrs['_class_name'] = _cls - attrs['_superclasses'] = superclasses - attrs['_subclasses'] = (_cls, ) - attrs['_types'] = attrs['_subclasses'] # TODO depreciate _types + _cls = ".".join(reversed(class_name)) + attrs["_class_name"] = _cls + attrs["_superclasses"] = superclasses + attrs["_subclasses"] = (_cls,) + attrs["_types"] = attrs["_subclasses"] # TODO depreciate _types # Create the new_class new_class = super_new(mcs, name, bases, attrs) @@ -149,8 +160,12 @@ class DocumentMetaclass(type): base._subclasses += (_cls,) base._types = base._subclasses # TODO depreciate _types - (Document, EmbeddedDocument, DictField, - CachedReferenceField) = mcs._import_classes() + ( + Document, + EmbeddedDocument, + DictField, + CachedReferenceField, + ) = mcs._import_classes() if issubclass(new_class, Document): new_class._collection = None @@ -169,52 +184,55 @@ class DocumentMetaclass(type): for val in new_class.__dict__.values(): if isinstance(val, classmethod): f = val.__get__(new_class) - if hasattr(f, '__func__') and not hasattr(f, 'im_func'): - f.__dict__.update({'im_func': getattr(f, '__func__')}) - if hasattr(f, '__self__') and not hasattr(f, 'im_self'): - f.__dict__.update({'im_self': getattr(f, '__self__')}) + if hasattr(f, "__func__") and not hasattr(f, "im_func"): + f.__dict__.update({"im_func": getattr(f, "__func__")}) + if hasattr(f, "__self__") and not hasattr(f, "im_self"): + f.__dict__.update({"im_self": getattr(f, "__self__")}) # Handle delete rules for field in itervalues(new_class._fields): f = field if f.owner_document is None: f.owner_document = new_class - delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) + delete_rule = getattr(f, "reverse_delete_rule", DO_NOTHING) if isinstance(f, CachedReferenceField): if issubclass(new_class, EmbeddedDocument): - raise InvalidDocumentError('CachedReferenceFields is not ' - 'allowed in EmbeddedDocuments') + raise InvalidDocumentError( + "CachedReferenceFields is not allowed in EmbeddedDocuments" + ) if f.auto_sync: f.start_listener() f.document_type._cached_reference_fields.append(f) - if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): - delete_rule = getattr(f.field, - 'reverse_delete_rule', - DO_NOTHING) + if isinstance(f, ComplexBaseField) and hasattr(f, "field"): + delete_rule = getattr(f.field, "reverse_delete_rule", DO_NOTHING) if isinstance(f, DictField) and delete_rule != DO_NOTHING: - msg = ('Reverse delete rules are not supported ' - 'for %s (field: %s)' % - (field.__class__.__name__, field.name)) + msg = ( + "Reverse delete rules are not supported " + "for %s (field: %s)" % (field.__class__.__name__, field.name) + ) raise InvalidDocumentError(msg) f = field.field if delete_rule != DO_NOTHING: if issubclass(new_class, EmbeddedDocument): - msg = ('Reverse delete rules are not supported for ' - 'EmbeddedDocuments (field: %s)' % field.name) + msg = ( + "Reverse delete rules are not supported for " + "EmbeddedDocuments (field: %s)" % field.name + ) raise InvalidDocumentError(msg) - f.document_type.register_delete_rule(new_class, - field.name, delete_rule) + f.document_type.register_delete_rule(new_class, field.name, delete_rule) - if (field.name and hasattr(Document, field.name) and - EmbeddedDocument not in new_class.mro()): - msg = ('%s is a document method and not a valid ' - 'field name' % field.name) + if ( + field.name + and hasattr(Document, field.name) + and EmbeddedDocument not in new_class.mro() + ): + msg = "%s is a document method and not a valid field name" % field.name raise InvalidDocumentError(msg) return new_class @@ -239,10 +257,10 @@ class DocumentMetaclass(type): @classmethod def _import_classes(mcs): - Document = _import_class('Document') - EmbeddedDocument = _import_class('EmbeddedDocument') - DictField = _import_class('DictField') - CachedReferenceField = _import_class('CachedReferenceField') + Document = _import_class("Document") + EmbeddedDocument = _import_class("EmbeddedDocument") + DictField = _import_class("DictField") + CachedReferenceField = _import_class("CachedReferenceField") return Document, EmbeddedDocument, DictField, CachedReferenceField @@ -256,65 +274,67 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): super_new = super(TopLevelDocumentMetaclass, mcs).__new__ # Set default _meta data if base class, otherwise get user defined meta - if attrs.get('my_metaclass') == TopLevelDocumentMetaclass: + if attrs.get("my_metaclass") == TopLevelDocumentMetaclass: # defaults - attrs['_meta'] = { - 'abstract': True, - 'max_documents': None, - 'max_size': None, - 'ordering': [], # default ordering applied at runtime - 'indexes': [], # indexes to be ensured at runtime - 'id_field': None, - 'index_background': False, - 'index_drop_dups': False, - 'index_opts': None, - 'delete_rules': None, - + attrs["_meta"] = { + "abstract": True, + "max_documents": None, + "max_size": None, + "ordering": [], # default ordering applied at runtime + "indexes": [], # indexes to be ensured at runtime + "id_field": None, + "index_background": False, + "index_drop_dups": False, + "index_opts": None, + "delete_rules": None, # allow_inheritance can be True, False, and None. True means # "allow inheritance", False means "don't allow inheritance", # None means "do whatever your parent does, or don't allow # inheritance if you're a top-level class". - 'allow_inheritance': None, + "allow_inheritance": None, } - attrs['_is_base_cls'] = True - attrs['_meta'].update(attrs.get('meta', {})) + attrs["_is_base_cls"] = True + attrs["_meta"].update(attrs.get("meta", {})) else: - attrs['_meta'] = attrs.get('meta', {}) + attrs["_meta"] = attrs.get("meta", {}) # Explicitly set abstract to false unless set - attrs['_meta']['abstract'] = attrs['_meta'].get('abstract', False) - attrs['_is_base_cls'] = False + attrs["_meta"]["abstract"] = attrs["_meta"].get("abstract", False) + attrs["_is_base_cls"] = False # Set flag marking as document class - as opposed to an object mixin - attrs['_is_document'] = True + attrs["_is_document"] = True # Ensure queryset_class is inherited - if 'objects' in attrs: - manager = attrs['objects'] - if hasattr(manager, 'queryset_class'): - attrs['_meta']['queryset_class'] = manager.queryset_class + if "objects" in attrs: + manager = attrs["objects"] + if hasattr(manager, "queryset_class"): + attrs["_meta"]["queryset_class"] = manager.queryset_class # Clean up top level meta - if 'meta' in attrs: - del attrs['meta'] + if "meta" in attrs: + del attrs["meta"] # Find the parent document class - parent_doc_cls = [b for b in flattened_bases - if b.__class__ == TopLevelDocumentMetaclass] + parent_doc_cls = [ + b for b in flattened_bases if b.__class__ == TopLevelDocumentMetaclass + ] parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0] # Prevent classes setting collection different to their parents # If parent wasn't an abstract class - if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and - not parent_doc_cls._meta.get('abstract', True)): - msg = 'Trying to set a collection on a subclass (%s)' % name + if ( + parent_doc_cls + and "collection" in attrs.get("_meta", {}) + and not parent_doc_cls._meta.get("abstract", True) + ): + msg = "Trying to set a collection on a subclass (%s)" % name warnings.warn(msg, SyntaxWarning) - del attrs['_meta']['collection'] + del attrs["_meta"]["collection"] # Ensure abstract documents have abstract bases - if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): - if (parent_doc_cls and - not parent_doc_cls._meta.get('abstract', False)): - msg = 'Abstract document cannot have non-abstract base' + if attrs.get("_is_base_cls") or attrs["_meta"].get("abstract"): + if parent_doc_cls and not parent_doc_cls._meta.get("abstract", False): + msg = "Abstract document cannot have non-abstract base" raise ValueError(msg) return super_new(mcs, name, bases, attrs) @@ -323,38 +343,43 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): meta = MetaDict() for base in flattened_bases[::-1]: # Add any mixin metadata from plain objects - if hasattr(base, 'meta'): + if hasattr(base, "meta"): meta.merge(base.meta) - elif hasattr(base, '_meta'): + elif hasattr(base, "_meta"): meta.merge(base._meta) # Set collection in the meta if its callable - if (getattr(base, '_is_document', False) and - not base._meta.get('abstract')): - collection = meta.get('collection', None) + if getattr(base, "_is_document", False) and not base._meta.get("abstract"): + collection = meta.get("collection", None) if callable(collection): - meta['collection'] = collection(base) + meta["collection"] = collection(base) - meta.merge(attrs.get('_meta', {})) # Top level meta + meta.merge(attrs.get("_meta", {})) # Top level meta # Only simple classes (i.e. direct subclasses of Document) may set # allow_inheritance to False. If the base Document allows inheritance, # none of its subclasses can override allow_inheritance to False. - simple_class = all([b._meta.get('abstract') - for b in flattened_bases if hasattr(b, '_meta')]) + simple_class = all( + [b._meta.get("abstract") for b in flattened_bases if hasattr(b, "_meta")] + ) if ( - not simple_class and - meta['allow_inheritance'] is False and - not meta['abstract'] + not simple_class + and meta["allow_inheritance"] is False + and not meta["abstract"] ): - raise ValueError('Only direct subclasses of Document may set ' - '"allow_inheritance" to False') + raise ValueError( + "Only direct subclasses of Document may set " + '"allow_inheritance" to False' + ) # Set default collection name - if 'collection' not in meta: - meta['collection'] = ''.join('_%s' % c if c.isupper() else c - for c in name).strip('_').lower() - attrs['_meta'] = meta + if "collection" not in meta: + meta["collection"] = ( + "".join("_%s" % c if c.isupper() else c for c in name) + .strip("_") + .lower() + ) + attrs["_meta"] = meta # Call super and get the new class new_class = super_new(mcs, name, bases, attrs) @@ -362,36 +387,36 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): meta = new_class._meta # Set index specifications - meta['index_specs'] = new_class._build_index_specs(meta['indexes']) + meta["index_specs"] = new_class._build_index_specs(meta["indexes"]) # If collection is a callable - call it and set the value - collection = meta.get('collection') + collection = meta.get("collection") if callable(collection): - new_class._meta['collection'] = collection(new_class) + new_class._meta["collection"] = collection(new_class) # Provide a default queryset unless exists or one has been set - if 'objects' not in dir(new_class): + if "objects" not in dir(new_class): new_class.objects = QuerySetManager() # Validate the fields and set primary key if needed for field_name, field in iteritems(new_class._fields): if field.primary_key: # Ensure only one primary key is set - current_pk = new_class._meta.get('id_field') + current_pk = new_class._meta.get("id_field") if current_pk and current_pk != field_name: - raise ValueError('Cannot override primary key field') + raise ValueError("Cannot override primary key field") # Set primary key if not current_pk: - new_class._meta['id_field'] = field_name + new_class._meta["id_field"] = field_name new_class.id = field # If the document doesn't explicitly define a primary key field, create # one. Make it an ObjectIdField and give it a non-clashing name ("id" # by default, but can be different if that one's taken). - if not new_class._meta.get('id_field'): + if not new_class._meta.get("id_field"): id_name, id_db_name = mcs.get_auto_id_names(new_class) - new_class._meta['id_field'] = id_name + new_class._meta["id_field"] = id_name new_class._fields[id_name] = ObjectIdField(db_field=id_db_name) new_class._fields[id_name].name = id_name new_class.id = new_class._fields[id_name] @@ -400,22 +425,20 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Prepend the ID field to _fields_ordered (so that it's *always* # the first field). - new_class._fields_ordered = (id_name, ) + new_class._fields_ordered + new_class._fields_ordered = (id_name,) + new_class._fields_ordered # Merge in exceptions with parent hierarchy. exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) - module = attrs.get('__module__') + module = attrs.get("__module__") for exc in exceptions_to_merge: name = exc.__name__ parents = tuple( - getattr(base, name) - for base in flattened_bases - if hasattr(base, name) + getattr(base, name) for base in flattened_bases if hasattr(base, name) ) or (exc,) # Create a new exception and set it as an attribute on the new # class. - exception = type(name, parents, {'__module__': module}) + exception = type(name, parents, {"__module__": module}) setattr(new_class, name, exception) return new_class @@ -431,23 +454,17 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): Defaults to ('id', '_id'), or generates a non-clashing name in the form of ('auto_id_X', '_auto_id_X') if the default name is already taken. """ - id_name, id_db_name = ('id', '_id') + id_name, id_db_name = ("id", "_id") existing_fields = {field_name for field_name in new_class._fields} existing_db_fields = {v.db_field for v in new_class._fields.values()} - if ( - id_name not in existing_fields and - id_db_name not in existing_db_fields - ): + if id_name not in existing_fields and id_db_name not in existing_db_fields: return id_name, id_db_name - id_basename, id_db_basename, i = ('auto_id', '_auto_id', 0) + id_basename, id_db_basename, i = ("auto_id", "_auto_id", 0) for i in itertools.count(): - id_name = '{0}_{1}'.format(id_basename, i) - id_db_name = '{0}_{1}'.format(id_db_basename, i) - if ( - id_name not in existing_fields and - id_db_name not in existing_db_fields - ): + id_name = "{0}_{1}".format(id_basename, i) + id_db_name = "{0}_{1}".format(id_db_basename, i) + if id_name not in existing_fields and id_db_name not in existing_db_fields: return id_name, id_db_name @@ -455,7 +472,8 @@ class MetaDict(dict): """Custom dictionary for meta classes. Handles the merging of set indexes """ - _merge_options = ('indexes',) + + _merge_options = ("indexes",) def merge(self, new_options): for k, v in iteritems(new_options): @@ -467,4 +485,5 @@ class MetaDict(dict): class BasesTuple(tuple): """Special class to handle introspection of bases tuple in __new__""" + pass diff --git a/mongoengine/common.py b/mongoengine/common.py index bcdea194..640384ec 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -19,34 +19,44 @@ def _import_class(cls_name): if cls_name in _class_registry_cache: return _class_registry_cache.get(cls_name) - doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument', - 'MapReduceDocument') + doc_classes = ( + "Document", + "DynamicEmbeddedDocument", + "EmbeddedDocument", + "MapReduceDocument", + ) # Field Classes if not _field_list_cache: from mongoengine.fields import __all__ as fields + _field_list_cache.extend(fields) from mongoengine.base.fields import __all__ as fields + _field_list_cache.extend(fields) field_classes = _field_list_cache - deref_classes = ('DeReference',) + deref_classes = ("DeReference",) - if cls_name == 'BaseDocument': + if cls_name == "BaseDocument": from mongoengine.base import document as module - import_classes = ['BaseDocument'] + + import_classes = ["BaseDocument"] elif cls_name in doc_classes: from mongoengine import document as module + import_classes = doc_classes elif cls_name in field_classes: from mongoengine import fields as module + import_classes = field_classes elif cls_name in deref_classes: from mongoengine import dereference as module + import_classes = deref_classes else: - raise ValueError('No import set for: %s' % cls_name) + raise ValueError("No import set for: %s" % cls_name) for cls in import_classes: _class_registry_cache[cls] = getattr(module, cls) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 6a613a42..ef0dd27c 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -3,21 +3,21 @@ from pymongo.database import _check_name import six __all__ = [ - 'DEFAULT_CONNECTION_NAME', - 'DEFAULT_DATABASE_NAME', - 'MongoEngineConnectionError', - 'connect', - 'disconnect', - 'disconnect_all', - 'get_connection', - 'get_db', - 'register_connection', + "DEFAULT_CONNECTION_NAME", + "DEFAULT_DATABASE_NAME", + "MongoEngineConnectionError", + "connect", + "disconnect", + "disconnect_all", + "get_connection", + "get_db", + "register_connection", ] -DEFAULT_CONNECTION_NAME = 'default' -DEFAULT_DATABASE_NAME = 'test' -DEFAULT_HOST = 'localhost' +DEFAULT_CONNECTION_NAME = "default" +DEFAULT_DATABASE_NAME = "test" +DEFAULT_HOST = "localhost" DEFAULT_PORT = 27017 _connection_settings = {} @@ -31,6 +31,7 @@ class MongoEngineConnectionError(Exception): """Error raised when the database connection can't be established or when a connection with a requested alias can't be retrieved. """ + pass @@ -39,18 +40,23 @@ def _check_db_name(name): This functionality is copied from pymongo Database class constructor. """ if not isinstance(name, six.string_types): - raise TypeError('name must be an instance of %s' % six.string_types) - elif name != '$external': + raise TypeError("name must be an instance of %s" % six.string_types) + elif name != "$external": _check_name(name) def _get_connection_settings( - db=None, name=None, host=None, port=None, - read_preference=READ_PREFERENCE, - username=None, password=None, - authentication_source=None, - authentication_mechanism=None, - **kwargs): + db=None, + name=None, + host=None, + port=None, + read_preference=READ_PREFERENCE, + username=None, + password=None, + authentication_source=None, + authentication_mechanism=None, + **kwargs +): """Get the connection settings as a dict : param db: the name of the database to use, for compatibility with connect @@ -73,18 +79,18 @@ def _get_connection_settings( .. versionchanged:: 0.10.6 - added mongomock support """ conn_settings = { - 'name': name or db or DEFAULT_DATABASE_NAME, - 'host': host or DEFAULT_HOST, - 'port': port or DEFAULT_PORT, - 'read_preference': read_preference, - 'username': username, - 'password': password, - 'authentication_source': authentication_source, - 'authentication_mechanism': authentication_mechanism + "name": name or db or DEFAULT_DATABASE_NAME, + "host": host or DEFAULT_HOST, + "port": port or DEFAULT_PORT, + "read_preference": read_preference, + "username": username, + "password": password, + "authentication_source": authentication_source, + "authentication_mechanism": authentication_mechanism, } - _check_db_name(conn_settings['name']) - conn_host = conn_settings['host'] + _check_db_name(conn_settings["name"]) + conn_host = conn_settings["host"] # Host can be a list or a string, so if string, force to a list. if isinstance(conn_host, six.string_types): @@ -94,32 +100,32 @@ def _get_connection_settings( for entity in conn_host: # Handle Mongomock - if entity.startswith('mongomock://'): - conn_settings['is_mock'] = True + if entity.startswith("mongomock://"): + conn_settings["is_mock"] = True # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` - resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1)) + resolved_hosts.append(entity.replace("mongomock://", "mongodb://", 1)) # Handle URI style connections, only updating connection params which # were explicitly specified in the URI. - elif '://' in entity: + elif "://" in entity: uri_dict = uri_parser.parse_uri(entity) resolved_hosts.append(entity) - if uri_dict.get('database'): - conn_settings['name'] = uri_dict.get('database') + if uri_dict.get("database"): + conn_settings["name"] = uri_dict.get("database") - for param in ('read_preference', 'username', 'password'): + for param in ("read_preference", "username", "password"): if uri_dict.get(param): conn_settings[param] = uri_dict[param] - uri_options = uri_dict['options'] - if 'replicaset' in uri_options: - conn_settings['replicaSet'] = uri_options['replicaset'] - if 'authsource' in uri_options: - conn_settings['authentication_source'] = uri_options['authsource'] - if 'authmechanism' in uri_options: - conn_settings['authentication_mechanism'] = uri_options['authmechanism'] - if 'readpreference' in uri_options: + uri_options = uri_dict["options"] + if "replicaset" in uri_options: + conn_settings["replicaSet"] = uri_options["replicaset"] + if "authsource" in uri_options: + conn_settings["authentication_source"] = uri_options["authsource"] + if "authmechanism" in uri_options: + conn_settings["authentication_mechanism"] = uri_options["authmechanism"] + if "readpreference" in uri_options: read_preferences = ( ReadPreference.NEAREST, ReadPreference.PRIMARY, @@ -133,34 +139,41 @@ def _get_connection_settings( # int (e.g. 3). # TODO simplify the code below once we drop support for # PyMongo v3.4. - read_pf_mode = uri_options['readpreference'] + read_pf_mode = uri_options["readpreference"] if isinstance(read_pf_mode, six.string_types): read_pf_mode = read_pf_mode.lower() for preference in read_preferences: if ( - preference.name.lower() == read_pf_mode or - preference.mode == read_pf_mode + preference.name.lower() == read_pf_mode + or preference.mode == read_pf_mode ): - conn_settings['read_preference'] = preference + conn_settings["read_preference"] = preference break else: resolved_hosts.append(entity) - conn_settings['host'] = resolved_hosts + conn_settings["host"] = resolved_hosts # Deprecated parameters that should not be passed on - kwargs.pop('slaves', None) - kwargs.pop('is_slave', None) + kwargs.pop("slaves", None) + kwargs.pop("is_slave", None) conn_settings.update(kwargs) return conn_settings -def register_connection(alias, db=None, name=None, host=None, port=None, - read_preference=READ_PREFERENCE, - username=None, password=None, - authentication_source=None, - authentication_mechanism=None, - **kwargs): +def register_connection( + alias, + db=None, + name=None, + host=None, + port=None, + read_preference=READ_PREFERENCE, + username=None, + password=None, + authentication_source=None, + authentication_mechanism=None, + **kwargs +): """Register the connection settings. : param alias: the name that will be used to refer to this connection @@ -185,12 +198,17 @@ def register_connection(alias, db=None, name=None, host=None, port=None, .. versionchanged:: 0.10.6 - added mongomock support """ conn_settings = _get_connection_settings( - db=db, name=name, host=host, port=port, + db=db, + name=name, + host=host, + port=port, read_preference=read_preference, - username=username, password=password, + username=username, + password=password, authentication_source=authentication_source, authentication_mechanism=authentication_mechanism, - **kwargs) + **kwargs + ) _connection_settings[alias] = conn_settings @@ -206,7 +224,7 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME): if alias in _dbs: # Detach all cached collections in Documents for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME): - if issubclass(doc_cls, Document): # Skip EmbeddedDocument + if issubclass(doc_cls, Document): # Skip EmbeddedDocument doc_cls._disconnect() del _dbs[alias] @@ -237,19 +255,21 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): # Raise MongoEngineConnectionError if it doesn't. if alias not in _connection_settings: if alias == DEFAULT_CONNECTION_NAME: - msg = 'You have not defined a default connection' + msg = "You have not defined a default connection" else: msg = 'Connection with alias "%s" has not been defined' % alias raise MongoEngineConnectionError(msg) def _clean_settings(settings_dict): irrelevant_fields_set = { - 'name', 'username', 'password', - 'authentication_source', 'authentication_mechanism' + "name", + "username", + "password", + "authentication_source", + "authentication_mechanism", } return { - k: v for k, v in settings_dict.items() - if k not in irrelevant_fields_set + k: v for k, v in settings_dict.items() if k not in irrelevant_fields_set } raw_conn_settings = _connection_settings[alias].copy() @@ -260,13 +280,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn_settings = _clean_settings(raw_conn_settings) # Determine if we should use PyMongo's or mongomock's MongoClient. - is_mock = conn_settings.pop('is_mock', False) + is_mock = conn_settings.pop("is_mock", False) if is_mock: try: import mongomock except ImportError: - raise RuntimeError('You need mongomock installed to mock ' - 'MongoEngine.') + raise RuntimeError("You need mongomock installed to mock MongoEngine.") connection_class = mongomock.MongoClient else: connection_class = MongoClient @@ -277,9 +296,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): connection = existing_connection else: connection = _create_connection( - alias=alias, - connection_class=connection_class, - **conn_settings + alias=alias, connection_class=connection_class, **conn_settings ) _connections[alias] = connection return _connections[alias] @@ -294,7 +311,8 @@ def _create_connection(alias, connection_class, **connection_settings): return connection_class(**connection_settings) except Exception as e: raise MongoEngineConnectionError( - 'Cannot connect to database %s :\n%s' % (alias, e)) + "Cannot connect to database %s :\n%s" % (alias, e) + ) def _find_existing_connection(connection_settings): @@ -316,7 +334,7 @@ def _find_existing_connection(connection_settings): # Only remove the name but it's important to # keep the username/password/authentication_source/authentication_mechanism # to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047) - return {k: v for k, v in settings_dict.items() if k != 'name'} + return {k: v for k, v in settings_dict.items() if k != "name"} cleaned_conn_settings = _clean_settings(connection_settings) for db_alias, connection_settings in connection_settings_bis: @@ -332,14 +350,18 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): if alias not in _dbs: conn = get_connection(alias) conn_settings = _connection_settings[alias] - db = conn[conn_settings['name']] - auth_kwargs = {'source': conn_settings['authentication_source']} - if conn_settings['authentication_mechanism'] is not None: - auth_kwargs['mechanism'] = conn_settings['authentication_mechanism'] + db = conn[conn_settings["name"]] + auth_kwargs = {"source": conn_settings["authentication_source"]} + if conn_settings["authentication_mechanism"] is not None: + auth_kwargs["mechanism"] = conn_settings["authentication_mechanism"] # Authenticate if necessary - if conn_settings['username'] and (conn_settings['password'] or - conn_settings['authentication_mechanism'] == 'MONGODB-X509'): - db.authenticate(conn_settings['username'], conn_settings['password'], **auth_kwargs) + if conn_settings["username"] and ( + conn_settings["password"] + or conn_settings["authentication_mechanism"] == "MONGODB-X509" + ): + db.authenticate( + conn_settings["username"], conn_settings["password"], **auth_kwargs + ) _dbs[alias] = db return _dbs[alias] @@ -368,8 +390,8 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): if new_conn_settings != prev_conn_setting: err_msg = ( - u'A different connection with alias `{}` was already ' - u'registered. Use disconnect() first' + u"A different connection with alias `{}` was already " + u"registered. Use disconnect() first" ).format(alias) raise MongoEngineConnectionError(err_msg) else: diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index 98bd897b..3424a5d5 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -7,8 +7,14 @@ from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db from mongoengine.pymongo_support import count_documents -__all__ = ('switch_db', 'switch_collection', 'no_dereference', - 'no_sub_classes', 'query_counter', 'set_write_concern') +__all__ = ( + "switch_db", + "switch_collection", + "no_dereference", + "no_sub_classes", + "query_counter", + "set_write_concern", +) class switch_db(object): @@ -38,17 +44,17 @@ class switch_db(object): self.cls = cls self.collection = cls._get_collection() self.db_alias = db_alias - self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME) + self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME) def __enter__(self): """Change the db_alias and clear the cached collection.""" - self.cls._meta['db_alias'] = self.db_alias + self.cls._meta["db_alias"] = self.db_alias self.cls._collection = None return self.cls def __exit__(self, t, value, traceback): """Reset the db_alias and collection.""" - self.cls._meta['db_alias'] = self.ori_db_alias + self.cls._meta["db_alias"] = self.ori_db_alias self.cls._collection = self.collection @@ -111,14 +117,15 @@ class no_dereference(object): """ self.cls = cls - ReferenceField = _import_class('ReferenceField') - GenericReferenceField = _import_class('GenericReferenceField') - ComplexBaseField = _import_class('ComplexBaseField') + ReferenceField = _import_class("ReferenceField") + GenericReferenceField = _import_class("GenericReferenceField") + ComplexBaseField = _import_class("ComplexBaseField") - self.deref_fields = [k for k, v in iteritems(self.cls._fields) - if isinstance(v, (ReferenceField, - GenericReferenceField, - ComplexBaseField))] + self.deref_fields = [ + k + for k, v in iteritems(self.cls._fields) + if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField)) + ] def __enter__(self): """Change the objects default and _auto_dereference values.""" @@ -180,15 +187,12 @@ class query_counter(object): """ self.db = get_db() self.initial_profiling_level = None - self._ctx_query_counter = 0 # number of queries issued by the context + self._ctx_query_counter = 0 # number of queries issued by the context self._ignored_query = { - 'ns': - {'$ne': '%s.system.indexes' % self.db.name}, - 'op': # MONGODB < 3.2 - {'$ne': 'killcursors'}, - 'command.killCursors': # MONGODB >= 3.2 - {'$exists': False} + "ns": {"$ne": "%s.system.indexes" % self.db.name}, + "op": {"$ne": "killcursors"}, # MONGODB < 3.2 + "command.killCursors": {"$exists": False}, # MONGODB >= 3.2 } def _turn_on_profiling(self): @@ -238,8 +242,13 @@ class query_counter(object): and substracting the queries issued by this context. In fact everytime this is called, 1 query is issued so we need to balance that """ - count = count_documents(self.db.system.profile, self._ignored_query) - self._ctx_query_counter - self._ctx_query_counter += 1 # Account for the query we just issued to gather the information + count = ( + count_documents(self.db.system.profile, self._ignored_query) + - self._ctx_query_counter + ) + self._ctx_query_counter += ( + 1 + ) # Account for the query we just issued to gather the information return count diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index eaebb56f..9e75f353 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -2,8 +2,13 @@ from bson import DBRef, SON import six from six import iteritems -from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList, - TopLevelDocumentMetaclass, get_document) +from mongoengine.base import ( + BaseDict, + BaseList, + EmbeddedDocumentList, + TopLevelDocumentMetaclass, + get_document, +) from mongoengine.base.datastructures import LazyReference from mongoengine.connection import get_db from mongoengine.document import Document, EmbeddedDocument @@ -36,21 +41,23 @@ class DeReference(object): self.max_depth = max_depth doc_type = None - if instance and isinstance(instance, (Document, EmbeddedDocument, - TopLevelDocumentMetaclass)): + if instance and isinstance( + instance, (Document, EmbeddedDocument, TopLevelDocumentMetaclass) + ): doc_type = instance._fields.get(name) - while hasattr(doc_type, 'field'): + while hasattr(doc_type, "field"): doc_type = doc_type.field if isinstance(doc_type, ReferenceField): field = doc_type doc_type = doc_type.document_type - is_list = not hasattr(items, 'items') + is_list = not hasattr(items, "items") if is_list and all([i.__class__ == doc_type for i in items]): return items elif not is_list and all( - [i.__class__ == doc_type for i in items.values()]): + [i.__class__ == doc_type for i in items.values()] + ): return items elif not field.dbref: # We must turn the ObjectIds into DBRefs @@ -83,7 +90,7 @@ class DeReference(object): new_items[k] = value return new_items - if not hasattr(items, 'items'): + if not hasattr(items, "items"): items = _get_items_from_list(items) else: items = _get_items_from_dict(items) @@ -120,13 +127,19 @@ class DeReference(object): continue elif isinstance(v, DBRef): reference_map.setdefault(field.document_type, set()).add(v.id) - elif isinstance(v, (dict, SON)) and '_ref' in v: - reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id) + elif isinstance(v, (dict, SON)) and "_ref" in v: + reference_map.setdefault(get_document(v["_cls"]), set()).add( + v["_ref"].id + ) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: - field_cls = getattr(getattr(field, 'field', None), 'document_type', None) + field_cls = getattr( + getattr(field, "field", None), "document_type", None + ) references = self._find_references(v, depth) for key, refs in iteritems(references): - if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): + if isinstance( + field_cls, (Document, TopLevelDocumentMetaclass) + ): key = field_cls reference_map.setdefault(key, set()).update(refs) elif isinstance(item, LazyReference): @@ -134,8 +147,10 @@ class DeReference(object): continue elif isinstance(item, DBRef): reference_map.setdefault(item.collection, set()).add(item.id) - elif isinstance(item, (dict, SON)) and '_ref' in item: - reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id) + elif isinstance(item, (dict, SON)) and "_ref" in item: + reference_map.setdefault(get_document(item["_cls"]), set()).add( + item["_ref"].id + ) elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: references = self._find_references(item, depth - 1) for key, refs in iteritems(references): @@ -151,12 +166,13 @@ class DeReference(object): # we use getattr instead of hasattr because hasattr swallows any exception under python2 # so it could hide nasty things without raising exceptions (cfr bug #1688)) - ref_document_cls_exists = (getattr(collection, 'objects', None) is not None) + ref_document_cls_exists = getattr(collection, "objects", None) is not None if ref_document_cls_exists: col_name = collection._get_collection_name() - refs = [dbref for dbref in dbrefs - if (col_name, dbref) not in object_map] + refs = [ + dbref for dbref in dbrefs if (col_name, dbref) not in object_map + ] references = collection.objects.in_bulk(refs) for key, doc in iteritems(references): object_map[(col_name, key)] = doc @@ -164,23 +180,26 @@ class DeReference(object): if isinstance(doc_type, (ListField, DictField, MapField)): continue - refs = [dbref for dbref in dbrefs - if (collection, dbref) not in object_map] + refs = [ + dbref for dbref in dbrefs if (collection, dbref) not in object_map + ] if doc_type: - references = doc_type._get_db()[collection].find({'_id': {'$in': refs}}) + references = doc_type._get_db()[collection].find( + {"_id": {"$in": refs}} + ) for ref in references: doc = doc_type._from_son(ref) object_map[(collection, doc.id)] = doc else: - references = get_db()[collection].find({'_id': {'$in': refs}}) + references = get_db()[collection].find({"_id": {"$in": refs}}) for ref in references: - if '_cls' in ref: - doc = get_document(ref['_cls'])._from_son(ref) + if "_cls" in ref: + doc = get_document(ref["_cls"])._from_son(ref) elif doc_type is None: doc = get_document( - ''.join(x.capitalize() - for x in collection.split('_')))._from_son(ref) + "".join(x.capitalize() for x in collection.split("_")) + )._from_son(ref) else: doc = doc_type._from_son(ref) object_map[(collection, doc.id)] = doc @@ -208,19 +227,20 @@ class DeReference(object): return BaseList(items, instance, name) if isinstance(items, (dict, SON)): - if '_ref' in items: + if "_ref" in items: return self.object_map.get( - (items['_ref'].collection, items['_ref'].id), items) - elif '_cls' in items: - doc = get_document(items['_cls'])._from_son(items) - _cls = doc._data.pop('_cls', None) - del items['_cls'] + (items["_ref"].collection, items["_ref"].id), items + ) + elif "_cls" in items: + doc = get_document(items["_cls"])._from_son(items) + _cls = doc._data.pop("_cls", None) + del items["_cls"] doc._data = self._attach_objects(doc._data, depth, doc, None) if _cls is not None: - doc._data['_cls'] = _cls + doc._data["_cls"] = _cls return doc - if not hasattr(items, 'items'): + if not hasattr(items, "items"): is_list = True list_type = BaseList if isinstance(items, EmbeddedDocumentList): @@ -247,17 +267,25 @@ class DeReference(object): v = data[k]._data.get(field_name, None) if isinstance(v, DBRef): data[k]._data[field_name] = self.object_map.get( - (v.collection, v.id), v) - elif isinstance(v, (dict, SON)) and '_ref' in v: + (v.collection, v.id), v + ) + elif isinstance(v, (dict, SON)) and "_ref" in v: data[k]._data[field_name] = self.object_map.get( - (v['_ref'].collection, v['_ref'].id), v) + (v["_ref"].collection, v["_ref"].id), v + ) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: - item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name) - data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) + item_name = six.text_type("{0}.{1}.{2}").format( + name, k, field_name + ) + data[k]._data[field_name] = self._attach_objects( + v, depth, instance=instance, name=item_name + ) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: - item_name = '%s.%s' % (name, k) if name else name - data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name) - elif isinstance(v, DBRef) and hasattr(v, 'id'): + item_name = "%s.%s" % (name, k) if name else name + data[k] = self._attach_objects( + v, depth - 1, instance=instance, name=item_name + ) + elif isinstance(v, DBRef) and hasattr(v, "id"): data[k] = self.object_map.get((v.collection, v.id), v) if instance and name: diff --git a/mongoengine/document.py b/mongoengine/document.py index 520de5bf..41166df4 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -8,23 +8,36 @@ import six from six import iteritems from mongoengine import signals -from mongoengine.base import (BaseDict, BaseDocument, BaseList, - DocumentMetaclass, EmbeddedDocumentList, - TopLevelDocumentMetaclass, get_document) +from mongoengine.base import ( + BaseDict, + BaseDocument, + BaseList, + DocumentMetaclass, + EmbeddedDocumentList, + TopLevelDocumentMetaclass, + get_document, +) from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db -from mongoengine.context_managers import (set_write_concern, - switch_collection, - switch_db) -from mongoengine.errors import (InvalidDocumentError, InvalidQueryError, - SaveConditionError) +from mongoengine.context_managers import set_write_concern, switch_collection, switch_db +from mongoengine.errors import ( + InvalidDocumentError, + InvalidQueryError, + SaveConditionError, +) from mongoengine.pymongo_support import list_collection_names -from mongoengine.queryset import (NotUniqueError, OperationError, - QuerySet, transform) +from mongoengine.queryset import NotUniqueError, OperationError, QuerySet, transform -__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', - 'DynamicEmbeddedDocument', 'OperationError', - 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument') +__all__ = ( + "Document", + "EmbeddedDocument", + "DynamicDocument", + "DynamicEmbeddedDocument", + "OperationError", + "InvalidCollectionError", + "NotUniqueError", + "MapReduceDocument", +) def includes_cls(fields): @@ -35,7 +48,7 @@ def includes_cls(fields): first_field = fields[0] elif isinstance(fields[0], (list, tuple)) and len(fields[0]): first_field = fields[0][0] - return first_field == '_cls' + return first_field == "_cls" class InvalidCollectionError(Exception): @@ -56,7 +69,7 @@ class EmbeddedDocument(six.with_metaclass(DocumentMetaclass, BaseDocument)): :attr:`meta` dictionary. """ - __slots__ = ('_instance', ) + __slots__ = ("_instance",) # The __metaclass__ attribute is removed by 2to3 when running with Python3 # my_metaclass is defined so that metaclass can be queried in Python 2 & 3 @@ -85,8 +98,8 @@ class EmbeddedDocument(six.with_metaclass(DocumentMetaclass, BaseDocument)): data = super(EmbeddedDocument, self).to_mongo(*args, **kwargs) # remove _id from the SON if it's in it and it's None - if '_id' in data and data['_id'] is None: - del data['_id'] + if "_id" in data and data["_id"] is None: + del data["_id"] return data @@ -147,19 +160,19 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): # my_metaclass is defined so that metaclass can be queried in Python 2 & 3 my_metaclass = TopLevelDocumentMetaclass - __slots__ = ('__objects',) + __slots__ = ("__objects",) @property def pk(self): """Get the primary key.""" - if 'id_field' not in self._meta: + if "id_field" not in self._meta: return None - return getattr(self, self._meta['id_field']) + return getattr(self, self._meta["id_field"]) @pk.setter def pk(self, value): """Set the primary key.""" - return setattr(self, self._meta['id_field'], value) + return setattr(self, self._meta["id_field"], value) def __hash__(self): """Return the hash based on the PK of this document. If it's new @@ -173,7 +186,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): @classmethod def _get_db(cls): """Some Model using other db_alias""" - return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)) + return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)) @classmethod def _disconnect(cls): @@ -190,9 +203,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): 2. Creates indexes defined in this document's :attr:`meta` dictionary. This happens only if `auto_create_index` is True. """ - if not hasattr(cls, '_collection') or cls._collection is None: + if not hasattr(cls, "_collection") or cls._collection is None: # Get the collection, either capped or regular. - if cls._meta.get('max_size') or cls._meta.get('max_documents'): + if cls._meta.get("max_size") or cls._meta.get("max_documents"): cls._collection = cls._get_capped_collection() else: db = cls._get_db() @@ -203,8 +216,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): # set to False. # Also there is no need to ensure indexes on slave. db = cls._get_db() - if cls._meta.get('auto_create_index', True) and\ - db.client.is_primary: + if cls._meta.get("auto_create_index", True) and db.client.is_primary: cls.ensure_indexes() return cls._collection @@ -216,8 +228,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): collection_name = cls._get_collection_name() # Get max document limit and max byte size from meta. - max_size = cls._meta.get('max_size') or 10 * 2 ** 20 # 10MB default - max_documents = cls._meta.get('max_documents') + max_size = cls._meta.get("max_size") or 10 * 2 ** 20 # 10MB default + max_documents = cls._meta.get("max_documents") # MongoDB will automatically raise the size to make it a multiple of # 256 bytes. We raise it here ourselves to be able to reliably compare @@ -227,24 +239,23 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): # If the collection already exists and has different options # (i.e. isn't capped or has different max/size), raise an error. - if collection_name in list_collection_names(db, include_system_collections=True): + if collection_name in list_collection_names( + db, include_system_collections=True + ): collection = db[collection_name] options = collection.options() - if ( - options.get('max') != max_documents or - options.get('size') != max_size - ): + if options.get("max") != max_documents or options.get("size") != max_size: raise InvalidCollectionError( 'Cannot create collection "{}" as a capped ' - 'collection as it already exists'.format(cls._collection) + "collection as it already exists".format(cls._collection) ) return collection # Create a new capped collection. - opts = {'capped': True, 'size': max_size} + opts = {"capped": True, "size": max_size} if max_documents: - opts['max'] = max_documents + opts["max"] = max_documents return db.create_collection(collection_name, **opts) @@ -253,11 +264,11 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): # If '_id' is None, try and set it from self._data. If that # doesn't exist either, remove '_id' from the SON completely. - if data['_id'] is None: - if self._data.get('id') is None: - del data['_id'] + if data["_id"] is None: + if self._data.get("id") is None: + del data["_id"] else: - data['_id'] = self._data['id'] + data["_id"] = self._data["id"] return data @@ -279,15 +290,17 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): query = {} if self.pk is None: - raise InvalidDocumentError('The document does not have a primary key.') + raise InvalidDocumentError("The document does not have a primary key.") - id_field = self._meta['id_field'] + id_field = self._meta["id_field"] query = query.copy() if isinstance(query, dict) else query.to_query(self) if id_field not in query: query[id_field] = self.pk elif query[id_field] != self.pk: - raise InvalidQueryError('Invalid document modify query: it must modify only this document.') + raise InvalidQueryError( + "Invalid document modify query: it must modify only this document." + ) # Need to add shard key to query, or you get an error query.update(self._object_key) @@ -304,9 +317,19 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): return True - def save(self, force_insert=False, validate=True, clean=True, - write_concern=None, cascade=None, cascade_kwargs=None, - _refs=None, save_condition=None, signal_kwargs=None, **kwargs): + def save( + self, + force_insert=False, + validate=True, + clean=True, + write_concern=None, + cascade=None, + cascade_kwargs=None, + _refs=None, + save_condition=None, + signal_kwargs=None, + **kwargs + ): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. @@ -360,8 +383,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): """ signal_kwargs = signal_kwargs or {} - if self._meta.get('abstract'): - raise InvalidDocumentError('Cannot save an abstract document.') + if self._meta.get("abstract"): + raise InvalidDocumentError("Cannot save an abstract document.") signals.pre_save.send(self.__class__, document=self, **signal_kwargs) @@ -371,15 +394,16 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): if write_concern is None: write_concern = {} - doc_id = self.to_mongo(fields=[self._meta['id_field']]) - created = ('_id' not in doc_id or self._created or force_insert) + doc_id = self.to_mongo(fields=[self._meta["id_field"]]) + created = "_id" not in doc_id or self._created or force_insert - signals.pre_save_post_validation.send(self.__class__, document=self, - created=created, **signal_kwargs) + signals.pre_save_post_validation.send( + self.__class__, document=self, created=created, **signal_kwargs + ) # it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation doc = self.to_mongo() - if self._meta.get('auto_create_index', True): + if self._meta.get("auto_create_index", True): self.ensure_indexes() try: @@ -387,44 +411,45 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): if created: object_id = self._save_create(doc, force_insert, write_concern) else: - object_id, created = self._save_update(doc, save_condition, - write_concern) + object_id, created = self._save_update( + doc, save_condition, write_concern + ) if cascade is None: - cascade = (self._meta.get('cascade', False) or - cascade_kwargs is not None) + cascade = self._meta.get("cascade", False) or cascade_kwargs is not None if cascade: kwargs = { - 'force_insert': force_insert, - 'validate': validate, - 'write_concern': write_concern, - 'cascade': cascade + "force_insert": force_insert, + "validate": validate, + "write_concern": write_concern, + "cascade": cascade, } if cascade_kwargs: # Allow granular control over cascades kwargs.update(cascade_kwargs) - kwargs['_refs'] = _refs + kwargs["_refs"] = _refs self.cascade_save(**kwargs) except pymongo.errors.DuplicateKeyError as err: - message = u'Tried to save duplicate unique keys (%s)' + message = u"Tried to save duplicate unique keys (%s)" raise NotUniqueError(message % six.text_type(err)) except pymongo.errors.OperationFailure as err: - message = 'Could not save document (%s)' - if re.match('^E1100[01] duplicate key', six.text_type(err)): + message = "Could not save document (%s)" + if re.match("^E1100[01] duplicate key", six.text_type(err)): # E11000 - duplicate key error index # E11001 - duplicate key on update - message = u'Tried to save duplicate unique keys (%s)' + message = u"Tried to save duplicate unique keys (%s)" raise NotUniqueError(message % six.text_type(err)) raise OperationError(message % six.text_type(err)) # Make sure we store the PK on this document now that it's saved - id_field = self._meta['id_field'] - if created or id_field not in self._meta.get('shard_key', []): + id_field = self._meta["id_field"] + if created or id_field not in self._meta.get("shard_key", []): self[id_field] = self._fields[id_field].to_python(object_id) - signals.post_save.send(self.__class__, document=self, - created=created, **signal_kwargs) + signals.post_save.send( + self.__class__, document=self, created=created, **signal_kwargs + ) self._clear_changed_fields() self._created = False @@ -442,11 +467,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): return wc_collection.insert_one(doc).inserted_id # insert_one will provoke UniqueError alongside save does not # therefore, it need to catch and call replace_one. - if '_id' in doc: + if "_id" in doc: raw_object = wc_collection.find_one_and_replace( - {'_id': doc['_id']}, doc) + {"_id": doc["_id"]}, doc + ) if raw_object: - return doc['_id'] + return doc["_id"] object_id = wc_collection.insert_one(doc).inserted_id @@ -461,9 +487,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): update_doc = {} if updates: - update_doc['$set'] = updates + update_doc["$set"] = updates if removals: - update_doc['$unset'] = removals + update_doc["$unset"] = removals return update_doc @@ -473,39 +499,38 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): Helper method, should only be used inside save(). """ collection = self._get_collection() - object_id = doc['_id'] + object_id = doc["_id"] created = False select_dict = {} if save_condition is not None: select_dict = transform.query(self.__class__, **save_condition) - select_dict['_id'] = object_id + select_dict["_id"] = object_id # Need to add shard key to query, or you get an error - shard_key = self._meta.get('shard_key', tuple()) + shard_key = self._meta.get("shard_key", tuple()) for k in shard_key: - path = self._lookup_field(k.split('.')) + path = self._lookup_field(k.split(".")) actual_key = [p.db_field for p in path] val = doc for ak in actual_key: val = val[ak] - select_dict['.'.join(actual_key)] = val + select_dict[".".join(actual_key)] = val update_doc = self._get_update_doc() if update_doc: upsert = save_condition is None with set_write_concern(collection, write_concern) as wc_collection: last_error = wc_collection.update_one( - select_dict, - update_doc, - upsert=upsert + select_dict, update_doc, upsert=upsert ).raw_result - if not upsert and last_error['n'] == 0: - raise SaveConditionError('Race condition preventing' - ' document update detected') + if not upsert and last_error["n"] == 0: + raise SaveConditionError( + "Race condition preventing document update detected" + ) if last_error is not None: - updated_existing = last_error.get('updatedExisting') + updated_existing = last_error.get("updatedExisting") if updated_existing is False: created = True # !!! This is bad, means we accidentally created a new, @@ -518,21 +543,20 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): """Recursively save any references and generic references on the document. """ - _refs = kwargs.get('_refs') or [] + _refs = kwargs.get("_refs") or [] - ReferenceField = _import_class('ReferenceField') - GenericReferenceField = _import_class('GenericReferenceField') + ReferenceField = _import_class("ReferenceField") + GenericReferenceField = _import_class("GenericReferenceField") for name, cls in self._fields.items(): - if not isinstance(cls, (ReferenceField, - GenericReferenceField)): + if not isinstance(cls, (ReferenceField, GenericReferenceField)): continue ref = self._data.get(name) if not ref or isinstance(ref, DBRef): continue - if not getattr(ref, '_changed_fields', True): + if not getattr(ref, "_changed_fields", True): continue ref_id = "%s,%s" % (ref.__class__.__name__, str(ref._data)) @@ -545,7 +569,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): @property def _qs(self): """Return the default queryset corresponding to this document.""" - if not hasattr(self, '__objects'): + if not hasattr(self, "__objects"): self.__objects = QuerySet(self, self._get_collection()) return self.__objects @@ -558,15 +582,15 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): a sharded collection with a compound shard key, it can contain a more complex query. """ - select_dict = {'pk': self.pk} - shard_key = self.__class__._meta.get('shard_key', tuple()) + select_dict = {"pk": self.pk} + shard_key = self.__class__._meta.get("shard_key", tuple()) for k in shard_key: - path = self._lookup_field(k.split('.')) + path = self._lookup_field(k.split(".")) actual_key = [p.db_field for p in path] val = self for ak in actual_key: val = getattr(val, ak) - select_dict['__'.join(actual_key)] = val + select_dict["__".join(actual_key)] = val return select_dict def update(self, **kwargs): @@ -577,14 +601,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): been saved. """ if self.pk is None: - if kwargs.get('upsert', False): + if kwargs.get("upsert", False): query = self.to_mongo() - if '_cls' in query: - del query['_cls'] + if "_cls" in query: + del query["_cls"] return self._qs.filter(**query).update_one(**kwargs) else: - raise OperationError( - 'attempt to update a document not yet saved') + raise OperationError("attempt to update a document not yet saved") # Need to add shard key to query, or you get an error return self._qs.filter(**self._object_key).update_one(**kwargs) @@ -608,16 +631,17 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): signals.pre_delete.send(self.__class__, document=self, **signal_kwargs) # Delete FileFields separately - FileField = _import_class('FileField') + FileField = _import_class("FileField") for name, field in iteritems(self._fields): if isinstance(field, FileField): getattr(self, name).delete() try: - self._qs.filter( - **self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) + self._qs.filter(**self._object_key).delete( + write_concern=write_concern, _from_doc_delete=True + ) except pymongo.errors.OperationFailure as err: - message = u'Could not delete document (%s)' % err.message + message = u"Could not delete document (%s)" % err.message raise OperationError(message) signals.post_delete.send(self.__class__, document=self, **signal_kwargs) @@ -686,7 +710,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): .. versionadded:: 0.5 """ - DeReference = _import_class('DeReference') + DeReference = _import_class("DeReference") DeReference()([self], max_depth + 1) return self @@ -704,20 +728,24 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): if fields and isinstance(fields[0], int): max_depth = fields[0] fields = fields[1:] - elif 'max_depth' in kwargs: - max_depth = kwargs['max_depth'] + elif "max_depth" in kwargs: + max_depth = kwargs["max_depth"] if self.pk is None: - raise self.DoesNotExist('Document does not exist') + raise self.DoesNotExist("Document does not exist") - obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( - **self._object_key).only(*fields).limit( - 1).select_related(max_depth=max_depth) + obj = ( + self._qs.read_preference(ReadPreference.PRIMARY) + .filter(**self._object_key) + .only(*fields) + .limit(1) + .select_related(max_depth=max_depth) + ) if obj: obj = obj[0] else: - raise self.DoesNotExist('Document does not exist') + raise self.DoesNotExist("Document does not exist") for field in obj._data: if not fields or field in fields: try: @@ -733,9 +761,11 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): # i.e. obj.update(unset__field=1) followed by obj.reload() delattr(self, field) - self._changed_fields = list( - set(self._changed_fields) - set(fields) - ) if fields else obj._changed_fields + self._changed_fields = ( + list(set(self._changed_fields) - set(fields)) + if fields + else obj._changed_fields + ) self._created = False return self @@ -761,7 +791,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): """Returns an instance of :class:`~bson.dbref.DBRef` useful in `__raw__` queries.""" if self.pk is None: - msg = 'Only saved documents can have a valid dbref' + msg = "Only saved documents can have a valid dbref" raise OperationError(msg) return DBRef(self.__class__._get_collection_name(), self.pk) @@ -770,18 +800,22 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): """This method registers the delete rules to apply when removing this object. """ - classes = [get_document(class_name) - for class_name in cls._subclasses - if class_name != cls.__name__] + [cls] - documents = [get_document(class_name) - for class_name in document_cls._subclasses - if class_name != document_cls.__name__] + [document_cls] + classes = [ + get_document(class_name) + for class_name in cls._subclasses + if class_name != cls.__name__ + ] + [cls] + documents = [ + get_document(class_name) + for class_name in document_cls._subclasses + if class_name != document_cls.__name__ + ] + [document_cls] for klass in classes: for document_cls in documents: - delete_rules = klass._meta.get('delete_rules') or {} + delete_rules = klass._meta.get("delete_rules") or {} delete_rules[(document_cls, field_name)] = rule - klass._meta['delete_rules'] = delete_rules + klass._meta["delete_rules"] = delete_rules @classmethod def drop_collection(cls): @@ -796,8 +830,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): """ coll_name = cls._get_collection_name() if not coll_name: - raise OperationError('Document %s has no collection defined ' - '(is it abstract ?)' % cls) + raise OperationError( + "Document %s has no collection defined (is it abstract ?)" % cls + ) cls._collection = None db = cls._get_db() db.drop_collection(coll_name) @@ -813,19 +848,18 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): """ index_spec = cls._build_index_spec(keys) index_spec = index_spec.copy() - fields = index_spec.pop('fields') - drop_dups = kwargs.get('drop_dups', False) + fields = index_spec.pop("fields") + drop_dups = kwargs.get("drop_dups", False) if drop_dups: - msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' + msg = "drop_dups is deprecated and is removed when using PyMongo 3+." warnings.warn(msg, DeprecationWarning) - index_spec['background'] = background + index_spec["background"] = background index_spec.update(kwargs) return cls._get_collection().create_index(fields, **index_spec) @classmethod - def ensure_index(cls, key_or_list, drop_dups=False, background=False, - **kwargs): + def ensure_index(cls, key_or_list, drop_dups=False, background=False, **kwargs): """Ensure that the given indexes are in place. Deprecated in favour of create_index. @@ -837,7 +871,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): will be removed if PyMongo3+ is used """ if drop_dups: - msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' + msg = "drop_dups is deprecated and is removed when using PyMongo 3+." warnings.warn(msg, DeprecationWarning) return cls.create_index(key_or_list, background=background, **kwargs) @@ -850,12 +884,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): .. note:: You can disable automatic index creation by setting `auto_create_index` to False in the documents meta data """ - background = cls._meta.get('index_background', False) - drop_dups = cls._meta.get('index_drop_dups', False) - index_opts = cls._meta.get('index_opts') or {} - index_cls = cls._meta.get('index_cls', True) + background = cls._meta.get("index_background", False) + drop_dups = cls._meta.get("index_drop_dups", False) + index_opts = cls._meta.get("index_opts") or {} + index_cls = cls._meta.get("index_cls", True) if drop_dups: - msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' + msg = "drop_dups is deprecated and is removed when using PyMongo 3+." warnings.warn(msg, DeprecationWarning) collection = cls._get_collection() @@ -871,40 +905,39 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): cls_indexed = False # Ensure document-defined indexes are created - if cls._meta['index_specs']: - index_spec = cls._meta['index_specs'] + if cls._meta["index_specs"]: + index_spec = cls._meta["index_specs"] for spec in index_spec: spec = spec.copy() - fields = spec.pop('fields') + fields = spec.pop("fields") cls_indexed = cls_indexed or includes_cls(fields) opts = index_opts.copy() opts.update(spec) # we shouldn't pass 'cls' to the collection.ensureIndex options # because of https://jira.mongodb.org/browse/SERVER-769 - if 'cls' in opts: - del opts['cls'] + if "cls" in opts: + del opts["cls"] collection.create_index(fields, background=background, **opts) # If _cls is being used (for polymorphism), it needs an index, # only if another index doesn't begin with _cls - if index_cls and not cls_indexed and cls._meta.get('allow_inheritance'): + if index_cls and not cls_indexed and cls._meta.get("allow_inheritance"): # we shouldn't pass 'cls' to the collection.ensureIndex options # because of https://jira.mongodb.org/browse/SERVER-769 - if 'cls' in index_opts: - del index_opts['cls'] + if "cls" in index_opts: + del index_opts["cls"] - collection.create_index('_cls', background=background, - **index_opts) + collection.create_index("_cls", background=background, **index_opts) @classmethod def list_indexes(cls): """ Lists all of the indexes that should be created for given collection. It includes all the indexes from super- and sub-classes. """ - if cls._meta.get('abstract'): + if cls._meta.get("abstract"): return [] # get all the base classes, subclasses and siblings @@ -912,22 +945,27 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): def get_classes(cls): - if (cls not in classes and - isinstance(cls, TopLevelDocumentMetaclass)): + if cls not in classes and isinstance(cls, TopLevelDocumentMetaclass): classes.append(cls) for base_cls in cls.__bases__: - if (isinstance(base_cls, TopLevelDocumentMetaclass) and - base_cls != Document and - not base_cls._meta.get('abstract') and - base_cls._get_collection().full_name == cls._get_collection().full_name and - base_cls not in classes): + if ( + isinstance(base_cls, TopLevelDocumentMetaclass) + and base_cls != Document + and not base_cls._meta.get("abstract") + and base_cls._get_collection().full_name + == cls._get_collection().full_name + and base_cls not in classes + ): classes.append(base_cls) get_classes(base_cls) for subclass in cls.__subclasses__(): - if (isinstance(base_cls, TopLevelDocumentMetaclass) and - subclass._get_collection().full_name == cls._get_collection().full_name and - subclass not in classes): + if ( + isinstance(base_cls, TopLevelDocumentMetaclass) + and subclass._get_collection().full_name + == cls._get_collection().full_name + and subclass not in classes + ): classes.append(subclass) get_classes(subclass) @@ -937,11 +975,11 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): def get_indexes_spec(cls): indexes = [] - if cls._meta['index_specs']: - index_spec = cls._meta['index_specs'] + if cls._meta["index_specs"]: + index_spec = cls._meta["index_specs"] for spec in index_spec: spec = spec.copy() - fields = spec.pop('fields') + fields = spec.pop("fields") indexes.append(fields) return indexes @@ -952,10 +990,10 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): indexes.append(index) # finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed - if [(u'_id', 1)] not in indexes: - indexes.append([(u'_id', 1)]) - if cls._meta.get('index_cls', True) and cls._meta.get('allow_inheritance'): - indexes.append([(u'_cls', 1)]) + if [(u"_id", 1)] not in indexes: + indexes.append([(u"_id", 1)]) + if cls._meta.get("index_cls", True) and cls._meta.get("allow_inheritance"): + indexes.append([(u"_cls", 1)]) return indexes @@ -969,27 +1007,26 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): existing = [] for info in cls._get_collection().index_information().values(): - if '_fts' in info['key'][0]: - index_type = info['key'][0][1] - text_index_fields = info.get('weights').keys() - existing.append( - [(key, index_type) for key in text_index_fields]) + if "_fts" in info["key"][0]: + index_type = info["key"][0][1] + text_index_fields = info.get("weights").keys() + existing.append([(key, index_type) for key in text_index_fields]) else: - existing.append(info['key']) + existing.append(info["key"]) missing = [index for index in required if index not in existing] extra = [index for index in existing if index not in required] # if { _cls: 1 } is missing, make sure it's *really* necessary - if [(u'_cls', 1)] in missing: + if [(u"_cls", 1)] in missing: cls_obsolete = False for index in existing: if includes_cls(index) and index not in extra: cls_obsolete = True break if cls_obsolete: - missing.remove([(u'_cls', 1)]) + missing.remove([(u"_cls", 1)]) - return {'missing': missing, 'extra': extra} + return {"missing": missing, "extra": extra} class DynamicDocument(six.with_metaclass(TopLevelDocumentMetaclass, Document)): @@ -1074,17 +1111,16 @@ class MapReduceDocument(object): """Lazy-load the object referenced by ``self.key``. ``self.key`` should be the ``primary_key``. """ - id_field = self._document()._meta['id_field'] + id_field = self._document()._meta["id_field"] id_field_type = type(id_field) if not isinstance(self.key, id_field_type): try: self.key = id_field_type(self.key) except Exception: - raise Exception('Could not cast key as %s' % - id_field_type.__name__) + raise Exception("Could not cast key as %s" % id_field_type.__name__) - if not hasattr(self, '_key_object'): + if not hasattr(self, "_key_object"): self._key_object = self._document.objects.with_id(self.key) return self._key_object return self._key_object diff --git a/mongoengine/errors.py b/mongoengine/errors.py index bea1d3dc..9852f2a1 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -3,10 +3,20 @@ from collections import defaultdict import six from six import iteritems -__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', - 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', - 'OperationError', 'NotUniqueError', 'FieldDoesNotExist', - 'ValidationError', 'SaveConditionError', 'DeprecatedError') +__all__ = ( + "NotRegistered", + "InvalidDocumentError", + "LookUpError", + "DoesNotExist", + "MultipleObjectsReturned", + "InvalidQueryError", + "OperationError", + "NotUniqueError", + "FieldDoesNotExist", + "ValidationError", + "SaveConditionError", + "DeprecatedError", +) class NotRegistered(Exception): @@ -71,25 +81,25 @@ class ValidationError(AssertionError): field_name = None _message = None - def __init__(self, message='', **kwargs): + def __init__(self, message="", **kwargs): super(ValidationError, self).__init__(message) - self.errors = kwargs.get('errors', {}) - self.field_name = kwargs.get('field_name') + self.errors = kwargs.get("errors", {}) + self.field_name = kwargs.get("field_name") self.message = message def __str__(self): return six.text_type(self.message) def __repr__(self): - return '%s(%s,)' % (self.__class__.__name__, self.message) + return "%s(%s,)" % (self.__class__.__name__, self.message) def __getattribute__(self, name): message = super(ValidationError, self).__getattribute__(name) - if name == 'message': + if name == "message": if self.field_name: - message = '%s' % message + message = "%s" % message if self.errors: - message = '%s(%s)' % (message, self._format_errors()) + message = "%s(%s)" % (message, self._format_errors()) return message def _get_message(self): @@ -128,22 +138,22 @@ class ValidationError(AssertionError): def _format_errors(self): """Returns a string listing all errors within a document""" - def generate_key(value, prefix=''): + def generate_key(value, prefix=""): if isinstance(value, list): - value = ' '.join([generate_key(k) for k in value]) + value = " ".join([generate_key(k) for k in value]) elif isinstance(value, dict): - value = ' '.join( - [generate_key(v, k) for k, v in iteritems(value)]) + value = " ".join([generate_key(v, k) for k, v in iteritems(value)]) - results = '%s.%s' % (prefix, value) if prefix else value + results = "%s.%s" % (prefix, value) if prefix else value return results error_dict = defaultdict(list) for k, v in iteritems(self.to_dict()): error_dict[generate_key(v)].append(k) - return ' '.join(['%s: %s' % (k, v) for k, v in iteritems(error_dict)]) + return " ".join(["%s: %s" % (k, v) for k, v in iteritems(error_dict)]) class DeprecatedError(Exception): """Raise when a user uses a feature that has been Deprecated""" + pass diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 2a4a2ad8..7ab2276d 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -27,9 +27,15 @@ except ImportError: Int64 = long -from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField, - GeoJsonBaseField, LazyReference, ObjectIdField, - get_document) +from mongoengine.base import ( + BaseDocument, + BaseField, + ComplexBaseField, + GeoJsonBaseField, + LazyReference, + ObjectIdField, + get_document, +) from mongoengine.base.utils import LazyRegexCompiler from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db @@ -53,21 +59,51 @@ if six.PY3: __all__ = ( - 'StringField', 'URLField', 'EmailField', 'IntField', 'LongField', - 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', 'DateField', - 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', - 'GenericEmbeddedDocumentField', 'DynamicField', 'ListField', - 'SortedListField', 'EmbeddedDocumentListField', 'DictField', - 'MapField', 'ReferenceField', 'CachedReferenceField', - 'LazyReferenceField', 'GenericLazyReferenceField', - 'GenericReferenceField', 'BinaryField', 'GridFSError', 'GridFSProxy', - 'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField', - 'GeoPointField', 'PointField', 'LineStringField', 'PolygonField', - 'SequenceField', 'UUIDField', 'MultiPointField', 'MultiLineStringField', - 'MultiPolygonField', 'GeoJsonBaseField' + "StringField", + "URLField", + "EmailField", + "IntField", + "LongField", + "FloatField", + "DecimalField", + "BooleanField", + "DateTimeField", + "DateField", + "ComplexDateTimeField", + "EmbeddedDocumentField", + "ObjectIdField", + "GenericEmbeddedDocumentField", + "DynamicField", + "ListField", + "SortedListField", + "EmbeddedDocumentListField", + "DictField", + "MapField", + "ReferenceField", + "CachedReferenceField", + "LazyReferenceField", + "GenericLazyReferenceField", + "GenericReferenceField", + "BinaryField", + "GridFSError", + "GridFSProxy", + "FileField", + "ImageGridFsProxy", + "ImproperlyConfigured", + "ImageField", + "GeoPointField", + "PointField", + "LineStringField", + "PolygonField", + "SequenceField", + "UUIDField", + "MultiPointField", + "MultiLineStringField", + "MultiPolygonField", + "GeoJsonBaseField", ) -RECURSIVE_REFERENCE_CONSTANT = 'self' +RECURSIVE_REFERENCE_CONSTANT = "self" class StringField(BaseField): @@ -83,23 +119,23 @@ class StringField(BaseField): if isinstance(value, six.text_type): return value try: - value = value.decode('utf-8') + value = value.decode("utf-8") except Exception: pass return value def validate(self, value): if not isinstance(value, six.string_types): - self.error('StringField only accepts string values') + self.error("StringField only accepts string values") if self.max_length is not None and len(value) > self.max_length: - self.error('String value is too long') + self.error("String value is too long") if self.min_length is not None and len(value) < self.min_length: - self.error('String value is too short') + self.error("String value is too short") if self.regex is not None and self.regex.match(value) is None: - self.error('String value did not match validation regex') + self.error("String value did not match validation regex") def lookup_member(self, member_name): return None @@ -109,18 +145,18 @@ class StringField(BaseField): return value if op in STRING_OPERATORS: - case_insensitive = op.startswith('i') - op = op.lstrip('i') + case_insensitive = op.startswith("i") + op = op.lstrip("i") flags = re.IGNORECASE if case_insensitive else 0 - regex = r'%s' - if op == 'startswith': - regex = r'^%s' - elif op == 'endswith': - regex = r'%s$' - elif op == 'exact': - regex = r'^%s$' + regex = r"%s" + if op == "startswith": + regex = r"^%s" + elif op == "endswith": + regex = r"%s$" + elif op == "exact": + regex = r"^%s$" # escape unsafe characters which could lead to a re.error value = re.escape(value) @@ -135,14 +171,16 @@ class URLField(StringField): """ _URL_REGEX = LazyRegexCompiler( - r'^(?:[a-z0-9\.\-]*)://' # scheme is validated separately - r'(?:(?:[A-Z0-9](?:[A-Z0-9-_]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(? self.max_value: - self.error('Integer value is too large') + self.error("Integer value is too large") def prepare_query_value(self, op, value): if value is None: @@ -319,13 +365,13 @@ class LongField(BaseField): try: value = long(value) except (TypeError, ValueError): - self.error('%s could not be converted to long' % value) + self.error("%s could not be converted to long" % value) if self.min_value is not None and value < self.min_value: - self.error('Long value is too small') + self.error("Long value is too small") if self.max_value is not None and value > self.max_value: - self.error('Long value is too large') + self.error("Long value is too large") def prepare_query_value(self, op, value): if value is None: @@ -353,16 +399,16 @@ class FloatField(BaseField): try: value = float(value) except OverflowError: - self.error('The value is too large to be converted to float') + self.error("The value is too large to be converted to float") if not isinstance(value, float): - self.error('FloatField only accepts float and integer values') + self.error("FloatField only accepts float and integer values") if self.min_value is not None and value < self.min_value: - self.error('Float value is too small') + self.error("Float value is too small") if self.max_value is not None and value > self.max_value: - self.error('Float value is too large') + self.error("Float value is too large") def prepare_query_value(self, op, value): if value is None: @@ -379,8 +425,15 @@ class DecimalField(BaseField): .. versionadded:: 0.3 """ - def __init__(self, min_value=None, max_value=None, force_string=False, - precision=2, rounding=decimal.ROUND_HALF_UP, **kwargs): + def __init__( + self, + min_value=None, + max_value=None, + force_string=False, + precision=2, + rounding=decimal.ROUND_HALF_UP, + **kwargs + ): """ :param min_value: Validation rule for the minimum acceptable value. :param max_value: Validation rule for the maximum acceptable value. @@ -416,10 +469,12 @@ class DecimalField(BaseField): # Convert to string for python 2.6 before casting to Decimal try: - value = decimal.Decimal('%s' % value) + value = decimal.Decimal("%s" % value) except (TypeError, ValueError, decimal.InvalidOperation): return value - return value.quantize(decimal.Decimal('.%s' % ('0' * self.precision)), rounding=self.rounding) + return value.quantize( + decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding + ) def to_mongo(self, value): if value is None: @@ -435,13 +490,13 @@ class DecimalField(BaseField): try: value = decimal.Decimal(value) except (TypeError, ValueError, decimal.InvalidOperation) as exc: - self.error('Could not convert value to decimal: %s' % exc) + self.error("Could not convert value to decimal: %s" % exc) if self.min_value is not None and value < self.min_value: - self.error('Decimal value is too small') + self.error("Decimal value is too small") if self.max_value is not None and value > self.max_value: - self.error('Decimal value is too large') + self.error("Decimal value is too large") def prepare_query_value(self, op, value): return super(DecimalField, self).prepare_query_value(op, self.to_mongo(value)) @@ -462,7 +517,7 @@ class BooleanField(BaseField): def validate(self, value): if not isinstance(value, bool): - self.error('BooleanField only accepts boolean values') + self.error("BooleanField only accepts boolean values") class DateTimeField(BaseField): @@ -514,26 +569,29 @@ class DateTimeField(BaseField): return None # split usecs, because they are not recognized by strptime. - if '.' in value: + if "." in value: try: - value, usecs = value.split('.') + value, usecs = value.split(".") usecs = int(usecs) except ValueError: return None else: usecs = 0 - kwargs = {'microsecond': usecs} + kwargs = {"microsecond": usecs} try: # Seconds are optional, so try converting seconds first. - return datetime.datetime(*time.strptime(value, - '%Y-%m-%d %H:%M:%S')[:6], **kwargs) + return datetime.datetime( + *time.strptime(value, "%Y-%m-%d %H:%M:%S")[:6], **kwargs + ) except ValueError: try: # Try without seconds. - return datetime.datetime(*time.strptime(value, - '%Y-%m-%d %H:%M')[:5], **kwargs) + return datetime.datetime( + *time.strptime(value, "%Y-%m-%d %H:%M")[:5], **kwargs + ) except ValueError: # Try without hour/minutes/seconds. try: - return datetime.datetime(*time.strptime(value, - '%Y-%m-%d')[:3], **kwargs) + return datetime.datetime( + *time.strptime(value, "%Y-%m-%d")[:3], **kwargs + ) except ValueError: return None @@ -578,12 +636,12 @@ class ComplexDateTimeField(StringField): .. versionadded:: 0.5 """ - def __init__(self, separator=',', **kwargs): + def __init__(self, separator=",", **kwargs): """ :param separator: Allows to customize the separator used for storage (default ``,``) """ self.separator = separator - self.format = separator.join(['%Y', '%m', '%d', '%H', '%M', '%S', '%f']) + self.format = separator.join(["%Y", "%m", "%d", "%H", "%M", "%S", "%f"]) super(ComplexDateTimeField, self).__init__(**kwargs) def _convert_from_datetime(self, val): @@ -630,8 +688,7 @@ class ComplexDateTimeField(StringField): def validate(self, value): value = self.to_python(value) if not isinstance(value, datetime.datetime): - self.error('Only datetime objects may used in a ' - 'ComplexDateTimeField') + self.error("Only datetime objects may used in a ComplexDateTimeField") def to_python(self, value): original_value = value @@ -645,7 +702,9 @@ class ComplexDateTimeField(StringField): return self._convert_from_datetime(value) def prepare_query_value(self, op, value): - return super(ComplexDateTimeField, self).prepare_query_value(op, self._convert_from_datetime(value)) + return super(ComplexDateTimeField, self).prepare_query_value( + op, self._convert_from_datetime(value) + ) class EmbeddedDocumentField(BaseField): @@ -656,11 +715,13 @@ class EmbeddedDocumentField(BaseField): def __init__(self, document_type, **kwargs): # XXX ValidationError raised outside of the "validate" method. if not ( - isinstance(document_type, six.string_types) or - issubclass(document_type, EmbeddedDocument) + isinstance(document_type, six.string_types) + or issubclass(document_type, EmbeddedDocument) ): - self.error('Invalid embedded document class provided to an ' - 'EmbeddedDocumentField') + self.error( + "Invalid embedded document class provided to an " + "EmbeddedDocumentField" + ) self.document_type_obj = document_type super(EmbeddedDocumentField, self).__init__(**kwargs) @@ -676,15 +737,19 @@ class EmbeddedDocumentField(BaseField): if not issubclass(resolved_document_type, EmbeddedDocument): # Due to the late resolution of the document_type # There is a chance that it won't be an EmbeddedDocument (#1661) - self.error('Invalid embedded document class provided to an ' - 'EmbeddedDocumentField') + self.error( + "Invalid embedded document class provided to an " + "EmbeddedDocumentField" + ) self.document_type_obj = resolved_document_type return self.document_type_obj def to_python(self, value): if not isinstance(value, self.document_type): - return self.document_type._from_son(value, _auto_dereference=self._auto_dereference) + return self.document_type._from_son( + value, _auto_dereference=self._auto_dereference + ) return value def to_mongo(self, value, use_db_field=True, fields=None): @@ -698,8 +763,10 @@ class EmbeddedDocumentField(BaseField): """ # Using isinstance also works for subclasses of self.document if not isinstance(value, self.document_type): - self.error('Invalid embedded document instance provided to an ' - 'EmbeddedDocumentField') + self.error( + "Invalid embedded document instance provided to an " + "EmbeddedDocumentField" + ) self.document_type.validate(value, clean) def lookup_member(self, member_name): @@ -714,8 +781,10 @@ class EmbeddedDocumentField(BaseField): try: value = self.document_type._from_son(value) except ValueError: - raise InvalidQueryError("Querying the embedded document '%s' failed, due to an invalid query value" % - (self.document_type._class_name,)) + raise InvalidQueryError( + "Querying the embedded document '%s' failed, due to an invalid query value" + % (self.document_type._class_name,) + ) super(EmbeddedDocumentField, self).prepare_query_value(op, value) return self.to_mongo(value) @@ -732,11 +801,13 @@ class GenericEmbeddedDocumentField(BaseField): """ def prepare_query_value(self, op, value): - return super(GenericEmbeddedDocumentField, self).prepare_query_value(op, self.to_mongo(value)) + return super(GenericEmbeddedDocumentField, self).prepare_query_value( + op, self.to_mongo(value) + ) def to_python(self, value): if isinstance(value, dict): - doc_cls = get_document(value['_cls']) + doc_cls = get_document(value["_cls"]) value = doc_cls._from_son(value) return value @@ -744,12 +815,14 @@ class GenericEmbeddedDocumentField(BaseField): def validate(self, value, clean=True): if self.choices and isinstance(value, SON): for choice in self.choices: - if value['_cls'] == choice._class_name: + if value["_cls"] == choice._class_name: return True if not isinstance(value, EmbeddedDocument): - self.error('Invalid embedded document instance provided to an ' - 'GenericEmbeddedDocumentField') + self.error( + "Invalid embedded document instance provided to an " + "GenericEmbeddedDocumentField" + ) value.validate(clean=clean) @@ -766,8 +839,8 @@ class GenericEmbeddedDocumentField(BaseField): if document is None: return None data = document.to_mongo(use_db_field, fields) - if '_cls' not in data: - data['_cls'] = document._class_name + if "_cls" not in data: + data["_cls"] = document._class_name return data @@ -784,21 +857,21 @@ class DynamicField(BaseField): if isinstance(value, six.string_types): return value - if hasattr(value, 'to_mongo'): + if hasattr(value, "to_mongo"): cls = value.__class__ val = value.to_mongo(use_db_field, fields) # If we its a document thats not inherited add _cls if isinstance(value, Document): - val = {'_ref': value.to_dbref(), '_cls': cls.__name__} + val = {"_ref": value.to_dbref(), "_cls": cls.__name__} if isinstance(value, EmbeddedDocument): - val['_cls'] = cls.__name__ + val["_cls"] = cls.__name__ return val if not isinstance(value, (dict, list, tuple)): return value is_list = False - if not hasattr(value, 'items'): + if not hasattr(value, "items"): is_list = True value = {k: v for k, v in enumerate(value)} @@ -812,10 +885,10 @@ class DynamicField(BaseField): return value def to_python(self, value): - if isinstance(value, dict) and '_cls' in value: - doc_cls = get_document(value['_cls']) - if '_ref' in value: - value = doc_cls._get_db().dereference(value['_ref']) + if isinstance(value, dict) and "_cls" in value: + doc_cls = get_document(value["_cls"]) + if "_ref" in value: + value = doc_cls._get_db().dereference(value["_ref"]) return doc_cls._from_son(value) return super(DynamicField, self).to_python(value) @@ -829,7 +902,7 @@ class DynamicField(BaseField): return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value)) def validate(self, value, clean=True): - if hasattr(value, 'validate'): + if hasattr(value, "validate"): value.validate(clean=clean) @@ -845,7 +918,7 @@ class ListField(ComplexBaseField): def __init__(self, field=None, **kwargs): self.field = field - kwargs.setdefault('default', lambda: []) + kwargs.setdefault("default", lambda: []) super(ListField, self).__init__(**kwargs) def __get__(self, instance, owner): @@ -853,16 +926,19 @@ class ListField(ComplexBaseField): # Document class being used rather than a document object return self value = instance._data.get(self.name) - LazyReferenceField = _import_class('LazyReferenceField') - GenericLazyReferenceField = _import_class('GenericLazyReferenceField') - if isinstance(self.field, (LazyReferenceField, GenericLazyReferenceField)) and value: + LazyReferenceField = _import_class("LazyReferenceField") + GenericLazyReferenceField = _import_class("GenericLazyReferenceField") + if ( + isinstance(self.field, (LazyReferenceField, GenericLazyReferenceField)) + and value + ): instance._data[self.name] = [self.field.build_lazyref(x) for x in value] return super(ListField, self).__get__(instance, owner) def validate(self, value): """Make sure that a list of valid fields is being used.""" if not isinstance(value, (list, tuple, BaseQuerySet)): - self.error('Only lists and tuples may be used in a list field') + self.error("Only lists and tuples may be used in a list field") super(ListField, self).validate(value) def prepare_query_value(self, op, value): @@ -871,10 +947,10 @@ class ListField(ComplexBaseField): # If the value is iterable and it's not a string nor a # BaseDocument, call prepare_query_value for each of its items. if ( - op in ('set', 'unset', None) and - hasattr(value, '__iter__') and - not isinstance(value, six.string_types) and - not isinstance(value, BaseDocument) + op in ("set", "unset", None) + and hasattr(value, "__iter__") + and not isinstance(value, six.string_types) + and not isinstance(value, BaseDocument) ): return [self.field.prepare_query_value(op, v) for v in value] @@ -925,17 +1001,18 @@ class SortedListField(ListField): _order_reverse = False def __init__(self, field, **kwargs): - if 'ordering' in kwargs.keys(): - self._ordering = kwargs.pop('ordering') - if 'reverse' in kwargs.keys(): - self._order_reverse = kwargs.pop('reverse') + if "ordering" in kwargs.keys(): + self._ordering = kwargs.pop("ordering") + if "reverse" in kwargs.keys(): + self._order_reverse = kwargs.pop("reverse") super(SortedListField, self).__init__(field, **kwargs) def to_mongo(self, value, use_db_field=True, fields=None): value = super(SortedListField, self).to_mongo(value, use_db_field, fields) if self._ordering is not None: - return sorted(value, key=itemgetter(self._ordering), - reverse=self._order_reverse) + return sorted( + value, key=itemgetter(self._ordering), reverse=self._order_reverse + ) return sorted(value, reverse=self._order_reverse) @@ -944,7 +1021,9 @@ def key_not_string(d): dictionary is not a string. """ for k, v in d.items(): - if not isinstance(k, six.string_types) or (isinstance(v, dict) and key_not_string(v)): + if not isinstance(k, six.string_types) or ( + isinstance(v, dict) and key_not_string(v) + ): return True @@ -953,7 +1032,9 @@ def key_has_dot_or_dollar(d): dictionary contains a dot or a dollar sign. """ for k, v in d.items(): - if ('.' in k or k.startswith('$')) or (isinstance(v, dict) and key_has_dot_or_dollar(v)): + if ("." in k or k.startswith("$")) or ( + isinstance(v, dict) and key_has_dot_or_dollar(v) + ): return True @@ -972,39 +1053,48 @@ class DictField(ComplexBaseField): self.field = field self._auto_dereference = False - kwargs.setdefault('default', lambda: {}) + kwargs.setdefault("default", lambda: {}) super(DictField, self).__init__(*args, **kwargs) def validate(self, value): """Make sure that a list of valid fields is being used.""" if not isinstance(value, dict): - self.error('Only dictionaries may be used in a DictField') + self.error("Only dictionaries may be used in a DictField") if key_not_string(value): - msg = ('Invalid dictionary key - documents must ' - 'have only string keys') + msg = "Invalid dictionary key - documents must have only string keys" self.error(msg) if key_has_dot_or_dollar(value): - self.error('Invalid dictionary key name - keys may not contain "."' - ' or startswith "$" characters') + self.error( + 'Invalid dictionary key name - keys may not contain "."' + ' or startswith "$" characters' + ) super(DictField, self).validate(value) def lookup_member(self, member_name): return DictField(db_field=member_name) def prepare_query_value(self, op, value): - match_operators = ['contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', - 'exact', 'iexact'] + match_operators = [ + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "exact", + "iexact", + ] if op in match_operators and isinstance(value, six.string_types): return StringField().prepare_query_value(op, value) - if hasattr(self.field, 'field'): # Used for instance when using DictField(ListField(IntField())) - if op in ('set', 'unset') and isinstance(value, dict): + if hasattr( + self.field, "field" + ): # Used for instance when using DictField(ListField(IntField())) + if op in ("set", "unset") and isinstance(value, dict): return { - k: self.field.prepare_query_value(op, v) - for k, v in value.items() + k: self.field.prepare_query_value(op, v) for k, v in value.items() } return self.field.prepare_query_value(op, value) @@ -1022,8 +1112,7 @@ class MapField(DictField): def __init__(self, field=None, *args, **kwargs): # XXX ValidationError raised outside of the "validate" method. if not isinstance(field, BaseField): - self.error('Argument to MapField constructor must be a valid ' - 'field') + self.error("Argument to MapField constructor must be a valid field") super(MapField, self).__init__(field=field, *args, **kwargs) @@ -1069,8 +1158,9 @@ class ReferenceField(BaseField): .. versionchanged:: 0.5 added `reverse_delete_rule` """ - def __init__(self, document_type, dbref=False, - reverse_delete_rule=DO_NOTHING, **kwargs): + def __init__( + self, document_type, dbref=False, reverse_delete_rule=DO_NOTHING, **kwargs + ): """Initialises the Reference Field. :param dbref: Store the reference as :class:`~pymongo.dbref.DBRef` @@ -1083,12 +1173,13 @@ class ReferenceField(BaseField): :class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`. """ # XXX ValidationError raised outside of the "validate" method. - if ( - not isinstance(document_type, six.string_types) and - not issubclass(document_type, Document) + if not isinstance(document_type, six.string_types) and not issubclass( + document_type, Document ): - self.error('Argument to ReferenceField constructor must be a ' - 'document class or a string') + self.error( + "Argument to ReferenceField constructor must be a " + "document class or a string" + ) self.dbref = dbref self.document_type_obj = document_type @@ -1115,14 +1206,14 @@ class ReferenceField(BaseField): auto_dereference = instance._fields[self.name]._auto_dereference # Dereference DBRefs if auto_dereference and isinstance(value, DBRef): - if hasattr(value, 'cls'): + if hasattr(value, "cls"): # Dereference using the class type specified in the reference cls = get_document(value.cls) else: cls = self.document_type dereferenced = cls._get_db().dereference(value) if dereferenced is None: - raise DoesNotExist('Trying to dereference unknown document %s' % value) + raise DoesNotExist("Trying to dereference unknown document %s" % value) else: instance._data[self.name] = cls._from_son(dereferenced) @@ -1140,8 +1231,10 @@ class ReferenceField(BaseField): # XXX ValidationError raised outside of the "validate" method. if id_ is None: - self.error('You can only reference documents once they have' - ' been saved to the database') + self.error( + "You can only reference documents once they have" + " been saved to the database" + ) # Use the attributes from the document instance, so that they # override the attributes of this field's document type @@ -1150,11 +1243,11 @@ class ReferenceField(BaseField): id_ = document cls = self.document_type - id_field_name = cls._meta['id_field'] + id_field_name = cls._meta["id_field"] id_field = cls._fields[id_field_name] id_ = id_field.to_mongo(id_) - if self.document_type._meta.get('abstract'): + if self.document_type._meta.get("abstract"): collection = cls._get_collection_name() return DBRef(collection, id_, cls=cls._class_name) elif self.dbref: @@ -1165,8 +1258,9 @@ class ReferenceField(BaseField): def to_python(self, value): """Convert a MongoDB-compatible type to a Python type.""" - if (not self.dbref and - not isinstance(value, (DBRef, Document, EmbeddedDocument))): + if not self.dbref and not isinstance( + value, (DBRef, Document, EmbeddedDocument) + ): collection = self.document_type._get_collection_name() value = DBRef(collection, self.document_type.id.to_python(value)) return value @@ -1179,11 +1273,15 @@ class ReferenceField(BaseField): def validate(self, value): if not isinstance(value, (self.document_type, LazyReference, DBRef, ObjectId)): - self.error('A ReferenceField only accepts DBRef, LazyReference, ObjectId or documents') + self.error( + "A ReferenceField only accepts DBRef, LazyReference, ObjectId or documents" + ) if isinstance(value, Document) and value.id is None: - self.error('You can only reference documents once they have been ' - 'saved to the database') + self.error( + "You can only reference documents once they have been " + "saved to the database" + ) def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -1206,12 +1304,13 @@ class CachedReferenceField(BaseField): fields = [] # XXX ValidationError raised outside of the "validate" method. - if ( - not isinstance(document_type, six.string_types) and - not issubclass(document_type, Document) + if not isinstance(document_type, six.string_types) and not issubclass( + document_type, Document ): - self.error('Argument to CachedReferenceField constructor must be a' - ' document class or a string') + self.error( + "Argument to CachedReferenceField constructor must be a" + " document class or a string" + ) self.auto_sync = auto_sync self.document_type_obj = document_type @@ -1221,15 +1320,14 @@ class CachedReferenceField(BaseField): def start_listener(self): from mongoengine import signals - signals.post_save.connect(self.on_document_pre_save, - sender=self.document_type) + signals.post_save.connect(self.on_document_pre_save, sender=self.document_type) def on_document_pre_save(self, sender, document, created, **kwargs): if created: return None update_kwargs = { - 'set__%s__%s' % (self.name, key): val + "set__%s__%s" % (self.name, key): val for key, val in document._delta()[0].items() if key in self.fields } @@ -1237,15 +1335,15 @@ class CachedReferenceField(BaseField): filter_kwargs = {} filter_kwargs[self.name] = document - self.owner_document.objects( - **filter_kwargs).update(**update_kwargs) + self.owner_document.objects(**filter_kwargs).update(**update_kwargs) def to_python(self, value): if isinstance(value, dict): collection = self.document_type._get_collection_name() - value = DBRef( - collection, self.document_type.id.to_python(value['_id'])) - return self.document_type._from_son(self.document_type._get_db().dereference(value)) + value = DBRef(collection, self.document_type.id.to_python(value["_id"])) + return self.document_type._from_son( + self.document_type._get_db().dereference(value) + ) return value @@ -1271,14 +1369,14 @@ class CachedReferenceField(BaseField): if auto_dereference and isinstance(value, DBRef): dereferenced = self.document_type._get_db().dereference(value) if dereferenced is None: - raise DoesNotExist('Trying to dereference unknown document %s' % value) + raise DoesNotExist("Trying to dereference unknown document %s" % value) else: instance._data[self.name] = self.document_type._from_son(dereferenced) return super(CachedReferenceField, self).__get__(instance, owner) def to_mongo(self, document, use_db_field=True, fields=None): - id_field_name = self.document_type._meta['id_field'] + id_field_name = self.document_type._meta["id_field"] id_field = self.document_type._fields[id_field_name] # XXX ValidationError raised outside of the "validate" method. @@ -1286,14 +1384,14 @@ class CachedReferenceField(BaseField): # We need the id from the saved object to create the DBRef id_ = document.pk if id_ is None: - self.error('You can only reference documents once they have' - ' been saved to the database') + self.error( + "You can only reference documents once they have" + " been saved to the database" + ) else: - self.error('Only accept a document object') + self.error("Only accept a document object") - value = SON(( - ('_id', id_field.to_mongo(id_)), - )) + value = SON((("_id", id_field.to_mongo(id_)),)) if fields: new_fields = [f for f in self.fields if f in fields] @@ -1310,9 +1408,11 @@ class CachedReferenceField(BaseField): # XXX ValidationError raised outside of the "validate" method. if isinstance(value, Document): if value.pk is None: - self.error('You can only reference documents once they have' - ' been saved to the database') - value_dict = {'_id': value.pk} + self.error( + "You can only reference documents once they have" + " been saved to the database" + ) + value_dict = {"_id": value.pk} for field in self.fields: value_dict.update({field: value[field]}) @@ -1322,11 +1422,13 @@ class CachedReferenceField(BaseField): def validate(self, value): if not isinstance(value, self.document_type): - self.error('A CachedReferenceField only accepts documents') + self.error("A CachedReferenceField only accepts documents") if isinstance(value, Document) and value.id is None: - self.error('You can only reference documents once they have been ' - 'saved to the database') + self.error( + "You can only reference documents once they have been " + "saved to the database" + ) def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -1336,7 +1438,7 @@ class CachedReferenceField(BaseField): Sync all cached fields on demand. Caution: this operation may be slower. """ - update_key = 'set__%s' % self.name + update_key = "set__%s" % self.name for doc in self.document_type.objects: filter_kwargs = {} @@ -1345,8 +1447,7 @@ class CachedReferenceField(BaseField): update_kwargs = {} update_kwargs[update_key] = doc - self.owner_document.objects( - **filter_kwargs).update(**update_kwargs) + self.owner_document.objects(**filter_kwargs).update(**update_kwargs) class GenericReferenceField(BaseField): @@ -1370,7 +1471,7 @@ class GenericReferenceField(BaseField): """ def __init__(self, *args, **kwargs): - choices = kwargs.pop('choices', None) + choices = kwargs.pop("choices", None) super(GenericReferenceField, self).__init__(*args, **kwargs) self.choices = [] # Keep the choices as a list of allowed Document class names @@ -1383,14 +1484,16 @@ class GenericReferenceField(BaseField): else: # XXX ValidationError raised outside of the "validate" # method. - self.error('Invalid choices provided: must be a list of' - 'Document subclasses and/or six.string_typess') + self.error( + "Invalid choices provided: must be a list of" + "Document subclasses and/or six.string_typess" + ) def _validate_choices(self, value): if isinstance(value, dict): # If the field has not been dereferenced, it is still a dict # of class and DBRef - value = value.get('_cls') + value = value.get("_cls") elif isinstance(value, Document): value = value._class_name super(GenericReferenceField, self)._validate_choices(value) @@ -1405,7 +1508,7 @@ class GenericReferenceField(BaseField): if auto_dereference and isinstance(value, (dict, SON)): dereferenced = self.dereference(value) if dereferenced is None: - raise DoesNotExist('Trying to dereference unknown document %s' % value) + raise DoesNotExist("Trying to dereference unknown document %s" % value) else: instance._data[self.name] = dereferenced @@ -1413,20 +1516,22 @@ class GenericReferenceField(BaseField): def validate(self, value): if not isinstance(value, (Document, DBRef, dict, SON)): - self.error('GenericReferences can only contain documents') + self.error("GenericReferences can only contain documents") if isinstance(value, (dict, SON)): - if '_ref' not in value or '_cls' not in value: - self.error('GenericReferences can only contain documents') + if "_ref" not in value or "_cls" not in value: + self.error("GenericReferences can only contain documents") # We need the id from the saved object to create the DBRef elif isinstance(value, Document) and value.id is None: - self.error('You can only reference documents once they have been' - ' saved to the database') + self.error( + "You can only reference documents once they have been" + " saved to the database" + ) def dereference(self, value): - doc_cls = get_document(value['_cls']) - reference = value['_ref'] + doc_cls = get_document(value["_cls"]) + reference = value["_ref"] doc = doc_cls._get_db().dereference(reference) if doc is not None: doc = doc_cls._from_son(doc) @@ -1439,7 +1544,7 @@ class GenericReferenceField(BaseField): if isinstance(document, (dict, SON, ObjectId, DBRef)): return document - id_field_name = document.__class__._meta['id_field'] + id_field_name = document.__class__._meta["id_field"] id_field = document.__class__._fields[id_field_name] if isinstance(document, Document): @@ -1447,18 +1552,17 @@ class GenericReferenceField(BaseField): id_ = document.id if id_ is None: # XXX ValidationError raised outside of the "validate" method. - self.error('You can only reference documents once they have' - ' been saved to the database') + self.error( + "You can only reference documents once they have" + " been saved to the database" + ) else: id_ = document id_ = id_field.to_mongo(id_) collection = document._get_collection_name() ref = DBRef(collection, id_) - return SON(( - ('_cls', document._class_name), - ('_ref', ref) - )) + return SON((("_cls", document._class_name), ("_ref", ref))) def prepare_query_value(self, op, value): if value is None: @@ -1485,18 +1589,18 @@ class BinaryField(BaseField): def validate(self, value): if not isinstance(value, (six.binary_type, Binary)): - self.error('BinaryField only accepts instances of ' - '(%s, %s, Binary)' % ( - six.binary_type.__name__, Binary.__name__)) + self.error( + "BinaryField only accepts instances of " + "(%s, %s, Binary)" % (six.binary_type.__name__, Binary.__name__) + ) if self.max_bytes is not None and len(value) > self.max_bytes: - self.error('Binary value is too long') + self.error("Binary value is too long") def prepare_query_value(self, op, value): if value is None: return value - return super(BinaryField, self).prepare_query_value( - op, self.to_mongo(value)) + return super(BinaryField, self).prepare_query_value(op, self.to_mongo(value)) class GridFSError(Exception): @@ -1513,10 +1617,14 @@ class GridFSProxy(object): _fs = None - def __init__(self, grid_id=None, key=None, - instance=None, - db_alias=DEFAULT_CONNECTION_NAME, - collection_name='fs'): + def __init__( + self, + grid_id=None, + key=None, + instance=None, + db_alias=DEFAULT_CONNECTION_NAME, + collection_name="fs", + ): self.grid_id = grid_id # Store GridFS id for file self.key = key self.instance = instance @@ -1526,8 +1634,16 @@ class GridFSProxy(object): self.gridout = None def __getattr__(self, name): - attrs = ('_fs', 'grid_id', 'key', 'instance', 'db_alias', - 'collection_name', 'newfile', 'gridout') + attrs = ( + "_fs", + "grid_id", + "key", + "instance", + "db_alias", + "collection_name", + "newfile", + "gridout", + ) if name in attrs: return self.__getattribute__(name) obj = self.get() @@ -1545,7 +1661,7 @@ class GridFSProxy(object): def __getstate__(self): self_dict = self.__dict__ - self_dict['_fs'] = None + self_dict["_fs"] = None return self_dict def __copy__(self): @@ -1557,18 +1673,20 @@ class GridFSProxy(object): return self.__copy__() def __repr__(self): - return '<%s: %s>' % (self.__class__.__name__, self.grid_id) + return "<%s: %s>" % (self.__class__.__name__, self.grid_id) def __str__(self): gridout = self.get() - filename = getattr(gridout, 'filename') if gridout else '' - return '<%s: %s (%s)>' % (self.__class__.__name__, filename, self.grid_id) + filename = getattr(gridout, "filename") if gridout else "" + return "<%s: %s (%s)>" % (self.__class__.__name__, filename, self.grid_id) def __eq__(self, other): if isinstance(other, GridFSProxy): - return ((self.grid_id == other.grid_id) and - (self.collection_name == other.collection_name) and - (self.db_alias == other.db_alias)) + return ( + (self.grid_id == other.grid_id) + and (self.collection_name == other.collection_name) + and (self.db_alias == other.db_alias) + ) else: return False @@ -1578,8 +1696,7 @@ class GridFSProxy(object): @property def fs(self): if not self._fs: - self._fs = gridfs.GridFS( - get_db(self.db_alias), self.collection_name) + self._fs = gridfs.GridFS(get_db(self.db_alias), self.collection_name) return self._fs def get(self, grid_id=None): @@ -1604,16 +1721,20 @@ class GridFSProxy(object): def put(self, file_obj, **kwargs): if self.grid_id: - raise GridFSError('This document already has a file. Either delete ' - 'it or call replace to overwrite it') + raise GridFSError( + "This document already has a file. Either delete " + "it or call replace to overwrite it" + ) self.grid_id = self.fs.put(file_obj, **kwargs) self._mark_as_changed() def write(self, string): if self.grid_id: if not self.newfile: - raise GridFSError('This document already has a file. Either ' - 'delete it or call replace to overwrite it') + raise GridFSError( + "This document already has a file. Either " + "delete it or call replace to overwrite it" + ) else: self.new_file() self.newfile.write(string) @@ -1632,7 +1753,7 @@ class GridFSProxy(object): try: return gridout.read(size) except Exception: - return '' + return "" def delete(self): # Delete file from GridFS, FileField still remains @@ -1662,10 +1783,12 @@ class FileField(BaseField): .. versionchanged:: 0.5 added optional size param for read .. versionchanged:: 0.6 added db_alias for multidb support """ + proxy_class = GridFSProxy - def __init__(self, db_alias=DEFAULT_CONNECTION_NAME, collection_name='fs', - **kwargs): + def __init__( + self, db_alias=DEFAULT_CONNECTION_NAME, collection_name="fs", **kwargs + ): super(FileField, self).__init__(**kwargs) self.collection_name = collection_name self.db_alias = db_alias @@ -1688,9 +1811,8 @@ class FileField(BaseField): def __set__(self, instance, value): key = self.name if ( - (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or - isinstance(value, (six.binary_type, six.string_types)) - ): + hasattr(value, "read") and not isinstance(value, GridFSProxy) + ) or isinstance(value, (six.binary_type, six.string_types)): # using "FileField() = file/string" notation grid_file = instance._data.get(self.name) # If a file already exists, delete it @@ -1701,8 +1823,7 @@ class FileField(BaseField): pass # Create a new proxy object as we don't already have one - instance._data[key] = self.get_proxy_obj( - key=key, instance=instance) + instance._data[key] = self.get_proxy_obj(key=key, instance=instance) instance._data[key].put(value) else: instance._data[key] = value @@ -1715,9 +1836,12 @@ class FileField(BaseField): if collection_name is None: collection_name = self.collection_name - return self.proxy_class(key=key, instance=instance, - db_alias=db_alias, - collection_name=collection_name) + return self.proxy_class( + key=key, + instance=instance, + db_alias=db_alias, + collection_name=collection_name, + ) def to_mongo(self, value): # Store the GridFS file id in MongoDB @@ -1727,16 +1851,16 @@ class FileField(BaseField): def to_python(self, value): if value is not None: - return self.proxy_class(value, - collection_name=self.collection_name, - db_alias=self.db_alias) + return self.proxy_class( + value, collection_name=self.collection_name, db_alias=self.db_alias + ) def validate(self, value): if value.grid_id is not None: if not isinstance(value, self.proxy_class): - self.error('FileField only accepts GridFSProxy values') + self.error("FileField only accepts GridFSProxy values") if not isinstance(value.grid_id, ObjectId): - self.error('Invalid GridFSProxy value') + self.error("Invalid GridFSProxy value") class ImageGridFsProxy(GridFSProxy): @@ -1753,52 +1877,51 @@ class ImageGridFsProxy(GridFSProxy): """ field = self.instance._fields[self.key] # Handle nested fields - if hasattr(field, 'field') and isinstance(field.field, FileField): + if hasattr(field, "field") and isinstance(field.field, FileField): field = field.field try: img = Image.open(file_obj) img_format = img.format except Exception as e: - raise ValidationError('Invalid image: %s' % e) + raise ValidationError("Invalid image: %s" % e) # Progressive JPEG # TODO: fixme, at least unused, at worst bad implementation - progressive = img.info.get('progressive') or False + progressive = img.info.get("progressive") or False - if (kwargs.get('progressive') and - isinstance(kwargs.get('progressive'), bool) and - img_format == 'JPEG'): + if ( + kwargs.get("progressive") + and isinstance(kwargs.get("progressive"), bool) + and img_format == "JPEG" + ): progressive = True else: progressive = False - if (field.size and (img.size[0] > field.size['width'] or - img.size[1] > field.size['height'])): + if field.size and ( + img.size[0] > field.size["width"] or img.size[1] > field.size["height"] + ): size = field.size - if size['force']: - img = ImageOps.fit(img, - (size['width'], - size['height']), - Image.ANTIALIAS) + if size["force"]: + img = ImageOps.fit( + img, (size["width"], size["height"]), Image.ANTIALIAS + ) else: - img.thumbnail((size['width'], - size['height']), - Image.ANTIALIAS) + img.thumbnail((size["width"], size["height"]), Image.ANTIALIAS) thumbnail = None if field.thumbnail_size: size = field.thumbnail_size - if size['force']: + if size["force"]: thumbnail = ImageOps.fit( - img, (size['width'], size['height']), Image.ANTIALIAS) + img, (size["width"], size["height"]), Image.ANTIALIAS + ) else: thumbnail = img.copy() - thumbnail.thumbnail((size['width'], - size['height']), - Image.ANTIALIAS) + thumbnail.thumbnail((size["width"], size["height"]), Image.ANTIALIAS) if thumbnail: thumb_id = self._put_thumbnail(thumbnail, img_format, progressive) @@ -1811,12 +1934,9 @@ class ImageGridFsProxy(GridFSProxy): img.save(io, img_format, progressive=progressive) io.seek(0) - return super(ImageGridFsProxy, self).put(io, - width=w, - height=h, - format=img_format, - thumbnail_id=thumb_id, - **kwargs) + return super(ImageGridFsProxy, self).put( + io, width=w, height=h, format=img_format, thumbnail_id=thumb_id, **kwargs + ) def delete(self, *args, **kwargs): # deletes thumbnail @@ -1833,10 +1953,7 @@ class ImageGridFsProxy(GridFSProxy): thumbnail.save(io, format, progressive=progressive) io.seek(0) - return self.fs.put(io, width=w, - height=h, - format=format, - **kwargs) + return self.fs.put(io, width=w, height=h, format=format, **kwargs) @property def size(self): @@ -1888,32 +2005,30 @@ class ImageField(FileField): .. versionadded:: 0.6 """ + proxy_class = ImageGridFsProxy - def __init__(self, size=None, thumbnail_size=None, - collection_name='images', **kwargs): + def __init__( + self, size=None, thumbnail_size=None, collection_name="images", **kwargs + ): if not Image: - raise ImproperlyConfigured('PIL library was not found') + raise ImproperlyConfigured("PIL library was not found") - params_size = ('width', 'height', 'force') - extra_args = { - 'size': size, - 'thumbnail_size': thumbnail_size - } + params_size = ("width", "height", "force") + extra_args = {"size": size, "thumbnail_size": thumbnail_size} for att_name, att in extra_args.items(): value = None if isinstance(att, (tuple, list)): if six.PY3: - value = dict(itertools.zip_longest(params_size, att, - fillvalue=None)) + value = dict( + itertools.zip_longest(params_size, att, fillvalue=None) + ) else: value = dict(map(None, params_size, att)) setattr(self, att_name, value) - super(ImageField, self).__init__( - collection_name=collection_name, - **kwargs) + super(ImageField, self).__init__(collection_name=collection_name, **kwargs) class SequenceField(BaseField): @@ -1947,15 +2062,24 @@ class SequenceField(BaseField): """ _auto_gen = True - COLLECTION_NAME = 'mongoengine.counters' + COLLECTION_NAME = "mongoengine.counters" VALUE_DECORATOR = int - def __init__(self, collection_name=None, db_alias=None, sequence_name=None, - value_decorator=None, *args, **kwargs): + def __init__( + self, + collection_name=None, + db_alias=None, + sequence_name=None, + value_decorator=None, + *args, + **kwargs + ): self.collection_name = collection_name or self.COLLECTION_NAME self.db_alias = db_alias or DEFAULT_CONNECTION_NAME self.sequence_name = sequence_name - self.value_decorator = value_decorator if callable(value_decorator) else self.VALUE_DECORATOR + self.value_decorator = ( + value_decorator if callable(value_decorator) else self.VALUE_DECORATOR + ) super(SequenceField, self).__init__(*args, **kwargs) def generate(self): @@ -1963,15 +2087,16 @@ class SequenceField(BaseField): Generate and Increment the counter """ sequence_name = self.get_sequence_name() - sequence_id = '%s.%s' % (sequence_name, self.name) + sequence_id = "%s.%s" % (sequence_name, self.name) collection = get_db(alias=self.db_alias)[self.collection_name] counter = collection.find_one_and_update( - filter={'_id': sequence_id}, - update={'$inc': {'next': 1}}, + filter={"_id": sequence_id}, + update={"$inc": {"next": 1}}, return_document=ReturnDocument.AFTER, - upsert=True) - return self.value_decorator(counter['next']) + upsert=True, + ) + return self.value_decorator(counter["next"]) def set_next_value(self, value): """Helper method to set the next sequence value""" @@ -1982,8 +2107,9 @@ class SequenceField(BaseField): filter={"_id": sequence_id}, update={"$set": {"next": value}}, return_document=ReturnDocument.AFTER, - upsert=True) - return self.value_decorator(counter['next']) + upsert=True, + ) + return self.value_decorator(counter["next"]) def get_next_value(self): """Helper method to get the next value for previewing. @@ -1992,12 +2118,12 @@ class SequenceField(BaseField): as it is only fixed on set. """ sequence_name = self.get_sequence_name() - sequence_id = '%s.%s' % (sequence_name, self.name) + sequence_id = "%s.%s" % (sequence_name, self.name) collection = get_db(alias=self.db_alias)[self.collection_name] - data = collection.find_one({'_id': sequence_id}) + data = collection.find_one({"_id": sequence_id}) if data: - return self.value_decorator(data['next'] + 1) + return self.value_decorator(data["next"] + 1) return self.value_decorator(1) @@ -2005,11 +2131,14 @@ class SequenceField(BaseField): if self.sequence_name: return self.sequence_name owner = self.owner_document - if issubclass(owner, Document) and not owner._meta.get('abstract'): + if issubclass(owner, Document) and not owner._meta.get("abstract"): return owner._get_collection_name() else: - return ''.join('_%s' % c if c.isupper() else c - for c in owner._class_name).strip('_').lower() + return ( + "".join("_%s" % c if c.isupper() else c for c in owner._class_name) + .strip("_") + .lower() + ) def __get__(self, instance, owner): value = super(SequenceField, self).__get__(instance, owner) @@ -2046,6 +2175,7 @@ class UUIDField(BaseField): .. versionadded:: 0.6 """ + _binary = None def __init__(self, binary=True, **kwargs): @@ -2090,7 +2220,7 @@ class UUIDField(BaseField): try: uuid.UUID(value) except (ValueError, TypeError, AttributeError) as exc: - self.error('Could not convert to UUID: %s' % exc) + self.error("Could not convert to UUID: %s" % exc) class GeoPointField(BaseField): @@ -2109,16 +2239,14 @@ class GeoPointField(BaseField): def validate(self, value): """Make sure that a geo-value is of type (x, y)""" if not isinstance(value, (list, tuple)): - self.error('GeoPointField can only accept tuples or lists ' - 'of (x, y)') + self.error("GeoPointField can only accept tuples or lists of (x, y)") if not len(value) == 2: - self.error('Value (%s) must be a two-dimensional point' % - repr(value)) - elif (not isinstance(value[0], (float, int)) or - not isinstance(value[1], (float, int))): - self.error( - 'Both values (%s) in point must be float or int' % repr(value)) + self.error("Value (%s) must be a two-dimensional point" % repr(value)) + elif not isinstance(value[0], (float, int)) or not isinstance( + value[1], (float, int) + ): + self.error("Both values (%s) in point must be float or int" % repr(value)) class PointField(GeoJsonBaseField): @@ -2138,7 +2266,8 @@ class PointField(GeoJsonBaseField): .. versionadded:: 0.8 """ - _type = 'Point' + + _type = "Point" class LineStringField(GeoJsonBaseField): @@ -2157,7 +2286,8 @@ class LineStringField(GeoJsonBaseField): .. versionadded:: 0.8 """ - _type = 'LineString' + + _type = "LineString" class PolygonField(GeoJsonBaseField): @@ -2179,7 +2309,8 @@ class PolygonField(GeoJsonBaseField): .. versionadded:: 0.8 """ - _type = 'Polygon' + + _type = "Polygon" class MultiPointField(GeoJsonBaseField): @@ -2199,7 +2330,8 @@ class MultiPointField(GeoJsonBaseField): .. versionadded:: 0.9 """ - _type = 'MultiPoint' + + _type = "MultiPoint" class MultiLineStringField(GeoJsonBaseField): @@ -2219,7 +2351,8 @@ class MultiLineStringField(GeoJsonBaseField): .. versionadded:: 0.9 """ - _type = 'MultiLineString' + + _type = "MultiLineString" class MultiPolygonField(GeoJsonBaseField): @@ -2246,7 +2379,8 @@ class MultiPolygonField(GeoJsonBaseField): .. versionadded:: 0.9 """ - _type = 'MultiPolygon' + + _type = "MultiPolygon" class LazyReferenceField(BaseField): @@ -2260,8 +2394,14 @@ class LazyReferenceField(BaseField): .. versionadded:: 0.15 """ - def __init__(self, document_type, passthrough=False, dbref=False, - reverse_delete_rule=DO_NOTHING, **kwargs): + def __init__( + self, + document_type, + passthrough=False, + dbref=False, + reverse_delete_rule=DO_NOTHING, + **kwargs + ): """Initialises the Reference Field. :param dbref: Store the reference as :class:`~pymongo.dbref.DBRef` @@ -2274,12 +2414,13 @@ class LazyReferenceField(BaseField): document. Note this only work getting field (not setting or deleting). """ # XXX ValidationError raised outside of the "validate" method. - if ( - not isinstance(document_type, six.string_types) and - not issubclass(document_type, Document) + if not isinstance(document_type, six.string_types) and not issubclass( + document_type, Document ): - self.error('Argument to LazyReferenceField constructor must be a ' - 'document class or a string') + self.error( + "Argument to LazyReferenceField constructor must be a " + "document class or a string" + ) self.dbref = dbref self.passthrough = passthrough @@ -2299,15 +2440,23 @@ class LazyReferenceField(BaseField): def build_lazyref(self, value): if isinstance(value, LazyReference): if value.passthrough != self.passthrough: - value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough) + value = LazyReference( + value.document_type, value.pk, passthrough=self.passthrough + ) elif value is not None: if isinstance(value, self.document_type): - value = LazyReference(self.document_type, value.pk, passthrough=self.passthrough) + value = LazyReference( + self.document_type, value.pk, passthrough=self.passthrough + ) elif isinstance(value, DBRef): - value = LazyReference(self.document_type, value.id, passthrough=self.passthrough) + value = LazyReference( + self.document_type, value.id, passthrough=self.passthrough + ) else: # value is the primary key of the referenced document - value = LazyReference(self.document_type, value, passthrough=self.passthrough) + value = LazyReference( + self.document_type, value, passthrough=self.passthrough + ) return value def __get__(self, instance, owner): @@ -2332,7 +2481,7 @@ class LazyReferenceField(BaseField): else: # value is the primary key of the referenced document pk = value - id_field_name = self.document_type._meta['id_field'] + id_field_name = self.document_type._meta["id_field"] id_field = self.document_type._fields[id_field_name] pk = id_field.to_mongo(pk) if self.dbref: @@ -2343,7 +2492,7 @@ class LazyReferenceField(BaseField): def validate(self, value): if isinstance(value, LazyReference): if value.collection != self.document_type._get_collection_name(): - self.error('Reference must be on a `%s` document.' % self.document_type) + self.error("Reference must be on a `%s` document." % self.document_type) pk = value.pk elif isinstance(value, self.document_type): pk = value.pk @@ -2355,7 +2504,7 @@ class LazyReferenceField(BaseField): pk = value.id else: # value is the primary key of the referenced document - id_field_name = self.document_type._meta['id_field'] + id_field_name = self.document_type._meta["id_field"] id_field = getattr(self.document_type, id_field_name) pk = value try: @@ -2364,11 +2513,15 @@ class LazyReferenceField(BaseField): self.error( "value should be `{0}` document, LazyReference or DBRef on `{0}` " "or `{0}`'s primary key (i.e. `{1}`)".format( - self.document_type.__name__, type(id_field).__name__)) + self.document_type.__name__, type(id_field).__name__ + ) + ) if pk is None: - self.error('You can only reference documents once they have been ' - 'saved to the database') + self.error( + "You can only reference documents once they have been " + "saved to the database" + ) def prepare_query_value(self, op, value): if value is None: @@ -2399,7 +2552,7 @@ class GenericLazyReferenceField(GenericReferenceField): """ def __init__(self, *args, **kwargs): - self.passthrough = kwargs.pop('passthrough', False) + self.passthrough = kwargs.pop("passthrough", False) super(GenericLazyReferenceField, self).__init__(*args, **kwargs) def _validate_choices(self, value): @@ -2410,12 +2563,20 @@ class GenericLazyReferenceField(GenericReferenceField): def build_lazyref(self, value): if isinstance(value, LazyReference): if value.passthrough != self.passthrough: - value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough) + value = LazyReference( + value.document_type, value.pk, passthrough=self.passthrough + ) elif value is not None: if isinstance(value, (dict, SON)): - value = LazyReference(get_document(value['_cls']), value['_ref'].id, passthrough=self.passthrough) + value = LazyReference( + get_document(value["_cls"]), + value["_ref"].id, + passthrough=self.passthrough, + ) elif isinstance(value, Document): - value = LazyReference(type(value), value.pk, passthrough=self.passthrough) + value = LazyReference( + type(value), value.pk, passthrough=self.passthrough + ) return value def __get__(self, instance, owner): @@ -2430,8 +2591,10 @@ class GenericLazyReferenceField(GenericReferenceField): def validate(self, value): if isinstance(value, LazyReference) and value.pk is None: - self.error('You can only reference documents once they have been' - ' saved to the database') + self.error( + "You can only reference documents once they have been" + " saved to the database" + ) return super(GenericLazyReferenceField, self).validate(value) def to_mongo(self, document): @@ -2439,9 +2602,16 @@ class GenericLazyReferenceField(GenericReferenceField): return None if isinstance(document, LazyReference): - return SON(( - ('_cls', document.document_type._class_name), - ('_ref', DBRef(document.document_type._get_collection_name(), document.pk)) - )) + return SON( + ( + ("_cls", document.document_type._class_name), + ( + "_ref", + DBRef( + document.document_type._get_collection_name(), document.pk + ), + ), + ) + ) else: return super(GenericLazyReferenceField, self).to_mongo(document) diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py index b20ebc1e..5d437fef 100644 --- a/mongoengine/mongodb_support.py +++ b/mongoengine/mongodb_support.py @@ -15,5 +15,5 @@ def get_mongodb_version(): :return: tuple(int, int) """ - version_list = get_connection().server_info()['versionArray'][:2] # e.g: (3, 2) + version_list = get_connection().server_info()["versionArray"][:2] # e.g: (3, 2) return tuple(version_list) diff --git a/mongoengine/pymongo_support.py b/mongoengine/pymongo_support.py index f66c038e..80c0661b 100644 --- a/mongoengine/pymongo_support.py +++ b/mongoengine/pymongo_support.py @@ -27,6 +27,6 @@ def list_collection_names(db, include_system_collections=False): collections = db.collection_names() if not include_system_collections: - collections = [c for c in collections if not c.startswith('system.')] + collections = [c for c in collections if not c.startswith("system.")] return collections diff --git a/mongoengine/queryset/__init__.py b/mongoengine/queryset/__init__.py index 5219c39e..f041d07b 100644 --- a/mongoengine/queryset/__init__.py +++ b/mongoengine/queryset/__init__.py @@ -7,11 +7,22 @@ from mongoengine.queryset.visitor import * # Expose just the public subset of all imported objects and constants. __all__ = ( - 'QuerySet', 'QuerySetNoCache', 'Q', 'queryset_manager', 'QuerySetManager', - 'QueryFieldList', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL', - + "QuerySet", + "QuerySetNoCache", + "Q", + "queryset_manager", + "QuerySetManager", + "QueryFieldList", + "DO_NOTHING", + "NULLIFY", + "CASCADE", + "DENY", + "PULL", # Errors that might be related to a queryset, mostly here for backward # compatibility - 'DoesNotExist', 'InvalidQueryError', 'MultipleObjectsReturned', - 'NotUniqueError', 'OperationError', + "DoesNotExist", + "InvalidQueryError", + "MultipleObjectsReturned", + "NotUniqueError", + "OperationError", ) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 85616c85..78e85399 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -20,14 +20,18 @@ from mongoengine.base import get_document from mongoengine.common import _import_class from mongoengine.connection import get_db from mongoengine.context_managers import set_write_concern, switch_db -from mongoengine.errors import (InvalidQueryError, LookUpError, - NotUniqueError, OperationError) +from mongoengine.errors import ( + InvalidQueryError, + LookUpError, + NotUniqueError, + OperationError, +) from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode -__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') +__all__ = ("BaseQuerySet", "DO_NOTHING", "NULLIFY", "CASCADE", "DENY", "PULL") # Delete rules DO_NOTHING = 0 @@ -41,6 +45,7 @@ class BaseQuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, providing :class:`~mongoengine.Document` objects as the results. """ + __dereference = False _auto_dereference = True @@ -66,13 +71,12 @@ class BaseQuerySet(object): # If inheritance is allowed, only return instances and instances of # subclasses of the class being used - if document._meta.get('allow_inheritance') is True: + if document._meta.get("allow_inheritance") is True: if len(self._document._subclasses) == 1: - self._initial_query = {'_cls': self._document._subclasses[0]} + self._initial_query = {"_cls": self._document._subclasses[0]} else: - self._initial_query = { - '_cls': {'$in': self._document._subclasses}} - self._loaded_fields = QueryFieldList(always_include=['_cls']) + self._initial_query = {"_cls": {"$in": self._document._subclasses}} + self._loaded_fields = QueryFieldList(always_include=["_cls"]) self._cursor_obj = None self._limit = None @@ -83,8 +87,7 @@ class BaseQuerySet(object): self._max_time_ms = None self._comment = None - def __call__(self, q_obj=None, class_check=True, read_preference=None, - **query): + def __call__(self, q_obj=None, class_check=True, read_preference=None, **query): """Filter the selected documents by calling the :class:`~mongoengine.queryset.QuerySet` with a query. @@ -102,8 +105,10 @@ class BaseQuerySet(object): if q_obj: # make sure proper query object is passed if not isinstance(q_obj, QNode): - msg = ('Not a query object: %s. ' - 'Did you intend to use key=value?' % q_obj) + msg = ( + "Not a query object: %s. " + "Did you intend to use key=value?" % q_obj + ) raise InvalidQueryError(msg) query &= q_obj @@ -130,10 +135,10 @@ class BaseQuerySet(object): obj_dict = self.__dict__.copy() # don't picke collection, instead pickle collection params - obj_dict.pop('_collection_obj') + obj_dict.pop("_collection_obj") # don't pickle cursor - obj_dict['_cursor_obj'] = None + obj_dict["_cursor_obj"] = None return obj_dict @@ -144,7 +149,7 @@ class BaseQuerySet(object): See https://github.com/MongoEngine/mongoengine/issues/442 """ - obj_dict['_collection_obj'] = obj_dict['_document']._get_collection() + obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection() # update attributes self.__dict__.update(obj_dict) @@ -182,7 +187,7 @@ class BaseQuerySet(object): queryset._document._from_son( queryset._cursor[key], _auto_dereference=self._auto_dereference, - only_fields=self.only_fields + only_fields=self.only_fields, ) ) @@ -192,10 +197,10 @@ class BaseQuerySet(object): return queryset._document._from_son( queryset._cursor[key], _auto_dereference=self._auto_dereference, - only_fields=self.only_fields + only_fields=self.only_fields, ) - raise TypeError('Provide a slice or an integer index') + raise TypeError("Provide a slice or an integer index") def __iter__(self): raise NotImplementedError @@ -235,14 +240,13 @@ class BaseQuerySet(object): """ queryset = self.clone() if queryset._search_text: - raise OperationError( - 'It is not possible to use search_text two times.') + raise OperationError("It is not possible to use search_text two times.") - query_kwargs = SON({'$search': text}) + query_kwargs = SON({"$search": text}) if language: - query_kwargs['$language'] = language + query_kwargs["$language"] = language - queryset._query_obj &= Q(__raw__={'$text': query_kwargs}) + queryset._query_obj &= Q(__raw__={"$text": query_kwargs}) queryset._mongo_query = None queryset._cursor_obj = None queryset._search_text = text @@ -265,8 +269,7 @@ class BaseQuerySet(object): try: result = six.next(queryset) except StopIteration: - msg = ('%s matching query does not exist.' - % queryset._document._class_name) + msg = "%s matching query does not exist." % queryset._document._class_name raise queryset._document.DoesNotExist(msg) try: six.next(queryset) @@ -276,7 +279,7 @@ class BaseQuerySet(object): # If we were able to retrieve the 2nd doc, rewind the cursor and # raise the MultipleObjectsReturned exception. queryset.rewind() - message = u'%d items returned, instead of 1' % queryset.count() + message = u"%d items returned, instead of 1" % queryset.count() raise queryset._document.MultipleObjectsReturned(message) def create(self, **kwargs): @@ -295,8 +298,9 @@ class BaseQuerySet(object): result = None return result - def insert(self, doc_or_docs, load_bulk=True, - write_concern=None, signal_kwargs=None): + def insert( + self, doc_or_docs, load_bulk=True, write_concern=None, signal_kwargs=None + ): """bulk insert documents :param doc_or_docs: a document or list of documents to be inserted @@ -319,7 +323,7 @@ class BaseQuerySet(object): .. versionchanged:: 0.10.7 Add signal_kwargs argument """ - Document = _import_class('Document') + Document = _import_class("Document") if write_concern is None: write_concern = {} @@ -332,16 +336,16 @@ class BaseQuerySet(object): for doc in docs: if not isinstance(doc, self._document): - msg = ("Some documents inserted aren't instances of %s" - % str(self._document)) + msg = "Some documents inserted aren't instances of %s" % str( + self._document + ) raise OperationError(msg) if doc.pk and not doc._created: - msg = 'Some documents have ObjectIds, use doc.update() instead' + msg = "Some documents have ObjectIds, use doc.update() instead" raise OperationError(msg) signal_kwargs = signal_kwargs or {} - signals.pre_bulk_insert.send(self._document, - documents=docs, **signal_kwargs) + signals.pre_bulk_insert.send(self._document, documents=docs, **signal_kwargs) raw = [doc.to_mongo() for doc in docs] @@ -353,21 +357,25 @@ class BaseQuerySet(object): try: inserted_result = insert_func(raw) - ids = [inserted_result.inserted_id] if return_one else inserted_result.inserted_ids + ids = ( + [inserted_result.inserted_id] + if return_one + else inserted_result.inserted_ids + ) except pymongo.errors.DuplicateKeyError as err: - message = 'Could not save document (%s)' + message = "Could not save document (%s)" raise NotUniqueError(message % six.text_type(err)) except pymongo.errors.BulkWriteError as err: # inserting documents that already have an _id field will # give huge performance debt or raise - message = u'Document must not have _id value before bulk write (%s)' + message = u"Document must not have _id value before bulk write (%s)" raise NotUniqueError(message % six.text_type(err)) except pymongo.errors.OperationFailure as err: - message = 'Could not save document (%s)' - if re.match('^E1100[01] duplicate key', six.text_type(err)): + message = "Could not save document (%s)" + if re.match("^E1100[01] duplicate key", six.text_type(err)): # E11000 - duplicate key error index # E11001 - duplicate key on update - message = u'Tried to save duplicate unique keys (%s)' + message = u"Tried to save duplicate unique keys (%s)" raise NotUniqueError(message % six.text_type(err)) raise OperationError(message % six.text_type(err)) @@ -377,13 +385,15 @@ class BaseQuerySet(object): if not load_bulk: signals.post_bulk_insert.send( - self._document, documents=docs, loaded=False, **signal_kwargs) + self._document, documents=docs, loaded=False, **signal_kwargs + ) return ids[0] if return_one else ids documents = self.in_bulk(ids) results = [documents.get(obj_id) for obj_id in ids] signals.post_bulk_insert.send( - self._document, documents=results, loaded=True, **signal_kwargs) + self._document, documents=results, loaded=True, **signal_kwargs + ) return results[0] if return_one else results def count(self, with_limit_and_skip=False): @@ -399,8 +409,7 @@ class BaseQuerySet(object): self._cursor_obj = None return count - def delete(self, write_concern=None, _from_doc_delete=False, - cascade_refs=None): + def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None): """Delete the documents matched by the query. :param write_concern: Extra keyword arguments are passed down which @@ -423,12 +432,13 @@ class BaseQuerySet(object): # Handle deletes where skips or limits have been applied or # there is an untriggered delete signal has_delete_signal = signals.signals_available and ( - signals.pre_delete.has_receivers_for(doc) or - signals.post_delete.has_receivers_for(doc) + signals.pre_delete.has_receivers_for(doc) + or signals.post_delete.has_receivers_for(doc) ) - call_document_delete = (queryset._skip or queryset._limit or - has_delete_signal) and not _from_doc_delete + call_document_delete = ( + queryset._skip or queryset._limit or has_delete_signal + ) and not _from_doc_delete if call_document_delete: cnt = 0 @@ -437,28 +447,28 @@ class BaseQuerySet(object): cnt += 1 return cnt - delete_rules = doc._meta.get('delete_rules') or {} + delete_rules = doc._meta.get("delete_rules") or {} delete_rules = list(delete_rules.items()) # Check for DENY rules before actually deleting/nullifying any other # references for rule_entry, rule in delete_rules: document_cls, field_name = rule_entry - if document_cls._meta.get('abstract'): + if document_cls._meta.get("abstract"): continue if rule == DENY: - refs = document_cls.objects(**{field_name + '__in': self}) + refs = document_cls.objects(**{field_name + "__in": self}) if refs.limit(1).count() > 0: raise OperationError( - 'Could not delete document (%s.%s refers to it)' + "Could not delete document (%s.%s refers to it)" % (document_cls.__name__, field_name) ) # Check all the other rules for rule_entry, rule in delete_rules: document_cls, field_name = rule_entry - if document_cls._meta.get('abstract'): + if document_cls._meta.get("abstract"): continue if rule == CASCADE: @@ -467,19 +477,19 @@ class BaseQuerySet(object): if doc._collection == document_cls._collection: for ref in queryset: cascade_refs.add(ref.id) - refs = document_cls.objects(**{field_name + '__in': self, - 'pk__nin': cascade_refs}) + refs = document_cls.objects( + **{field_name + "__in": self, "pk__nin": cascade_refs} + ) if refs.count() > 0: - refs.delete(write_concern=write_concern, - cascade_refs=cascade_refs) + refs.delete(write_concern=write_concern, cascade_refs=cascade_refs) elif rule == NULLIFY: - document_cls.objects(**{field_name + '__in': self}).update( - write_concern=write_concern, - **{'unset__%s' % field_name: 1}) + document_cls.objects(**{field_name + "__in": self}).update( + write_concern=write_concern, **{"unset__%s" % field_name: 1} + ) elif rule == PULL: - document_cls.objects(**{field_name + '__in': self}).update( - write_concern=write_concern, - **{'pull_all__%s' % field_name: self}) + document_cls.objects(**{field_name + "__in": self}).update( + write_concern=write_concern, **{"pull_all__%s" % field_name: self} + ) with set_write_concern(queryset._collection, write_concern) as collection: result = collection.delete_many(queryset._query) @@ -490,8 +500,9 @@ class BaseQuerySet(object): if result.acknowledged: return result.deleted_count - def update(self, upsert=False, multi=True, write_concern=None, - full_result=False, **update): + def update( + self, upsert=False, multi=True, write_concern=None, full_result=False, **update + ): """Perform an atomic update on the fields matched by the query. :param upsert: insert if document doesn't exist (default ``False``) @@ -511,7 +522,7 @@ class BaseQuerySet(object): .. versionadded:: 0.2 """ if not update and not upsert: - raise OperationError('No update parameters, would remove data') + raise OperationError("No update parameters, would remove data") if write_concern is None: write_concern = {} @@ -522,11 +533,11 @@ class BaseQuerySet(object): # If doing an atomic upsert on an inheritable class # then ensure we add _cls to the update operation - if upsert and '_cls' in query: - if '$set' in update: - update['$set']['_cls'] = queryset._document._class_name + if upsert and "_cls" in query: + if "$set" in update: + update["$set"]["_cls"] = queryset._document._class_name else: - update['$set'] = {'_cls': queryset._document._class_name} + update["$set"] = {"_cls": queryset._document._class_name} try: with set_write_concern(queryset._collection, write_concern) as collection: update_func = collection.update_one @@ -536,14 +547,14 @@ class BaseQuerySet(object): if full_result: return result elif result.raw_result: - return result.raw_result['n'] + return result.raw_result["n"] except pymongo.errors.DuplicateKeyError as err: - raise NotUniqueError(u'Update failed (%s)' % six.text_type(err)) + raise NotUniqueError(u"Update failed (%s)" % six.text_type(err)) except pymongo.errors.OperationFailure as err: - if six.text_type(err) == u'multi not coded yet': - message = u'update() method requires MongoDB 1.1.3+' + if six.text_type(err) == u"multi not coded yet": + message = u"update() method requires MongoDB 1.1.3+" raise OperationError(message) - raise OperationError(u'Update failed (%s)' % six.text_type(err)) + raise OperationError(u"Update failed (%s)" % six.text_type(err)) def upsert_one(self, write_concern=None, **update): """Overwrite or add the first document matched by the query. @@ -561,11 +572,15 @@ class BaseQuerySet(object): .. versionadded:: 0.10.2 """ - atomic_update = self.update(multi=False, upsert=True, - write_concern=write_concern, - full_result=True, **update) + atomic_update = self.update( + multi=False, + upsert=True, + write_concern=write_concern, + full_result=True, + **update + ) - if atomic_update.raw_result['updatedExisting']: + if atomic_update.raw_result["updatedExisting"]: document = self.get() else: document = self._document.objects.with_id(atomic_update.upserted_id) @@ -594,9 +609,12 @@ class BaseQuerySet(object): multi=False, write_concern=write_concern, full_result=full_result, - **update) + **update + ) - def modify(self, upsert=False, full_response=False, remove=False, new=False, **update): + def modify( + self, upsert=False, full_response=False, remove=False, new=False, **update + ): """Update and return the updated document. Returns either the document before or after modification based on `new` @@ -621,11 +639,10 @@ class BaseQuerySet(object): """ if remove and new: - raise OperationError('Conflicting parameters: remove and new') + raise OperationError("Conflicting parameters: remove and new") if not update and not upsert and not remove: - raise OperationError( - 'No update parameters, must either update or remove') + raise OperationError("No update parameters, must either update or remove") queryset = self.clone() query = queryset._query @@ -635,27 +652,35 @@ class BaseQuerySet(object): try: if full_response: - msg = 'With PyMongo 3+, it is not possible anymore to get the full response.' + msg = "With PyMongo 3+, it is not possible anymore to get the full response." warnings.warn(msg, DeprecationWarning) if remove: result = queryset._collection.find_one_and_delete( - query, sort=sort, **self._cursor_args) + query, sort=sort, **self._cursor_args + ) else: if new: return_doc = ReturnDocument.AFTER else: return_doc = ReturnDocument.BEFORE result = queryset._collection.find_one_and_update( - query, update, upsert=upsert, sort=sort, return_document=return_doc, - **self._cursor_args) + query, + update, + upsert=upsert, + sort=sort, + return_document=return_doc, + **self._cursor_args + ) except pymongo.errors.DuplicateKeyError as err: - raise NotUniqueError(u'Update failed (%s)' % err) + raise NotUniqueError(u"Update failed (%s)" % err) except pymongo.errors.OperationFailure as err: - raise OperationError(u'Update failed (%s)' % err) + raise OperationError(u"Update failed (%s)" % err) if full_response: - if result['value'] is not None: - result['value'] = self._document._from_son(result['value'], only_fields=self.only_fields) + if result["value"] is not None: + result["value"] = self._document._from_son( + result["value"], only_fields=self.only_fields + ) else: if result is not None: result = self._document._from_son(result, only_fields=self.only_fields) @@ -673,7 +698,7 @@ class BaseQuerySet(object): """ queryset = self.clone() if not queryset._query_obj.empty: - msg = 'Cannot use a filter whilst using `with_id`' + msg = "Cannot use a filter whilst using `with_id`" raise InvalidQueryError(msg) return queryset.filter(pk=object_id).first() @@ -688,21 +713,22 @@ class BaseQuerySet(object): """ doc_map = {} - docs = self._collection.find({'_id': {'$in': object_ids}}, - **self._cursor_args) + docs = self._collection.find({"_id": {"$in": object_ids}}, **self._cursor_args) if self._scalar: for doc in docs: - doc_map[doc['_id']] = self._get_scalar( - self._document._from_son(doc, only_fields=self.only_fields)) + doc_map[doc["_id"]] = self._get_scalar( + self._document._from_son(doc, only_fields=self.only_fields) + ) elif self._as_pymongo: for doc in docs: - doc_map[doc['_id']] = doc + doc_map[doc["_id"]] = doc else: for doc in docs: - doc_map[doc['_id']] = self._document._from_son( + doc_map[doc["_id"]] = self._document._from_son( doc, only_fields=self.only_fields, - _auto_dereference=self._auto_dereference) + _auto_dereference=self._auto_dereference, + ) return doc_map @@ -717,8 +743,8 @@ class BaseQuerySet(object): Do NOT return any inherited documents. """ - if self._document._meta.get('allow_inheritance') is True: - self._initial_query = {'_cls': self._document._class_name} + if self._document._meta.get("allow_inheritance") is True: + self._initial_query = {"_cls": self._document._class_name} return self @@ -747,15 +773,35 @@ class BaseQuerySet(object): """ if not isinstance(new_qs, BaseQuerySet): raise OperationError( - '%s is not a subclass of BaseQuerySet' % new_qs.__name__) + "%s is not a subclass of BaseQuerySet" % new_qs.__name__ + ) - copy_props = ('_mongo_query', '_initial_query', '_none', '_query_obj', - '_where_clause', '_loaded_fields', '_ordering', - '_snapshot', '_timeout', '_class_check', '_slave_okay', - '_read_preference', '_iter', '_scalar', '_as_pymongo', - '_limit', '_skip', '_hint', '_auto_dereference', - '_search_text', 'only_fields', '_max_time_ms', - '_comment', '_batch_size') + copy_props = ( + "_mongo_query", + "_initial_query", + "_none", + "_query_obj", + "_where_clause", + "_loaded_fields", + "_ordering", + "_snapshot", + "_timeout", + "_class_check", + "_slave_okay", + "_read_preference", + "_iter", + "_scalar", + "_as_pymongo", + "_limit", + "_skip", + "_hint", + "_auto_dereference", + "_search_text", + "only_fields", + "_max_time_ms", + "_comment", + "_batch_size", + ) for prop in copy_props: val = getattr(self, prop) @@ -868,37 +914,43 @@ class BaseQuerySet(object): except LookUpError: pass - distinct = self._dereference(queryset._cursor.distinct(field), 1, - name=field, instance=self._document) + distinct = self._dereference( + queryset._cursor.distinct(field), 1, name=field, instance=self._document + ) - doc_field = self._document._fields.get(field.split('.', 1)[0]) + doc_field = self._document._fields.get(field.split(".", 1)[0]) instance = None # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) - EmbeddedDocumentField = _import_class('EmbeddedDocumentField') - ListField = _import_class('ListField') - GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField') + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + ListField = _import_class("ListField") + GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") if isinstance(doc_field, ListField): - doc_field = getattr(doc_field, 'field', doc_field) + doc_field = getattr(doc_field, "field", doc_field) if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): - instance = getattr(doc_field, 'document_type', None) + instance = getattr(doc_field, "document_type", None) # handle distinct on subdocuments - if '.' in field: - for field_part in field.split('.')[1:]: + if "." in field: + for field_part in field.split(".")[1:]: # if looping on embedded document, get the document type instance - if instance and isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): + if instance and isinstance( + doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField) + ): doc_field = instance # now get the subdocument doc_field = getattr(doc_field, field_part, doc_field) # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) if isinstance(doc_field, ListField): - doc_field = getattr(doc_field, 'field', doc_field) - if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): - instance = getattr(doc_field, 'document_type', None) + doc_field = getattr(doc_field, "field", doc_field) + if isinstance( + doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField) + ): + instance = getattr(doc_field, "document_type", None) - if instance and isinstance(doc_field, (EmbeddedDocumentField, - GenericEmbeddedDocumentField)): + if instance and isinstance( + doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField) + ): distinct = [instance(**doc) for doc in distinct] return distinct @@ -970,14 +1022,14 @@ class BaseQuerySet(object): """ # Check for an operator and transform to mongo-style if there is - operators = ['slice'] + operators = ["slice"] cleaned_fields = [] for key, value in kwargs.items(): - parts = key.split('__') + parts = key.split("__") if parts[0] in operators: op = parts.pop(0) - value = {'$' + op: value} - key = '.'.join(parts) + value = {"$" + op: value} + key = ".".join(parts) cleaned_fields.append((key, value)) # Sort fields by their values, explicitly excluded fields first, then @@ -998,7 +1050,8 @@ class BaseQuerySet(object): fields = [field for field, value in group] fields = queryset._fields_to_dbfields(fields) queryset._loaded_fields += QueryFieldList( - fields, value=value, _only_called=_only_called) + fields, value=value, _only_called=_only_called + ) return queryset @@ -1012,7 +1065,8 @@ class BaseQuerySet(object): """ queryset = self.clone() queryset._loaded_fields = QueryFieldList( - always_include=queryset._loaded_fields.always_include) + always_include=queryset._loaded_fields.always_include + ) return queryset def order_by(self, *keys): @@ -1053,7 +1107,7 @@ class BaseQuerySet(object): See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment for details. """ - return self._chainable_method('comment', text) + return self._chainable_method("comment", text) def explain(self, format=False): """Return an explain plan record for the @@ -1066,8 +1120,10 @@ class BaseQuerySet(object): # TODO remove this option completely - it's useless. If somebody # wants to pretty-print the output, they easily can. if format: - msg = ('"format" param of BaseQuerySet.explain has been ' - 'deprecated and will be removed in future versions.') + msg = ( + '"format" param of BaseQuerySet.explain has been ' + "deprecated and will be removed in future versions." + ) warnings.warn(msg, DeprecationWarning) plan = pprint.pformat(plan) @@ -1082,7 +1138,7 @@ class BaseQuerySet(object): ..versionchanged:: 0.5 - made chainable .. deprecated:: Ignored with PyMongo 3+ """ - msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.' + msg = "snapshot is deprecated as it has no impact when using PyMongo 3+." warnings.warn(msg, DeprecationWarning) queryset = self.clone() queryset._snapshot = enabled @@ -1107,7 +1163,7 @@ class BaseQuerySet(object): .. deprecated:: Ignored with PyMongo 3+ """ - msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.' + msg = "slave_okay is deprecated as it has no impact when using PyMongo 3+." warnings.warn(msg, DeprecationWarning) queryset = self.clone() queryset._slave_okay = enabled @@ -1119,10 +1175,12 @@ class BaseQuerySet(object): :param read_preference: override ReplicaSetConnection-level preference. """ - validate_read_preference('read_preference', read_preference) + validate_read_preference("read_preference", read_preference) queryset = self.clone() queryset._read_preference = read_preference - queryset._cursor_obj = None # we need to re-create the cursor object whenever we apply read_preference + queryset._cursor_obj = ( + None + ) # we need to re-create the cursor object whenever we apply read_preference return queryset def scalar(self, *fields): @@ -1168,7 +1226,7 @@ class BaseQuerySet(object): :param ms: the number of milliseconds before killing the query on the server """ - return self._chainable_method('max_time_ms', ms) + return self._chainable_method("max_time_ms", ms) # JSON Helpers @@ -1179,7 +1237,10 @@ class BaseQuerySet(object): def from_json(self, json_data): """Converts json data to unsaved objects""" son_data = json_util.loads(json_data) - return [self._document._from_son(data, only_fields=self.only_fields) for data in son_data] + return [ + self._document._from_son(data, only_fields=self.only_fields) + for data in son_data + ] def aggregate(self, *pipeline, **kwargs): """ @@ -1192,32 +1253,34 @@ class BaseQuerySet(object): initial_pipeline = [] if self._query: - initial_pipeline.append({'$match': self._query}) + initial_pipeline.append({"$match": self._query}) if self._ordering: - initial_pipeline.append({'$sort': dict(self._ordering)}) + initial_pipeline.append({"$sort": dict(self._ordering)}) if self._limit is not None: # As per MongoDB Documentation (https://docs.mongodb.com/manual/reference/operator/aggregation/limit/), # keeping limit stage right after sort stage is more efficient. But this leads to wrong set of documents # for a skip stage that might succeed these. So we need to maintain more documents in memory in such a # case (https://stackoverflow.com/a/24161461). - initial_pipeline.append({'$limit': self._limit + (self._skip or 0)}) + initial_pipeline.append({"$limit": self._limit + (self._skip or 0)}) if self._skip is not None: - initial_pipeline.append({'$skip': self._skip}) + initial_pipeline.append({"$skip": self._skip}) pipeline = initial_pipeline + list(pipeline) if self._read_preference is not None: - return self._collection.with_options(read_preference=self._read_preference) \ - .aggregate(pipeline, cursor={}, **kwargs) + return self._collection.with_options( + read_preference=self._read_preference + ).aggregate(pipeline, cursor={}, **kwargs) return self._collection.aggregate(pipeline, cursor={}, **kwargs) # JS functionality - def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, - scope=None): + def map_reduce( + self, map_f, reduce_f, output, finalize_f=None, limit=None, scope=None + ): """Perform a map/reduce query using the current query spec and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, it must be the last call made, as it does not return a maleable @@ -1257,10 +1320,10 @@ class BaseQuerySet(object): """ queryset = self.clone() - MapReduceDocument = _import_class('MapReduceDocument') + MapReduceDocument = _import_class("MapReduceDocument") - if not hasattr(self._collection, 'map_reduce'): - raise NotImplementedError('Requires MongoDB >= 1.7.1') + if not hasattr(self._collection, "map_reduce"): + raise NotImplementedError("Requires MongoDB >= 1.7.1") map_f_scope = {} if isinstance(map_f, Code): @@ -1275,7 +1338,7 @@ class BaseQuerySet(object): reduce_f_code = queryset._sub_js_fields(reduce_f) reduce_f = Code(reduce_f_code, reduce_f_scope) - mr_args = {'query': queryset._query} + mr_args = {"query": queryset._query} if finalize_f: finalize_f_scope = {} @@ -1284,39 +1347,39 @@ class BaseQuerySet(object): finalize_f = six.text_type(finalize_f) finalize_f_code = queryset._sub_js_fields(finalize_f) finalize_f = Code(finalize_f_code, finalize_f_scope) - mr_args['finalize'] = finalize_f + mr_args["finalize"] = finalize_f if scope: - mr_args['scope'] = scope + mr_args["scope"] = scope if limit: - mr_args['limit'] = limit + mr_args["limit"] = limit - if output == 'inline' and not queryset._ordering: - map_reduce_function = 'inline_map_reduce' + if output == "inline" and not queryset._ordering: + map_reduce_function = "inline_map_reduce" else: - map_reduce_function = 'map_reduce' + map_reduce_function = "map_reduce" if isinstance(output, six.string_types): - mr_args['out'] = output + mr_args["out"] = output elif isinstance(output, dict): ordered_output = [] - for part in ('replace', 'merge', 'reduce'): + for part in ("replace", "merge", "reduce"): value = output.get(part) if value: ordered_output.append((part, value)) break else: - raise OperationError('actionData not specified for output') + raise OperationError("actionData not specified for output") - db_alias = output.get('db_alias') - remaing_args = ['db', 'sharded', 'nonAtomic'] + db_alias = output.get("db_alias") + remaing_args = ["db", "sharded", "nonAtomic"] if db_alias: - ordered_output.append(('db', get_db(db_alias).name)) + ordered_output.append(("db", get_db(db_alias).name)) del remaing_args[0] for part in remaing_args: @@ -1324,20 +1387,22 @@ class BaseQuerySet(object): if value: ordered_output.append((part, value)) - mr_args['out'] = SON(ordered_output) + mr_args["out"] = SON(ordered_output) results = getattr(queryset._collection, map_reduce_function)( - map_f, reduce_f, **mr_args) + map_f, reduce_f, **mr_args + ) - if map_reduce_function == 'map_reduce': + if map_reduce_function == "map_reduce": results = results.find() if queryset._ordering: results = results.sort(queryset._ordering) for doc in results: - yield MapReduceDocument(queryset._document, queryset._collection, - doc['_id'], doc['value']) + yield MapReduceDocument( + queryset._document, queryset._collection, doc["_id"], doc["value"] + ) def exec_js(self, code, *fields, **options): """Execute a Javascript function on the server. A list of fields may be @@ -1368,16 +1433,13 @@ class BaseQuerySet(object): fields = [queryset._document._translate_field_name(f) for f in fields] collection = queryset._document._get_collection_name() - scope = { - 'collection': collection, - 'options': options or {}, - } + scope = {"collection": collection, "options": options or {}} query = queryset._query if queryset._where_clause: - query['$where'] = queryset._where_clause + query["$where"] = queryset._where_clause - scope['query'] = query + scope["query"] = query code = Code(code, scope=scope) db = queryset._document._get_db() @@ -1407,22 +1469,22 @@ class BaseQuerySet(object): """ db_field = self._fields_to_dbfields([field]).pop() pipeline = [ - {'$match': self._query}, - {'$group': {'_id': 'sum', 'total': {'$sum': '$' + db_field}}} + {"$match": self._query}, + {"$group": {"_id": "sum", "total": {"$sum": "$" + db_field}}}, ] # if we're performing a sum over a list field, we sum up all the # elements in the list, hence we need to $unwind the arrays first - ListField = _import_class('ListField') - field_parts = field.split('.') + ListField = _import_class("ListField") + field_parts = field.split(".") field_instances = self._document._lookup_field(field_parts) if isinstance(field_instances[-1], ListField): - pipeline.insert(1, {'$unwind': '$' + field}) + pipeline.insert(1, {"$unwind": "$" + field}) result = tuple(self._document._get_collection().aggregate(pipeline)) if result: - return result[0]['total'] + return result[0]["total"] return 0 def average(self, field): @@ -1433,22 +1495,22 @@ class BaseQuerySet(object): """ db_field = self._fields_to_dbfields([field]).pop() pipeline = [ - {'$match': self._query}, - {'$group': {'_id': 'avg', 'total': {'$avg': '$' + db_field}}} + {"$match": self._query}, + {"$group": {"_id": "avg", "total": {"$avg": "$" + db_field}}}, ] # if we're performing an average over a list field, we average out # all the elements in the list, hence we need to $unwind the arrays # first - ListField = _import_class('ListField') - field_parts = field.split('.') + ListField = _import_class("ListField") + field_parts = field.split(".") field_instances = self._document._lookup_field(field_parts) if isinstance(field_instances[-1], ListField): - pipeline.insert(1, {'$unwind': '$' + field}) + pipeline.insert(1, {"$unwind": "$" + field}) result = tuple(self._document._get_collection().aggregate(pipeline)) if result: - return result[0]['total'] + return result[0]["total"] return 0 def item_frequencies(self, field, normalize=False, map_reduce=True): @@ -1474,8 +1536,7 @@ class BaseQuerySet(object): document lookups """ if map_reduce: - return self._item_frequencies_map_reduce(field, - normalize=normalize) + return self._item_frequencies_map_reduce(field, normalize=normalize) return self._item_frequencies_exec_js(field, normalize=normalize) # Iterator helpers @@ -1492,15 +1553,17 @@ class BaseQuerySet(object): return raw_doc doc = self._document._from_son( - raw_doc, _auto_dereference=self._auto_dereference, - only_fields=self.only_fields) + raw_doc, + _auto_dereference=self._auto_dereference, + only_fields=self.only_fields, + ) if self._scalar: return self._get_scalar(doc) return doc - next = __next__ # For Python2 support + next = __next__ # For Python2 support def rewind(self): """Rewind the cursor to its unevaluated state. @@ -1521,15 +1584,13 @@ class BaseQuerySet(object): @property def _cursor_args(self): - fields_name = 'projection' + fields_name = "projection" # snapshot is not handled at all by PyMongo 3+ # TODO: evaluate similar possibilities using modifiers if self._snapshot: - msg = 'The snapshot option is not anymore available with PyMongo 3+' + msg = "The snapshot option is not anymore available with PyMongo 3+" warnings.warn(msg, DeprecationWarning) - cursor_args = { - 'no_cursor_timeout': not self._timeout - } + cursor_args = {"no_cursor_timeout": not self._timeout} if self._loaded_fields: cursor_args[fields_name] = self._loaded_fields.as_dict() @@ -1538,7 +1599,7 @@ class BaseQuerySet(object): if fields_name not in cursor_args: cursor_args[fields_name] = {} - cursor_args[fields_name]['_text_score'] = {'$meta': 'textScore'} + cursor_args[fields_name]["_text_score"] = {"$meta": "textScore"} return cursor_args @@ -1555,12 +1616,11 @@ class BaseQuerySet(object): # level, not a cursor level. Thus, we need to get a cloned collection # object using `with_options` first. if self._read_preference is not None: - self._cursor_obj = self._collection\ - .with_options(read_preference=self._read_preference)\ - .find(self._query, **self._cursor_args) + self._cursor_obj = self._collection.with_options( + read_preference=self._read_preference + ).find(self._query, **self._cursor_args) else: - self._cursor_obj = self._collection.find(self._query, - **self._cursor_args) + self._cursor_obj = self._collection.find(self._query, **self._cursor_args) # Apply "where" clauses to cursor if self._where_clause: where_clause = self._sub_js_fields(self._where_clause) @@ -1576,9 +1636,9 @@ class BaseQuerySet(object): if self._ordering: # explicit ordering self._cursor_obj.sort(self._ordering) - elif self._ordering is None and self._document._meta['ordering']: + elif self._ordering is None and self._document._meta["ordering"]: # default ordering - order = self._get_order_by(self._document._meta['ordering']) + order = self._get_order_by(self._document._meta["ordering"]) self._cursor_obj.sort(order) if self._limit is not None: @@ -1607,8 +1667,10 @@ class BaseQuerySet(object): if self._mongo_query is None: self._mongo_query = self._query_obj.to_query(self._document) if self._class_check and self._initial_query: - if '_cls' in self._mongo_query: - self._mongo_query = {'$and': [self._initial_query, self._mongo_query]} + if "_cls" in self._mongo_query: + self._mongo_query = { + "$and": [self._initial_query, self._mongo_query] + } else: self._mongo_query.update(self._initial_query) return self._mongo_query @@ -1616,7 +1678,7 @@ class BaseQuerySet(object): @property def _dereference(self): if not self.__dereference: - self.__dereference = _import_class('DeReference')() + self.__dereference = _import_class("DeReference")() return self.__dereference def no_dereference(self): @@ -1649,7 +1711,9 @@ class BaseQuerySet(object): emit(null, 1); } } - """ % {'field': field} + """ % { + "field": field + } reduce_func = """ function(key, values) { var total = 0; @@ -1660,7 +1724,7 @@ class BaseQuerySet(object): return total; } """ - values = self.map_reduce(map_func, reduce_func, 'inline') + values = self.map_reduce(map_func, reduce_func, "inline") frequencies = {} for f in values: key = f.key @@ -1671,8 +1735,7 @@ class BaseQuerySet(object): if normalize: count = sum(frequencies.values()) - frequencies = {k: float(v) / count - for k, v in frequencies.items()} + frequencies = {k: float(v) / count for k, v in frequencies.items()} return frequencies @@ -1742,15 +1805,14 @@ class BaseQuerySet(object): def _fields_to_dbfields(self, fields): """Translate fields' paths to their db equivalents.""" subclasses = [] - if self._document._meta['allow_inheritance']: - subclasses = [get_document(x) - for x in self._document._subclasses][1:] + if self._document._meta["allow_inheritance"]: + subclasses = [get_document(x) for x in self._document._subclasses][1:] db_field_paths = [] for field in fields: - field_parts = field.split('.') + field_parts = field.split(".") try: - field = '.'.join( + field = ".".join( f if isinstance(f, six.string_types) else f.db_field for f in self._document._lookup_field(field_parts) ) @@ -1762,7 +1824,7 @@ class BaseQuerySet(object): # through its subclasses and see if it exists on any of them. for subdoc in subclasses: try: - subfield = '.'.join( + subfield = ".".join( f if isinstance(f, six.string_types) else f.db_field for f in subdoc._lookup_field(field_parts) ) @@ -1790,18 +1852,18 @@ class BaseQuerySet(object): if not key: continue - if key == '$text_score': - key_list.append(('_text_score', {'$meta': 'textScore'})) + if key == "$text_score": + key_list.append(("_text_score", {"$meta": "textScore"})) continue direction = pymongo.ASCENDING - if key[0] == '-': + if key[0] == "-": direction = pymongo.DESCENDING - if key[0] in ('-', '+'): + if key[0] in ("-", "+"): key = key[1:] - key = key.replace('__', '.') + key = key.replace("__", ".") try: key = self._document._translate_field_name(key) except Exception: @@ -1813,9 +1875,8 @@ class BaseQuerySet(object): return key_list def _get_scalar(self, doc): - def lookup(obj, name): - chunks = name.split('__') + chunks = name.split("__") for chunk in chunks: obj = getattr(obj, chunk) return obj @@ -1835,21 +1896,20 @@ class BaseQuerySet(object): def field_sub(match): # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') + field_name = match.group(1).split(".") fields = self._document._lookup_field(field_name) # Substitute the correct name for the field into the javascript return u'["%s"]' % fields[-1].db_field def field_path_sub(match): # Extract just the field name, and look up the field objects - field_name = match.group(1).split('.') + field_name = match.group(1).split(".") fields = self._document._lookup_field(field_name) # Substitute the correct name for the field into the javascript - return '.'.join([f.db_field for f in fields]) + return ".".join([f.db_field for f in fields]) - code = re.sub(r'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) - code = re.sub(r'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, - code) + code = re.sub(r"\[\s*~([A-z_][A-z_0-9.]+?)\s*\]", field_sub, code) + code = re.sub(r"\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}", field_path_sub, code) return code def _chainable_method(self, method_name, val): @@ -1866,22 +1926,26 @@ class BaseQuerySet(object): getattr(cursor, method_name)(val) # Cache the value on the queryset._{method_name} - setattr(queryset, '_' + method_name, val) + setattr(queryset, "_" + method_name, val) return queryset # Deprecated def ensure_index(self, **kwargs): """Deprecated use :func:`Document.ensure_index`""" - msg = ('Doc.objects()._ensure_index() is deprecated. ' - 'Use Doc.ensure_index() instead.') + msg = ( + "Doc.objects()._ensure_index() is deprecated. " + "Use Doc.ensure_index() instead." + ) warnings.warn(msg, DeprecationWarning) self._document.__class__.ensure_index(**kwargs) return self def _ensure_indexes(self): """Deprecated use :func:`~Document.ensure_indexes`""" - msg = ('Doc.objects()._ensure_indexes() is deprecated. ' - 'Use Doc.ensure_indexes() instead.') + msg = ( + "Doc.objects()._ensure_indexes() is deprecated. " + "Use Doc.ensure_indexes() instead." + ) warnings.warn(msg, DeprecationWarning) self._document.__class__.ensure_indexes() diff --git a/mongoengine/queryset/field_list.py b/mongoengine/queryset/field_list.py index dba724af..5c3ff222 100644 --- a/mongoengine/queryset/field_list.py +++ b/mongoengine/queryset/field_list.py @@ -1,12 +1,15 @@ -__all__ = ('QueryFieldList',) +__all__ = ("QueryFieldList",) class QueryFieldList(object): """Object that handles combinations of .only() and .exclude() calls""" + ONLY = 1 EXCLUDE = 0 - def __init__(self, fields=None, value=ONLY, always_include=None, _only_called=False): + def __init__( + self, fields=None, value=ONLY, always_include=None, _only_called=False + ): """The QueryFieldList builder :param fields: A list of fields used in `.only()` or `.exclude()` @@ -49,7 +52,7 @@ class QueryFieldList(object): self.fields = f.fields - self.fields self._clean_slice() - if '_id' in f.fields: + if "_id" in f.fields: self._id = f.value if self.always_include: @@ -59,7 +62,7 @@ class QueryFieldList(object): else: self.fields -= self.always_include - if getattr(f, '_only_called', False): + if getattr(f, "_only_called", False): self._only_called = True return self @@ -73,7 +76,7 @@ class QueryFieldList(object): if self.slice: field_list.update(self.slice) if self._id is not None: - field_list['_id'] = self._id + field_list["_id"] = self._id return field_list def reset(self): diff --git a/mongoengine/queryset/manager.py b/mongoengine/queryset/manager.py index f93dbb43..5067ffbf 100644 --- a/mongoengine/queryset/manager.py +++ b/mongoengine/queryset/manager.py @@ -1,7 +1,7 @@ from functools import partial from mongoengine.queryset.queryset import QuerySet -__all__ = ('queryset_manager', 'QuerySetManager') +__all__ = ("queryset_manager", "QuerySetManager") class QuerySetManager(object): @@ -33,7 +33,7 @@ class QuerySetManager(object): return self # owner is the document that contains the QuerySetManager - queryset_class = owner._meta.get('queryset_class', self.default) + queryset_class = owner._meta.get("queryset_class", self.default) queryset = queryset_class(owner, owner._get_collection()) if self.get_queryset: arg_count = self.get_queryset.__code__.co_argcount diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index c7c593b1..4ba62d46 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -1,11 +1,24 @@ import six from mongoengine.errors import OperationError -from mongoengine.queryset.base import (BaseQuerySet, CASCADE, DENY, DO_NOTHING, - NULLIFY, PULL) +from mongoengine.queryset.base import ( + BaseQuerySet, + CASCADE, + DENY, + DO_NOTHING, + NULLIFY, + PULL, +) -__all__ = ('QuerySet', 'QuerySetNoCache', 'DO_NOTHING', 'NULLIFY', 'CASCADE', - 'DENY', 'PULL') +__all__ = ( + "QuerySet", + "QuerySetNoCache", + "DO_NOTHING", + "NULLIFY", + "CASCADE", + "DENY", + "PULL", +) # The maximum number of items to display in a QuerySet.__repr__ REPR_OUTPUT_SIZE = 20 @@ -57,12 +70,12 @@ class QuerySet(BaseQuerySet): def __repr__(self): """Provide a string representation of the QuerySet""" if self._iter: - return '.. queryset mid-iteration ..' + return ".. queryset mid-iteration .." self._populate_cache() - data = self._result_cache[:REPR_OUTPUT_SIZE + 1] + data = self._result_cache[: REPR_OUTPUT_SIZE + 1] if len(data) > REPR_OUTPUT_SIZE: - data[-1] = '...(remaining elements truncated)...' + data[-1] = "...(remaining elements truncated)..." return repr(data) def _iter_results(self): @@ -143,10 +156,9 @@ class QuerySet(BaseQuerySet): .. versionadded:: 0.8.3 Convert to non caching queryset """ if self._result_cache is not None: - raise OperationError('QuerySet already cached') + raise OperationError("QuerySet already cached") - return self._clone_into(QuerySetNoCache(self._document, - self._collection)) + return self._clone_into(QuerySetNoCache(self._document, self._collection)) class QuerySetNoCache(BaseQuerySet): @@ -165,7 +177,7 @@ class QuerySetNoCache(BaseQuerySet): .. versionchanged:: 0.6.13 Now doesnt modify the cursor """ if self._iter: - return '.. queryset mid-iteration ..' + return ".. queryset mid-iteration .." data = [] for _ in six.moves.range(REPR_OUTPUT_SIZE + 1): @@ -175,7 +187,7 @@ class QuerySetNoCache(BaseQuerySet): break if len(data) > REPR_OUTPUT_SIZE: - data[-1] = '...(remaining elements truncated)...' + data[-1] = "...(remaining elements truncated)..." self.rewind() return repr(data) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 128a4e44..0b73e99b 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -10,21 +10,54 @@ from mongoengine.base import UPDATE_OPERATORS from mongoengine.common import _import_class from mongoengine.errors import InvalidQueryError -__all__ = ('query', 'update') +__all__ = ("query", "update") -COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists', 'not', 'elemMatch', 'type') -GEO_OPERATORS = ('within_distance', 'within_spherical_distance', - 'within_box', 'within_polygon', 'near', 'near_sphere', - 'max_distance', 'min_distance', 'geo_within', 'geo_within_box', - 'geo_within_polygon', 'geo_within_center', - 'geo_within_sphere', 'geo_intersects') -STRING_OPERATORS = ('contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', - 'exact', 'iexact') -CUSTOM_OPERATORS = ('match',) -MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + - STRING_OPERATORS + CUSTOM_OPERATORS) +COMPARISON_OPERATORS = ( + "ne", + "gt", + "gte", + "lt", + "lte", + "in", + "nin", + "mod", + "all", + "size", + "exists", + "not", + "elemMatch", + "type", +) +GEO_OPERATORS = ( + "within_distance", + "within_spherical_distance", + "within_box", + "within_polygon", + "near", + "near_sphere", + "max_distance", + "min_distance", + "geo_within", + "geo_within_box", + "geo_within_polygon", + "geo_within_center", + "geo_within_sphere", + "geo_intersects", +) +STRING_OPERATORS = ( + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "exact", + "iexact", +) +CUSTOM_OPERATORS = ("match",) +MATCH_OPERATORS = ( + COMPARISON_OPERATORS + GEO_OPERATORS + STRING_OPERATORS + CUSTOM_OPERATORS +) # TODO make this less complex @@ -33,11 +66,11 @@ def query(_doc_cls=None, **kwargs): mongo_query = {} merge_query = defaultdict(list) for key, value in sorted(kwargs.items()): - if key == '__raw__': + if key == "__raw__": mongo_query.update(value) continue - parts = key.rsplit('__') + parts = key.rsplit("__") indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] parts = [part for part in parts if not part.isdigit()] # Check for an operator and transform to mongo-style if there is @@ -46,11 +79,11 @@ def query(_doc_cls=None, **kwargs): op = parts.pop() # Allow to escape operator-like field name by __ - if len(parts) > 1 and parts[-1] == '': + if len(parts) > 1 and parts[-1] == "": parts.pop() negate = False - if len(parts) > 1 and parts[-1] == 'not': + if len(parts) > 1 and parts[-1] == "not": parts.pop() negate = True @@ -62,8 +95,8 @@ def query(_doc_cls=None, **kwargs): raise InvalidQueryError(e) parts = [] - CachedReferenceField = _import_class('CachedReferenceField') - GenericReferenceField = _import_class('GenericReferenceField') + CachedReferenceField = _import_class("CachedReferenceField") + GenericReferenceField = _import_class("GenericReferenceField") cleaned_fields = [] for field in fields: @@ -73,7 +106,7 @@ def query(_doc_cls=None, **kwargs): append_field = False # is last and CachedReferenceField elif isinstance(field, CachedReferenceField) and fields[-1] == field: - parts.append('%s._id' % field.db_field) + parts.append("%s._id" % field.db_field) else: parts.append(field.db_field) @@ -83,15 +116,15 @@ def query(_doc_cls=None, **kwargs): # Convert value to proper value field = cleaned_fields[-1] - singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] + singular_ops = [None, "ne", "gt", "gte", "lt", "lte", "not"] singular_ops += STRING_OPERATORS if op in singular_ops: value = field.prepare_query_value(op, value) if isinstance(field, CachedReferenceField) and value: - value = value['_id'] + value = value["_id"] - elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): + elif op in ("in", "nin", "all", "near") and not isinstance(value, dict): # Raise an error if the in/nin/all/near param is not iterable. value = _prepare_query_for_iterable(field, op, value) @@ -101,39 +134,40 @@ def query(_doc_cls=None, **kwargs): # * If the value is an ObjectId, the key should be "field_name._ref.$id". if isinstance(field, GenericReferenceField): if isinstance(value, DBRef): - parts[-1] += '._ref' + parts[-1] += "._ref" elif isinstance(value, ObjectId): - parts[-1] += '._ref.$id' + parts[-1] += "._ref.$id" # if op and op not in COMPARISON_OPERATORS: if op: if op in GEO_OPERATORS: value = _geo_operator(field, op, value) - elif op in ('match', 'elemMatch'): - ListField = _import_class('ListField') - EmbeddedDocumentField = _import_class('EmbeddedDocumentField') + elif op in ("match", "elemMatch"): + ListField = _import_class("ListField") + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") if ( - isinstance(value, dict) and - isinstance(field, ListField) and - isinstance(field.field, EmbeddedDocumentField) + isinstance(value, dict) + and isinstance(field, ListField) + and isinstance(field.field, EmbeddedDocumentField) ): value = query(field.field.document_type, **value) else: value = field.prepare_query_value(op, value) - value = {'$elemMatch': value} + value = {"$elemMatch": value} elif op in CUSTOM_OPERATORS: - NotImplementedError('Custom method "%s" has not ' - 'been implemented' % op) + NotImplementedError( + 'Custom method "%s" has not ' "been implemented" % op + ) elif op not in STRING_OPERATORS: - value = {'$' + op: value} + value = {"$" + op: value} if negate: - value = {'$not': value} + value = {"$not": value} for i, part in indices: parts.insert(i, part) - key = '.'.join(parts) + key = ".".join(parts) if op is None or key not in mongo_query: mongo_query[key] = value @@ -142,30 +176,35 @@ def query(_doc_cls=None, **kwargs): mongo_query[key].update(value) # $max/minDistance needs to come last - convert to SON value_dict = mongo_query[key] - if ('$maxDistance' in value_dict or '$minDistance' in value_dict) and \ - ('$near' in value_dict or '$nearSphere' in value_dict): + if ("$maxDistance" in value_dict or "$minDistance" in value_dict) and ( + "$near" in value_dict or "$nearSphere" in value_dict + ): value_son = SON() for k, v in iteritems(value_dict): - if k == '$maxDistance' or k == '$minDistance': + if k == "$maxDistance" or k == "$minDistance": continue value_son[k] = v # Required for MongoDB >= 2.6, may fail when combining # PyMongo 3+ and MongoDB < 2.6 near_embedded = False - for near_op in ('$near', '$nearSphere'): + for near_op in ("$near", "$nearSphere"): if isinstance(value_dict.get(near_op), dict): value_son[near_op] = SON(value_son[near_op]) - if '$maxDistance' in value_dict: - value_son[near_op]['$maxDistance'] = value_dict['$maxDistance'] - if '$minDistance' in value_dict: - value_son[near_op]['$minDistance'] = value_dict['$minDistance'] + if "$maxDistance" in value_dict: + value_son[near_op]["$maxDistance"] = value_dict[ + "$maxDistance" + ] + if "$minDistance" in value_dict: + value_son[near_op]["$minDistance"] = value_dict[ + "$minDistance" + ] near_embedded = True if not near_embedded: - if '$maxDistance' in value_dict: - value_son['$maxDistance'] = value_dict['$maxDistance'] - if '$minDistance' in value_dict: - value_son['$minDistance'] = value_dict['$minDistance'] + if "$maxDistance" in value_dict: + value_son["$maxDistance"] = value_dict["$maxDistance"] + if "$minDistance" in value_dict: + value_son["$minDistance"] = value_dict["$minDistance"] mongo_query[key] = value_son else: # Store for manually merging later @@ -177,10 +216,10 @@ def query(_doc_cls=None, **kwargs): del mongo_query[k] if isinstance(v, list): value = [{k: val} for val in v] - if '$and' in mongo_query.keys(): - mongo_query['$and'].extend(value) + if "$and" in mongo_query.keys(): + mongo_query["$and"].extend(value) else: - mongo_query['$and'] = value + mongo_query["$and"] = value return mongo_query @@ -192,15 +231,15 @@ def update(_doc_cls=None, **update): mongo_update = {} for key, value in update.items(): - if key == '__raw__': + if key == "__raw__": mongo_update.update(value) continue - parts = key.split('__') + parts = key.split("__") # if there is no operator, default to 'set' if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: - parts.insert(0, 'set') + parts.insert(0, "set") # Check for an operator and transform to mongo-style if there is op = None @@ -208,13 +247,13 @@ def update(_doc_cls=None, **update): op = parts.pop(0) # Convert Pythonic names to Mongo equivalents operator_map = { - 'push_all': 'pushAll', - 'pull_all': 'pullAll', - 'dec': 'inc', - 'add_to_set': 'addToSet', - 'set_on_insert': 'setOnInsert' + "push_all": "pushAll", + "pull_all": "pullAll", + "dec": "inc", + "add_to_set": "addToSet", + "set_on_insert": "setOnInsert", } - if op == 'dec': + if op == "dec": # Support decrement by flipping a positive value's sign # and using 'inc' value = -value @@ -227,7 +266,7 @@ def update(_doc_cls=None, **update): match = parts.pop() # Allow to escape operator-like field name by __ - if len(parts) > 1 and parts[-1] == '': + if len(parts) > 1 and parts[-1] == "": parts.pop() if _doc_cls: @@ -244,8 +283,8 @@ def update(_doc_cls=None, **update): append_field = True if isinstance(field, six.string_types): # Convert the S operator to $ - if field == 'S': - field = '$' + if field == "S": + field = "$" parts.append(field) append_field = False else: @@ -253,7 +292,7 @@ def update(_doc_cls=None, **update): if append_field: appended_sub_field = False cleaned_fields.append(field) - if hasattr(field, 'field'): + if hasattr(field, "field"): cleaned_fields.append(field.field) appended_sub_field = True @@ -263,52 +302,53 @@ def update(_doc_cls=None, **update): else: field = cleaned_fields[-1] - GeoJsonBaseField = _import_class('GeoJsonBaseField') + GeoJsonBaseField = _import_class("GeoJsonBaseField") if isinstance(field, GeoJsonBaseField): value = field.to_mongo(value) - if op == 'pull': + if op == "pull": if field.required or value is not None: - if match in ('in', 'nin') and not isinstance(value, dict): + if match in ("in", "nin") and not isinstance(value, dict): value = _prepare_query_for_iterable(field, op, value) else: value = field.prepare_query_value(op, value) - elif op == 'push' and isinstance(value, (list, tuple, set)): + elif op == "push" and isinstance(value, (list, tuple, set)): value = [field.prepare_query_value(op, v) for v in value] - elif op in (None, 'set', 'push'): + elif op in (None, "set", "push"): if field.required or value is not None: value = field.prepare_query_value(op, value) - elif op in ('pushAll', 'pullAll'): + elif op in ("pushAll", "pullAll"): value = [field.prepare_query_value(op, v) for v in value] - elif op in ('addToSet', 'setOnInsert'): + elif op in ("addToSet", "setOnInsert"): if isinstance(value, (list, tuple, set)): value = [field.prepare_query_value(op, v) for v in value] elif field.required or value is not None: value = field.prepare_query_value(op, value) - elif op == 'unset': + elif op == "unset": value = 1 - elif op == 'inc': + elif op == "inc": value = field.prepare_query_value(op, value) if match: - match = '$' + match + match = "$" + match value = {match: value} - key = '.'.join(parts) + key = ".".join(parts) - if 'pull' in op and '.' in key: + if "pull" in op and "." in key: # Dot operators don't work on pull operations # unless they point to a list field # Otherwise it uses nested dict syntax - if op == 'pullAll': - raise InvalidQueryError('pullAll operations only support ' - 'a single field depth') + if op == "pullAll": + raise InvalidQueryError( + "pullAll operations only support a single field depth" + ) # Look for the last list field and use dot notation until there field_classes = [c.__class__ for c in cleaned_fields] field_classes.reverse() - ListField = _import_class('ListField') - EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') + ListField = _import_class("ListField") + EmbeddedDocumentListField = _import_class("EmbeddedDocumentListField") if ListField in field_classes or EmbeddedDocumentListField in field_classes: # Join all fields via dot notation to the last ListField or EmbeddedDocumentListField # Then process as normal @@ -317,37 +357,36 @@ def update(_doc_cls=None, **update): else: _check_field = EmbeddedDocumentListField - last_listField = len( - cleaned_fields) - field_classes.index(_check_field) - key = '.'.join(parts[:last_listField]) + last_listField = len(cleaned_fields) - field_classes.index(_check_field) + key = ".".join(parts[:last_listField]) parts = parts[last_listField:] parts.insert(0, key) parts.reverse() for key in parts: value = {key: value} - elif op == 'addToSet' and isinstance(value, list): - value = {key: {'$each': value}} - elif op in ('push', 'pushAll'): + elif op == "addToSet" and isinstance(value, list): + value = {key: {"$each": value}} + elif op in ("push", "pushAll"): if parts[-1].isdigit(): - key = '.'.join(parts[0:-1]) + key = ".".join(parts[0:-1]) position = int(parts[-1]) # $position expects an iterable. If pushing a single value, # wrap it in a list. if not isinstance(value, (set, tuple, list)): value = [value] - value = {key: {'$each': value, '$position': position}} + value = {key: {"$each": value, "$position": position}} else: - if op == 'pushAll': - op = 'push' # convert to non-deprecated keyword + if op == "pushAll": + op = "push" # convert to non-deprecated keyword if not isinstance(value, (set, tuple, list)): value = [value] - value = {key: {'$each': value}} + value = {key: {"$each": value}} else: value = {key: value} else: value = {key: value} - key = '$' + op + key = "$" + op if key not in mongo_update: mongo_update[key] = value elif key in mongo_update and isinstance(mongo_update[key], dict): @@ -358,45 +397,45 @@ def update(_doc_cls=None, **update): def _geo_operator(field, op, value): """Helper to return the query for a given geo query.""" - if op == 'max_distance': - value = {'$maxDistance': value} - elif op == 'min_distance': - value = {'$minDistance': value} + if op == "max_distance": + value = {"$maxDistance": value} + elif op == "min_distance": + value = {"$minDistance": value} elif field._geo_index == pymongo.GEO2D: - if op == 'within_distance': - value = {'$within': {'$center': value}} - elif op == 'within_spherical_distance': - value = {'$within': {'$centerSphere': value}} - elif op == 'within_polygon': - value = {'$within': {'$polygon': value}} - elif op == 'near': - value = {'$near': value} - elif op == 'near_sphere': - value = {'$nearSphere': value} - elif op == 'within_box': - value = {'$within': {'$box': value}} - else: - raise NotImplementedError('Geo method "%s" has not been ' - 'implemented for a GeoPointField' % op) - else: - if op == 'geo_within': - value = {'$geoWithin': _infer_geometry(value)} - elif op == 'geo_within_box': - value = {'$geoWithin': {'$box': value}} - elif op == 'geo_within_polygon': - value = {'$geoWithin': {'$polygon': value}} - elif op == 'geo_within_center': - value = {'$geoWithin': {'$center': value}} - elif op == 'geo_within_sphere': - value = {'$geoWithin': {'$centerSphere': value}} - elif op == 'geo_intersects': - value = {'$geoIntersects': _infer_geometry(value)} - elif op == 'near': - value = {'$near': _infer_geometry(value)} + if op == "within_distance": + value = {"$within": {"$center": value}} + elif op == "within_spherical_distance": + value = {"$within": {"$centerSphere": value}} + elif op == "within_polygon": + value = {"$within": {"$polygon": value}} + elif op == "near": + value = {"$near": value} + elif op == "near_sphere": + value = {"$nearSphere": value} + elif op == "within_box": + value = {"$within": {"$box": value}} else: raise NotImplementedError( - 'Geo method "%s" has not been implemented for a %s ' - % (op, field._name) + 'Geo method "%s" has not been ' "implemented for a GeoPointField" % op + ) + else: + if op == "geo_within": + value = {"$geoWithin": _infer_geometry(value)} + elif op == "geo_within_box": + value = {"$geoWithin": {"$box": value}} + elif op == "geo_within_polygon": + value = {"$geoWithin": {"$polygon": value}} + elif op == "geo_within_center": + value = {"$geoWithin": {"$center": value}} + elif op == "geo_within_sphere": + value = {"$geoWithin": {"$centerSphere": value}} + elif op == "geo_intersects": + value = {"$geoIntersects": _infer_geometry(value)} + elif op == "near": + value = {"$near": _infer_geometry(value)} + else: + raise NotImplementedError( + 'Geo method "%s" has not been implemented for a %s ' % (op, field._name) ) return value @@ -406,51 +445,58 @@ def _infer_geometry(value): given value. """ if isinstance(value, dict): - if '$geometry' in value: + if "$geometry" in value: return value - elif 'coordinates' in value and 'type' in value: - return {'$geometry': value} - raise InvalidQueryError('Invalid $geometry dictionary should have ' - 'type and coordinates keys') + elif "coordinates" in value and "type" in value: + return {"$geometry": value} + raise InvalidQueryError( + "Invalid $geometry dictionary should have type and coordinates keys" + ) elif isinstance(value, (list, set)): # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon? try: value[0][0][0] - return {'$geometry': {'type': 'Polygon', 'coordinates': value}} + return {"$geometry": {"type": "Polygon", "coordinates": value}} except (TypeError, IndexError): pass try: value[0][0] - return {'$geometry': {'type': 'LineString', 'coordinates': value}} + return {"$geometry": {"type": "LineString", "coordinates": value}} except (TypeError, IndexError): pass try: value[0] - return {'$geometry': {'type': 'Point', 'coordinates': value}} + return {"$geometry": {"type": "Point", "coordinates": value}} except (TypeError, IndexError): pass - raise InvalidQueryError('Invalid $geometry data. Can be either a ' - 'dictionary or (nested) lists of coordinate(s)') + raise InvalidQueryError( + "Invalid $geometry data. Can be either a " + "dictionary or (nested) lists of coordinate(s)" + ) def _prepare_query_for_iterable(field, op, value): # We need a special check for BaseDocument, because - although it's iterable - using # it as such in the context of this method is most definitely a mistake. - BaseDocument = _import_class('BaseDocument') + BaseDocument = _import_class("BaseDocument") if isinstance(value, BaseDocument): - raise TypeError("When using the `in`, `nin`, `all`, or " - "`near`-operators you can\'t use a " - "`Document`, you must wrap your object " - "in a list (object -> [object]).") + raise TypeError( + "When using the `in`, `nin`, `all`, or " + "`near`-operators you can't use a " + "`Document`, you must wrap your object " + "in a list (object -> [object])." + ) - if not hasattr(value, '__iter__'): - raise TypeError("The `in`, `nin`, `all`, or " - "`near`-operators must be applied to an " - "iterable (e.g. a list).") + if not hasattr(value, "__iter__"): + raise TypeError( + "The `in`, `nin`, `all`, or " + "`near`-operators must be applied to an " + "iterable (e.g. a list)." + ) return [field.prepare_query_value(op, v) for v in value] diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 9d97094b..0fe139fd 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -3,7 +3,7 @@ import copy from mongoengine.errors import InvalidQueryError from mongoengine.queryset import transform -__all__ = ('Q', 'QNode') +__all__ = ("Q", "QNode") class QNodeVisitor(object): @@ -69,9 +69,9 @@ class QueryCompilerVisitor(QNodeVisitor): self.document = document def visit_combination(self, combination): - operator = '$and' + operator = "$and" if combination.operation == combination.OR: - operator = '$or' + operator = "$or" return {operator: combination.children} def visit_query(self, query): @@ -96,7 +96,7 @@ class QNode(object): """Combine this node with another node into a QCombination object. """ - if getattr(other, 'empty', True): + if getattr(other, "empty", True): return self if self.empty: @@ -132,8 +132,8 @@ class QCombination(QNode): self.children.append(node) def __repr__(self): - op = ' & ' if self.operation is self.AND else ' | ' - return '(%s)' % op.join([repr(node) for node in self.children]) + op = " & " if self.operation is self.AND else " | " + return "(%s)" % op.join([repr(node) for node in self.children]) def accept(self, visitor): for i in range(len(self.children)): @@ -156,7 +156,7 @@ class Q(QNode): self.query = query def __repr__(self): - return 'Q(**%s)' % repr(self.query) + return "Q(**%s)" % repr(self.query) def accept(self, visitor): return visitor.visit_query(self) diff --git a/mongoengine/signals.py b/mongoengine/signals.py index a892dec0..0db63604 100644 --- a/mongoengine/signals.py +++ b/mongoengine/signals.py @@ -1,5 +1,12 @@ -__all__ = ('pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', - 'post_save', 'pre_delete', 'post_delete') +__all__ = ( + "pre_init", + "post_init", + "pre_save", + "pre_save_post_validation", + "post_save", + "pre_delete", + "post_delete", +) signals_available = False try: @@ -7,6 +14,7 @@ try: signals_available = True except ImportError: + class Namespace(object): def signal(self, name, doc=None): return _FakeSignal(name, doc) @@ -23,13 +31,16 @@ except ImportError: self.__doc__ = doc def _fail(self, *args, **kwargs): - raise RuntimeError('signalling support is unavailable ' - 'because the blinker library is ' - 'not installed.') + raise RuntimeError( + "signalling support is unavailable " + "because the blinker library is " + "not installed." + ) send = lambda *a, **kw: None # noqa - connect = disconnect = has_receivers_for = receivers_for = \ - temporarily_connected_to = _fail + connect = ( + disconnect + ) = has_receivers_for = receivers_for = temporarily_connected_to = _fail del _fail @@ -37,12 +48,12 @@ except ImportError: # not put signals in here. Create your own namespace instead. _signals = Namespace() -pre_init = _signals.signal('pre_init') -post_init = _signals.signal('post_init') -pre_save = _signals.signal('pre_save') -pre_save_post_validation = _signals.signal('pre_save_post_validation') -post_save = _signals.signal('post_save') -pre_delete = _signals.signal('pre_delete') -post_delete = _signals.signal('post_delete') -pre_bulk_insert = _signals.signal('pre_bulk_insert') -post_bulk_insert = _signals.signal('post_bulk_insert') +pre_init = _signals.signal("pre_init") +post_init = _signals.signal("post_init") +pre_save = _signals.signal("pre_save") +pre_save_post_validation = _signals.signal("pre_save_post_validation") +post_save = _signals.signal("post_save") +pre_delete = _signals.signal("pre_delete") +post_delete = _signals.signal("post_delete") +pre_bulk_insert = _signals.signal("pre_bulk_insert") +post_bulk_insert = _signals.signal("post_bulk_insert") diff --git a/requirements.txt b/requirements.txt index 9bb319a5..62ad8766 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +black nose pymongo>=3.4 six==1.10.0 diff --git a/setup.cfg b/setup.cfg index 84086601..4bded428 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,7 +5,7 @@ detailed-errors=1 cover-package=mongoengine [flake8] -ignore=E501,F401,F403,F405,I201,I202,W504, W605 +ignore=E501,F401,F403,F405,I201,I202,W504, W605, W503 exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests max-complexity=47 application-import-names=mongoengine,tests diff --git a/setup.py b/setup.py index f1f5dea7..c73a93ff 100644 --- a/setup.py +++ b/setup.py @@ -8,13 +8,10 @@ try: except ImportError: pass -DESCRIPTION = ( - 'MongoEngine is a Python Object-Document ' - 'Mapper for working with MongoDB.' -) +DESCRIPTION = "MongoEngine is a Python Object-Document Mapper for working with MongoDB." try: - with open('README.rst') as fin: + with open("README.rst") as fin: LONG_DESCRIPTION = fin.read() except Exception: LONG_DESCRIPTION = None @@ -24,23 +21,23 @@ def get_version(version_tuple): """Return the version tuple as a string, e.g. for (0, 10, 7), return '0.10.7'. """ - return '.'.join(map(str, version_tuple)) + return ".".join(map(str, version_tuple)) # Dirty hack to get version number from monogengine/__init__.py - we can't # import it as it depends on PyMongo and PyMongo isn't installed until this # file is read -init = os.path.join(os.path.dirname(__file__), 'mongoengine', '__init__.py') -version_line = list(filter(lambda l: l.startswith('VERSION'), open(init)))[0] +init = os.path.join(os.path.dirname(__file__), "mongoengine", "__init__.py") +version_line = list(filter(lambda l: l.startswith("VERSION"), open(init)))[0] -VERSION = get_version(eval(version_line.split('=')[-1])) +VERSION = get_version(eval(version_line.split("=")[-1])) CLASSIFIERS = [ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python", "Programming Language :: Python :: 2", "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", @@ -48,39 +45,40 @@ CLASSIFIERS = [ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", - 'Topic :: Database', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Topic :: Database", + "Topic :: Software Development :: Libraries :: Python Modules", ] extra_opts = { - 'packages': find_packages(exclude=['tests', 'tests.*']), - 'tests_require': ['nose', 'coverage==4.2', 'blinker', 'Pillow>=2.0.0'] + "packages": find_packages(exclude=["tests", "tests.*"]), + "tests_require": ["nose", "coverage==4.2", "blinker", "Pillow>=2.0.0"], } if sys.version_info[0] == 3: - extra_opts['use_2to3'] = True - if 'test' in sys.argv or 'nosetests' in sys.argv: - extra_opts['packages'] = find_packages() - extra_opts['package_data'] = { - 'tests': ['fields/mongoengine.png', 'fields/mongodb_leaf.png']} + extra_opts["use_2to3"] = True + if "test" in sys.argv or "nosetests" in sys.argv: + extra_opts["packages"] = find_packages() + extra_opts["package_data"] = { + "tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"] + } else: - extra_opts['tests_require'] += ['python-dateutil'] + extra_opts["tests_require"] += ["python-dateutil"] setup( - name='mongoengine', + name="mongoengine", version=VERSION, - author='Harry Marr', - author_email='harry.marr@gmail.com', + author="Harry Marr", + author_email="harry.marr@gmail.com", maintainer="Stefan Wojcik", maintainer_email="wojcikstefan@gmail.com", - url='http://mongoengine.org/', - download_url='https://github.com/MongoEngine/mongoengine/tarball/master', - license='MIT', + url="http://mongoengine.org/", + download_url="https://github.com/MongoEngine/mongoengine/tarball/master", + license="MIT", include_package_data=True, description=DESCRIPTION, long_description=LONG_DESCRIPTION, - platforms=['any'], + platforms=["any"], classifiers=CLASSIFIERS, - install_requires=['pymongo>=3.4', 'six'], - test_suite='nose.collector', + install_requires=["pymongo>=3.4", "six"], + test_suite="nose.collector", **extra_opts ) diff --git a/tests/all_warnings/__init__.py b/tests/all_warnings/__init__.py index 3aebe4ba..a755e7a3 100644 --- a/tests/all_warnings/__init__.py +++ b/tests/all_warnings/__init__.py @@ -9,34 +9,32 @@ import warnings from mongoengine import * -__all__ = ('AllWarnings', ) +__all__ = ("AllWarnings",) class AllWarnings(unittest.TestCase): - def setUp(self): - connect(db='mongoenginetest') + connect(db="mongoenginetest") self.warning_list = [] self.showwarning_default = warnings.showwarning warnings.showwarning = self.append_to_warning_list def append_to_warning_list(self, message, category, *args): - self.warning_list.append({"message": message, - "category": category}) + self.warning_list.append({"message": message, "category": category}) def tearDown(self): # restore default handling of warnings warnings.showwarning = self.showwarning_default def test_document_collection_syntax_warning(self): - class NonAbstractBase(Document): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class InheritedDocumentFailTest(NonAbstractBase): - meta = {'collection': 'fail'} + meta = {"collection": "fail"} warning = self.warning_list[0] self.assertEqual(SyntaxWarning, warning["category"]) - self.assertEqual('non_abstract_base', - InheritedDocumentFailTest._get_collection_name()) + self.assertEqual( + "non_abstract_base", InheritedDocumentFailTest._get_collection_name() + ) diff --git a/tests/document/__init__.py b/tests/document/__init__.py index dc35c969..f2230c48 100644 --- a/tests/document/__init__.py +++ b/tests/document/__init__.py @@ -9,5 +9,5 @@ from .instance import * from .json_serialisation import * from .validation import * -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index 4fc648b7..87f1215b 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -7,13 +7,12 @@ from mongoengine.pymongo_support import list_collection_names from mongoengine.queryset import NULLIFY, PULL from mongoengine.connection import get_db -__all__ = ("ClassMethodsTest", ) +__all__ = ("ClassMethodsTest",) class ClassMethodsTest(unittest.TestCase): - def setUp(self): - connect(db='mongoenginetest') + connect(db="mongoenginetest") self.db = get_db() class Person(Document): @@ -33,11 +32,13 @@ class ClassMethodsTest(unittest.TestCase): def test_definition(self): """Ensure that document may be defined using fields. """ - self.assertEqual(['_cls', 'age', 'id', 'name'], - sorted(self.Person._fields.keys())) - self.assertEqual(["IntField", "ObjectIdField", "StringField", "StringField"], - sorted([x.__class__.__name__ for x in - self.Person._fields.values()])) + self.assertEqual( + ["_cls", "age", "id", "name"], sorted(self.Person._fields.keys()) + ) + self.assertEqual( + ["IntField", "ObjectIdField", "StringField", "StringField"], + sorted([x.__class__.__name__ for x in self.Person._fields.values()]), + ) def test_get_db(self): """Ensure that get_db returns the expected db. @@ -49,21 +50,21 @@ class ClassMethodsTest(unittest.TestCase): """Ensure that get_collection_name returns the expected collection name. """ - collection_name = 'person' + collection_name = "person" self.assertEqual(collection_name, self.Person._get_collection_name()) def test_get_collection(self): """Ensure that get_collection returns the expected collection. """ - collection_name = 'person' + collection_name = "person" collection = self.Person._get_collection() self.assertEqual(self.db[collection_name], collection) def test_drop_collection(self): """Ensure that the collection may be dropped from the database. """ - collection_name = 'person' - self.Person(name='Test').save() + collection_name = "person" + self.Person(name="Test").save() self.assertIn(collection_name, list_collection_names(self.db)) self.Person.drop_collection() @@ -73,14 +74,16 @@ class ClassMethodsTest(unittest.TestCase): """Ensure that register delete rule adds a delete rule to the document meta. """ + class Job(Document): employee = ReferenceField(self.Person) - self.assertEqual(self.Person._meta.get('delete_rules'), None) + self.assertEqual(self.Person._meta.get("delete_rules"), None) - self.Person.register_delete_rule(Job, 'employee', NULLIFY) - self.assertEqual(self.Person._meta['delete_rules'], - {(Job, 'employee'): NULLIFY}) + self.Person.register_delete_rule(Job, "employee", NULLIFY) + self.assertEqual( + self.Person._meta["delete_rules"], {(Job, "employee"): NULLIFY} + ) def test_compare_indexes(self): """ Ensure that the indexes are properly created and that @@ -93,23 +96,27 @@ class ClassMethodsTest(unittest.TestCase): description = StringField() tags = StringField() - meta = { - 'indexes': [('author', 'title')] - } + meta = {"indexes": [("author", "title")]} BlogPost.drop_collection() BlogPost.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) + self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) - BlogPost.ensure_index(['author', 'description']) - self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': [[('author', 1), ('description', 1)]]}) + BlogPost.ensure_index(["author", "description"]) + self.assertEqual( + BlogPost.compare_indexes(), + {"missing": [], "extra": [[("author", 1), ("description", 1)]]}, + ) - BlogPost._get_collection().drop_index('author_1_description_1') - self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) + BlogPost._get_collection().drop_index("author_1_description_1") + self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) - BlogPost._get_collection().drop_index('author_1_title_1') - self.assertEqual(BlogPost.compare_indexes(), {'missing': [[('author', 1), ('title', 1)]], 'extra': []}) + BlogPost._get_collection().drop_index("author_1_title_1") + self.assertEqual( + BlogPost.compare_indexes(), + {"missing": [[("author", 1), ("title", 1)]], "extra": []}, + ) def test_compare_indexes_inheritance(self): """ Ensure that the indexes are properly created and that @@ -122,32 +129,34 @@ class ClassMethodsTest(unittest.TestCase): title = StringField() description = StringField() - meta = { - 'allow_inheritance': True - } + meta = {"allow_inheritance": True} class BlogPostWithTags(BlogPost): tags = StringField() tag_list = ListField(StringField()) - meta = { - 'indexes': [('author', 'tags')] - } + meta = {"indexes": [("author", "tags")]} BlogPost.drop_collection() BlogPost.ensure_indexes() BlogPostWithTags.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) + self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) - BlogPostWithTags.ensure_index(['author', 'tag_list']) - self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': [[('_cls', 1), ('author', 1), ('tag_list', 1)]]}) + BlogPostWithTags.ensure_index(["author", "tag_list"]) + self.assertEqual( + BlogPost.compare_indexes(), + {"missing": [], "extra": [[("_cls", 1), ("author", 1), ("tag_list", 1)]]}, + ) - BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tag_list_1') - self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) + BlogPostWithTags._get_collection().drop_index("_cls_1_author_1_tag_list_1") + self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) - BlogPostWithTags._get_collection().drop_index('_cls_1_author_1_tags_1') - self.assertEqual(BlogPost.compare_indexes(), {'missing': [[('_cls', 1), ('author', 1), ('tags', 1)]], 'extra': []}) + BlogPostWithTags._get_collection().drop_index("_cls_1_author_1_tags_1") + self.assertEqual( + BlogPost.compare_indexes(), + {"missing": [[("_cls", 1), ("author", 1), ("tags", 1)]], "extra": []}, + ) def test_compare_indexes_multiple_subclasses(self): """ Ensure that compare_indexes behaves correctly if called from a @@ -159,32 +168,30 @@ class ClassMethodsTest(unittest.TestCase): title = StringField() description = StringField() - meta = { - 'allow_inheritance': True - } + meta = {"allow_inheritance": True} class BlogPostWithTags(BlogPost): tags = StringField() tag_list = ListField(StringField()) - meta = { - 'indexes': [('author', 'tags')] - } + meta = {"indexes": [("author", "tags")]} class BlogPostWithCustomField(BlogPost): custom = DictField() - meta = { - 'indexes': [('author', 'custom')] - } + meta = {"indexes": [("author", "custom")]} BlogPost.ensure_indexes() BlogPostWithTags.ensure_indexes() BlogPostWithCustomField.ensure_indexes() - self.assertEqual(BlogPost.compare_indexes(), {'missing': [], 'extra': []}) - self.assertEqual(BlogPostWithTags.compare_indexes(), {'missing': [], 'extra': []}) - self.assertEqual(BlogPostWithCustomField.compare_indexes(), {'missing': [], 'extra': []}) + self.assertEqual(BlogPost.compare_indexes(), {"missing": [], "extra": []}) + self.assertEqual( + BlogPostWithTags.compare_indexes(), {"missing": [], "extra": []} + ) + self.assertEqual( + BlogPostWithCustomField.compare_indexes(), {"missing": [], "extra": []} + ) def test_compare_indexes_for_text_indexes(self): """ Ensure that compare_indexes behaves correctly for text indexes """ @@ -192,17 +199,20 @@ class ClassMethodsTest(unittest.TestCase): class Doc(Document): a = StringField() b = StringField() - meta = {'indexes': [ - {'fields': ['$a', "$b"], - 'default_language': 'english', - 'weights': {'a': 10, 'b': 2} - } - ]} + meta = { + "indexes": [ + { + "fields": ["$a", "$b"], + "default_language": "english", + "weights": {"a": 10, "b": 2}, + } + ] + } Doc.drop_collection() Doc.ensure_indexes() actual = Doc.compare_indexes() - expected = {'missing': [], 'extra': []} + expected = {"missing": [], "extra": []} self.assertEqual(actual, expected) def test_list_indexes_inheritance(self): @@ -215,23 +225,17 @@ class ClassMethodsTest(unittest.TestCase): title = StringField() description = StringField() - meta = { - 'allow_inheritance': True - } + meta = {"allow_inheritance": True} class BlogPostWithTags(BlogPost): tags = StringField() - meta = { - 'indexes': [('author', 'tags')] - } + meta = {"indexes": [("author", "tags")]} class BlogPostWithTagsAndExtraText(BlogPostWithTags): extra_text = StringField() - meta = { - 'indexes': [('author', 'tags', 'extra_text')] - } + meta = {"indexes": [("author", "tags", "extra_text")]} BlogPost.drop_collection() @@ -239,17 +243,21 @@ class ClassMethodsTest(unittest.TestCase): BlogPostWithTags.ensure_indexes() BlogPostWithTagsAndExtraText.ensure_indexes() - self.assertEqual(BlogPost.list_indexes(), - BlogPostWithTags.list_indexes()) - self.assertEqual(BlogPost.list_indexes(), - BlogPostWithTagsAndExtraText.list_indexes()) - self.assertEqual(BlogPost.list_indexes(), - [[('_cls', 1), ('author', 1), ('tags', 1)], - [('_cls', 1), ('author', 1), ('tags', 1), ('extra_text', 1)], - [(u'_id', 1)], [('_cls', 1)]]) + self.assertEqual(BlogPost.list_indexes(), BlogPostWithTags.list_indexes()) + self.assertEqual( + BlogPost.list_indexes(), BlogPostWithTagsAndExtraText.list_indexes() + ) + self.assertEqual( + BlogPost.list_indexes(), + [ + [("_cls", 1), ("author", 1), ("tags", 1)], + [("_cls", 1), ("author", 1), ("tags", 1), ("extra_text", 1)], + [(u"_id", 1)], + [("_cls", 1)], + ], + ) def test_register_delete_rule_inherited(self): - class Vaccine(Document): name = StringField(required=True) @@ -257,15 +265,17 @@ class ClassMethodsTest(unittest.TestCase): class Animal(Document): family = StringField(required=True) - vaccine_made = ListField(ReferenceField("Vaccine", reverse_delete_rule=PULL)) + vaccine_made = ListField( + ReferenceField("Vaccine", reverse_delete_rule=PULL) + ) meta = {"allow_inheritance": True, "indexes": ["family"]} class Cat(Animal): name = StringField(required=True) - self.assertEqual(Vaccine._meta['delete_rules'][(Animal, 'vaccine_made')], PULL) - self.assertEqual(Vaccine._meta['delete_rules'][(Cat, 'vaccine_made')], PULL) + self.assertEqual(Vaccine._meta["delete_rules"][(Animal, "vaccine_made")], PULL) + self.assertEqual(Vaccine._meta["delete_rules"][(Cat, "vaccine_made")], PULL) def test_collection_naming(self): """Ensure that a collection with a specified name may be used. @@ -273,74 +283,73 @@ class ClassMethodsTest(unittest.TestCase): class DefaultNamingTest(Document): pass - self.assertEqual('default_naming_test', - DefaultNamingTest._get_collection_name()) + + self.assertEqual( + "default_naming_test", DefaultNamingTest._get_collection_name() + ) class CustomNamingTest(Document): - meta = {'collection': 'pimp_my_collection'} + meta = {"collection": "pimp_my_collection"} - self.assertEqual('pimp_my_collection', - CustomNamingTest._get_collection_name()) + self.assertEqual("pimp_my_collection", CustomNamingTest._get_collection_name()) class DynamicNamingTest(Document): - meta = {'collection': lambda c: "DYNAMO"} - self.assertEqual('DYNAMO', DynamicNamingTest._get_collection_name()) + meta = {"collection": lambda c: "DYNAMO"} + + self.assertEqual("DYNAMO", DynamicNamingTest._get_collection_name()) # Use Abstract class to handle backwards compatibility class BaseDocument(Document): - meta = { - 'abstract': True, - 'collection': lambda c: c.__name__.lower() - } + meta = {"abstract": True, "collection": lambda c: c.__name__.lower()} class OldNamingConvention(BaseDocument): pass - self.assertEqual('oldnamingconvention', - OldNamingConvention._get_collection_name()) + + self.assertEqual( + "oldnamingconvention", OldNamingConvention._get_collection_name() + ) class InheritedAbstractNamingTest(BaseDocument): - meta = {'collection': 'wibble'} - self.assertEqual('wibble', - InheritedAbstractNamingTest._get_collection_name()) + meta = {"collection": "wibble"} + + self.assertEqual("wibble", InheritedAbstractNamingTest._get_collection_name()) # Mixin tests class BaseMixin(object): - meta = { - 'collection': lambda c: c.__name__.lower() - } + meta = {"collection": lambda c: c.__name__.lower()} class OldMixinNamingConvention(Document, BaseMixin): pass - self.assertEqual('oldmixinnamingconvention', - OldMixinNamingConvention._get_collection_name()) + + self.assertEqual( + "oldmixinnamingconvention", OldMixinNamingConvention._get_collection_name() + ) class BaseMixin(object): - meta = { - 'collection': lambda c: c.__name__.lower() - } + meta = {"collection": lambda c: c.__name__.lower()} class BaseDocument(Document, BaseMixin): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class MyDocument(BaseDocument): pass - self.assertEqual('basedocument', MyDocument._get_collection_name()) + self.assertEqual("basedocument", MyDocument._get_collection_name()) def test_custom_collection_name_operations(self): """Ensure that a collection with a specified name is used as expected. """ - collection_name = 'personCollTest' + collection_name = "personCollTest" class Person(Document): name = StringField() - meta = {'collection': collection_name} + meta = {"collection": collection_name} Person(name="Test User").save() self.assertIn(collection_name, list_collection_names(self.db)) user_obj = self.db[collection_name].find_one() - self.assertEqual(user_obj['name'], "Test User") + self.assertEqual(user_obj["name"], "Test User") user_obj = Person.objects[0] self.assertEqual(user_obj.name, "Test User") @@ -354,7 +363,7 @@ class ClassMethodsTest(unittest.TestCase): class Person(Document): name = StringField(primary_key=True) - meta = {'collection': 'app'} + meta = {"collection": "app"} Person(name="Test User").save() @@ -364,5 +373,5 @@ class ClassMethodsTest(unittest.TestCase): Person.drop_collection() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/document/delta.py b/tests/document/delta.py index 504c1707..8f1575e6 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -8,7 +8,6 @@ from tests.utils import MongoDBTestCase class DeltaTest(MongoDBTestCase): - def setUp(self): super(DeltaTest, self).setUp() @@ -31,7 +30,6 @@ class DeltaTest(MongoDBTestCase): self.delta(DynamicDocument) def delta(self, DocClass): - class Doc(DocClass): string_field = StringField() int_field = IntField() @@ -46,37 +44,37 @@ class DeltaTest(MongoDBTestCase): self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(doc._delta(), ({}, {})) - doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['string_field']) - self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) + doc.string_field = "hello" + self.assertEqual(doc._get_changed_fields(), ["string_field"]) + self.assertEqual(doc._delta(), ({"string_field": "hello"}, {})) doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['int_field']) - self.assertEqual(doc._delta(), ({'int_field': 1}, {})) + self.assertEqual(doc._get_changed_fields(), ["int_field"]) + self.assertEqual(doc._delta(), ({"int_field": 1}, {})) doc._changed_fields = [] - dict_value = {'hello': 'world', 'ping': 'pong'} + dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) + self.assertEqual(doc._get_changed_fields(), ["dict_field"]) + self.assertEqual(doc._delta(), ({"dict_field": dict_value}, {})) doc._changed_fields = [] - list_value = ['1', 2, {'hello': 'world'}] + list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) + self.assertEqual(doc._get_changed_fields(), ["list_field"]) + self.assertEqual(doc._delta(), ({"list_field": list_value}, {})) # Test unsetting doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) + self.assertEqual(doc._get_changed_fields(), ["dict_field"]) + self.assertEqual(doc._delta(), ({}, {"dict_field": 1})) doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({}, {'list_field': 1})) + self.assertEqual(doc._get_changed_fields(), ["list_field"]) + self.assertEqual(doc._delta(), ({}, {"list_field": 1})) def test_delta_recursive(self): self.delta_recursive(Document, EmbeddedDocument) @@ -85,7 +83,6 @@ class DeltaTest(MongoDBTestCase): self.delta_recursive(DynamicDocument, DynamicEmbeddedDocument) def delta_recursive(self, DocClass, EmbeddedClass): - class Embedded(EmbeddedClass): id = StringField() string_field = StringField() @@ -110,165 +107,207 @@ class DeltaTest(MongoDBTestCase): embedded_1 = Embedded() embedded_1.id = "010101" - embedded_1.string_field = 'hello' + embedded_1.string_field = "hello" embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] + embedded_1.dict_field = {"hello": "world"} + embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), ['embedded_field']) + self.assertEqual(doc._get_changed_fields(), ["embedded_field"]) embedded_delta = { - 'id': "010101", - 'string_field': 'hello', - 'int_field': 1, - 'dict_field': {'hello': 'world'}, - 'list_field': ['1', 2, {'hello': 'world'}] + "id": "010101", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ["1", 2, {"hello": "world"}], } self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - self.assertEqual(doc._delta(), - ({'embedded_field': embedded_delta}, {})) + self.assertEqual(doc._delta(), ({"embedded_field": embedded_delta}, {})) doc.save() doc = doc.reload(10) doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), - ['embedded_field.dict_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) - self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) + self.assertEqual(doc._get_changed_fields(), ["embedded_field.dict_field"]) + self.assertEqual(doc.embedded_field._delta(), ({}, {"dict_field": 1})) + self.assertEqual(doc._delta(), ({}, {"embedded_field.dict_field": 1})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.dict_field, {}) doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) - self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) + self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field"]) + self.assertEqual(doc.embedded_field._delta(), ({}, {"list_field": 1})) + self.assertEqual(doc._delta(), ({}, {"embedded_field.list_field": 1})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field, []) embedded_2 = Embedded() - embedded_2.string_field = 'hello' + embedded_2.string_field = "hello" embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] + embedded_2.dict_field = {"hello": "world"} + embedded_2.list_field = ["1", 2, {"hello": "world"}] - doc.embedded_field.list_field = ['1', 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) + doc.embedded_field.list_field = ["1", 2, embedded_2] + self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field"]) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_cls': 'Embedded', - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) + self.assertEqual( + doc.embedded_field._delta(), + ( + { + "list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "string_field": "hello", + "dict_field": {"hello": "world"}, + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, + ), + ) - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { - '_cls': 'Embedded', - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) + self.assertEqual( + doc._delta(), + ( + { + "embedded_field.list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "string_field": "hello", + "dict_field": {"hello": "world"}, + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, + ), + ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[0], '1') + self.assertEqual(doc.embedded_field.list_field[0], "1") self.assertEqual(doc.embedded_field.list_field[1], 2) for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], - embedded_2[k]) + self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) - doc.embedded_field.list_field[2].string_field = 'world' - self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field.2.string_field']) - self.assertEqual(doc.embedded_field._delta(), - ({'list_field.2.string_field': 'world'}, {})) - self.assertEqual(doc._delta(), - ({'embedded_field.list_field.2.string_field': 'world'}, {})) + doc.embedded_field.list_field[2].string_field = "world" + self.assertEqual( + doc._get_changed_fields(), ["embedded_field.list_field.2.string_field"] + ) + self.assertEqual( + doc.embedded_field._delta(), ({"list_field.2.string_field": "world"}, {}) + ) + self.assertEqual( + doc._delta(), ({"embedded_field.list_field.2.string_field": "world"}, {}) + ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, - 'world') + self.assertEqual(doc.embedded_field.list_field[2].string_field, "world") # Test multiple assignments - doc.embedded_field.list_field[2].string_field = 'hello world' + doc.embedded_field.list_field[2].string_field = "hello world" doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field.2']) - self.assertEqual(doc.embedded_field._delta(), ({'list_field.2': { - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}} - }, {})) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2': { - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}} - }, {})) + self.assertEqual(doc._get_changed_fields(), ["embedded_field.list_field.2"]) + self.assertEqual( + doc.embedded_field._delta(), + ( + { + "list_field.2": { + "_cls": "Embedded", + "string_field": "hello world", + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + "dict_field": {"hello": "world"}, + } + }, + {}, + ), + ) + self.assertEqual( + doc._delta(), + ( + { + "embedded_field.list_field.2": { + "_cls": "Embedded", + "string_field": "hello world", + "int_field": 1, + "list_field": ["1", 2, {"hello": "world"}], + "dict_field": {"hello": "world"}, + } + }, + {}, + ), + ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, - 'hello world') + self.assertEqual(doc.embedded_field.list_field[2].string_field, "hello world") # Test list native methods doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual(doc._delta(), - ({'embedded_field.list_field.2.list_field': - [2, {'hello': 'world'}]}, {})) + self.assertEqual( + doc._delta(), + ({"embedded_field.list_field.2.list_field": [2, {"hello": "world"}]}, {}), + ) doc.save() doc = doc.reload(10) doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual(doc._delta(), - ({'embedded_field.list_field.2.list_field': - [2, {'hello': 'world'}, 1]}, {})) + self.assertEqual( + doc._delta(), + ( + {"embedded_field.list_field.2.list_field": [2, {"hello": "world"}, 1]}, + {}, + ), + ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, - [2, {'hello': 'world'}, 1]) + self.assertEqual( + doc.embedded_field.list_field[2].list_field, [2, {"hello": "world"}, 1] + ) doc.embedded_field.list_field[2].list_field.sort(key=str) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, - [1, 2, {'hello': 'world'}]) + self.assertEqual( + doc.embedded_field.list_field[2].list_field, [1, 2, {"hello": "world"}] + ) - del doc.embedded_field.list_field[2].list_field[2]['hello'] - self.assertEqual(doc._delta(), - ({}, {'embedded_field.list_field.2.list_field.2.hello': 1})) + del doc.embedded_field.list_field[2].list_field[2]["hello"] + self.assertEqual( + doc._delta(), ({}, {"embedded_field.list_field.2.list_field.2.hello": 1}) + ) doc.save() doc = doc.reload(10) del doc.embedded_field.list_field[2].list_field - self.assertEqual(doc._delta(), - ({}, {'embedded_field.list_field.2.list_field': 1})) + self.assertEqual( + doc._delta(), ({}, {"embedded_field.list_field.2.list_field": 1}) + ) doc.save() doc = doc.reload(10) - doc.dict_field['Embedded'] = embedded_1 + doc.dict_field["Embedded"] = embedded_1 doc.save() doc = doc.reload(10) - doc.dict_field['Embedded'].string_field = 'Hello World' - self.assertEqual(doc._get_changed_fields(), - ['dict_field.Embedded.string_field']) - self.assertEqual(doc._delta(), - ({'dict_field.Embedded.string_field': 'Hello World'}, {})) + doc.dict_field["Embedded"].string_field = "Hello World" + self.assertEqual( + doc._get_changed_fields(), ["dict_field.Embedded.string_field"] + ) + self.assertEqual( + doc._delta(), ({"dict_field.Embedded.string_field": "Hello World"}, {}) + ) def test_circular_reference_deltas(self): self.circular_reference_deltas(Document, Document) @@ -277,14 +316,13 @@ class DeltaTest(MongoDBTestCase): self.circular_reference_deltas(DynamicDocument, DynamicDocument) def circular_reference_deltas(self, DocClass1, DocClass2): - class Person(DocClass1): name = StringField() - owns = ListField(ReferenceField('Organization')) + owns = ListField(ReferenceField("Organization")) class Organization(DocClass2): name = StringField() - owner = ReferenceField('Person') + owner = ReferenceField("Person") Person.drop_collection() Organization.drop_collection() @@ -310,16 +348,15 @@ class DeltaTest(MongoDBTestCase): self.circular_reference_deltas_2(DynamicDocument, DynamicDocument) def circular_reference_deltas_2(self, DocClass1, DocClass2, dbref=True): - class Person(DocClass1): name = StringField() - owns = ListField(ReferenceField('Organization', dbref=dbref)) - employer = ReferenceField('Organization', dbref=dbref) + owns = ListField(ReferenceField("Organization", dbref=dbref)) + employer = ReferenceField("Organization", dbref=dbref) class Organization(DocClass2): name = StringField() - owner = ReferenceField('Person', dbref=dbref) - employees = ListField(ReferenceField('Person', dbref=dbref)) + owner = ReferenceField("Person", dbref=dbref) + employees = ListField(ReferenceField("Person", dbref=dbref)) Person.drop_collection() Organization.drop_collection() @@ -353,12 +390,11 @@ class DeltaTest(MongoDBTestCase): self.delta_db_field(DynamicDocument) def delta_db_field(self, DocClass): - class Doc(DocClass): - string_field = StringField(db_field='db_string_field') - int_field = IntField(db_field='db_int_field') - dict_field = DictField(db_field='db_dict_field') - list_field = ListField(db_field='db_list_field') + string_field = StringField(db_field="db_string_field") + int_field = IntField(db_field="db_int_field") + dict_field = DictField(db_field="db_dict_field") + list_field = ListField(db_field="db_list_field") Doc.drop_collection() doc = Doc() @@ -368,53 +404,53 @@ class DeltaTest(MongoDBTestCase): self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(doc._delta(), ({}, {})) - doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['db_string_field']) - self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {})) + doc.string_field = "hello" + self.assertEqual(doc._get_changed_fields(), ["db_string_field"]) + self.assertEqual(doc._delta(), ({"db_string_field": "hello"}, {})) doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['db_int_field']) - self.assertEqual(doc._delta(), ({'db_int_field': 1}, {})) + self.assertEqual(doc._get_changed_fields(), ["db_int_field"]) + self.assertEqual(doc._delta(), ({"db_int_field": 1}, {})) doc._changed_fields = [] - dict_value = {'hello': 'world', 'ping': 'pong'} + dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) - self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {})) + self.assertEqual(doc._get_changed_fields(), ["db_dict_field"]) + self.assertEqual(doc._delta(), ({"db_dict_field": dict_value}, {})) doc._changed_fields = [] - list_value = ['1', 2, {'hello': 'world'}] + list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['db_list_field']) - self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {})) + self.assertEqual(doc._get_changed_fields(), ["db_list_field"]) + self.assertEqual(doc._delta(), ({"db_list_field": list_value}, {})) # Test unsetting doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) - self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1})) + self.assertEqual(doc._get_changed_fields(), ["db_dict_field"]) + self.assertEqual(doc._delta(), ({}, {"db_dict_field": 1})) doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['db_list_field']) - self.assertEqual(doc._delta(), ({}, {'db_list_field': 1})) + self.assertEqual(doc._get_changed_fields(), ["db_list_field"]) + self.assertEqual(doc._delta(), ({}, {"db_list_field": 1})) # Test it saves that data doc = Doc() doc.save() - doc.string_field = 'hello' + doc.string_field = "hello" doc.int_field = 1 - doc.dict_field = {'hello': 'world'} - doc.list_field = ['1', 2, {'hello': 'world'}] + doc.dict_field = {"hello": "world"} + doc.list_field = ["1", 2, {"hello": "world"}] doc.save() doc = doc.reload(10) - self.assertEqual(doc.string_field, 'hello') + self.assertEqual(doc.string_field, "hello") self.assertEqual(doc.int_field, 1) - self.assertEqual(doc.dict_field, {'hello': 'world'}) - self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}]) + self.assertEqual(doc.dict_field, {"hello": "world"}) + self.assertEqual(doc.list_field, ["1", 2, {"hello": "world"}]) def test_delta_recursive_db_field(self): self.delta_recursive_db_field(Document, EmbeddedDocument) @@ -423,20 +459,20 @@ class DeltaTest(MongoDBTestCase): self.delta_recursive_db_field(DynamicDocument, DynamicEmbeddedDocument) def delta_recursive_db_field(self, DocClass, EmbeddedClass): - class Embedded(EmbeddedClass): - string_field = StringField(db_field='db_string_field') - int_field = IntField(db_field='db_int_field') - dict_field = DictField(db_field='db_dict_field') - list_field = ListField(db_field='db_list_field') + string_field = StringField(db_field="db_string_field") + int_field = IntField(db_field="db_int_field") + dict_field = DictField(db_field="db_dict_field") + list_field = ListField(db_field="db_list_field") class Doc(DocClass): - string_field = StringField(db_field='db_string_field') - int_field = IntField(db_field='db_int_field') - dict_field = DictField(db_field='db_dict_field') - list_field = ListField(db_field='db_list_field') - embedded_field = EmbeddedDocumentField(Embedded, - db_field='db_embedded_field') + string_field = StringField(db_field="db_string_field") + int_field = IntField(db_field="db_int_field") + dict_field = DictField(db_field="db_dict_field") + list_field = ListField(db_field="db_list_field") + embedded_field = EmbeddedDocumentField( + Embedded, db_field="db_embedded_field" + ) Doc.drop_collection() doc = Doc() @@ -447,171 +483,228 @@ class DeltaTest(MongoDBTestCase): self.assertEqual(doc._delta(), ({}, {})) embedded_1 = Embedded() - embedded_1.string_field = 'hello' + embedded_1.string_field = "hello" embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] + embedded_1.dict_field = {"hello": "world"} + embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field']) + self.assertEqual(doc._get_changed_fields(), ["db_embedded_field"]) embedded_delta = { - 'db_string_field': 'hello', - 'db_int_field': 1, - 'db_dict_field': {'hello': 'world'}, - 'db_list_field': ['1', 2, {'hello': 'world'}] + "db_string_field": "hello", + "db_int_field": 1, + "db_dict_field": {"hello": "world"}, + "db_list_field": ["1", 2, {"hello": "world"}], } self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - self.assertEqual(doc._delta(), - ({'db_embedded_field': embedded_delta}, {})) + self.assertEqual(doc._delta(), ({"db_embedded_field": embedded_delta}, {})) doc.save() doc = doc.reload(10) doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), - ['db_embedded_field.db_dict_field']) - self.assertEqual(doc.embedded_field._delta(), - ({}, {'db_dict_field': 1})) - self.assertEqual(doc._delta(), - ({}, {'db_embedded_field.db_dict_field': 1})) + self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_dict_field"]) + self.assertEqual(doc.embedded_field._delta(), ({}, {"db_dict_field": 1})) + self.assertEqual(doc._delta(), ({}, {"db_embedded_field.db_dict_field": 1})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.dict_field, {}) doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), - ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), - ({}, {'db_list_field': 1})) - self.assertEqual(doc._delta(), - ({}, {'db_embedded_field.db_list_field': 1})) + self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_list_field"]) + self.assertEqual(doc.embedded_field._delta(), ({}, {"db_list_field": 1})) + self.assertEqual(doc._delta(), ({}, {"db_embedded_field.db_list_field": 1})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field, []) embedded_2 = Embedded() - embedded_2.string_field = 'hello' + embedded_2.string_field = "hello" embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] + embedded_2.dict_field = {"hello": "world"} + embedded_2.list_field = ["1", 2, {"hello": "world"}] - doc.embedded_field.list_field = ['1', 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), - ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'db_list_field': ['1', 2, { - '_cls': 'Embedded', - 'db_string_field': 'hello', - 'db_dict_field': {'hello': 'world'}, - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) + doc.embedded_field.list_field = ["1", 2, embedded_2] + self.assertEqual(doc._get_changed_fields(), ["db_embedded_field.db_list_field"]) + self.assertEqual( + doc.embedded_field._delta(), + ( + { + "db_list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "db_string_field": "hello", + "db_dict_field": {"hello": "world"}, + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, + ), + ) - self.assertEqual(doc._delta(), ({ - 'db_embedded_field.db_list_field': ['1', 2, { - '_cls': 'Embedded', - 'db_string_field': 'hello', - 'db_dict_field': {'hello': 'world'}, - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) + self.assertEqual( + doc._delta(), + ( + { + "db_embedded_field.db_list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "db_string_field": "hello", + "db_dict_field": {"hello": "world"}, + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + }, + ] + }, + {}, + ), + ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[0], '1') + self.assertEqual(doc.embedded_field.list_field[0], "1") self.assertEqual(doc.embedded_field.list_field[1], 2) for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], - embedded_2[k]) + self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) - doc.embedded_field.list_field[2].string_field = 'world' - self.assertEqual(doc._get_changed_fields(), - ['db_embedded_field.db_list_field.2.db_string_field']) - self.assertEqual(doc.embedded_field._delta(), - ({'db_list_field.2.db_string_field': 'world'}, {})) - self.assertEqual(doc._delta(), - ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, - {})) + doc.embedded_field.list_field[2].string_field = "world" + self.assertEqual( + doc._get_changed_fields(), + ["db_embedded_field.db_list_field.2.db_string_field"], + ) + self.assertEqual( + doc.embedded_field._delta(), + ({"db_list_field.2.db_string_field": "world"}, {}), + ) + self.assertEqual( + doc._delta(), + ({"db_embedded_field.db_list_field.2.db_string_field": "world"}, {}), + ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, - 'world') + self.assertEqual(doc.embedded_field.list_field[2].string_field, "world") # Test multiple assignments - doc.embedded_field.list_field[2].string_field = 'hello world' + doc.embedded_field.list_field[2].string_field = "hello world" doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual(doc._get_changed_fields(), - ['db_embedded_field.db_list_field.2']) - self.assertEqual(doc.embedded_field._delta(), ({'db_list_field.2': { - '_cls': 'Embedded', - 'db_string_field': 'hello world', - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - 'db_dict_field': {'hello': 'world'}}}, {})) - self.assertEqual(doc._delta(), ({ - 'db_embedded_field.db_list_field.2': { - '_cls': 'Embedded', - 'db_string_field': 'hello world', - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - 'db_dict_field': {'hello': 'world'}} - }, {})) + self.assertEqual( + doc._get_changed_fields(), ["db_embedded_field.db_list_field.2"] + ) + self.assertEqual( + doc.embedded_field._delta(), + ( + { + "db_list_field.2": { + "_cls": "Embedded", + "db_string_field": "hello world", + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + "db_dict_field": {"hello": "world"}, + } + }, + {}, + ), + ) + self.assertEqual( + doc._delta(), + ( + { + "db_embedded_field.db_list_field.2": { + "_cls": "Embedded", + "db_string_field": "hello world", + "db_int_field": 1, + "db_list_field": ["1", 2, {"hello": "world"}], + "db_dict_field": {"hello": "world"}, + } + }, + {}, + ), + ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, - 'hello world') + self.assertEqual(doc.embedded_field.list_field[2].string_field, "hello world") # Test list native methods doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual(doc._delta(), - ({'db_embedded_field.db_list_field.2.db_list_field': - [2, {'hello': 'world'}]}, {})) + self.assertEqual( + doc._delta(), + ( + { + "db_embedded_field.db_list_field.2.db_list_field": [ + 2, + {"hello": "world"}, + ] + }, + {}, + ), + ) doc.save() doc = doc.reload(10) doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual(doc._delta(), - ({'db_embedded_field.db_list_field.2.db_list_field': - [2, {'hello': 'world'}, 1]}, {})) + self.assertEqual( + doc._delta(), + ( + { + "db_embedded_field.db_list_field.2.db_list_field": [ + 2, + {"hello": "world"}, + 1, + ] + }, + {}, + ), + ) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, - [2, {'hello': 'world'}, 1]) + self.assertEqual( + doc.embedded_field.list_field[2].list_field, [2, {"hello": "world"}, 1] + ) doc.embedded_field.list_field[2].list_field.sort(key=str) doc.save() doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, - [1, 2, {'hello': 'world'}]) + self.assertEqual( + doc.embedded_field.list_field[2].list_field, [1, 2, {"hello": "world"}] + ) - del doc.embedded_field.list_field[2].list_field[2]['hello'] - self.assertEqual(doc._delta(), - ({}, {'db_embedded_field.db_list_field.2.db_list_field.2.hello': 1})) + del doc.embedded_field.list_field[2].list_field[2]["hello"] + self.assertEqual( + doc._delta(), + ({}, {"db_embedded_field.db_list_field.2.db_list_field.2.hello": 1}), + ) doc.save() doc = doc.reload(10) del doc.embedded_field.list_field[2].list_field - self.assertEqual(doc._delta(), ({}, - {'db_embedded_field.db_list_field.2.db_list_field': 1})) + self.assertEqual( + doc._delta(), ({}, {"db_embedded_field.db_list_field.2.db_list_field": 1}) + ) def test_delta_for_dynamic_documents(self): class Person(DynamicDocument): name = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} Person.drop_collection() p = Person(name="James", age=34) - self.assertEqual(p._delta(), ( - SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {})) + self.assertEqual( + p._delta(), (SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), {}) + ) p.doc = 123 del p.doc - self.assertEqual(p._delta(), ( - SON([('_cls', 'Person'), ('name', 'James'), ('age', 34)]), {})) + self.assertEqual( + p._delta(), (SON([("_cls", "Person"), ("name", "James"), ("age", 34)]), {}) + ) p = Person() p.name = "Dean" @@ -620,20 +713,19 @@ class DeltaTest(MongoDBTestCase): p.age = 24 self.assertEqual(p.age, 24) - self.assertEqual(p._get_changed_fields(), ['age']) - self.assertEqual(p._delta(), ({'age': 24}, {})) + self.assertEqual(p._get_changed_fields(), ["age"]) + self.assertEqual(p._delta(), ({"age": 24}, {})) p = Person.objects(age=22).get() p.age = 24 self.assertEqual(p.age, 24) - self.assertEqual(p._get_changed_fields(), ['age']) - self.assertEqual(p._delta(), ({'age': 24}, {})) + self.assertEqual(p._get_changed_fields(), ["age"]) + self.assertEqual(p._delta(), ({"age": 24}, {})) p.save() self.assertEqual(1, Person.objects(age=24).count()) def test_dynamic_delta(self): - class Doc(DynamicDocument): pass @@ -645,41 +737,43 @@ class DeltaTest(MongoDBTestCase): self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(doc._delta(), ({}, {})) - doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['string_field']) - self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) + doc.string_field = "hello" + self.assertEqual(doc._get_changed_fields(), ["string_field"]) + self.assertEqual(doc._delta(), ({"string_field": "hello"}, {})) doc._changed_fields = [] doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['int_field']) - self.assertEqual(doc._delta(), ({'int_field': 1}, {})) + self.assertEqual(doc._get_changed_fields(), ["int_field"]) + self.assertEqual(doc._delta(), ({"int_field": 1}, {})) doc._changed_fields = [] - dict_value = {'hello': 'world', 'ping': 'pong'} + dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) + self.assertEqual(doc._get_changed_fields(), ["dict_field"]) + self.assertEqual(doc._delta(), ({"dict_field": dict_value}, {})) doc._changed_fields = [] - list_value = ['1', 2, {'hello': 'world'}] + list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) + self.assertEqual(doc._get_changed_fields(), ["list_field"]) + self.assertEqual(doc._delta(), ({"list_field": list_value}, {})) # Test unsetting doc._changed_fields = [] doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) + self.assertEqual(doc._get_changed_fields(), ["dict_field"]) + self.assertEqual(doc._delta(), ({}, {"dict_field": 1})) doc._changed_fields = [] doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({}, {'list_field': 1})) + self.assertEqual(doc._get_changed_fields(), ["list_field"]) + self.assertEqual(doc._delta(), ({}, {"list_field": 1})) def test_delta_with_dbref_true(self): - person, organization, employee = self.circular_reference_deltas_2(Document, Document, True) - employee.name = 'test' + person, organization, employee = self.circular_reference_deltas_2( + Document, Document, True + ) + employee.name = "test" self.assertEqual(organization._get_changed_fields(), []) @@ -690,11 +784,13 @@ class DeltaTest(MongoDBTestCase): organization.employees.append(person) updates, removals = organization._delta() self.assertEqual({}, removals) - self.assertIn('employees', updates) + self.assertIn("employees", updates) def test_delta_with_dbref_false(self): - person, organization, employee = self.circular_reference_deltas_2(Document, Document, False) - employee.name = 'test' + person, organization, employee = self.circular_reference_deltas_2( + Document, Document, False + ) + employee.name = "test" self.assertEqual(organization._get_changed_fields(), []) @@ -705,7 +801,7 @@ class DeltaTest(MongoDBTestCase): organization.employees.append(person) updates, removals = organization._delta() self.assertEqual({}, removals) - self.assertIn('employees', updates) + self.assertIn("employees", updates) def test_nested_nested_fields_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): @@ -717,11 +813,13 @@ class DeltaTest(MongoDBTestCase): MyDoc.drop_collection() - mydoc = MyDoc(name='testcase1', subs={'a': {'b': EmbeddedDoc(name='foo')}}).save() + mydoc = MyDoc( + name="testcase1", subs={"a": {"b": EmbeddedDoc(name="foo")}} + ).save() mydoc = MyDoc.objects.first() - subdoc = mydoc.subs['a']['b'] - subdoc.name = 'bar' + subdoc = mydoc.subs["a"]["b"] + subdoc.name = "bar" self.assertEqual(["name"], subdoc._get_changed_fields()) self.assertEqual(["subs.a.b.name"], mydoc._get_changed_fields()) @@ -741,11 +839,11 @@ class DeltaTest(MongoDBTestCase): MyDoc().save() mydoc = MyDoc.objects.first() - mydoc.subs['a'] = EmbeddedDoc() + mydoc.subs["a"] = EmbeddedDoc() self.assertEqual(["subs.a"], mydoc._get_changed_fields()) - subdoc = mydoc.subs['a'] - subdoc.name = 'bar' + subdoc = mydoc.subs["a"] + subdoc.name = "bar" self.assertEqual(["name"], subdoc._get_changed_fields()) self.assertEqual(["subs.a"], mydoc._get_changed_fields()) @@ -763,16 +861,16 @@ class DeltaTest(MongoDBTestCase): MyDoc.drop_collection() - MyDoc(subs={'a': EmbeddedDoc(name='foo')}).save() + MyDoc(subs={"a": EmbeddedDoc(name="foo")}).save() mydoc = MyDoc.objects.first() - subdoc = mydoc.subs['a'] - subdoc.name = 'bar' + subdoc = mydoc.subs["a"] + subdoc.name = "bar" self.assertEqual(["name"], subdoc._get_changed_fields()) self.assertEqual(["subs.a.name"], mydoc._get_changed_fields()) - mydoc.subs['a'] = EmbeddedDoc() + mydoc.subs["a"] = EmbeddedDoc() self.assertEqual(["subs.a"], mydoc._get_changed_fields()) mydoc.save() @@ -787,39 +885,39 @@ class DeltaTest(MongoDBTestCase): class User(Document): name = StringField() - org = ReferenceField('Organization', required=True) + org = ReferenceField("Organization", required=True) Organization.drop_collection() User.drop_collection() - org1 = Organization(name='Org 1') + org1 = Organization(name="Org 1") org1.save() - org2 = Organization(name='Org 2') + org2 = Organization(name="Org 2") org2.save() - user = User(name='Fred', org=org1) + user = User(name="Fred", org=org1) user.save() org1.reload() org2.reload() user.reload() - self.assertEqual(org1.name, 'Org 1') - self.assertEqual(org2.name, 'Org 2') - self.assertEqual(user.name, 'Fred') + self.assertEqual(org1.name, "Org 1") + self.assertEqual(org2.name, "Org 2") + self.assertEqual(user.name, "Fred") - user.name = 'Harold' + user.name = "Harold" user.org = org2 - org2.name = 'New Org 2' - self.assertEqual(org2.name, 'New Org 2') + org2.name = "New Org 2" + self.assertEqual(org2.name, "New Org 2") user.save() org2.save() - self.assertEqual(org2.name, 'New Org 2') + self.assertEqual(org2.name, "New Org 2") org2.reload() - self.assertEqual(org2.name, 'New Org 2') + self.assertEqual(org2.name, "New Org 2") def test_delta_for_nested_map_fields(self): class UInfoDocument(Document): @@ -855,10 +953,10 @@ class DeltaTest(MongoDBTestCase): self.assertEqual(True, "users.007.roles.666" in delta[0]) self.assertEqual(True, "users.007.rolist" in delta[0]) self.assertEqual(True, "users.007.info" in delta[0]) - self.assertEqual('superadmin', delta[0]["users.007.roles.666"]["type"]) - self.assertEqual('oops', delta[0]["users.007.rolist"][0]["type"]) + self.assertEqual("superadmin", delta[0]["users.007.roles.666"]["type"]) + self.assertEqual("oops", delta[0]["users.007.rolist"][0]["type"]) self.assertEqual(uinfo.id, delta[0]["users.007.info"]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py index 44548d27..414d3352 100644 --- a/tests/document/dynamic.py +++ b/tests/document/dynamic.py @@ -3,17 +3,16 @@ import unittest from mongoengine import * from tests.utils import MongoDBTestCase -__all__ = ("TestDynamicDocument", ) +__all__ = ("TestDynamicDocument",) class TestDynamicDocument(MongoDBTestCase): - def setUp(self): super(TestDynamicDocument, self).setUp() class Person(DynamicDocument): name = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} Person.drop_collection() @@ -26,8 +25,7 @@ class TestDynamicDocument(MongoDBTestCase): p.name = "James" p.age = 34 - self.assertEqual(p.to_mongo(), {"_cls": "Person", "name": "James", - "age": 34}) + self.assertEqual(p.to_mongo(), {"_cls": "Person", "name": "James", "age": 34}) self.assertEqual(p.to_mongo().keys(), ["_cls", "name", "age"]) p.save() self.assertEqual(p.to_mongo().keys(), ["_id", "_cls", "name", "age"]) @@ -35,7 +33,7 @@ class TestDynamicDocument(MongoDBTestCase): self.assertEqual(self.Person.objects.first().age, 34) # Confirm no changes to self.Person - self.assertFalse(hasattr(self.Person, 'age')) + self.assertFalse(hasattr(self.Person, "age")) def test_change_scope_of_variable(self): """Test changing the scope of a dynamic field has no adverse effects""" @@ -45,11 +43,11 @@ class TestDynamicDocument(MongoDBTestCase): p.save() p = self.Person.objects.get() - p.misc = {'hello': 'world'} + p.misc = {"hello": "world"} p.save() p = self.Person.objects.get() - self.assertEqual(p.misc, {'hello': 'world'}) + self.assertEqual(p.misc, {"hello": "world"}) def test_delete_dynamic_field(self): """Test deleting a dynamic field works""" @@ -60,23 +58,23 @@ class TestDynamicDocument(MongoDBTestCase): p.save() p = self.Person.objects.get() - p.misc = {'hello': 'world'} + p.misc = {"hello": "world"} p.save() p = self.Person.objects.get() - self.assertEqual(p.misc, {'hello': 'world'}) + self.assertEqual(p.misc, {"hello": "world"}) collection = self.db[self.Person._get_collection_name()] obj = collection.find_one() - self.assertEqual(sorted(obj.keys()), ['_cls', '_id', 'misc', 'name']) + self.assertEqual(sorted(obj.keys()), ["_cls", "_id", "misc", "name"]) del p.misc p.save() p = self.Person.objects.get() - self.assertFalse(hasattr(p, 'misc')) + self.assertFalse(hasattr(p, "misc")) obj = collection.find_one() - self.assertEqual(sorted(obj.keys()), ['_cls', '_id', 'name']) + self.assertEqual(sorted(obj.keys()), ["_cls", "_id", "name"]) def test_reload_after_unsetting(self): p = self.Person() @@ -91,77 +89,52 @@ class TestDynamicDocument(MongoDBTestCase): p.update(age=1) self.assertEqual(len(p._data), 3) - self.assertEqual(sorted(p._data.keys()), ['_cls', 'id', 'name']) + self.assertEqual(sorted(p._data.keys()), ["_cls", "id", "name"]) p.reload() self.assertEqual(len(p._data), 4) - self.assertEqual(sorted(p._data.keys()), ['_cls', 'age', 'id', 'name']) + self.assertEqual(sorted(p._data.keys()), ["_cls", "age", "id", "name"]) def test_fields_without_underscore(self): """Ensure we can query dynamic fields""" Person = self.Person - p = self.Person(name='Dean') + p = self.Person(name="Dean") p.save() raw_p = Person.objects.as_pymongo().get(id=p.id) - self.assertEqual( - raw_p, - { - '_cls': u'Person', - '_id': p.id, - 'name': u'Dean' - } - ) + self.assertEqual(raw_p, {"_cls": u"Person", "_id": p.id, "name": u"Dean"}) - p.name = 'OldDean' - p.newattr = 'garbage' + p.name = "OldDean" + p.newattr = "garbage" p.save() raw_p = Person.objects.as_pymongo().get(id=p.id) self.assertEqual( raw_p, - { - '_cls': u'Person', - '_id': p.id, - 'name': 'OldDean', - 'newattr': u'garbage' - } + {"_cls": u"Person", "_id": p.id, "name": "OldDean", "newattr": u"garbage"}, ) def test_fields_containing_underscore(self): """Ensure we can query dynamic fields""" + class WeirdPerson(DynamicDocument): name = StringField() _name = StringField() WeirdPerson.drop_collection() - p = WeirdPerson(name='Dean', _name='Dean') + p = WeirdPerson(name="Dean", _name="Dean") p.save() raw_p = WeirdPerson.objects.as_pymongo().get(id=p.id) - self.assertEqual( - raw_p, - { - '_id': p.id, - '_name': u'Dean', - 'name': u'Dean' - } - ) + self.assertEqual(raw_p, {"_id": p.id, "_name": u"Dean", "name": u"Dean"}) - p.name = 'OldDean' - p._name = 'NewDean' - p._newattr1 = 'garbage' # Unknown fields won't be added + p.name = "OldDean" + p._name = "NewDean" + p._newattr1 = "garbage" # Unknown fields won't be added p.save() raw_p = WeirdPerson.objects.as_pymongo().get(id=p.id) - self.assertEqual( - raw_p, - { - '_id': p.id, - '_name': u'NewDean', - 'name': u'OldDean', - } - ) + self.assertEqual(raw_p, {"_id": p.id, "_name": u"NewDean", "name": u"OldDean"}) def test_dynamic_document_queries(self): """Ensure we can query dynamic fields""" @@ -193,26 +166,25 @@ class TestDynamicDocument(MongoDBTestCase): p2.age = 10 p2.save() - self.assertEqual(Person.objects(age__icontains='ten').count(), 2) + self.assertEqual(Person.objects(age__icontains="ten").count(), 2) self.assertEqual(Person.objects(age__gte=10).count(), 1) def test_complex_data_lookups(self): """Ensure you can query dynamic document dynamic fields""" p = self.Person() - p.misc = {'hello': 'world'} + p.misc = {"hello": "world"} p.save() - self.assertEqual(1, self.Person.objects(misc__hello='world').count()) + self.assertEqual(1, self.Person.objects(misc__hello="world").count()) def test_three_level_complex_data_lookups(self): """Ensure you can query three level document dynamic fields""" - p = self.Person.objects.create( - misc={'hello': {'hello2': 'world'}} - ) - self.assertEqual(1, self.Person.objects(misc__hello__hello2='world').count()) + p = self.Person.objects.create(misc={"hello": {"hello2": "world"}}) + self.assertEqual(1, self.Person.objects(misc__hello__hello2="world").count()) def test_complex_embedded_document_validation(self): """Ensure embedded dynamic documents may be validated""" + class Embedded(DynamicEmbeddedDocument): content = URLField() @@ -222,10 +194,10 @@ class TestDynamicDocument(MongoDBTestCase): Doc.drop_collection() doc = Doc() - embedded_doc_1 = Embedded(content='http://mongoengine.org') + embedded_doc_1 = Embedded(content="http://mongoengine.org") embedded_doc_1.validate() - embedded_doc_2 = Embedded(content='this is not a url') + embedded_doc_2 = Embedded(content="this is not a url") self.assertRaises(ValidationError, embedded_doc_2.validate) doc.embedded_field_1 = embedded_doc_1 @@ -234,15 +206,17 @@ class TestDynamicDocument(MongoDBTestCase): def test_inheritance(self): """Ensure that dynamic document plays nice with inheritance""" + class Employee(self.Person): salary = IntField() Employee.drop_collection() - self.assertIn('name', Employee._fields) - self.assertIn('salary', Employee._fields) - self.assertEqual(Employee._get_collection_name(), - self.Person._get_collection_name()) + self.assertIn("name", Employee._fields) + self.assertIn("salary", Employee._fields) + self.assertEqual( + Employee._get_collection_name(), self.Person._get_collection_name() + ) joe_bloggs = Employee() joe_bloggs.name = "Joe Bloggs" @@ -258,6 +232,7 @@ class TestDynamicDocument(MongoDBTestCase): def test_embedded_dynamic_document(self): """Test dynamic embedded documents""" + class Embedded(DynamicEmbeddedDocument): pass @@ -268,78 +243,88 @@ class TestDynamicDocument(MongoDBTestCase): doc = Doc() embedded_1 = Embedded() - embedded_1.string_field = 'hello' + embedded_1.string_field = "hello" embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] + embedded_1.dict_field = {"hello": "world"} + embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - self.assertEqual(doc.to_mongo(), { - "embedded_field": { - "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ['1', 2, {'hello': 'world'}] - } - }) - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc.embedded_field.__class__, Embedded) - self.assertEqual(doc.embedded_field.string_field, "hello") - self.assertEqual(doc.embedded_field.int_field, 1) - self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'}) - self.assertEqual(doc.embedded_field.list_field, - ['1', 2, {'hello': 'world'}]) - - def test_complex_embedded_documents(self): - """Test complex dynamic embedded documents setups""" - class Embedded(DynamicEmbeddedDocument): - pass - - class Doc(DynamicDocument): - pass - - Doc.drop_collection() - doc = Doc() - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - - embedded_2 = Embedded() - embedded_2.string_field = 'hello' - embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] - - embedded_1.list_field = ['1', 2, embedded_2] - doc.embedded_field = embedded_1 - - self.assertEqual(doc.to_mongo(), { - "embedded_field": { - "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ['1', 2, - {"_cls": "Embedded", + self.assertEqual( + doc.to_mongo(), + { + "embedded_field": { + "_cls": "Embedded", "string_field": "hello", "int_field": 1, "dict_field": {"hello": "world"}, - "list_field": ['1', 2, {'hello': 'world'}]} - ] - } - }) + "list_field": ["1", 2, {"hello": "world"}], + } + }, + ) + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc.embedded_field.__class__, Embedded) + self.assertEqual(doc.embedded_field.string_field, "hello") + self.assertEqual(doc.embedded_field.int_field, 1) + self.assertEqual(doc.embedded_field.dict_field, {"hello": "world"}) + self.assertEqual(doc.embedded_field.list_field, ["1", 2, {"hello": "world"}]) + + def test_complex_embedded_documents(self): + """Test complex dynamic embedded documents setups""" + + class Embedded(DynamicEmbeddedDocument): + pass + + class Doc(DynamicDocument): + pass + + Doc.drop_collection() + doc = Doc() + + embedded_1 = Embedded() + embedded_1.string_field = "hello" + embedded_1.int_field = 1 + embedded_1.dict_field = {"hello": "world"} + + embedded_2 = Embedded() + embedded_2.string_field = "hello" + embedded_2.int_field = 1 + embedded_2.dict_field = {"hello": "world"} + embedded_2.list_field = ["1", 2, {"hello": "world"}] + + embedded_1.list_field = ["1", 2, embedded_2] + doc.embedded_field = embedded_1 + + self.assertEqual( + doc.to_mongo(), + { + "embedded_field": { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": [ + "1", + 2, + { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ["1", 2, {"hello": "world"}], + }, + ], + } + }, + ) doc.save() doc = Doc.objects.first() self.assertEqual(doc.embedded_field.__class__, Embedded) self.assertEqual(doc.embedded_field.string_field, "hello") self.assertEqual(doc.embedded_field.int_field, 1) - self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'}) - self.assertEqual(doc.embedded_field.list_field[0], '1') + self.assertEqual(doc.embedded_field.dict_field, {"hello": "world"}) + self.assertEqual(doc.embedded_field.list_field[0], "1") self.assertEqual(doc.embedded_field.list_field[1], 2) embedded_field = doc.embedded_field.list_field[2] @@ -347,9 +332,8 @@ class TestDynamicDocument(MongoDBTestCase): self.assertEqual(embedded_field.__class__, Embedded) self.assertEqual(embedded_field.string_field, "hello") self.assertEqual(embedded_field.int_field, 1) - self.assertEqual(embedded_field.dict_field, {'hello': 'world'}) - self.assertEqual(embedded_field.list_field, ['1', 2, - {'hello': 'world'}]) + self.assertEqual(embedded_field.dict_field, {"hello": "world"}) + self.assertEqual(embedded_field.list_field, ["1", 2, {"hello": "world"}]) def test_dynamic_and_embedded(self): """Ensure embedded documents play nicely""" @@ -392,10 +376,15 @@ class TestDynamicDocument(MongoDBTestCase): Person.drop_collection() - Person(name="Eric", address=Address(city="San Francisco", street_number="1337")).save() + Person( + name="Eric", address=Address(city="San Francisco", street_number="1337") + ).save() - self.assertEqual(Person.objects.first().address.street_number, '1337') - self.assertEqual(Person.objects.only('address__street_number').first().address.street_number, '1337') + self.assertEqual(Person.objects.first().address.street_number, "1337") + self.assertEqual( + Person.objects.only("address__street_number").first().address.street_number, + "1337", + ) def test_dynamic_and_embedded_dict_access(self): """Ensure embedded dynamic documents work with dict[] style access""" @@ -435,5 +424,5 @@ class TestDynamicDocument(MongoDBTestCase): self.assertEqual(Person.objects.first().age, 35) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 764ef0c5..570e619e 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -10,13 +10,12 @@ from six import iteritems from mongoengine import * from mongoengine.connection import get_db -__all__ = ("IndexesTest", ) +__all__ = ("IndexesTest",) class IndexesTest(unittest.TestCase): - def setUp(self): - self.connection = connect(db='mongoenginetest') + self.connection = connect(db="mongoenginetest") self.db = get_db() class Person(Document): @@ -45,52 +44,43 @@ class IndexesTest(unittest.TestCase): self._index_test(DynamicDocument) def _index_test(self, InheritFrom): - class BlogPost(InheritFrom): - date = DateTimeField(db_field='addDate', default=datetime.now) + date = DateTimeField(db_field="addDate", default=datetime.now) category = StringField() tags = ListField(StringField()) - meta = { - 'indexes': [ - '-date', - 'tags', - ('category', '-date') - ] - } + meta = {"indexes": ["-date", "tags", ("category", "-date")]} - expected_specs = [{'fields': [('addDate', -1)]}, - {'fields': [('tags', 1)]}, - {'fields': [('category', 1), ('addDate', -1)]}] - self.assertEqual(expected_specs, BlogPost._meta['index_specs']) + expected_specs = [ + {"fields": [("addDate", -1)]}, + {"fields": [("tags", 1)]}, + {"fields": [("category", 1), ("addDate", -1)]}, + ] + self.assertEqual(expected_specs, BlogPost._meta["index_specs"]) BlogPost.ensure_indexes() info = BlogPost.objects._collection.index_information() # _id, '-date', 'tags', ('cat', 'date') self.assertEqual(len(info), 4) - info = [value['key'] for key, value in iteritems(info)] + info = [value["key"] for key, value in iteritems(info)] for expected in expected_specs: - self.assertIn(expected['fields'], info) + self.assertIn(expected["fields"], info) def _index_test_inheritance(self, InheritFrom): - class BlogPost(InheritFrom): - date = DateTimeField(db_field='addDate', default=datetime.now) + date = DateTimeField(db_field="addDate", default=datetime.now) category = StringField() tags = ListField(StringField()) meta = { - 'indexes': [ - '-date', - 'tags', - ('category', '-date') - ], - 'allow_inheritance': True + "indexes": ["-date", "tags", ("category", "-date")], + "allow_inheritance": True, } - expected_specs = [{'fields': [('_cls', 1), ('addDate', -1)]}, - {'fields': [('_cls', 1), ('tags', 1)]}, - {'fields': [('_cls', 1), ('category', 1), - ('addDate', -1)]}] - self.assertEqual(expected_specs, BlogPost._meta['index_specs']) + expected_specs = [ + {"fields": [("_cls", 1), ("addDate", -1)]}, + {"fields": [("_cls", 1), ("tags", 1)]}, + {"fields": [("_cls", 1), ("category", 1), ("addDate", -1)]}, + ] + self.assertEqual(expected_specs, BlogPost._meta["index_specs"]) BlogPost.ensure_indexes() info = BlogPost.objects._collection.index_information() @@ -99,24 +89,24 @@ class IndexesTest(unittest.TestCase): # the indices on -date and tags will both contain # _cls as first element in the key self.assertEqual(len(info), 4) - info = [value['key'] for key, value in iteritems(info)] + info = [value["key"] for key, value in iteritems(info)] for expected in expected_specs: - self.assertIn(expected['fields'], info) + self.assertIn(expected["fields"], info) class ExtendedBlogPost(BlogPost): title = StringField() - meta = {'indexes': ['title']} + meta = {"indexes": ["title"]} - expected_specs.append({'fields': [('_cls', 1), ('title', 1)]}) - self.assertEqual(expected_specs, ExtendedBlogPost._meta['index_specs']) + expected_specs.append({"fields": [("_cls", 1), ("title", 1)]}) + self.assertEqual(expected_specs, ExtendedBlogPost._meta["index_specs"]) BlogPost.drop_collection() ExtendedBlogPost.ensure_indexes() info = ExtendedBlogPost.objects._collection.index_information() - info = [value['key'] for key, value in iteritems(info)] + info = [value["key"] for key, value in iteritems(info)] for expected in expected_specs: - self.assertIn(expected['fields'], info) + self.assertIn(expected["fields"], info) def test_indexes_document_inheritance(self): """Ensure that indexes are used when meta[indexes] is specified for @@ -135,21 +125,15 @@ class IndexesTest(unittest.TestCase): class A(Document): title = StringField() - meta = { - 'indexes': [ - { - 'fields': ('title',), - }, - ], - 'allow_inheritance': True, - } + meta = {"indexes": [{"fields": ("title",)}], "allow_inheritance": True} class B(A): description = StringField() - self.assertEqual(A._meta['index_specs'], B._meta['index_specs']) - self.assertEqual([{'fields': [('_cls', 1), ('title', 1)]}], - A._meta['index_specs']) + self.assertEqual(A._meta["index_specs"], B._meta["index_specs"]) + self.assertEqual( + [{"fields": [("_cls", 1), ("title", 1)]}], A._meta["index_specs"] + ) def test_index_no_cls(self): """Ensure index specs are inhertited correctly""" @@ -157,14 +141,12 @@ class IndexesTest(unittest.TestCase): class A(Document): title = StringField() meta = { - 'indexes': [ - {'fields': ('title',), 'cls': False}, - ], - 'allow_inheritance': True, - 'index_cls': False - } + "indexes": [{"fields": ("title",), "cls": False}], + "allow_inheritance": True, + "index_cls": False, + } - self.assertEqual([('title', 1)], A._meta['index_specs'][0]['fields']) + self.assertEqual([("title", 1)], A._meta["index_specs"][0]["fields"]) A._get_collection().drop_indexes() A.ensure_indexes() info = A._get_collection().index_information() @@ -174,34 +156,30 @@ class IndexesTest(unittest.TestCase): c = StringField() d = StringField() meta = { - 'indexes': [{'fields': ['c']}, {'fields': ['d'], 'cls': True}], - 'allow_inheritance': True + "indexes": [{"fields": ["c"]}, {"fields": ["d"], "cls": True}], + "allow_inheritance": True, } - self.assertEqual([('c', 1)], B._meta['index_specs'][1]['fields']) - self.assertEqual([('_cls', 1), ('d', 1)], B._meta['index_specs'][2]['fields']) + + self.assertEqual([("c", 1)], B._meta["index_specs"][1]["fields"]) + self.assertEqual([("_cls", 1), ("d", 1)], B._meta["index_specs"][2]["fields"]) def test_build_index_spec_is_not_destructive(self): - class MyDoc(Document): keywords = StringField() - meta = { - 'indexes': ['keywords'], - 'allow_inheritance': False - } + meta = {"indexes": ["keywords"], "allow_inheritance": False} - self.assertEqual(MyDoc._meta['index_specs'], - [{'fields': [('keywords', 1)]}]) + self.assertEqual(MyDoc._meta["index_specs"], [{"fields": [("keywords", 1)]}]) # Force index creation MyDoc.ensure_indexes() - self.assertEqual(MyDoc._meta['index_specs'], - [{'fields': [('keywords', 1)]}]) + self.assertEqual(MyDoc._meta["index_specs"], [{"fields": [("keywords", 1)]}]) def test_embedded_document_index_meta(self): """Ensure that embedded document indexes are created explicitly """ + class Rank(EmbeddedDocument): title = StringField(required=True) @@ -209,138 +187,123 @@ class IndexesTest(unittest.TestCase): name = StringField(required=True) rank = EmbeddedDocumentField(Rank, required=False) - meta = { - 'indexes': [ - 'rank.title', - ], - 'allow_inheritance': False - } + meta = {"indexes": ["rank.title"], "allow_inheritance": False} - self.assertEqual([{'fields': [('rank.title', 1)]}], - Person._meta['index_specs']) + self.assertEqual([{"fields": [("rank.title", 1)]}], Person._meta["index_specs"]) Person.drop_collection() # Indexes are lazy so use list() to perform query list(Person.objects) info = Person.objects._collection.index_information() - info = [value['key'] for key, value in iteritems(info)] - self.assertIn([('rank.title', 1)], info) + info = [value["key"] for key, value in iteritems(info)] + self.assertIn([("rank.title", 1)], info) def test_explicit_geo2d_index(self): """Ensure that geo2d indexes work when created via meta[indexes] """ + class Place(Document): location = DictField() - meta = { - 'allow_inheritance': True, - 'indexes': [ - '*location.point', - ] - } + meta = {"allow_inheritance": True, "indexes": ["*location.point"]} - self.assertEqual([{'fields': [('location.point', '2d')]}], - Place._meta['index_specs']) + self.assertEqual( + [{"fields": [("location.point", "2d")]}], Place._meta["index_specs"] + ) Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in iteritems(info)] - self.assertIn([('location.point', '2d')], info) + info = [value["key"] for key, value in iteritems(info)] + self.assertIn([("location.point", "2d")], info) def test_explicit_geo2d_index_embedded(self): """Ensure that geo2d indexes work when created via meta[indexes] """ + class EmbeddedLocation(EmbeddedDocument): location = DictField() class Place(Document): - current = DictField(field=EmbeddedDocumentField('EmbeddedLocation')) - meta = { - 'allow_inheritance': True, - 'indexes': [ - '*current.location.point', - ] - } + current = DictField(field=EmbeddedDocumentField("EmbeddedLocation")) + meta = {"allow_inheritance": True, "indexes": ["*current.location.point"]} - self.assertEqual([{'fields': [('current.location.point', '2d')]}], - Place._meta['index_specs']) + self.assertEqual( + [{"fields": [("current.location.point", "2d")]}], Place._meta["index_specs"] + ) Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in iteritems(info)] - self.assertIn([('current.location.point', '2d')], info) + info = [value["key"] for key, value in iteritems(info)] + self.assertIn([("current.location.point", "2d")], info) def test_explicit_geosphere_index(self): """Ensure that geosphere indexes work when created via meta[indexes] """ + class Place(Document): location = DictField() - meta = { - 'allow_inheritance': True, - 'indexes': [ - '(location.point', - ] - } + meta = {"allow_inheritance": True, "indexes": ["(location.point"]} - self.assertEqual([{'fields': [('location.point', '2dsphere')]}], - Place._meta['index_specs']) + self.assertEqual( + [{"fields": [("location.point", "2dsphere")]}], Place._meta["index_specs"] + ) Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in iteritems(info)] - self.assertIn([('location.point', '2dsphere')], info) + info = [value["key"] for key, value in iteritems(info)] + self.assertIn([("location.point", "2dsphere")], info) def test_explicit_geohaystack_index(self): """Ensure that geohaystack indexes work when created via meta[indexes] """ - raise SkipTest('GeoHaystack index creation is not supported for now' - 'from meta, as it requires a bucketSize parameter.') + raise SkipTest( + "GeoHaystack index creation is not supported for now" + "from meta, as it requires a bucketSize parameter." + ) class Place(Document): location = DictField() name = StringField() - meta = { - 'indexes': [ - (')location.point', 'name') - ] - } - self.assertEqual([{'fields': [('location.point', 'geoHaystack'), ('name', 1)]}], - Place._meta['index_specs']) + meta = {"indexes": [(")location.point", "name")]} + + self.assertEqual( + [{"fields": [("location.point", "geoHaystack"), ("name", 1)]}], + Place._meta["index_specs"], + ) Place.ensure_indexes() info = Place._get_collection().index_information() - info = [value['key'] for key, value in iteritems(info)] - self.assertIn([('location.point', 'geoHaystack')], info) + info = [value["key"] for key, value in iteritems(info)] + self.assertIn([("location.point", "geoHaystack")], info) def test_create_geohaystack_index(self): """Ensure that geohaystack indexes can be created """ + class Place(Document): location = DictField() name = StringField() - Place.create_index({'fields': (')location.point', 'name')}, bucketSize=10) + Place.create_index({"fields": (")location.point", "name")}, bucketSize=10) info = Place._get_collection().index_information() - info = [value['key'] for key, value in iteritems(info)] - self.assertIn([('location.point', 'geoHaystack'), ('name', 1)], info) + info = [value["key"] for key, value in iteritems(info)] + self.assertIn([("location.point", "geoHaystack"), ("name", 1)], info) def test_dictionary_indexes(self): """Ensure that indexes are used when meta[indexes] contains dictionaries instead of lists. """ + class BlogPost(Document): - date = DateTimeField(db_field='addDate', default=datetime.now) + date = DateTimeField(db_field="addDate", default=datetime.now) category = StringField() tags = ListField(StringField()) - meta = { - 'indexes': [ - {'fields': ['-date'], 'unique': True, 'sparse': True}, - ], - } + meta = {"indexes": [{"fields": ["-date"], "unique": True, "sparse": True}]} - self.assertEqual([{'fields': [('addDate', -1)], 'unique': True, - 'sparse': True}], - BlogPost._meta['index_specs']) + self.assertEqual( + [{"fields": [("addDate", -1)], "unique": True, "sparse": True}], + BlogPost._meta["index_specs"], + ) BlogPost.drop_collection() @@ -351,48 +314,48 @@ class IndexesTest(unittest.TestCase): # Indexes are lazy so use list() to perform query list(BlogPost.objects) info = BlogPost.objects._collection.index_information() - info = [(value['key'], - value.get('unique', False), - value.get('sparse', False)) - for key, value in iteritems(info)] - self.assertIn(([('addDate', -1)], True, True), info) + info = [ + (value["key"], value.get("unique", False), value.get("sparse", False)) + for key, value in iteritems(info) + ] + self.assertIn(([("addDate", -1)], True, True), info) BlogPost.drop_collection() def test_abstract_index_inheritance(self): - class UserBase(Document): user_guid = StringField(required=True) meta = { - 'abstract': True, - 'indexes': ['user_guid'], - 'allow_inheritance': True + "abstract": True, + "indexes": ["user_guid"], + "allow_inheritance": True, } class Person(UserBase): name = StringField() - meta = { - 'indexes': ['name'], - } + meta = {"indexes": ["name"]} + Person.drop_collection() - Person(name="test", user_guid='123').save() + Person(name="test", user_guid="123").save() self.assertEqual(1, Person.objects.count()) info = Person.objects._collection.index_information() - self.assertEqual(sorted(info.keys()), - ['_cls_1_name_1', '_cls_1_user_guid_1', '_id_']) + self.assertEqual( + sorted(info.keys()), ["_cls_1_name_1", "_cls_1_user_guid_1", "_id_"] + ) def test_disable_index_creation(self): """Tests setting auto_create_index to False on the connection will disable any index generation. """ + class User(Document): meta = { - 'allow_inheritance': True, - 'indexes': ['user_guid'], - 'auto_create_index': False + "allow_inheritance": True, + "indexes": ["user_guid"], + "auto_create_index": False, } user_guid = StringField(required=True) @@ -401,88 +364,81 @@ class IndexesTest(unittest.TestCase): User.drop_collection() - User(user_guid='123').save() - MongoUser(user_guid='123').save() + User(user_guid="123").save() + MongoUser(user_guid="123").save() self.assertEqual(2, User.objects.count()) info = User.objects._collection.index_information() - self.assertEqual(list(info.keys()), ['_id_']) + self.assertEqual(list(info.keys()), ["_id_"]) User.ensure_indexes() info = User.objects._collection.index_information() - self.assertEqual(sorted(info.keys()), ['_cls_1_user_guid_1', '_id_']) + self.assertEqual(sorted(info.keys()), ["_cls_1_user_guid_1", "_id_"]) def test_embedded_document_index(self): """Tests settings an index on an embedded document """ + class Date(EmbeddedDocument): - year = IntField(db_field='yr') + year = IntField(db_field="yr") class BlogPost(Document): title = StringField() date = EmbeddedDocumentField(Date) - meta = { - 'indexes': [ - '-date.year' - ], - } + meta = {"indexes": ["-date.year"]} BlogPost.drop_collection() info = BlogPost.objects._collection.index_information() - self.assertEqual(sorted(info.keys()), ['_id_', 'date.yr_-1']) + self.assertEqual(sorted(info.keys()), ["_id_", "date.yr_-1"]) def test_list_embedded_document_index(self): """Ensure list embedded documents can be indexed """ + class Tag(EmbeddedDocument): - name = StringField(db_field='tag') + name = StringField(db_field="tag") class BlogPost(Document): title = StringField() tags = ListField(EmbeddedDocumentField(Tag)) - meta = { - 'indexes': [ - 'tags.name' - ] - } + meta = {"indexes": ["tags.name"]} BlogPost.drop_collection() info = BlogPost.objects._collection.index_information() # we don't use _cls in with list fields by default - self.assertEqual(sorted(info.keys()), ['_id_', 'tags.tag_1']) + self.assertEqual(sorted(info.keys()), ["_id_", "tags.tag_1"]) - post1 = BlogPost(title="Embedded Indexes tests in place", - tags=[Tag(name="about"), Tag(name="time")]) + post1 = BlogPost( + title="Embedded Indexes tests in place", + tags=[Tag(name="about"), Tag(name="time")], + ) post1.save() def test_recursive_embedded_objects_dont_break_indexes(self): - class RecursiveObject(EmbeddedDocument): - obj = EmbeddedDocumentField('self') + obj = EmbeddedDocumentField("self") class RecursiveDocument(Document): recursive_obj = EmbeddedDocumentField(RecursiveObject) - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} RecursiveDocument.ensure_indexes() info = RecursiveDocument._get_collection().index_information() - self.assertEqual(sorted(info.keys()), ['_cls_1', '_id_']) + self.assertEqual(sorted(info.keys()), ["_cls_1", "_id_"]) def test_covered_index(self): """Ensure that covered indexes can be used """ + class Test(Document): a = IntField() b = IntField() - meta = { - 'indexes': ['a'], - 'allow_inheritance': False - } + meta = {"indexes": ["a"], "allow_inheritance": False} Test.drop_collection() @@ -491,45 +447,51 @@ class IndexesTest(unittest.TestCase): # Need to be explicit about covered indexes as mongoDB doesn't know if # the documents returned might have more keys in that here. - query_plan = Test.objects(id=obj.id).exclude('a').explain() + query_plan = Test.objects(id=obj.id).exclude("a").explain() self.assertEqual( - query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), - 'IDHACK' + query_plan.get("queryPlanner") + .get("winningPlan") + .get("inputStage") + .get("stage"), + "IDHACK", ) - query_plan = Test.objects(id=obj.id).only('id').explain() + query_plan = Test.objects(id=obj.id).only("id").explain() self.assertEqual( - query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), - 'IDHACK' + query_plan.get("queryPlanner") + .get("winningPlan") + .get("inputStage") + .get("stage"), + "IDHACK", ) - query_plan = Test.objects(a=1).only('a').exclude('id').explain() + query_plan = Test.objects(a=1).only("a").exclude("id").explain() self.assertEqual( - query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), - 'IXSCAN' + query_plan.get("queryPlanner") + .get("winningPlan") + .get("inputStage") + .get("stage"), + "IXSCAN", ) self.assertEqual( - query_plan.get('queryPlanner').get('winningPlan').get('stage'), - 'PROJECTION' + query_plan.get("queryPlanner").get("winningPlan").get("stage"), "PROJECTION" ) query_plan = Test.objects(a=1).explain() self.assertEqual( - query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), - 'IXSCAN' + query_plan.get("queryPlanner") + .get("winningPlan") + .get("inputStage") + .get("stage"), + "IXSCAN", ) self.assertEqual( - query_plan.get('queryPlanner').get('winningPlan').get('stage'), - 'FETCH' + query_plan.get("queryPlanner").get("winningPlan").get("stage"), "FETCH" ) def test_index_on_id(self): class BlogPost(Document): - meta = { - 'indexes': [ - ['categories', 'id'] - ] - } + meta = {"indexes": [["categories", "id"]]} title = StringField(required=True) description = StringField(required=True) @@ -538,22 +500,16 @@ class IndexesTest(unittest.TestCase): BlogPost.drop_collection() indexes = BlogPost.objects._collection.index_information() - self.assertEqual(indexes['categories_1__id_1']['key'], - [('categories', 1), ('_id', 1)]) + self.assertEqual( + indexes["categories_1__id_1"]["key"], [("categories", 1), ("_id", 1)] + ) def test_hint(self): - TAGS_INDEX_NAME = 'tags_1' + TAGS_INDEX_NAME = "tags_1" class BlogPost(Document): tags = ListField(StringField()) - meta = { - 'indexes': [ - { - 'fields': ['tags'], - 'name': TAGS_INDEX_NAME - } - ], - } + meta = {"indexes": [{"fields": ["tags"], "name": TAGS_INDEX_NAME}]} BlogPost.drop_collection() @@ -562,41 +518,42 @@ class IndexesTest(unittest.TestCase): BlogPost(tags=tags).save() # Hinting by shape should work. - self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10) + self.assertEqual(BlogPost.objects.hint([("tags", 1)]).count(), 10) # Hinting by index name should work. self.assertEqual(BlogPost.objects.hint(TAGS_INDEX_NAME).count(), 10) # Clearing the hint should work fine. self.assertEqual(BlogPost.objects.hint().count(), 10) - self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).hint().count(), 10) + self.assertEqual(BlogPost.objects.hint([("ZZ", 1)]).hint().count(), 10) # Hinting on a non-existent index shape should fail. with self.assertRaises(OperationFailure): - BlogPost.objects.hint([('ZZ', 1)]).count() + BlogPost.objects.hint([("ZZ", 1)]).count() # Hinting on a non-existent index name should fail. with self.assertRaises(OperationFailure): - BlogPost.objects.hint('Bad Name').count() + BlogPost.objects.hint("Bad Name").count() # Invalid shape argument (missing list brackets) should fail. with self.assertRaises(ValueError): - BlogPost.objects.hint(('tags', 1)).count() + BlogPost.objects.hint(("tags", 1)).count() def test_unique(self): """Ensure that uniqueness constraints are applied to fields. """ + class BlogPost(Document): title = StringField() slug = StringField(unique=True) BlogPost.drop_collection() - post1 = BlogPost(title='test1', slug='test') + post1 = BlogPost(title="test1", slug="test") post1.save() # Two posts with the same slug is not allowed - post2 = BlogPost(title='test2', slug='test') + post2 = BlogPost(title="test2", slug="test") self.assertRaises(NotUniqueError, post2.save) self.assertRaises(NotUniqueError, BlogPost.objects.insert, post2) @@ -605,54 +562,62 @@ class IndexesTest(unittest.TestCase): def test_primary_key_unique_not_working(self): """Relates to #1445""" + class Blog(Document): id = StringField(primary_key=True, unique=True) Blog.drop_collection() with self.assertRaises(OperationFailure) as ctx_err: - Blog(id='garbage').save() + Blog(id="garbage").save() # One of the errors below should happen. Which one depends on the # PyMongo version and dict order. err_msg = str(ctx_err.exception) self.assertTrue( - any([ - "The field 'unique' is not valid for an _id index specification" in err_msg, - "The field 'background' is not valid for an _id index specification" in err_msg, - "The field 'sparse' is not valid for an _id index specification" in err_msg, - ]) + any( + [ + "The field 'unique' is not valid for an _id index specification" + in err_msg, + "The field 'background' is not valid for an _id index specification" + in err_msg, + "The field 'sparse' is not valid for an _id index specification" + in err_msg, + ] + ) ) def test_unique_with(self): """Ensure that unique_with constraints are applied to fields. """ + class Date(EmbeddedDocument): - year = IntField(db_field='yr') + year = IntField(db_field="yr") class BlogPost(Document): title = StringField() date = EmbeddedDocumentField(Date) - slug = StringField(unique_with='date.year') + slug = StringField(unique_with="date.year") BlogPost.drop_collection() - post1 = BlogPost(title='test1', date=Date(year=2009), slug='test') + post1 = BlogPost(title="test1", date=Date(year=2009), slug="test") post1.save() # day is different so won't raise exception - post2 = BlogPost(title='test2', date=Date(year=2010), slug='test') + post2 = BlogPost(title="test2", date=Date(year=2010), slug="test") post2.save() # Now there will be two docs with the same slug and the same day: fail - post3 = BlogPost(title='test3', date=Date(year=2010), slug='test') + post3 = BlogPost(title="test3", date=Date(year=2010), slug="test") self.assertRaises(OperationError, post3.save) def test_unique_embedded_document(self): """Ensure that uniqueness constraints are applied to fields on embedded documents. """ + class SubDocument(EmbeddedDocument): - year = IntField(db_field='yr') + year = IntField(db_field="yr") slug = StringField(unique=True) class BlogPost(Document): @@ -661,18 +626,15 @@ class IndexesTest(unittest.TestCase): BlogPost.drop_collection() - post1 = BlogPost(title='test1', - sub=SubDocument(year=2009, slug="test")) + post1 = BlogPost(title="test1", sub=SubDocument(year=2009, slug="test")) post1.save() # sub.slug is different so won't raise exception - post2 = BlogPost(title='test2', - sub=SubDocument(year=2010, slug='another-slug')) + post2 = BlogPost(title="test2", sub=SubDocument(year=2010, slug="another-slug")) post2.save() # Now there will be two docs with the same sub.slug - post3 = BlogPost(title='test3', - sub=SubDocument(year=2010, slug='test')) + post3 = BlogPost(title="test3", sub=SubDocument(year=2010, slug="test")) self.assertRaises(NotUniqueError, post3.save) def test_unique_embedded_document_in_list(self): @@ -681,8 +643,9 @@ class IndexesTest(unittest.TestCase): embedded documents, even when the embedded documents in in a list field. """ + class SubDocument(EmbeddedDocument): - year = IntField(db_field='yr') + year = IntField(db_field="yr") slug = StringField(unique=True) class BlogPost(Document): @@ -692,16 +655,15 @@ class IndexesTest(unittest.TestCase): BlogPost.drop_collection() post1 = BlogPost( - title='test1', subs=[ - SubDocument(year=2009, slug='conflict'), - SubDocument(year=2009, slug='conflict') - ] + title="test1", + subs=[ + SubDocument(year=2009, slug="conflict"), + SubDocument(year=2009, slug="conflict"), + ], ) post1.save() - post2 = BlogPost( - title='test2', subs=[SubDocument(year=2014, slug='conflict')] - ) + post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")]) self.assertRaises(NotUniqueError, post2.save) @@ -711,33 +673,32 @@ class IndexesTest(unittest.TestCase): embedded documents, even when the embedded documents in a sorted list field. """ + class SubDocument(EmbeddedDocument): year = IntField() slug = StringField(unique=True) class BlogPost(Document): title = StringField() - subs = SortedListField(EmbeddedDocumentField(SubDocument), - ordering='year') + subs = SortedListField(EmbeddedDocumentField(SubDocument), ordering="year") BlogPost.drop_collection() post1 = BlogPost( - title='test1', subs=[ - SubDocument(year=2009, slug='conflict'), - SubDocument(year=2009, slug='conflict') - ] + title="test1", + subs=[ + SubDocument(year=2009, slug="conflict"), + SubDocument(year=2009, slug="conflict"), + ], ) post1.save() # confirm that the unique index is created indexes = BlogPost._get_collection().index_information() - self.assertIn('subs.slug_1', indexes) - self.assertTrue(indexes['subs.slug_1']['unique']) + self.assertIn("subs.slug_1", indexes) + self.assertTrue(indexes["subs.slug_1"]["unique"]) - post2 = BlogPost( - title='test2', subs=[SubDocument(year=2014, slug='conflict')] - ) + post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")]) self.assertRaises(NotUniqueError, post2.save) @@ -747,6 +708,7 @@ class IndexesTest(unittest.TestCase): embedded documents, even when the embedded documents in an embedded list field. """ + class SubDocument(EmbeddedDocument): year = IntField() slug = StringField(unique=True) @@ -758,21 +720,20 @@ class IndexesTest(unittest.TestCase): BlogPost.drop_collection() post1 = BlogPost( - title='test1', subs=[ - SubDocument(year=2009, slug='conflict'), - SubDocument(year=2009, slug='conflict') - ] + title="test1", + subs=[ + SubDocument(year=2009, slug="conflict"), + SubDocument(year=2009, slug="conflict"), + ], ) post1.save() # confirm that the unique index is created indexes = BlogPost._get_collection().index_information() - self.assertIn('subs.slug_1', indexes) - self.assertTrue(indexes['subs.slug_1']['unique']) + self.assertIn("subs.slug_1", indexes) + self.assertTrue(indexes["subs.slug_1"]["unique"]) - post2 = BlogPost( - title='test2', subs=[SubDocument(year=2014, slug='conflict')] - ) + post2 = BlogPost(title="test2", subs=[SubDocument(year=2014, slug="conflict")]) self.assertRaises(NotUniqueError, post2.save) @@ -780,60 +741,51 @@ class IndexesTest(unittest.TestCase): """Ensure that uniqueness constraints are applied to fields on embedded documents. And work with unique_with as well. """ + class SubDocument(EmbeddedDocument): - year = IntField(db_field='yr') + year = IntField(db_field="yr") slug = StringField(unique=True) class BlogPost(Document): - title = StringField(unique_with='sub.year') + title = StringField(unique_with="sub.year") sub = EmbeddedDocumentField(SubDocument) BlogPost.drop_collection() - post1 = BlogPost(title='test1', - sub=SubDocument(year=2009, slug="test")) + post1 = BlogPost(title="test1", sub=SubDocument(year=2009, slug="test")) post1.save() # sub.slug is different so won't raise exception - post2 = BlogPost(title='test2', - sub=SubDocument(year=2010, slug='another-slug')) + post2 = BlogPost(title="test2", sub=SubDocument(year=2010, slug="another-slug")) post2.save() # Now there will be two docs with the same sub.slug - post3 = BlogPost(title='test3', - sub=SubDocument(year=2010, slug='test')) + post3 = BlogPost(title="test3", sub=SubDocument(year=2010, slug="test")) self.assertRaises(NotUniqueError, post3.save) # Now there will be two docs with the same title and year - post3 = BlogPost(title='test1', - sub=SubDocument(year=2009, slug='test-1')) + post3 = BlogPost(title="test1", sub=SubDocument(year=2009, slug="test-1")) self.assertRaises(NotUniqueError, post3.save) def test_ttl_indexes(self): - class Log(Document): created = DateTimeField(default=datetime.now) - meta = { - 'indexes': [ - {'fields': ['created'], 'expireAfterSeconds': 3600} - ] - } + meta = {"indexes": [{"fields": ["created"], "expireAfterSeconds": 3600}]} Log.drop_collection() # Indexes are lazy so use list() to perform query list(Log.objects) info = Log.objects._collection.index_information() - self.assertEqual(3600, - info['created_1']['expireAfterSeconds']) + self.assertEqual(3600, info["created_1"]["expireAfterSeconds"]) def test_index_drop_dups_silently_ignored(self): class Customer(Document): cust_id = IntField(unique=True, required=True) meta = { - 'indexes': ['cust_id'], - 'index_drop_dups': True, - 'allow_inheritance': False, + "indexes": ["cust_id"], + "index_drop_dups": True, + "allow_inheritance": False, } Customer.drop_collection() @@ -843,12 +795,10 @@ class IndexesTest(unittest.TestCase): """Ensure that 'unique' constraints aren't overridden by meta.indexes. """ + class Customer(Document): cust_id = IntField(unique=True, required=True) - meta = { - 'indexes': ['cust_id'], - 'allow_inheritance': False, - } + meta = {"indexes": ["cust_id"], "allow_inheritance": False} Customer.drop_collection() cust = Customer(cust_id=1) @@ -870,37 +820,39 @@ class IndexesTest(unittest.TestCase): """If you set a field as primary, then unexpected behaviour can occur. You won't create a duplicate but you will update an existing document. """ + class User(Document): name = StringField(primary_key=True) password = StringField() User.drop_collection() - user = User(name='huangz', password='secret') + user = User(name="huangz", password="secret") user.save() - user = User(name='huangz', password='secret2') + user = User(name="huangz", password="secret2") user.save() self.assertEqual(User.objects.count(), 1) - self.assertEqual(User.objects.get().password, 'secret2') + self.assertEqual(User.objects.get().password, "secret2") def test_unique_and_primary_create(self): """Create a new record with a duplicate primary key throws an exception """ + class User(Document): name = StringField(primary_key=True) password = StringField() User.drop_collection() - User.objects.create(name='huangz', password='secret') + User.objects.create(name="huangz", password="secret") with self.assertRaises(NotUniqueError): - User.objects.create(name='huangz', password='secret2') + User.objects.create(name="huangz", password="secret2") self.assertEqual(User.objects.count(), 1) - self.assertEqual(User.objects.get().password, 'secret') + self.assertEqual(User.objects.get().password, "secret") def test_index_with_pk(self): """Ensure you can use `pk` as part of a query""" @@ -909,21 +861,24 @@ class IndexesTest(unittest.TestCase): comment_id = IntField(required=True) try: + class BlogPost(Document): comments = EmbeddedDocumentField(Comment) - meta = {'indexes': [ - {'fields': ['pk', 'comments.comment_id'], - 'unique': True}]} + meta = { + "indexes": [ + {"fields": ["pk", "comments.comment_id"], "unique": True} + ] + } + except UnboundLocalError: - self.fail('Unbound local error at index + pk definition') + self.fail("Unbound local error at index + pk definition") info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in iteritems(info)] - index_item = [('_id', 1), ('comments.comment_id', 1)] + info = [value["key"] for key, value in iteritems(info)] + index_item = [("_id", 1), ("comments.comment_id", 1)] self.assertIn(index_item, info) def test_compound_key_embedded(self): - class CompoundKey(EmbeddedDocument): name = StringField(required=True) term = StringField(required=True) @@ -935,12 +890,12 @@ class IndexesTest(unittest.TestCase): my_key = CompoundKey(name="n", term="ok") report = ReportEmbedded(text="OK", key=my_key).save() - self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, - report.to_mongo()) + self.assertEqual( + {"text": "OK", "_id": {"term": "ok", "name": "n"}}, report.to_mongo() + ) self.assertEqual(report, ReportEmbedded.objects.get(pk=my_key)) def test_compound_key_dictfield(self): - class ReportDictField(Document): key = DictField(primary_key=True) text = StringField() @@ -948,65 +903,60 @@ class IndexesTest(unittest.TestCase): my_key = {"name": "n", "term": "ok"} report = ReportDictField(text="OK", key=my_key).save() - self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}}, - report.to_mongo()) + self.assertEqual( + {"text": "OK", "_id": {"term": "ok", "name": "n"}}, report.to_mongo() + ) # We can't directly call ReportDictField.objects.get(pk=my_key), # because dicts are unordered, and if the order in MongoDB is # different than the one in `my_key`, this test will fail. - self.assertEqual(report, ReportDictField.objects.get(pk__name=my_key['name'])) - self.assertEqual(report, ReportDictField.objects.get(pk__term=my_key['term'])) + self.assertEqual(report, ReportDictField.objects.get(pk__name=my_key["name"])) + self.assertEqual(report, ReportDictField.objects.get(pk__term=my_key["term"])) def test_string_indexes(self): - class MyDoc(Document): provider_ids = DictField() - meta = { - "indexes": ["provider_ids.foo", "provider_ids.bar"], - } + meta = {"indexes": ["provider_ids.foo", "provider_ids.bar"]} info = MyDoc.objects._collection.index_information() - info = [value['key'] for key, value in iteritems(info)] - self.assertIn([('provider_ids.foo', 1)], info) - self.assertIn([('provider_ids.bar', 1)], info) + info = [value["key"] for key, value in iteritems(info)] + self.assertIn([("provider_ids.foo", 1)], info) + self.assertIn([("provider_ids.bar", 1)], info) def test_sparse_compound_indexes(self): - class MyDoc(Document): provider_ids = DictField() meta = { - "indexes": [{'fields': ("provider_ids.foo", "provider_ids.bar"), - 'sparse': True}], + "indexes": [ + {"fields": ("provider_ids.foo", "provider_ids.bar"), "sparse": True} + ] } info = MyDoc.objects._collection.index_information() - self.assertEqual([('provider_ids.foo', 1), ('provider_ids.bar', 1)], - info['provider_ids.foo_1_provider_ids.bar_1']['key']) - self.assertTrue(info['provider_ids.foo_1_provider_ids.bar_1']['sparse']) + self.assertEqual( + [("provider_ids.foo", 1), ("provider_ids.bar", 1)], + info["provider_ids.foo_1_provider_ids.bar_1"]["key"], + ) + self.assertTrue(info["provider_ids.foo_1_provider_ids.bar_1"]["sparse"]) def test_text_indexes(self): class Book(Document): title = DictField() - meta = { - "indexes": ["$title"], - } + meta = {"indexes": ["$title"]} indexes = Book.objects._collection.index_information() self.assertIn("title_text", indexes) key = indexes["title_text"]["key"] - self.assertIn(('_fts', 'text'), key) + self.assertIn(("_fts", "text"), key) def test_hashed_indexes(self): - class Book(Document): ref_id = StringField() - meta = { - "indexes": ["#ref_id"], - } + meta = {"indexes": ["#ref_id"]} indexes = Book.objects._collection.index_information() self.assertIn("ref_id_hashed", indexes) - self.assertIn(('ref_id', 'hashed'), indexes["ref_id_hashed"]["key"]) + self.assertIn(("ref_id", "hashed"), indexes["ref_id_hashed"]["key"]) def test_indexes_after_database_drop(self): """ @@ -1017,35 +967,36 @@ class IndexesTest(unittest.TestCase): """ # Use a new connection and database since dropping the database could # cause concurrent tests to fail. - connection = connect(db='tempdatabase', - alias='test_indexes_after_database_drop') + connection = connect( + db="tempdatabase", alias="test_indexes_after_database_drop" + ) class BlogPost(Document): title = StringField() slug = StringField(unique=True) - meta = {'db_alias': 'test_indexes_after_database_drop'} + meta = {"db_alias": "test_indexes_after_database_drop"} try: BlogPost.drop_collection() # Create Post #1 - post1 = BlogPost(title='test1', slug='test') + post1 = BlogPost(title="test1", slug="test") post1.save() # Drop the Database - connection.drop_database('tempdatabase') + connection.drop_database("tempdatabase") # Re-create Post #1 - post1 = BlogPost(title='test1', slug='test') + post1 = BlogPost(title="test1", slug="test") post1.save() # Create Post #2 - post2 = BlogPost(title='test2', slug='test') + post2 = BlogPost(title="test2", slug="test") self.assertRaises(NotUniqueError, post2.save) finally: # Drop the temporary database at the end - connection.drop_database('tempdatabase') + connection.drop_database("tempdatabase") def test_index_dont_send_cls_option(self): """ @@ -1057,24 +1008,19 @@ class IndexesTest(unittest.TestCase): options that are passed to ensureIndex. For more details, see: https://jira.mongodb.org/browse/SERVER-769 """ + class TestDoc(Document): txt = StringField() meta = { - 'allow_inheritance': True, - 'indexes': [ - {'fields': ('txt',), 'cls': False} - ] + "allow_inheritance": True, + "indexes": [{"fields": ("txt",), "cls": False}], } class TestChildDoc(TestDoc): txt2 = StringField() - meta = { - 'indexes': [ - {'fields': ('txt2',), 'cls': False} - ] - } + meta = {"indexes": [{"fields": ("txt2",), "cls": False}]} TestDoc.drop_collection() TestDoc.ensure_indexes() @@ -1082,54 +1028,51 @@ class IndexesTest(unittest.TestCase): index_info = TestDoc._get_collection().index_information() for key in index_info: - del index_info[key]['v'] # drop the index version - we don't care about that here - if 'ns' in index_info[key]: - del index_info[key]['ns'] # drop the index namespace - we don't care about that here, MongoDB 3+ - if 'dropDups' in index_info[key]: - del index_info[key]['dropDups'] # drop the index dropDups - it is deprecated in MongoDB 3+ + del index_info[key][ + "v" + ] # drop the index version - we don't care about that here + if "ns" in index_info[key]: + del index_info[key][ + "ns" + ] # drop the index namespace - we don't care about that here, MongoDB 3+ + if "dropDups" in index_info[key]: + del index_info[key][ + "dropDups" + ] # drop the index dropDups - it is deprecated in MongoDB 3+ - self.assertEqual(index_info, { - 'txt_1': { - 'key': [('txt', 1)], - 'background': False + self.assertEqual( + index_info, + { + "txt_1": {"key": [("txt", 1)], "background": False}, + "_id_": {"key": [("_id", 1)]}, + "txt2_1": {"key": [("txt2", 1)], "background": False}, + "_cls_1": {"key": [("_cls", 1)], "background": False}, }, - '_id_': { - 'key': [('_id', 1)], - }, - 'txt2_1': { - 'key': [('txt2', 1)], - 'background': False - }, - '_cls_1': { - 'key': [('_cls', 1)], - 'background': False, - } - }) + ) def test_compound_index_underscore_cls_not_overwritten(self): """ Test that the compound index doesn't get another _cls when it is specified """ + class TestDoc(Document): shard_1 = StringField() txt_1 = StringField() meta = { - 'collection': 'test', - 'allow_inheritance': True, - 'sparse': True, - 'shard_key': 'shard_1', - 'indexes': [ - ('shard_1', '_cls', 'txt_1'), - ] + "collection": "test", + "allow_inheritance": True, + "sparse": True, + "shard_key": "shard_1", + "indexes": [("shard_1", "_cls", "txt_1")], } TestDoc.drop_collection() TestDoc.ensure_indexes() index_info = TestDoc._get_collection().index_information() - self.assertIn('shard_1_1__cls_1_txt_1_1', index_info) + self.assertIn("shard_1_1__cls_1_txt_1_1", index_info) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index d81039f4..4f21d5f4 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -4,18 +4,24 @@ import warnings from six import iteritems -from mongoengine import (BooleanField, Document, EmbeddedDocument, - EmbeddedDocumentField, GenericReferenceField, - IntField, ReferenceField, StringField) +from mongoengine import ( + BooleanField, + Document, + EmbeddedDocument, + EmbeddedDocumentField, + GenericReferenceField, + IntField, + ReferenceField, + StringField, +) from mongoengine.pymongo_support import list_collection_names from tests.utils import MongoDBTestCase from tests.fixtures import Base -__all__ = ('InheritanceTest', ) +__all__ = ("InheritanceTest",) class InheritanceTest(MongoDBTestCase): - def tearDown(self): for collection in list_collection_names(self.db): self.db.drop_collection(collection) @@ -25,16 +31,16 @@ class InheritanceTest(MongoDBTestCase): # and when object gets reloaded (prevent regression of #1950) class EmbedData(EmbeddedDocument): data = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class DataDoc(Document): name = StringField() embed = EmbeddedDocumentField(EmbedData) - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} - test_doc = DataDoc(name='test', embed=EmbedData(data='data')) - self.assertEqual(test_doc._cls, 'DataDoc') - self.assertEqual(test_doc.embed._cls, 'EmbedData') + test_doc = DataDoc(name="test", embed=EmbedData(data="data")) + self.assertEqual(test_doc._cls, "DataDoc") + self.assertEqual(test_doc.embed._cls, "EmbedData") test_doc.save() saved_doc = DataDoc.objects.with_id(test_doc.id) self.assertEqual(test_doc._cls, saved_doc._cls) @@ -44,163 +50,234 @@ class InheritanceTest(MongoDBTestCase): def test_superclasses(self): """Ensure that the correct list of superclasses is assembled. """ + class Animal(Document): - meta = {'allow_inheritance': True} - class Fish(Animal): pass - class Guppy(Fish): pass - class Mammal(Animal): pass - class Dog(Mammal): pass - class Human(Mammal): pass + meta = {"allow_inheritance": True} + + class Fish(Animal): + pass + + class Guppy(Fish): + pass + + class Mammal(Animal): + pass + + class Dog(Mammal): + pass + + class Human(Mammal): + pass self.assertEqual(Animal._superclasses, ()) - self.assertEqual(Fish._superclasses, ('Animal',)) - self.assertEqual(Guppy._superclasses, ('Animal', 'Animal.Fish')) - self.assertEqual(Mammal._superclasses, ('Animal',)) - self.assertEqual(Dog._superclasses, ('Animal', 'Animal.Mammal')) - self.assertEqual(Human._superclasses, ('Animal', 'Animal.Mammal')) + self.assertEqual(Fish._superclasses, ("Animal",)) + self.assertEqual(Guppy._superclasses, ("Animal", "Animal.Fish")) + self.assertEqual(Mammal._superclasses, ("Animal",)) + self.assertEqual(Dog._superclasses, ("Animal", "Animal.Mammal")) + self.assertEqual(Human._superclasses, ("Animal", "Animal.Mammal")) def test_external_superclasses(self): """Ensure that the correct list of super classes is assembled when importing part of the model. """ - class Animal(Base): pass - class Fish(Animal): pass - class Guppy(Fish): pass - class Mammal(Animal): pass - class Dog(Mammal): pass - class Human(Mammal): pass - self.assertEqual(Animal._superclasses, ('Base', )) - self.assertEqual(Fish._superclasses, ('Base', 'Base.Animal',)) - self.assertEqual(Guppy._superclasses, ('Base', 'Base.Animal', - 'Base.Animal.Fish')) - self.assertEqual(Mammal._superclasses, ('Base', 'Base.Animal',)) - self.assertEqual(Dog._superclasses, ('Base', 'Base.Animal', - 'Base.Animal.Mammal')) - self.assertEqual(Human._superclasses, ('Base', 'Base.Animal', - 'Base.Animal.Mammal')) + class Animal(Base): + pass + + class Fish(Animal): + pass + + class Guppy(Fish): + pass + + class Mammal(Animal): + pass + + class Dog(Mammal): + pass + + class Human(Mammal): + pass + + self.assertEqual(Animal._superclasses, ("Base",)) + self.assertEqual(Fish._superclasses, ("Base", "Base.Animal")) + self.assertEqual( + Guppy._superclasses, ("Base", "Base.Animal", "Base.Animal.Fish") + ) + self.assertEqual(Mammal._superclasses, ("Base", "Base.Animal")) + self.assertEqual( + Dog._superclasses, ("Base", "Base.Animal", "Base.Animal.Mammal") + ) + self.assertEqual( + Human._superclasses, ("Base", "Base.Animal", "Base.Animal.Mammal") + ) def test_subclasses(self): """Ensure that the correct list of _subclasses (subclasses) is assembled. """ - class Animal(Document): - meta = {'allow_inheritance': True} - class Fish(Animal): pass - class Guppy(Fish): pass - class Mammal(Animal): pass - class Dog(Mammal): pass - class Human(Mammal): pass - self.assertEqual(Animal._subclasses, ('Animal', - 'Animal.Fish', - 'Animal.Fish.Guppy', - 'Animal.Mammal', - 'Animal.Mammal.Dog', - 'Animal.Mammal.Human')) - self.assertEqual(Fish._subclasses, ('Animal.Fish', - 'Animal.Fish.Guppy',)) - self.assertEqual(Guppy._subclasses, ('Animal.Fish.Guppy',)) - self.assertEqual(Mammal._subclasses, ('Animal.Mammal', - 'Animal.Mammal.Dog', - 'Animal.Mammal.Human')) - self.assertEqual(Human._subclasses, ('Animal.Mammal.Human',)) + class Animal(Document): + meta = {"allow_inheritance": True} + + class Fish(Animal): + pass + + class Guppy(Fish): + pass + + class Mammal(Animal): + pass + + class Dog(Mammal): + pass + + class Human(Mammal): + pass + + self.assertEqual( + Animal._subclasses, + ( + "Animal", + "Animal.Fish", + "Animal.Fish.Guppy", + "Animal.Mammal", + "Animal.Mammal.Dog", + "Animal.Mammal.Human", + ), + ) + self.assertEqual(Fish._subclasses, ("Animal.Fish", "Animal.Fish.Guppy")) + self.assertEqual(Guppy._subclasses, ("Animal.Fish.Guppy",)) + self.assertEqual( + Mammal._subclasses, + ("Animal.Mammal", "Animal.Mammal.Dog", "Animal.Mammal.Human"), + ) + self.assertEqual(Human._subclasses, ("Animal.Mammal.Human",)) def test_external_subclasses(self): """Ensure that the correct list of _subclasses (subclasses) is assembled when importing part of the model. """ - class Animal(Base): pass - class Fish(Animal): pass - class Guppy(Fish): pass - class Mammal(Animal): pass - class Dog(Mammal): pass - class Human(Mammal): pass - self.assertEqual(Animal._subclasses, ('Base.Animal', - 'Base.Animal.Fish', - 'Base.Animal.Fish.Guppy', - 'Base.Animal.Mammal', - 'Base.Animal.Mammal.Dog', - 'Base.Animal.Mammal.Human')) - self.assertEqual(Fish._subclasses, ('Base.Animal.Fish', - 'Base.Animal.Fish.Guppy',)) - self.assertEqual(Guppy._subclasses, ('Base.Animal.Fish.Guppy',)) - self.assertEqual(Mammal._subclasses, ('Base.Animal.Mammal', - 'Base.Animal.Mammal.Dog', - 'Base.Animal.Mammal.Human')) - self.assertEqual(Human._subclasses, ('Base.Animal.Mammal.Human',)) + class Animal(Base): + pass + + class Fish(Animal): + pass + + class Guppy(Fish): + pass + + class Mammal(Animal): + pass + + class Dog(Mammal): + pass + + class Human(Mammal): + pass + + self.assertEqual( + Animal._subclasses, + ( + "Base.Animal", + "Base.Animal.Fish", + "Base.Animal.Fish.Guppy", + "Base.Animal.Mammal", + "Base.Animal.Mammal.Dog", + "Base.Animal.Mammal.Human", + ), + ) + self.assertEqual( + Fish._subclasses, ("Base.Animal.Fish", "Base.Animal.Fish.Guppy") + ) + self.assertEqual(Guppy._subclasses, ("Base.Animal.Fish.Guppy",)) + self.assertEqual( + Mammal._subclasses, + ( + "Base.Animal.Mammal", + "Base.Animal.Mammal.Dog", + "Base.Animal.Mammal.Human", + ), + ) + self.assertEqual(Human._subclasses, ("Base.Animal.Mammal.Human",)) def test_dynamic_declarations(self): """Test that declaring an extra class updates meta data""" class Animal(Document): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} self.assertEqual(Animal._superclasses, ()) - self.assertEqual(Animal._subclasses, ('Animal',)) + self.assertEqual(Animal._subclasses, ("Animal",)) # Test dynamically adding a class changes the meta data class Fish(Animal): pass self.assertEqual(Animal._superclasses, ()) - self.assertEqual(Animal._subclasses, ('Animal', 'Animal.Fish')) + self.assertEqual(Animal._subclasses, ("Animal", "Animal.Fish")) - self.assertEqual(Fish._superclasses, ('Animal', )) - self.assertEqual(Fish._subclasses, ('Animal.Fish',)) + self.assertEqual(Fish._superclasses, ("Animal",)) + self.assertEqual(Fish._subclasses, ("Animal.Fish",)) # Test dynamically adding an inherited class changes the meta data class Pike(Fish): pass self.assertEqual(Animal._superclasses, ()) - self.assertEqual(Animal._subclasses, ('Animal', 'Animal.Fish', - 'Animal.Fish.Pike')) + self.assertEqual( + Animal._subclasses, ("Animal", "Animal.Fish", "Animal.Fish.Pike") + ) - self.assertEqual(Fish._superclasses, ('Animal', )) - self.assertEqual(Fish._subclasses, ('Animal.Fish', 'Animal.Fish.Pike')) + self.assertEqual(Fish._superclasses, ("Animal",)) + self.assertEqual(Fish._subclasses, ("Animal.Fish", "Animal.Fish.Pike")) - self.assertEqual(Pike._superclasses, ('Animal', 'Animal.Fish')) - self.assertEqual(Pike._subclasses, ('Animal.Fish.Pike',)) + self.assertEqual(Pike._superclasses, ("Animal", "Animal.Fish")) + self.assertEqual(Pike._subclasses, ("Animal.Fish.Pike",)) def test_inheritance_meta_data(self): """Ensure that document may inherit fields from a superclass document. """ + class Person(Document): name = StringField() age = IntField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class Employee(Person): salary = IntField() - self.assertEqual(['_cls', 'age', 'id', 'name', 'salary'], - sorted(Employee._fields.keys())) - self.assertEqual(Employee._get_collection_name(), - Person._get_collection_name()) + self.assertEqual( + ["_cls", "age", "id", "name", "salary"], sorted(Employee._fields.keys()) + ) + self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) def test_inheritance_to_mongo_keys(self): """Ensure that document may inherit fields from a superclass document. """ + class Person(Document): name = StringField() age = IntField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class Employee(Person): salary = IntField() - self.assertEqual(['_cls', 'age', 'id', 'name', 'salary'], - sorted(Employee._fields.keys())) - self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), - ['_cls', 'name', 'age']) - self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ['_cls', 'name', 'age', 'salary']) - self.assertEqual(Employee._get_collection_name(), - Person._get_collection_name()) + self.assertEqual( + ["_cls", "age", "id", "name", "salary"], sorted(Employee._fields.keys()) + ) + self.assertEqual( + Person(name="Bob", age=35).to_mongo().keys(), ["_cls", "name", "age"] + ) + self.assertEqual( + Employee(name="Bob", age=35, salary=0).to_mongo().keys(), + ["_cls", "name", "age", "salary"], + ) + self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) def test_indexes_and_multiple_inheritance(self): """ Ensure that all of the indexes are created for a document with @@ -210,18 +287,12 @@ class InheritanceTest(MongoDBTestCase): class A(Document): a = StringField() - meta = { - 'allow_inheritance': True, - 'indexes': ['a'] - } + meta = {"allow_inheritance": True, "indexes": ["a"]} class B(Document): b = StringField() - meta = { - 'allow_inheritance': True, - 'indexes': ['b'] - } + meta = {"allow_inheritance": True, "indexes": ["b"]} class C(A, B): pass @@ -233,8 +304,12 @@ class InheritanceTest(MongoDBTestCase): C.ensure_indexes() self.assertEqual( - sorted([idx['key'] for idx in C._get_collection().index_information().values()]), - sorted([[(u'_cls', 1), (u'b', 1)], [(u'_id', 1)], [(u'_cls', 1), (u'a', 1)]]) + sorted( + [idx["key"] for idx in C._get_collection().index_information().values()] + ), + sorted( + [[(u"_cls", 1), (u"b", 1)], [(u"_id", 1)], [(u"_cls", 1), (u"a", 1)]] + ), ) def test_polymorphic_queries(self): @@ -242,11 +317,19 @@ class InheritanceTest(MongoDBTestCase): """ class Animal(Document): - meta = {'allow_inheritance': True} - class Fish(Animal): pass - class Mammal(Animal): pass - class Dog(Mammal): pass - class Human(Mammal): pass + meta = {"allow_inheritance": True} + + class Fish(Animal): + pass + + class Mammal(Animal): + pass + + class Dog(Mammal): + pass + + class Human(Mammal): + pass Animal.drop_collection() @@ -269,58 +352,68 @@ class InheritanceTest(MongoDBTestCase): """Ensure that inheritance is disabled by default on simple classes and that _cls will not be used. """ + class Animal(Document): name = StringField() # can't inherit because Animal didn't explicitly allow inheritance with self.assertRaises(ValueError) as cm: + class Dog(Animal): pass + self.assertIn("Document Animal may not be subclassed", str(cm.exception)) # Check that _cls etc aren't present on simple documents - dog = Animal(name='dog').save() - self.assertEqual(dog.to_mongo().keys(), ['_id', 'name']) + dog = Animal(name="dog").save() + self.assertEqual(dog.to_mongo().keys(), ["_id", "name"]) collection = self.db[Animal._get_collection_name()] obj = collection.find_one() - self.assertNotIn('_cls', obj) + self.assertNotIn("_cls", obj) def test_cant_turn_off_inheritance_on_subclass(self): """Ensure if inheritance is on in a subclass you cant turn it off. """ + class Animal(Document): name = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} with self.assertRaises(ValueError) as cm: + class Mammal(Animal): - meta = {'allow_inheritance': False} - self.assertEqual(str(cm.exception), 'Only direct subclasses of Document may set "allow_inheritance" to False') + meta = {"allow_inheritance": False} + + self.assertEqual( + str(cm.exception), + 'Only direct subclasses of Document may set "allow_inheritance" to False', + ) def test_allow_inheritance_abstract_document(self): """Ensure that abstract documents can set inheritance rules and that _cls will not be used. """ + class FinalDocument(Document): - meta = {'abstract': True, - 'allow_inheritance': False} + meta = {"abstract": True, "allow_inheritance": False} class Animal(FinalDocument): name = StringField() with self.assertRaises(ValueError) as cm: + class Mammal(Animal): pass # Check that _cls isn't present in simple documents - doc = Animal(name='dog') - self.assertNotIn('_cls', doc.to_mongo()) + doc = Animal(name="dog") + self.assertNotIn("_cls", doc.to_mongo()) def test_using_abstract_class_in_reference_field(self): # Ensures no regression of #1920 class AbstractHuman(Document): - meta = {'abstract': True} + meta = {"abstract": True} class Dad(AbstractHuman): name = StringField() @@ -329,130 +422,122 @@ class InheritanceTest(MongoDBTestCase): dad = ReferenceField(AbstractHuman) # Referencing the abstract class address = StringField() - dad = Dad(name='5').save() - Home(dad=dad, address='street').save() + dad = Dad(name="5").save() + Home(dad=dad, address="street").save() home = Home.objects.first() - home.address = 'garbage' - home.save() # Was failing with ValidationError + home.address = "garbage" + home.save() # Was failing with ValidationError def test_abstract_class_referencing_self(self): # Ensures no regression of #1920 class Human(Document): - meta = {'abstract': True} - creator = ReferenceField('self', dbref=True) + meta = {"abstract": True} + creator = ReferenceField("self", dbref=True) class User(Human): name = StringField() - user = User(name='John').save() - user2 = User(name='Foo', creator=user).save() + user = User(name="John").save() + user2 = User(name="Foo", creator=user).save() user2 = User.objects.with_id(user2.id) - user2.name = 'Bar' - user2.save() # Was failing with ValidationError + user2.name = "Bar" + user2.save() # Was failing with ValidationError def test_abstract_handle_ids_in_metaclass_properly(self): - class City(Document): continent = StringField() - meta = {'abstract': True, - 'allow_inheritance': False} + meta = {"abstract": True, "allow_inheritance": False} class EuropeanCity(City): name = StringField() - berlin = EuropeanCity(name='Berlin', continent='Europe') + berlin = EuropeanCity(name="Berlin", continent="Europe") self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered)) self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered)) self.assertEqual(len(berlin._fields_ordered), 3) - self.assertEqual(berlin._fields_ordered[0], 'id') + self.assertEqual(berlin._fields_ordered[0], "id") def test_auto_id_not_set_if_specific_in_parent_class(self): - class City(Document): continent = StringField() city_id = IntField(primary_key=True) - meta = {'abstract': True, - 'allow_inheritance': False} + meta = {"abstract": True, "allow_inheritance": False} class EuropeanCity(City): name = StringField() - berlin = EuropeanCity(name='Berlin', continent='Europe') + berlin = EuropeanCity(name="Berlin", continent="Europe") self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered)) self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered)) self.assertEqual(len(berlin._fields_ordered), 3) - self.assertEqual(berlin._fields_ordered[0], 'city_id') + self.assertEqual(berlin._fields_ordered[0], "city_id") def test_auto_id_vs_non_pk_id_field(self): - class City(Document): continent = StringField() id = IntField() - meta = {'abstract': True, - 'allow_inheritance': False} + meta = {"abstract": True, "allow_inheritance": False} class EuropeanCity(City): name = StringField() - berlin = EuropeanCity(name='Berlin', continent='Europe') + berlin = EuropeanCity(name="Berlin", continent="Europe") self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered)) self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered)) self.assertEqual(len(berlin._fields_ordered), 4) - self.assertEqual(berlin._fields_ordered[0], 'auto_id_0') + self.assertEqual(berlin._fields_ordered[0], "auto_id_0") berlin.save() self.assertEqual(berlin.pk, berlin.auto_id_0) def test_abstract_document_creation_does_not_fail(self): class City(Document): continent = StringField() - meta = {'abstract': True, - 'allow_inheritance': False} + meta = {"abstract": True, "allow_inheritance": False} - city = City(continent='asia') + city = City(continent="asia") self.assertEqual(None, city.pk) # TODO: expected error? Shouldn't we create a new error type? with self.assertRaises(KeyError): - setattr(city, 'pk', 1) + setattr(city, "pk", 1) def test_allow_inheritance_embedded_document(self): """Ensure embedded documents respect inheritance.""" + class Comment(EmbeddedDocument): content = StringField() with self.assertRaises(ValueError): + class SpecialComment(Comment): pass - doc = Comment(content='test') - self.assertNotIn('_cls', doc.to_mongo()) + doc = Comment(content="test") + self.assertNotIn("_cls", doc.to_mongo()) class Comment(EmbeddedDocument): content = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} - doc = Comment(content='test') - self.assertIn('_cls', doc.to_mongo()) + doc = Comment(content="test") + self.assertIn("_cls", doc.to_mongo()) def test_document_inheritance(self): """Ensure mutliple inheritance of abstract documents """ + class DateCreatedDocument(Document): - meta = { - 'allow_inheritance': True, - 'abstract': True, - } + meta = {"allow_inheritance": True, "abstract": True} class DateUpdatedDocument(Document): - meta = { - 'allow_inheritance': True, - 'abstract': True, - } + meta = {"allow_inheritance": True, "abstract": True} try: + class MyDocument(DateCreatedDocument, DateUpdatedDocument): pass + except Exception: self.assertTrue(False, "Couldn't create MyDocument class") @@ -460,47 +545,55 @@ class InheritanceTest(MongoDBTestCase): """Ensure that a document superclass can be marked as abstract thereby not using it as the name for the collection.""" - defaults = {'index_background': True, - 'index_drop_dups': True, - 'index_opts': {'hello': 'world'}, - 'allow_inheritance': True, - 'queryset_class': 'QuerySet', - 'db_alias': 'myDB', - 'shard_key': ('hello', 'world')} + defaults = { + "index_background": True, + "index_drop_dups": True, + "index_opts": {"hello": "world"}, + "allow_inheritance": True, + "queryset_class": "QuerySet", + "db_alias": "myDB", + "shard_key": ("hello", "world"), + } - meta_settings = {'abstract': True} + meta_settings = {"abstract": True} meta_settings.update(defaults) class Animal(Document): name = StringField() meta = meta_settings - class Fish(Animal): pass - class Guppy(Fish): pass + class Fish(Animal): + pass + + class Guppy(Fish): + pass class Mammal(Animal): - meta = {'abstract': True} - class Human(Mammal): pass + meta = {"abstract": True} + + class Human(Mammal): + pass for k, v in iteritems(defaults): for cls in [Animal, Fish, Guppy]: self.assertEqual(cls._meta[k], v) - self.assertNotIn('collection', Animal._meta) - self.assertNotIn('collection', Mammal._meta) + self.assertNotIn("collection", Animal._meta) + self.assertNotIn("collection", Mammal._meta) self.assertEqual(Animal._get_collection_name(), None) self.assertEqual(Mammal._get_collection_name(), None) - self.assertEqual(Fish._get_collection_name(), 'fish') - self.assertEqual(Guppy._get_collection_name(), 'fish') - self.assertEqual(Human._get_collection_name(), 'human') + self.assertEqual(Fish._get_collection_name(), "fish") + self.assertEqual(Guppy._get_collection_name(), "fish") + self.assertEqual(Human._get_collection_name(), "human") # ensure that a subclass of a non-abstract class can't be abstract with self.assertRaises(ValueError): + class EvilHuman(Human): evil = BooleanField(default=True) - meta = {'abstract': True} + meta = {"abstract": True} def test_abstract_embedded_documents(self): # 789: EmbeddedDocument shouldn't inherit abstract @@ -519,7 +612,7 @@ class InheritanceTest(MongoDBTestCase): class Drink(Document): name = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class Drinker(Document): drink = GenericReferenceField() @@ -528,13 +621,13 @@ class InheritanceTest(MongoDBTestCase): warnings.simplefilter("error") class AcloholicDrink(Drink): - meta = {'collection': 'booze'} + meta = {"collection": "booze"} except SyntaxWarning: warnings.simplefilter("ignore") class AlcoholicDrink(Drink): - meta = {'collection': 'booze'} + meta = {"collection": "booze"} else: raise AssertionError("SyntaxWarning should be triggered") @@ -545,13 +638,13 @@ class InheritanceTest(MongoDBTestCase): AlcoholicDrink.drop_collection() Drinker.drop_collection() - red_bull = Drink(name='Red Bull') + red_bull = Drink(name="Red Bull") red_bull.save() programmer = Drinker(drink=red_bull) programmer.save() - beer = AlcoholicDrink(name='Beer') + beer = AlcoholicDrink(name="Beer") beer.save() real_person = Drinker(drink=beer) real_person.save() @@ -560,5 +653,5 @@ class InheritanceTest(MongoDBTestCase): self.assertEqual(Drinker.objects[1].drink.name, beer.name) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/document/instance.py b/tests/document/instance.py index 06f65076..49606cff 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -16,24 +16,33 @@ from mongoengine import signals from mongoengine.base import _document_registry, get_document from mongoengine.connection import get_db from mongoengine.context_managers import query_counter, switch_db -from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, \ - InvalidQueryError, NotRegistered, NotUniqueError, SaveConditionError) +from mongoengine.errors import ( + FieldDoesNotExist, + InvalidDocumentError, + InvalidQueryError, + NotRegistered, + NotUniqueError, + SaveConditionError, +) from mongoengine.mongodb_support import MONGODB_34, MONGODB_36, get_mongodb_version from mongoengine.pymongo_support import list_collection_names from mongoengine.queryset import NULLIFY, Q from tests import fixtures -from tests.fixtures import (PickleDynamicEmbedded, PickleDynamicTest, \ - PickleEmbedded, PickleSignalsTest, PickleTest) +from tests.fixtures import ( + PickleDynamicEmbedded, + PickleDynamicTest, + PickleEmbedded, + PickleSignalsTest, + PickleTest, +) from tests.utils import MongoDBTestCase, get_as_pymongo -TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), - '../fields/mongoengine.png') +TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "../fields/mongoengine.png") __all__ = ("InstanceTest",) class InstanceTest(MongoDBTestCase): - def setUp(self): class Job(EmbeddedDocument): name = StringField() @@ -58,7 +67,8 @@ class InstanceTest(MongoDBTestCase): def assertDbEqual(self, docs): self.assertEqual( list(self.Person._get_collection().find().sort("id")), - sorted(docs, key=lambda doc: doc["_id"])) + sorted(docs, key=lambda doc: doc["_id"]), + ) def assertHasInstance(self, field, instance): self.assertTrue(hasattr(field, "_instance")) @@ -70,12 +80,10 @@ class InstanceTest(MongoDBTestCase): def test_capped_collection(self): """Ensure that capped collections work properly.""" + class Log(Document): date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 10, - 'max_size': 4096, - } + meta = {"max_documents": 10, "max_size": 4096} Log.drop_collection() @@ -90,16 +98,14 @@ class InstanceTest(MongoDBTestCase): self.assertEqual(Log.objects.count(), 10) options = Log.objects._collection.options() - self.assertEqual(options['capped'], True) - self.assertEqual(options['max'], 10) - self.assertEqual(options['size'], 4096) + self.assertEqual(options["capped"], True) + self.assertEqual(options["max"], 10) + self.assertEqual(options["size"], 4096) # Check that the document cannot be redefined with different options class Log(Document): date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 11, - } + meta = {"max_documents": 11} # Accessing Document.objects creates the collection with self.assertRaises(InvalidCollectionError): @@ -107,11 +113,10 @@ class InstanceTest(MongoDBTestCase): def test_capped_collection_default(self): """Ensure that capped collections defaults work properly.""" + class Log(Document): date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 10, - } + meta = {"max_documents": 10} Log.drop_collection() @@ -119,16 +124,14 @@ class InstanceTest(MongoDBTestCase): Log().save() options = Log.objects._collection.options() - self.assertEqual(options['capped'], True) - self.assertEqual(options['max'], 10) - self.assertEqual(options['size'], 10 * 2**20) + self.assertEqual(options["capped"], True) + self.assertEqual(options["max"], 10) + self.assertEqual(options["size"], 10 * 2 ** 20) # Check that the document with default value can be recreated class Log(Document): date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 10, - } + meta = {"max_documents": 10} # Create the collection by accessing Document.objects Log.objects @@ -138,11 +141,10 @@ class InstanceTest(MongoDBTestCase): MongoDB rounds up max_size to next multiple of 256, recreating a doc with the same spec failed in mongoengine <0.10 """ + class Log(Document): date = DateTimeField(default=datetime.now) - meta = { - 'max_size': 10000, - } + meta = {"max_size": 10000} Log.drop_collection() @@ -150,15 +152,13 @@ class InstanceTest(MongoDBTestCase): Log().save() options = Log.objects._collection.options() - self.assertEqual(options['capped'], True) - self.assertTrue(options['size'] >= 10000) + self.assertEqual(options["capped"], True) + self.assertTrue(options["size"] >= 10000) # Check that the document with odd max_size value can be recreated class Log(Document): date = DateTimeField(default=datetime.now) - meta = { - 'max_size': 10000, - } + meta = {"max_size": 10000} # Create the collection by accessing Document.objects Log.objects @@ -166,26 +166,28 @@ class InstanceTest(MongoDBTestCase): def test_repr(self): """Ensure that unicode representation works """ + class Article(Document): title = StringField() def __unicode__(self): return self.title - doc = Article(title=u'привет мир') + doc = Article(title=u"привет мир") - self.assertEqual('', repr(doc)) + self.assertEqual("", repr(doc)) def test_repr_none(self): """Ensure None values are handled correctly.""" + class Article(Document): title = StringField() def __str__(self): return None - doc = Article(title=u'привет мир') - self.assertEqual('', repr(doc)) + doc = Article(title=u"привет мир") + self.assertEqual("", repr(doc)) def test_queryset_resurrects_dropped_collection(self): self.Person.drop_collection() @@ -203,8 +205,9 @@ class InstanceTest(MongoDBTestCase): """Ensure that the correct subclasses are returned from a query when using references / generic references """ + class Animal(Document): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class Fish(Animal): pass @@ -255,7 +258,7 @@ class InstanceTest(MongoDBTestCase): class Stats(Document): created = DateTimeField(default=datetime.now) - meta = {'allow_inheritance': False} + meta = {"allow_inheritance": False} class CompareStats(Document): generated = DateTimeField(default=datetime.now) @@ -278,6 +281,7 @@ class InstanceTest(MongoDBTestCase): def test_db_field_load(self): """Ensure we load data correctly from the right db field.""" + class Person(Document): name = StringField(required=True) _rank = StringField(required=False, db_field="rank") @@ -297,14 +301,13 @@ class InstanceTest(MongoDBTestCase): def test_db_embedded_doc_field_load(self): """Ensure we load embedded document data correctly.""" + class Rank(EmbeddedDocument): title = StringField(required=True) class Person(Document): name = StringField(required=True) - rank_ = EmbeddedDocumentField(Rank, - required=False, - db_field='rank') + rank_ = EmbeddedDocumentField(Rank, required=False, db_field="rank") @property def rank(self): @@ -322,45 +325,50 @@ class InstanceTest(MongoDBTestCase): def test_custom_id_field(self): """Ensure that documents may be created with custom primary keys.""" + class User(Document): username = StringField(primary_key=True) name = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} User.drop_collection() - self.assertEqual(User._fields['username'].db_field, '_id') - self.assertEqual(User._meta['id_field'], 'username') + self.assertEqual(User._fields["username"].db_field, "_id") + self.assertEqual(User._meta["id_field"], "username") - User.objects.create(username='test', name='test user') + User.objects.create(username="test", name="test user") user = User.objects.first() - self.assertEqual(user.id, 'test') - self.assertEqual(user.pk, 'test') + self.assertEqual(user.id, "test") + self.assertEqual(user.pk, "test") user_dict = User.objects._collection.find_one() - self.assertEqual(user_dict['_id'], 'test') + self.assertEqual(user_dict["_id"], "test") def test_change_custom_id_field_in_subclass(self): """Subclasses cannot override which field is the primary key.""" + class User(Document): username = StringField(primary_key=True) name = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} with self.assertRaises(ValueError) as e: + class EmailUser(User): email = StringField(primary_key=True) + exc = e.exception - self.assertEqual(str(exc), 'Cannot override primary key field') + self.assertEqual(str(exc), "Cannot override primary key field") def test_custom_id_field_is_required(self): """Ensure the custom primary key field is required.""" + class User(Document): username = StringField(primary_key=True) name = StringField() with self.assertRaises(ValidationError) as e: - User(name='test').save() + User(name="test").save() exc = e.exception self.assertTrue("Field is required: ['username']" in str(exc)) @@ -368,7 +376,7 @@ class InstanceTest(MongoDBTestCase): class Place(Document): name = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class NicePlace(Place): pass @@ -380,7 +388,7 @@ class InstanceTest(MongoDBTestCase): # Mimic Place and NicePlace definitions being in a different file # and the NicePlace model not being imported in at query time. - del(_document_registry['Place.NicePlace']) + del _document_registry["Place.NicePlace"] with self.assertRaises(NotRegistered): list(Place.objects.all()) @@ -388,10 +396,10 @@ class InstanceTest(MongoDBTestCase): def test_document_registry_regressions(self): class Location(Document): name = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class Area(Location): - location = ReferenceField('Location', dbref=True) + location = ReferenceField("Location", dbref=True) Location.drop_collection() @@ -413,18 +421,19 @@ class InstanceTest(MongoDBTestCase): def test_key_like_attribute_access(self): person = self.Person(age=30) - self.assertEqual(person['age'], 30) + self.assertEqual(person["age"], 30) with self.assertRaises(KeyError): - person['unknown_attr'] + person["unknown_attr"] def test_save_abstract_document(self): """Saving an abstract document should fail.""" + class Doc(Document): name = StringField() - meta = {'abstract': True} + meta = {"abstract": True} with self.assertRaises(InvalidDocumentError): - Doc(name='aaa').save() + Doc(name="aaa").save() def test_reload(self): """Ensure that attributes may be reloaded.""" @@ -439,7 +448,7 @@ class InstanceTest(MongoDBTestCase): self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 20) - person.reload('age') + person.reload("age") self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 21) @@ -454,19 +463,22 @@ class InstanceTest(MongoDBTestCase): def test_reload_sharded(self): class Animal(Document): superphylum = StringField() - meta = {'shard_key': ('superphylum',)} + meta = {"shard_key": ("superphylum",)} Animal.drop_collection() - doc = Animal(superphylum='Deuterostomia') + doc = Animal(superphylum="Deuterostomia") doc.save() mongo_db = get_mongodb_version() - CMD_QUERY_KEY = 'command' if mongo_db >= MONGODB_36 else 'query' + CMD_QUERY_KEY = "command" if mongo_db >= MONGODB_36 else "query" with query_counter() as q: doc.reload() - query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0] - self.assertEqual(set(query_op[CMD_QUERY_KEY]['filter'].keys()), set(['_id', 'superphylum'])) + query_op = q.db.system.profile.find({"ns": "mongoenginetest.animal"})[0] + self.assertEqual( + set(query_op[CMD_QUERY_KEY]["filter"].keys()), + set(["_id", "superphylum"]), + ) Animal.drop_collection() @@ -476,10 +488,10 @@ class InstanceTest(MongoDBTestCase): class Animal(Document): superphylum = EmbeddedDocumentField(SuperPhylum) - meta = {'shard_key': ('superphylum.name',)} + meta = {"shard_key": ("superphylum.name",)} Animal.drop_collection() - doc = Animal(superphylum=SuperPhylum(name='Deuterostomia')) + doc = Animal(superphylum=SuperPhylum(name="Deuterostomia")) doc.save() doc.reload() Animal.drop_collection() @@ -488,49 +500,57 @@ class InstanceTest(MongoDBTestCase): """Ensures updating a doc with a specified shard_key includes it in the query. """ + class Animal(Document): is_mammal = BooleanField() name = StringField() - meta = {'shard_key': ('is_mammal', 'id')} + meta = {"shard_key": ("is_mammal", "id")} Animal.drop_collection() - doc = Animal(is_mammal=True, name='Dog') + doc = Animal(is_mammal=True, name="Dog") doc.save() mongo_db = get_mongodb_version() with query_counter() as q: - doc.name = 'Cat' + doc.name = "Cat" doc.save() - query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0] - self.assertEqual(query_op['op'], 'update') + query_op = q.db.system.profile.find({"ns": "mongoenginetest.animal"})[0] + self.assertEqual(query_op["op"], "update") if mongo_db <= MONGODB_34: - self.assertEqual(set(query_op['query'].keys()), set(['_id', 'is_mammal'])) + self.assertEqual( + set(query_op["query"].keys()), set(["_id", "is_mammal"]) + ) else: - self.assertEqual(set(query_op['command']['q'].keys()), set(['_id', 'is_mammal'])) + self.assertEqual( + set(query_op["command"]["q"].keys()), set(["_id", "is_mammal"]) + ) Animal.drop_collection() def test_reload_with_changed_fields(self): """Ensures reloading will not affect changed fields""" + class User(Document): name = StringField() number = IntField() + User.drop_collection() user = User(name="Bob", number=1).save() user.name = "John" user.number = 2 - self.assertEqual(user._get_changed_fields(), ['name', 'number']) - user.reload('number') - self.assertEqual(user._get_changed_fields(), ['name']) + self.assertEqual(user._get_changed_fields(), ["name", "number"]) + user.reload("number") + self.assertEqual(user._get_changed_fields(), ["name"]) user.save() user.reload() self.assertEqual(user.name, "John") def test_reload_referencing(self): """Ensures reloading updates weakrefs correctly.""" + class Embedded(EmbeddedDocument): dict_field = DictField() list_field = ListField() @@ -542,24 +562,30 @@ class InstanceTest(MongoDBTestCase): Doc.drop_collection() doc = Doc() - doc.dict_field = {'hello': 'world'} - doc.list_field = ['1', 2, {'hello': 'world'}] + doc.dict_field = {"hello": "world"} + doc.list_field = ["1", 2, {"hello": "world"}] embedded_1 = Embedded() - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] + embedded_1.dict_field = {"hello": "world"} + embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 doc.save() doc = doc.reload(10) doc.list_field.append(1) - doc.dict_field['woot'] = "woot" + doc.dict_field["woot"] = "woot" doc.embedded_field.list_field.append(1) - doc.embedded_field.dict_field['woot'] = "woot" + doc.embedded_field.dict_field["woot"] = "woot" - self.assertEqual(doc._get_changed_fields(), [ - 'list_field', 'dict_field.woot', 'embedded_field.list_field', - 'embedded_field.dict_field.woot']) + self.assertEqual( + doc._get_changed_fields(), + [ + "list_field", + "dict_field.woot", + "embedded_field.list_field", + "embedded_field.dict_field.woot", + ], + ) doc.save() self.assertEqual(len(doc.list_field), 4) @@ -572,9 +598,9 @@ class InstanceTest(MongoDBTestCase): doc.list_field.append(1) doc.save() - doc.dict_field['extra'] = 1 - doc = doc.reload(10, 'list_field') - self.assertEqual(doc._get_changed_fields(), ['dict_field.extra']) + doc.dict_field["extra"] = 1 + doc = doc.reload(10, "list_field") + self.assertEqual(doc._get_changed_fields(), ["dict_field.extra"]) self.assertEqual(len(doc.list_field), 5) self.assertEqual(len(doc.dict_field), 3) self.assertEqual(len(doc.embedded_field.list_field), 4) @@ -596,19 +622,17 @@ class InstanceTest(MongoDBTestCase): def test_reload_of_non_strict_with_special_field_name(self): """Ensures reloading works for documents with meta strict == False.""" + class Post(Document): - meta = { - 'strict': False - } + meta = {"strict": False} title = StringField() items = ListField() Post.drop_collection() - Post._get_collection().insert_one({ - "title": "Items eclipse", - "items": ["more lorem", "even more ipsum"] - }) + Post._get_collection().insert_one( + {"title": "Items eclipse", "items": ["more lorem", "even more ipsum"]} + ) post = Post.objects.first() post.reload() @@ -617,22 +641,22 @@ class InstanceTest(MongoDBTestCase): def test_dictionary_access(self): """Ensure that dictionary-style field access works properly.""" - person = self.Person(name='Test User', age=30, job=self.Job()) - self.assertEqual(person['name'], 'Test User') + person = self.Person(name="Test User", age=30, job=self.Job()) + self.assertEqual(person["name"], "Test User") - self.assertRaises(KeyError, person.__getitem__, 'salary') - self.assertRaises(KeyError, person.__setitem__, 'salary', 50) + self.assertRaises(KeyError, person.__getitem__, "salary") + self.assertRaises(KeyError, person.__setitem__, "salary", 50) - person['name'] = 'Another User' - self.assertEqual(person['name'], 'Another User') + person["name"] = "Another User" + self.assertEqual(person["name"], "Another User") # Length = length(assigned fields + id) self.assertEqual(len(person), 5) - self.assertIn('age', person) + self.assertIn("age", person) person.age = None - self.assertNotIn('age', person) - self.assertNotIn('nationality', person) + self.assertNotIn("age", person) + self.assertNotIn("nationality", person) def test_embedded_document_to_mongo(self): class Person(EmbeddedDocument): @@ -644,29 +668,33 @@ class InstanceTest(MongoDBTestCase): class Employee(Person): salary = IntField() - self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), - ['_cls', 'name', 'age']) + self.assertEqual( + Person(name="Bob", age=35).to_mongo().keys(), ["_cls", "name", "age"] + ) self.assertEqual( Employee(name="Bob", age=35, salary=0).to_mongo().keys(), - ['_cls', 'name', 'age', 'salary']) + ["_cls", "name", "age", "salary"], + ) def test_embedded_document_to_mongo_id(self): class SubDoc(EmbeddedDocument): id = StringField(required=True) sub_doc = SubDoc(id="abc") - self.assertEqual(sub_doc.to_mongo().keys(), ['id']) + self.assertEqual(sub_doc.to_mongo().keys(), ["id"]) def test_embedded_document(self): """Ensure that embedded documents are set up correctly.""" + class Comment(EmbeddedDocument): content = StringField() - self.assertIn('content', Comment._fields) - self.assertNotIn('id', Comment._fields) + self.assertIn("content", Comment._fields) + self.assertNotIn("id", Comment._fields) def test_embedded_document_instance(self): """Ensure that embedded documents can reference parent instance.""" + class Embedded(EmbeddedDocument): string = StringField() @@ -686,6 +714,7 @@ class InstanceTest(MongoDBTestCase): """Ensure that embedded documents in complex fields can reference parent instance. """ + class Embedded(EmbeddedDocument): string = StringField() @@ -702,15 +731,19 @@ class InstanceTest(MongoDBTestCase): def test_embedded_document_complex_instance_no_use_db_field(self): """Ensure that use_db_field is propagated to list of Emb Docs.""" + class Embedded(EmbeddedDocument): - string = StringField(db_field='s') + string = StringField(db_field="s") class Doc(Document): embedded_field = ListField(EmbeddedDocumentField(Embedded)) - d = Doc(embedded_field=[Embedded(string="Hi")]).to_mongo( - use_db_field=False).to_dict() - self.assertEqual(d['embedded_field'], [{'string': 'Hi'}]) + d = ( + Doc(embedded_field=[Embedded(string="Hi")]) + .to_mongo(use_db_field=False) + .to_dict() + ) + self.assertEqual(d["embedded_field"], [{"string": "Hi"}]) def test_instance_is_set_on_setattr(self): class Email(EmbeddedDocument): @@ -722,7 +755,7 @@ class InstanceTest(MongoDBTestCase): Account.drop_collection() acc = Account() - acc.email = Email(email='test@example.com') + acc.email = Email(email="test@example.com") self.assertHasInstance(acc._data["email"], acc) acc.save() @@ -738,7 +771,7 @@ class InstanceTest(MongoDBTestCase): Account.drop_collection() acc = Account() - acc.emails = [Email(email='test@example.com')] + acc.emails = [Email(email="test@example.com")] self.assertHasInstance(acc._data["emails"][0], acc) acc.save() @@ -764,22 +797,19 @@ class InstanceTest(MongoDBTestCase): @classmethod def pre_save_post_validation(cls, sender, document, **kwargs): - document.content = 'checked' + document.content = "checked" - signals.pre_save_post_validation.connect(BlogPost.pre_save_post_validation, sender=BlogPost) + signals.pre_save_post_validation.connect( + BlogPost.pre_save_post_validation, sender=BlogPost + ) BlogPost.drop_collection() - post = BlogPost(content='unchecked').save() - self.assertEqual(post.content, 'checked') + post = BlogPost(content="unchecked").save() + self.assertEqual(post.content, "checked") # Make sure pre_save_post_validation changes makes it to the db raw_doc = get_as_pymongo(post) - self.assertEqual( - raw_doc, - { - 'content': 'checked', - '_id': post.id - }) + self.assertEqual(raw_doc, {"content": "checked", "_id": post.id}) # Important to disconnect as it could cause some assertions in test_signals # to fail (due to the garbage collection timing of this signal) @@ -810,13 +840,7 @@ class InstanceTest(MongoDBTestCase): self.assertEqual(t.cleaned, True) raw_doc = get_as_pymongo(t) # Make sure clean changes makes it to the db - self.assertEqual( - raw_doc, - { - 'status': 'published', - 'cleaned': True, - '_id': t.id - }) + self.assertEqual(raw_doc, {"status": "published", "cleaned": True, "_id": t.id}) def test_document_embedded_clean(self): class TestEmbeddedDocument(EmbeddedDocument): @@ -824,12 +848,12 @@ class InstanceTest(MongoDBTestCase): y = IntField(required=True) z = IntField(required=True) - meta = {'allow_inheritance': False} + meta = {"allow_inheritance": False} def clean(self): if self.z: if self.z != self.x + self.y: - raise ValidationError('Value of z != x + y') + raise ValidationError("Value of z != x + y") else: self.z = self.x + self.y @@ -846,7 +870,7 @@ class InstanceTest(MongoDBTestCase): expected_msg = "Value of z != x + y" self.assertIn(expected_msg, cm.exception.message) - self.assertEqual(cm.exception.to_dict(), {'doc': {'__all__': expected_msg}}) + self.assertEqual(cm.exception.to_dict(), {"doc": {"__all__": expected_msg}}) t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save() self.assertEqual(t.doc.z, 35) @@ -869,7 +893,7 @@ class InstanceTest(MongoDBTestCase): docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] with self.assertRaises(InvalidQueryError): - doc1.modify({'id': doc2.id}, set__value=20) + doc1.modify({"id": doc2.id}, set__value=20) self.assertDbEqual(docs) @@ -878,7 +902,7 @@ class InstanceTest(MongoDBTestCase): doc2 = self.Person(name="jim", age=20).save() docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] - n_modified = doc1.modify({'name': doc2.name}, set__age=100) + n_modified = doc1.modify({"name": doc2.name}, set__age=100) self.assertEqual(n_modified, 0) self.assertDbEqual(docs) @@ -888,7 +912,7 @@ class InstanceTest(MongoDBTestCase): doc2 = self.Person(id=ObjectId(), name="jim", age=20) docs = [dict(doc1.to_mongo())] - n_modified = doc2.modify({'name': doc2.name}, set__age=100) + n_modified = doc2.modify({"name": doc2.name}, set__age=100) self.assertEqual(n_modified, 0) self.assertDbEqual(docs) @@ -896,7 +920,8 @@ class InstanceTest(MongoDBTestCase): def test_modify_update(self): other_doc = self.Person(name="bob", age=10).save() doc = self.Person( - name="jim", age=20, job=self.Job(name="10gen", years=3)).save() + name="jim", age=20, job=self.Job(name="10gen", years=3) + ).save() doc_copy = doc._from_son(doc.to_mongo()) @@ -906,7 +931,8 @@ class InstanceTest(MongoDBTestCase): doc.job.years = 3 n_modified = doc.modify( - set__age=21, set__job__name="MongoDB", unset__job__years=True) + set__age=21, set__job__name="MongoDB", unset__job__years=True + ) self.assertEqual(n_modified, 1) doc_copy.age = 21 doc_copy.job.name = "MongoDB" @@ -926,62 +952,56 @@ class InstanceTest(MongoDBTestCase): content = EmbeddedDocumentField(Content) post = BlogPost.objects.create( - tags=['python'], content=Content(keywords=['ipsum'])) - - self.assertEqual(post.tags, ['python']) - post.modify(push__tags__0=['code', 'mongo']) - self.assertEqual(post.tags, ['code', 'mongo', 'python']) - - # Assert same order of the list items is maintained in the db - self.assertEqual( - BlogPost._get_collection().find_one({'_id': post.pk})['tags'], - ['code', 'mongo', 'python'] + tags=["python"], content=Content(keywords=["ipsum"]) ) - self.assertEqual(post.content.keywords, ['ipsum']) - post.modify(push__content__keywords__0=['lorem']) - self.assertEqual(post.content.keywords, ['lorem', 'ipsum']) + self.assertEqual(post.tags, ["python"]) + post.modify(push__tags__0=["code", "mongo"]) + self.assertEqual(post.tags, ["code", "mongo", "python"]) # Assert same order of the list items is maintained in the db self.assertEqual( - BlogPost._get_collection().find_one({'_id': post.pk})['content']['keywords'], - ['lorem', 'ipsum'] + BlogPost._get_collection().find_one({"_id": post.pk})["tags"], + ["code", "mongo", "python"], + ) + + self.assertEqual(post.content.keywords, ["ipsum"]) + post.modify(push__content__keywords__0=["lorem"]) + self.assertEqual(post.content.keywords, ["lorem", "ipsum"]) + + # Assert same order of the list items is maintained in the db + self.assertEqual( + BlogPost._get_collection().find_one({"_id": post.pk})["content"][ + "keywords" + ], + ["lorem", "ipsum"], ) def test_save(self): """Ensure that a document may be saved in the database.""" # Create person object and save it to the database - person = self.Person(name='Test User', age=30) + person = self.Person(name="Test User", age=30) person.save() # Ensure that the object is in the database raw_doc = get_as_pymongo(person) self.assertEqual( raw_doc, - { - '_cls': 'Person', - 'name': 'Test User', - 'age': 30, - '_id': person.id - }) + {"_cls": "Person", "name": "Test User", "age": 30, "_id": person.id}, + ) def test_save_skip_validation(self): class Recipient(Document): email = EmailField(required=True) - recipient = Recipient(email='not-an-email') + recipient = Recipient(email="not-an-email") with self.assertRaises(ValidationError): recipient.save() recipient.save(validate=False) raw_doc = get_as_pymongo(recipient) - self.assertEqual( - raw_doc, - { - 'email': 'not-an-email', - '_id': recipient.id - }) + self.assertEqual(raw_doc, {"email": "not-an-email", "_id": recipient.id}) def test_save_with_bad_id(self): class Clown(Document): @@ -1012,8 +1032,8 @@ class InstanceTest(MongoDBTestCase): def test_save_max_recursion_not_hit(self): class Person(Document): name = StringField() - parent = ReferenceField('self') - friend = ReferenceField('self') + parent = ReferenceField("self") + friend = ReferenceField("self") Person.drop_collection() @@ -1031,28 +1051,28 @@ class InstanceTest(MongoDBTestCase): # Confirm can save and it resets the changed fields without hitting # max recursion error p0 = Person.objects.first() - p0.name = 'wpjunior' + p0.name = "wpjunior" p0.save() def test_save_max_recursion_not_hit_with_file_field(self): class Foo(Document): name = StringField() picture = FileField() - bar = ReferenceField('self') + bar = ReferenceField("self") Foo.drop_collection() - a = Foo(name='hello').save() + a = Foo(name="hello").save() a.bar = a - with open(TEST_IMAGE_PATH, 'rb') as test_image: + with open(TEST_IMAGE_PATH, "rb") as test_image: a.picture = test_image a.save() # Confirm can save and it resets the changed fields without hitting # max recursion error b = Foo.objects.with_id(a.id) - b.name = 'world' + b.name = "world" b.save() self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture) @@ -1060,7 +1080,7 @@ class InstanceTest(MongoDBTestCase): def test_save_cascades(self): class Person(Document): name = StringField() - parent = ReferenceField('self') + parent = ReferenceField("self") Person.drop_collection() @@ -1082,7 +1102,7 @@ class InstanceTest(MongoDBTestCase): def test_save_cascade_kwargs(self): class Person(Document): name = StringField() - parent = ReferenceField('self') + parent = ReferenceField("self") Person.drop_collection() @@ -1102,9 +1122,9 @@ class InstanceTest(MongoDBTestCase): def test_save_cascade_meta_false(self): class Person(Document): name = StringField() - parent = ReferenceField('self') + parent = ReferenceField("self") - meta = {'cascade': False} + meta = {"cascade": False} Person.drop_collection() @@ -1130,9 +1150,9 @@ class InstanceTest(MongoDBTestCase): def test_save_cascade_meta_true(self): class Person(Document): name = StringField() - parent = ReferenceField('self') + parent = ReferenceField("self") - meta = {'cascade': False} + meta = {"cascade": False} Person.drop_collection() @@ -1194,7 +1214,7 @@ class InstanceTest(MongoDBTestCase): w1 = Widget(toggle=False, save_id=UUID(1)) # ignore save_condition on new record creation - w1.save(save_condition={'save_id': UUID(42)}) + w1.save(save_condition={"save_id": UUID(42)}) w1.reload() self.assertFalse(w1.toggle) self.assertEqual(w1.save_id, UUID(1)) @@ -1204,8 +1224,9 @@ class InstanceTest(MongoDBTestCase): flip(w1) self.assertTrue(w1.toggle) self.assertEqual(w1.count, 1) - self.assertRaises(SaveConditionError, - w1.save, save_condition={'save_id': UUID(42)}) + self.assertRaises( + SaveConditionError, w1.save, save_condition={"save_id": UUID(42)} + ) w1.reload() self.assertFalse(w1.toggle) self.assertEqual(w1.count, 0) @@ -1214,7 +1235,7 @@ class InstanceTest(MongoDBTestCase): flip(w1) self.assertTrue(w1.toggle) self.assertEqual(w1.count, 1) - w1.save(save_condition={'save_id': UUID(1)}) + w1.save(save_condition={"save_id": UUID(1)}) w1.reload() self.assertTrue(w1.toggle) self.assertEqual(w1.count, 1) @@ -1227,27 +1248,29 @@ class InstanceTest(MongoDBTestCase): flip(w1) w1.save_id = UUID(2) - w1.save(save_condition={'save_id': old_id}) + w1.save(save_condition={"save_id": old_id}) w1.reload() self.assertFalse(w1.toggle) self.assertEqual(w1.count, 2) flip(w2) flip(w2) - self.assertRaises(SaveConditionError, - w2.save, save_condition={'save_id': old_id}) + self.assertRaises( + SaveConditionError, w2.save, save_condition={"save_id": old_id} + ) w2.reload() self.assertFalse(w2.toggle) self.assertEqual(w2.count, 2) # save_condition uses mongoengine-style operator syntax flip(w1) - w1.save(save_condition={'count__lt': w1.count}) + w1.save(save_condition={"count__lt": w1.count}) w1.reload() self.assertTrue(w1.toggle) self.assertEqual(w1.count, 3) flip(w1) - self.assertRaises(SaveConditionError, - w1.save, save_condition={'count__gte': w1.count}) + self.assertRaises( + SaveConditionError, w1.save, save_condition={"count__gte": w1.count} + ) w1.reload() self.assertTrue(w1.toggle) self.assertEqual(w1.count, 3) @@ -1259,19 +1282,19 @@ class InstanceTest(MongoDBTestCase): WildBoy.drop_collection() - WildBoy(age=12, name='John').save() + WildBoy(age=12, name="John").save() boy1 = WildBoy.objects().first() boy2 = WildBoy.objects().first() boy1.age = 99 boy1.save() - boy2.name = 'Bob' + boy2.name = "Bob" boy2.save() fresh_boy = WildBoy.objects().first() self.assertEqual(fresh_boy.age, 99) - self.assertEqual(fresh_boy.name, 'Bob') + self.assertEqual(fresh_boy.name, "Bob") def test_save_update_selectively_with_custom_pk(self): # Prevents regression of #2082 @@ -1282,30 +1305,30 @@ class InstanceTest(MongoDBTestCase): WildBoy.drop_collection() - WildBoy(pk_id='A', age=12, name='John').save() + WildBoy(pk_id="A", age=12, name="John").save() boy1 = WildBoy.objects().first() boy2 = WildBoy.objects().first() boy1.age = 99 boy1.save() - boy2.name = 'Bob' + boy2.name = "Bob" boy2.save() fresh_boy = WildBoy.objects().first() self.assertEqual(fresh_boy.age, 99) - self.assertEqual(fresh_boy.name, 'Bob') + self.assertEqual(fresh_boy.name, "Bob") def test_update(self): """Ensure that an existing document is updated instead of be overwritten. """ # Create person object and save it to the database - person = self.Person(name='Test User', age=30) + person = self.Person(name="Test User", age=30) person.save() # Create same person object, with same id, without age - same_person = self.Person(name='Test') + same_person = self.Person(name="Test") same_person.id = person.id same_person.save() @@ -1322,54 +1345,54 @@ class InstanceTest(MongoDBTestCase): self.assertEqual(person.age, same_person.age) # Confirm the saved values - self.assertEqual(person.name, 'Test') + self.assertEqual(person.name, "Test") self.assertEqual(person.age, 30) # Test only / exclude only updates included fields - person = self.Person.objects.only('name').get() - person.name = 'User' + person = self.Person.objects.only("name").get() + person.name = "User" person.save() person.reload() - self.assertEqual(person.name, 'User') + self.assertEqual(person.name, "User") self.assertEqual(person.age, 30) # test exclude only updates set fields - person = self.Person.objects.exclude('name').get() + person = self.Person.objects.exclude("name").get() person.age = 21 person.save() person.reload() - self.assertEqual(person.name, 'User') + self.assertEqual(person.name, "User") self.assertEqual(person.age, 21) # Test only / exclude can set non excluded / included fields - person = self.Person.objects.only('name').get() - person.name = 'Test' + person = self.Person.objects.only("name").get() + person.name = "Test" person.age = 30 person.save() person.reload() - self.assertEqual(person.name, 'Test') + self.assertEqual(person.name, "Test") self.assertEqual(person.age, 30) # test exclude only updates set fields - person = self.Person.objects.exclude('name').get() - person.name = 'User' + person = self.Person.objects.exclude("name").get() + person.name = "User" person.age = 21 person.save() person.reload() - self.assertEqual(person.name, 'User') + self.assertEqual(person.name, "User") self.assertEqual(person.age, 21) # Confirm does remove unrequired fields - person = self.Person.objects.exclude('name').get() + person = self.Person.objects.exclude("name").get() person.age = None person.save() person.reload() - self.assertEqual(person.name, 'User') + self.assertEqual(person.name, "User") self.assertEqual(person.age, None) person = self.Person.objects.get() @@ -1384,19 +1407,18 @@ class InstanceTest(MongoDBTestCase): def test_update_rename_operator(self): """Test the $rename operator.""" coll = self.Person._get_collection() - doc = self.Person(name='John').save() - raw_doc = coll.find_one({'_id': doc.pk}) - self.assertEqual(set(raw_doc.keys()), set(['_id', '_cls', 'name'])) + doc = self.Person(name="John").save() + raw_doc = coll.find_one({"_id": doc.pk}) + self.assertEqual(set(raw_doc.keys()), set(["_id", "_cls", "name"])) - doc.update(rename__name='first_name') - raw_doc = coll.find_one({'_id': doc.pk}) - self.assertEqual(set(raw_doc.keys()), - set(['_id', '_cls', 'first_name'])) - self.assertEqual(raw_doc['first_name'], 'John') + doc.update(rename__name="first_name") + raw_doc = coll.find_one({"_id": doc.pk}) + self.assertEqual(set(raw_doc.keys()), set(["_id", "_cls", "first_name"])) + self.assertEqual(raw_doc["first_name"], "John") def test_inserts_if_you_set_the_pk(self): - p1 = self.Person(name='p1', id=bson.ObjectId()).save() - p2 = self.Person(name='p2') + p1 = self.Person(name="p1", id=bson.ObjectId()).save() + p2 = self.Person(name="p2") p2.id = bson.ObjectId() p2.save() @@ -1410,33 +1432,34 @@ class InstanceTest(MongoDBTestCase): pass class Doc(Document): - string_field = StringField(default='1') + string_field = StringField(default="1") int_field = IntField(default=1) float_field = FloatField(default=1.1) boolean_field = BooleanField(default=True) datetime_field = DateTimeField(default=datetime.now) embedded_document_field = EmbeddedDocumentField( - EmbeddedDoc, default=lambda: EmbeddedDoc()) + EmbeddedDoc, default=lambda: EmbeddedDoc() + ) list_field = ListField(default=lambda: [1, 2, 3]) dict_field = DictField(default=lambda: {"hello": "world"}) objectid_field = ObjectIdField(default=bson.ObjectId) - reference_field = ReferenceField(Simple, default=lambda: - Simple().save()) + reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) generic_reference_field = GenericReferenceField( - default=lambda: Simple().save()) - sorted_list_field = SortedListField(IntField(), - default=lambda: [1, 2, 3]) + default=lambda: Simple().save() + ) + sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3]) email_field = EmailField(default="ross@example.com") geo_point_field = GeoPointField(default=lambda: [1, 2]) sequence_field = SequenceField() uuid_field = UUIDField(default=uuid.uuid4) generic_embedded_document_field = GenericEmbeddedDocumentField( - default=lambda: EmbeddedDoc()) + default=lambda: EmbeddedDoc() + ) Simple.drop_collection() Doc.drop_collection() @@ -1454,13 +1477,13 @@ class InstanceTest(MongoDBTestCase): # try updating a non-saved document with self.assertRaises(OperationError): - person = self.Person(name='dcrosta') - person.update(set__name='Dan Crosta') + person = self.Person(name="dcrosta") + person.update(set__name="Dan Crosta") - author = self.Person(name='dcrosta') + author = self.Person(name="dcrosta") author.save() - author.update(set__name='Dan Crosta') + author.update(set__name="Dan Crosta") author.reload() p1 = self.Person.objects.first() @@ -1490,9 +1513,9 @@ class InstanceTest(MongoDBTestCase): def test_embedded_update(self): """Test update on `EmbeddedDocumentField` fields.""" + class Page(EmbeddedDocument): - log_message = StringField(verbose_name="Log message", - required=True) + log_message = StringField(verbose_name="Log message", required=True) class Site(Document): page = EmbeddedDocumentField(Page) @@ -1512,28 +1535,30 @@ class InstanceTest(MongoDBTestCase): def test_update_list_field(self): """Test update on `ListField` with $pull + $in. """ + class Doc(Document): foo = ListField(StringField()) Doc.drop_collection() - doc = Doc(foo=['a', 'b', 'c']) + doc = Doc(foo=["a", "b", "c"]) doc.save() # Update doc = Doc.objects.first() - doc.update(pull__foo__in=['a', 'c']) + doc.update(pull__foo__in=["a", "c"]) doc = Doc.objects.first() - self.assertEqual(doc.foo, ['b']) + self.assertEqual(doc.foo, ["b"]) def test_embedded_update_db_field(self): """Test update on `EmbeddedDocumentField` fields when db_field is other than default. """ + class Page(EmbeddedDocument): - log_message = StringField(verbose_name="Log message", - db_field="page_log_message", - required=True) + log_message = StringField( + verbose_name="Log message", db_field="page_log_message", required=True + ) class Site(Document): page = EmbeddedDocumentField(Page) @@ -1553,13 +1578,14 @@ class InstanceTest(MongoDBTestCase): def test_save_only_changed_fields(self): """Ensure save only sets / unsets changed fields.""" + class User(self.Person): active = BooleanField(default=True) User.drop_collection() # Create person object and save it to the database - user = User(name='Test User', age=30, active=True) + user = User(name="Test User", age=30, active=True) user.save() user.reload() @@ -1570,28 +1596,31 @@ class InstanceTest(MongoDBTestCase): user.age = 21 user.save() - same_person.name = 'User' + same_person.name = "User" same_person.save() person = self.Person.objects.get() - self.assertEqual(person.name, 'User') + self.assertEqual(person.name, "User") self.assertEqual(person.age, 21) self.assertEqual(person.active, False) - def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_embedded_doc(self): + def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_embedded_doc( + self + ): # Refers to Issue #1685 class EmbeddedChildModel(EmbeddedDocument): id = DictField(primary_key=True) class ParentModel(Document): - child = EmbeddedDocumentField( - EmbeddedChildModel) + child = EmbeddedDocumentField(EmbeddedChildModel) - emb = EmbeddedChildModel(id={'1': [1]}) + emb = EmbeddedChildModel(id={"1": [1]}) changed_fields = ParentModel(child=emb)._get_changed_fields() self.assertEqual(changed_fields, []) - def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_different_doc(self): + def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_different_doc( + self + ): # Refers to Issue #1685 class User(Document): id = IntField(primary_key=True) @@ -1604,12 +1633,12 @@ class InstanceTest(MongoDBTestCase): Message.drop_collection() # All objects share the same id, but each in a different collection - user = User(id=1, name='user-name').save() + user = User(id=1, name="user-name").save() message = Message(id=1, author=user).save() - message.author.name = 'tutu' + message.author.name = "tutu" self.assertEqual(message._get_changed_fields(), []) - self.assertEqual(user._get_changed_fields(), ['name']) + self.assertEqual(user._get_changed_fields(), ["name"]) def test__get_changed_fields_same_ids_embedded(self): # Refers to Issue #1768 @@ -1624,24 +1653,25 @@ class InstanceTest(MongoDBTestCase): Message.drop_collection() # All objects share the same id, but each in a different collection - user = User(id=1, name='user-name') # .save() + user = User(id=1, name="user-name") # .save() message = Message(id=1, author=user).save() - message.author.name = 'tutu' - self.assertEqual(message._get_changed_fields(), ['author.name']) + message.author.name = "tutu" + self.assertEqual(message._get_changed_fields(), ["author.name"]) message.save() message_fetched = Message.objects.with_id(message.id) - self.assertEqual(message_fetched.author.name, 'tutu') + self.assertEqual(message_fetched.author.name, "tutu") def test_query_count_when_saving(self): """Ensure references don't cause extra fetches when saving""" + class Organization(Document): name = StringField() class User(Document): name = StringField() - orgs = ListField(ReferenceField('Organization')) + orgs = ListField(ReferenceField("Organization")) class Feed(Document): name = StringField() @@ -1667,9 +1697,9 @@ class InstanceTest(MongoDBTestCase): user = User.objects.first() # Even if stored as ObjectId's internally mongoengine uses DBRefs # As ObjectId's aren't automatically derefenced - self.assertIsInstance(user._data['orgs'][0], DBRef) + self.assertIsInstance(user._data["orgs"][0], DBRef) self.assertIsInstance(user.orgs[0], Organization) - self.assertIsInstance(user._data['orgs'][0], Organization) + self.assertIsInstance(user._data["orgs"][0], Organization) # Changing a value with query_counter() as q: @@ -1731,6 +1761,7 @@ class InstanceTest(MongoDBTestCase): """Ensure that $set and $unset actions are performed in the same operation. """ + class FooBar(Document): foo = StringField(default=None) bar = StringField(default=None) @@ -1738,11 +1769,11 @@ class InstanceTest(MongoDBTestCase): FooBar.drop_collection() # write an entity with a single prop - foo = FooBar(foo='foo').save() + foo = FooBar(foo="foo").save() - self.assertEqual(foo.foo, 'foo') + self.assertEqual(foo.foo, "foo") del foo.foo - foo.bar = 'bar' + foo.bar = "bar" with query_counter() as q: self.assertEqual(0, q) @@ -1751,6 +1782,7 @@ class InstanceTest(MongoDBTestCase): def test_save_only_changed_fields_recursive(self): """Ensure save only sets / unsets changed fields.""" + class Comment(EmbeddedDocument): published = BooleanField(default=True) @@ -1762,7 +1794,7 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() # Create person object and save it to the database - person = User(name='Test User', age=30, active=True) + person = User(name="Test User", age=30, active=True) person.comments.append(Comment()) person.save() person.reload() @@ -1777,17 +1809,17 @@ class InstanceTest(MongoDBTestCase): self.assertFalse(person.comments[0].published) # Simple dict w - person.comments_dict['first_post'] = Comment() + person.comments_dict["first_post"] = Comment() person.save() person = self.Person.objects.get() - self.assertTrue(person.comments_dict['first_post'].published) + self.assertTrue(person.comments_dict["first_post"].published) - person.comments_dict['first_post'].published = False + person.comments_dict["first_post"].published = False person.save() person = self.Person.objects.get() - self.assertFalse(person.comments_dict['first_post'].published) + self.assertFalse(person.comments_dict["first_post"].published) def test_delete(self): """Ensure that document may be deleted using the delete method.""" @@ -1801,31 +1833,30 @@ class InstanceTest(MongoDBTestCase): """Ensure that a document may be saved with a custom _id.""" # Create person object and save it to the database - person = self.Person(name='Test User', age=30, - id='497ce96f395f2f052a494fd4') + person = self.Person(name="Test User", age=30, id="497ce96f395f2f052a494fd4") person.save() # Ensure that the object is in the database with the correct _id collection = self.db[self.Person._get_collection_name()] - person_obj = collection.find_one({'name': 'Test User'}) - self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') + person_obj = collection.find_one({"name": "Test User"}) + self.assertEqual(str(person_obj["_id"]), "497ce96f395f2f052a494fd4") def test_save_custom_pk(self): """Ensure that a document may be saved with a custom _id using pk alias. """ # Create person object and save it to the database - person = self.Person(name='Test User', age=30, - pk='497ce96f395f2f052a494fd4') + person = self.Person(name="Test User", age=30, pk="497ce96f395f2f052a494fd4") person.save() # Ensure that the object is in the database with the correct _id collection = self.db[self.Person._get_collection_name()] - person_obj = collection.find_one({'name': 'Test User'}) - self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') + person_obj = collection.find_one({"name": "Test User"}) + self.assertEqual(str(person_obj["_id"]), "497ce96f395f2f052a494fd4") def test_save_list(self): """Ensure that a list field may be properly saved.""" + class Comment(EmbeddedDocument): content = StringField() @@ -1836,37 +1867,36 @@ class InstanceTest(MongoDBTestCase): BlogPost.drop_collection() - post = BlogPost(content='Went for a walk today...') - post.tags = tags = ['fun', 'leisure'] - comments = [Comment(content='Good for you'), Comment(content='Yay.')] + post = BlogPost(content="Went for a walk today...") + post.tags = tags = ["fun", "leisure"] + comments = [Comment(content="Good for you"), Comment(content="Yay.")] post.comments = comments post.save() collection = self.db[BlogPost._get_collection_name()] post_obj = collection.find_one() - self.assertEqual(post_obj['tags'], tags) - for comment_obj, comment in zip(post_obj['comments'], comments): - self.assertEqual(comment_obj['content'], comment['content']) + self.assertEqual(post_obj["tags"], tags) + for comment_obj, comment in zip(post_obj["comments"], comments): + self.assertEqual(comment_obj["content"], comment["content"]) def test_list_search_by_embedded(self): class User(Document): username = StringField(required=True) - meta = {'allow_inheritance': False} + meta = {"allow_inheritance": False} class Comment(EmbeddedDocument): comment = StringField() - user = ReferenceField(User, - required=True) + user = ReferenceField(User, required=True) - meta = {'allow_inheritance': False} + meta = {"allow_inheritance": False} class Page(Document): comments = ListField(EmbeddedDocumentField(Comment)) - meta = {'allow_inheritance': False, - 'indexes': [ - {'fields': ['comments.user']} - ]} + meta = { + "allow_inheritance": False, + "indexes": [{"fields": ["comments.user"]}], + } User.drop_collection() Page.drop_collection() @@ -1880,14 +1910,22 @@ class InstanceTest(MongoDBTestCase): u3 = User(username="hmarr") u3.save() - p1 = Page(comments=[Comment(user=u1, comment="Its very good"), - Comment(user=u2, comment="Hello world"), - Comment(user=u3, comment="Ping Pong"), - Comment(user=u1, comment="I like a beer")]) + p1 = Page( + comments=[ + Comment(user=u1, comment="Its very good"), + Comment(user=u2, comment="Hello world"), + Comment(user=u3, comment="Ping Pong"), + Comment(user=u1, comment="I like a beer"), + ] + ) p1.save() - p2 = Page(comments=[Comment(user=u1, comment="Its very good"), - Comment(user=u2, comment="Hello world")]) + p2 = Page( + comments=[ + Comment(user=u1, comment="Its very good"), + Comment(user=u2, comment="Hello world"), + ] + ) p2.save() p3 = Page(comments=[Comment(user=u3, comment="Its very good")]) @@ -1896,20 +1934,15 @@ class InstanceTest(MongoDBTestCase): p4 = Page(comments=[Comment(user=u2, comment="Heavy Metal song")]) p4.save() - self.assertEqual( - [p1, p2], - list(Page.objects.filter(comments__user=u1))) - self.assertEqual( - [p1, p2, p4], - list(Page.objects.filter(comments__user=u2))) - self.assertEqual( - [p1, p3], - list(Page.objects.filter(comments__user=u3))) + self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1))) + self.assertEqual([p1, p2, p4], list(Page.objects.filter(comments__user=u2))) + self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3))) def test_save_embedded_document(self): """Ensure that a document with an embedded document field may be saved in the database. """ + class EmployeeDetails(EmbeddedDocument): position = StringField() @@ -1918,26 +1951,26 @@ class InstanceTest(MongoDBTestCase): details = EmbeddedDocumentField(EmployeeDetails) # Create employee object and save it to the database - employee = Employee(name='Test Employee', age=50, salary=20000) - employee.details = EmployeeDetails(position='Developer') + employee = Employee(name="Test Employee", age=50, salary=20000) + employee.details = EmployeeDetails(position="Developer") employee.save() # Ensure that the object is in the database collection = self.db[self.Person._get_collection_name()] - employee_obj = collection.find_one({'name': 'Test Employee'}) - self.assertEqual(employee_obj['name'], 'Test Employee') - self.assertEqual(employee_obj['age'], 50) + employee_obj = collection.find_one({"name": "Test Employee"}) + self.assertEqual(employee_obj["name"], "Test Employee") + self.assertEqual(employee_obj["age"], 50) # Ensure that the 'details' embedded object saved correctly - self.assertEqual(employee_obj['details']['position'], 'Developer') + self.assertEqual(employee_obj["details"]["position"], "Developer") def test_embedded_update_after_save(self): """Test update of `EmbeddedDocumentField` attached to a newly saved document. """ + class Page(EmbeddedDocument): - log_message = StringField(verbose_name="Log message", - required=True) + log_message = StringField(verbose_name="Log message", required=True) class Site(Document): page = EmbeddedDocumentField(Page) @@ -1957,6 +1990,7 @@ class InstanceTest(MongoDBTestCase): """Ensure that a document with an embedded document field may be saved in the database. """ + class EmployeeDetails(EmbeddedDocument): position = StringField() @@ -1965,22 +1999,21 @@ class InstanceTest(MongoDBTestCase): details = EmbeddedDocumentField(EmployeeDetails) # Create employee object and save it to the database - employee = Employee(name='Test Employee', age=50, salary=20000) - employee.details = EmployeeDetails(position='Developer') + employee = Employee(name="Test Employee", age=50, salary=20000) + employee.details = EmployeeDetails(position="Developer") employee.save() # Test updating an embedded document - promoted_employee = Employee.objects.get(name='Test Employee') - promoted_employee.details.position = 'Senior Developer' + promoted_employee = Employee.objects.get(name="Test Employee") + promoted_employee.details.position = "Senior Developer" promoted_employee.save() promoted_employee.reload() - self.assertEqual(promoted_employee.name, 'Test Employee') + self.assertEqual(promoted_employee.name, "Test Employee") self.assertEqual(promoted_employee.age, 50) # Ensure that the 'details' embedded object saved correctly - self.assertEqual( - promoted_employee.details.position, 'Senior Developer') + self.assertEqual(promoted_employee.details.position, "Senior Developer") # Test removal promoted_employee.details = None @@ -1996,12 +2029,12 @@ class InstanceTest(MongoDBTestCase): class Foo(EmbeddedDocument, NameMixin): quantity = IntField() - self.assertEqual(['name', 'quantity'], sorted(Foo._fields.keys())) + self.assertEqual(["name", "quantity"], sorted(Foo._fields.keys())) class Bar(Document, NameMixin): widgets = StringField() - self.assertEqual(['id', 'name', 'widgets'], sorted(Bar._fields.keys())) + self.assertEqual(["id", "name", "widgets"], sorted(Bar._fields.keys())) def test_mixin_inheritance(self): class BaseMixIn(object): @@ -2015,8 +2048,7 @@ class InstanceTest(MongoDBTestCase): age = IntField() TestDoc.drop_collection() - t = TestDoc(count=12, data="test", - comment="great!", age=19) + t = TestDoc(count=12, data="test", comment="great!", age=19) t.save() @@ -2031,17 +2063,18 @@ class InstanceTest(MongoDBTestCase): """Ensure that a document reference field may be saved in the database. """ + class BlogPost(Document): - meta = {'collection': 'blogpost_1'} + meta = {"collection": "blogpost_1"} content = StringField() author = ReferenceField(self.Person) BlogPost.drop_collection() - author = self.Person(name='Test User') + author = self.Person(name="Test User") author.save() - post = BlogPost(content='Watched some TV today... how exciting.') + post = BlogPost(content="Watched some TV today... how exciting.") # Should only reference author when saving post.author = author post.save() @@ -2049,15 +2082,15 @@ class InstanceTest(MongoDBTestCase): post_obj = BlogPost.objects.first() # Test laziness - self.assertIsInstance(post_obj._data['author'], bson.DBRef) + self.assertIsInstance(post_obj._data["author"], bson.DBRef) self.assertIsInstance(post_obj.author, self.Person) - self.assertEqual(post_obj.author.name, 'Test User') + self.assertEqual(post_obj.author.name, "Test User") # Ensure that the dereferenced object may be changed and saved post_obj.author.age = 25 post_obj.author.save() - author = list(self.Person.objects(name='Test User'))[-1] + author = list(self.Person.objects(name="Test User"))[-1] self.assertEqual(author.age, 25) def test_duplicate_db_fields_raise_invalid_document_error(self): @@ -2065,12 +2098,14 @@ class InstanceTest(MongoDBTestCase): declare the same db_field. """ with self.assertRaises(InvalidDocumentError): + class Foo(Document): name = StringField() - name2 = StringField(db_field='name') + name2 = StringField(db_field="name") def test_invalid_son(self): """Raise an error if loading invalid data.""" + class Occurrence(EmbeddedDocument): number = IntField() @@ -2081,21 +2116,24 @@ class InstanceTest(MongoDBTestCase): occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) with self.assertRaises(InvalidDocumentError): - Word._from_son({ - 'stem': [1, 2, 3], - 'forms': 1, - 'count': 'one', - 'occurs': {"hello": None} - }) + Word._from_son( + { + "stem": [1, 2, 3], + "forms": 1, + "count": "one", + "occurs": {"hello": None}, + } + ) # Tests for issue #1438: https://github.com/MongoEngine/mongoengine/issues/1438 with self.assertRaises(ValueError): - Word._from_son('this is not a valid SON dict') + Word._from_son("this is not a valid SON dict") def test_reverse_delete_rule_cascade_and_nullify(self): """Ensure that a referenced document is also deleted upon deletion. """ + class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) @@ -2104,13 +2142,13 @@ class InstanceTest(MongoDBTestCase): self.Person.drop_collection() BlogPost.drop_collection() - author = self.Person(name='Test User') + author = self.Person(name="Test User") author.save() - reviewer = self.Person(name='Re Viewer') + reviewer = self.Person(name="Re Viewer") reviewer.save() - post = BlogPost(content='Watched some TV') + post = BlogPost(content="Watched some TV") post.author = author post.reviewer = reviewer post.save() @@ -2128,24 +2166,26 @@ class InstanceTest(MongoDBTestCase): """Ensure that a referenced document is also deleted with pull. """ + class Record(Document): name = StringField() - children = ListField(ReferenceField('self', reverse_delete_rule=PULL)) + children = ListField(ReferenceField("self", reverse_delete_rule=PULL)) Record.drop_collection() - parent_record = Record(name='parent').save() - child_record = Record(name='child').save() + parent_record = Record(name="parent").save() + child_record = Record(name="child").save() parent_record.children.append(child_record) parent_record.save() child_record.delete() - self.assertEqual(Record.objects(name='parent').get().children, []) + self.assertEqual(Record.objects(name="parent").get().children, []) def test_reverse_delete_rule_with_custom_id_field(self): """Ensure that a referenced document with custom primary key is also deleted upon deletion. """ + class User(Document): name = StringField(primary_key=True) @@ -2156,8 +2196,8 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() Book.drop_collection() - user = User(name='Mike').save() - reviewer = User(name='John').save() + user = User(name="Mike").save() + reviewer = User(name="John").save() book = Book(author=user, reviewer=reviewer).save() reviewer.delete() @@ -2171,6 +2211,7 @@ class InstanceTest(MongoDBTestCase): """Ensure that cascade delete rule doesn't mix id among collections. """ + class User(Document): id = IntField(primary_key=True) @@ -2203,6 +2244,7 @@ class InstanceTest(MongoDBTestCase): """Ensure that a referenced document is also deleted upon deletion of a child document. """ + class Writer(self.Person): pass @@ -2214,13 +2256,13 @@ class InstanceTest(MongoDBTestCase): self.Person.drop_collection() BlogPost.drop_collection() - author = Writer(name='Test User') + author = Writer(name="Test User") author.save() - reviewer = Writer(name='Re Viewer') + reviewer = Writer(name="Re Viewer") reviewer.save() - post = BlogPost(content='Watched some TV') + post = BlogPost(content="Watched some TV") post.author = author post.reviewer = reviewer post.save() @@ -2237,23 +2279,26 @@ class InstanceTest(MongoDBTestCase): """Ensure that a referenced document is also deleted upon deletion for complex fields. """ + class BlogPost(Document): content = StringField() - authors = ListField(ReferenceField( - self.Person, reverse_delete_rule=CASCADE)) - reviewers = ListField(ReferenceField( - self.Person, reverse_delete_rule=NULLIFY)) + authors = ListField( + ReferenceField(self.Person, reverse_delete_rule=CASCADE) + ) + reviewers = ListField( + ReferenceField(self.Person, reverse_delete_rule=NULLIFY) + ) self.Person.drop_collection() BlogPost.drop_collection() - author = self.Person(name='Test User') + author = self.Person(name="Test User") author.save() - reviewer = self.Person(name='Re Viewer') + reviewer = self.Person(name="Re Viewer") reviewer.save() - post = BlogPost(content='Watched some TV') + post = BlogPost(content="Watched some TV") post.authors = [author] post.reviewers = [reviewer] post.save() @@ -2273,6 +2318,7 @@ class InstanceTest(MongoDBTestCase): delete the author which triggers deletion of blogpost via cascade blog post's pre_delete signal alters an editor attribute. """ + class Editor(self.Person): review_queue = IntField(default=0) @@ -2292,32 +2338,32 @@ class InstanceTest(MongoDBTestCase): BlogPost.drop_collection() Editor.drop_collection() - author = self.Person(name='Will S.').save() - editor = Editor(name='Max P.', review_queue=1).save() - BlogPost(content='wrote some books', author=author, - editor=editor).save() + author = self.Person(name="Will S.").save() + editor = Editor(name="Max P.", review_queue=1).save() + BlogPost(content="wrote some books", author=author, editor=editor).save() # delete the author, the post is also deleted due to the CASCADE rule author.delete() # the pre-delete signal should have decremented the editor's queue - editor = Editor.objects(name='Max P.').get() + editor = Editor.objects(name="Max P.").get() self.assertEqual(editor.review_queue, 0) def test_two_way_reverse_delete_rule(self): """Ensure that Bi-Directional relationships work with reverse_delete_rule """ + class Bar(Document): content = StringField() - foo = ReferenceField('Foo') + foo = ReferenceField("Foo") class Foo(Document): content = StringField() bar = ReferenceField(Bar) - Bar.register_delete_rule(Foo, 'bar', NULLIFY) - Foo.register_delete_rule(Bar, 'foo', NULLIFY) + Bar.register_delete_rule(Foo, "bar", NULLIFY) + Foo.register_delete_rule(Bar, "foo", NULLIFY) Bar.drop_collection() Foo.drop_collection() @@ -2338,24 +2384,27 @@ class InstanceTest(MongoDBTestCase): def test_invalid_reverse_delete_rule_raise_errors(self): with self.assertRaises(InvalidDocumentError): + class Blog(Document): content = StringField() - authors = MapField(ReferenceField( - self.Person, reverse_delete_rule=CASCADE)) + authors = MapField( + ReferenceField(self.Person, reverse_delete_rule=CASCADE) + ) reviewers = DictField( - field=ReferenceField( - self.Person, - reverse_delete_rule=NULLIFY)) + field=ReferenceField(self.Person, reverse_delete_rule=NULLIFY) + ) with self.assertRaises(InvalidDocumentError): + class Parents(EmbeddedDocument): - father = ReferenceField('Person', reverse_delete_rule=DENY) - mother = ReferenceField('Person', reverse_delete_rule=DENY) + father = ReferenceField("Person", reverse_delete_rule=DENY) + mother = ReferenceField("Person", reverse_delete_rule=DENY) def test_reverse_delete_rule_cascade_recurs(self): """Ensure that a chain of documents is also deleted upon cascaded deletion. """ + class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) @@ -2368,14 +2417,14 @@ class InstanceTest(MongoDBTestCase): BlogPost.drop_collection() Comment.drop_collection() - author = self.Person(name='Test User') + author = self.Person(name="Test User") author.save() - post = BlogPost(content='Watched some TV') + post = BlogPost(content="Watched some TV") post.author = author post.save() - comment = Comment(text='Kudos.') + comment = Comment(text="Kudos.") comment.post = post comment.save() @@ -2388,6 +2437,7 @@ class InstanceTest(MongoDBTestCase): """Ensure that a document cannot be referenced if there are still documents referring to it. """ + class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=DENY) @@ -2395,20 +2445,22 @@ class InstanceTest(MongoDBTestCase): self.Person.drop_collection() BlogPost.drop_collection() - author = self.Person(name='Test User') + author = self.Person(name="Test User") author.save() - post = BlogPost(content='Watched some TV') + post = BlogPost(content="Watched some TV") post.author = author post.save() # Delete the Person should be denied self.assertRaises(OperationError, author.delete) # Should raise denied error - self.assertEqual(BlogPost.objects.count(), 1) # No objects may have been deleted + self.assertEqual( + BlogPost.objects.count(), 1 + ) # No objects may have been deleted self.assertEqual(self.Person.objects.count(), 1) # Other users, that don't have BlogPosts must be removable, like normal - author = self.Person(name='Another User') + author = self.Person(name="Another User") author.save() self.assertEqual(self.Person.objects.count(), 2) @@ -2434,6 +2486,7 @@ class InstanceTest(MongoDBTestCase): def test_document_hash(self): """Test document in list, dict, set.""" + class User(Document): pass @@ -2491,9 +2544,11 @@ class InstanceTest(MongoDBTestCase): self.assertEqual(len(all_user_set), 3) def test_picklable(self): - pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) + pickle_doc = PickleTest(number=1, string="One", lists=["1", "2"]) pickle_doc.embedded = PickleEmbedded() - pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved + pickled_doc = pickle.dumps( + pickle_doc + ) # make sure pickling works even before the doc is saved pickle_doc.save() pickled_doc = pickle.dumps(pickle_doc) @@ -2516,8 +2571,10 @@ class InstanceTest(MongoDBTestCase): self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) def test_regular_document_pickle(self): - pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) - pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved + pickle_doc = PickleTest(number=1, string="One", lists=["1", "2"]) + pickled_doc = pickle.dumps( + pickle_doc + ) # make sure pickling works even before the doc is saved pickle_doc.save() pickled_doc = pickle.dumps(pickle_doc) @@ -2527,21 +2584,23 @@ class InstanceTest(MongoDBTestCase): fixtures.PickleTest = fixtures.NewDocumentPickleTest resurrected = pickle.loads(pickled_doc) - self.assertEqual(resurrected.__class__, - fixtures.NewDocumentPickleTest) - self.assertEqual(resurrected._fields_ordered, - fixtures.NewDocumentPickleTest._fields_ordered) - self.assertNotEqual(resurrected._fields_ordered, - pickle_doc._fields_ordered) + self.assertEqual(resurrected.__class__, fixtures.NewDocumentPickleTest) + self.assertEqual( + resurrected._fields_ordered, fixtures.NewDocumentPickleTest._fields_ordered + ) + self.assertNotEqual(resurrected._fields_ordered, pickle_doc._fields_ordered) # The local PickleTest is still a ref to the original fixtures.PickleTest = PickleTest def test_dynamic_document_pickle(self): pickle_doc = PickleDynamicTest( - name="test", number=1, string="One", lists=['1', '2']) + name="test", number=1, string="One", lists=["1", "2"] + ) pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar") - pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved + pickled_doc = pickle.dumps( + pickle_doc + ) # make sure pickling works even before the doc is saved pickle_doc.save() @@ -2549,20 +2608,22 @@ class InstanceTest(MongoDBTestCase): resurrected = pickle.loads(pickled_doc) self.assertEqual(resurrected, pickle_doc) - self.assertEqual(resurrected._fields_ordered, - pickle_doc._fields_ordered) - self.assertEqual(resurrected._dynamic_fields.keys(), - pickle_doc._dynamic_fields.keys()) + self.assertEqual(resurrected._fields_ordered, pickle_doc._fields_ordered) + self.assertEqual( + resurrected._dynamic_fields.keys(), pickle_doc._dynamic_fields.keys() + ) self.assertEqual(resurrected.embedded, pickle_doc.embedded) - self.assertEqual(resurrected.embedded._fields_ordered, - pickle_doc.embedded._fields_ordered) - self.assertEqual(resurrected.embedded._dynamic_fields.keys(), - pickle_doc.embedded._dynamic_fields.keys()) + self.assertEqual( + resurrected.embedded._fields_ordered, pickle_doc.embedded._fields_ordered + ) + self.assertEqual( + resurrected.embedded._dynamic_fields.keys(), + pickle_doc.embedded._dynamic_fields.keys(), + ) def test_picklable_on_signals(self): - pickle_doc = PickleSignalsTest( - number=1, string="One", lists=['1', '2']) + pickle_doc = PickleSignalsTest(number=1, string="One", lists=["1", "2"]) pickle_doc.embedded = PickleEmbedded() pickle_doc.save() pickle_doc.delete() @@ -2572,12 +2633,13 @@ class InstanceTest(MongoDBTestCase): the "validate" method. """ with self.assertRaises(InvalidDocumentError): + class Blog(Document): validate = DictField() def test_mutating_documents(self): class B(EmbeddedDocument): - field1 = StringField(default='field1') + field1 = StringField(default="field1") class A(Document): b = EmbeddedDocumentField(B, default=lambda: B()) @@ -2587,27 +2649,28 @@ class InstanceTest(MongoDBTestCase): a = A() a.save() a.reload() - self.assertEqual(a.b.field1, 'field1') + self.assertEqual(a.b.field1, "field1") class C(EmbeddedDocument): - c_field = StringField(default='cfield') + c_field = StringField(default="cfield") class B(EmbeddedDocument): - field1 = StringField(default='field1') + field1 = StringField(default="field1") field2 = EmbeddedDocumentField(C, default=lambda: C()) class A(Document): b = EmbeddedDocumentField(B, default=lambda: B()) a = A.objects()[0] - a.b.field2.c_field = 'new value' + a.b.field2.c_field = "new value" a.save() a.reload() - self.assertEqual(a.b.field2.c_field, 'new value') + self.assertEqual(a.b.field2.c_field, "new value") def test_can_save_false_values(self): """Ensures you can save False values on save.""" + class Doc(Document): foo = StringField() archived = BooleanField(default=False, required=True) @@ -2623,6 +2686,7 @@ class InstanceTest(MongoDBTestCase): def test_can_save_false_values_dynamic(self): """Ensures you can save False values on dynamic docs.""" + class Doc(DynamicDocument): foo = StringField() @@ -2637,6 +2701,7 @@ class InstanceTest(MongoDBTestCase): def test_do_not_save_unchanged_references(self): """Ensures cascading saves dont auto update""" + class Job(Document): name = StringField() @@ -2655,8 +2720,10 @@ class InstanceTest(MongoDBTestCase): person = Person(name="name", age=10, job=job) from pymongo.collection import Collection + orig_update = Collection.update try: + def fake_update(*args, **kwargs): self.fail("Unexpected update for %s" % args[0].name) return orig_update(*args, **kwargs) @@ -2670,9 +2737,9 @@ class InstanceTest(MongoDBTestCase): """DB Alias tests.""" # mongoenginetest - Is default connection alias from setUp() # Register Aliases - register_connection('testdb-1', 'mongoenginetest2') - register_connection('testdb-2', 'mongoenginetest3') - register_connection('testdb-3', 'mongoenginetest4') + register_connection("testdb-1", "mongoenginetest2") + register_connection("testdb-2", "mongoenginetest3") + register_connection("testdb-3", "mongoenginetest4") class User(Document): name = StringField() @@ -2719,42 +2786,43 @@ class InstanceTest(MongoDBTestCase): # Collections self.assertEqual( - User._get_collection(), - get_db("testdb-1")[User._get_collection_name()]) + User._get_collection(), get_db("testdb-1")[User._get_collection_name()] + ) self.assertEqual( - Book._get_collection(), - get_db("testdb-2")[Book._get_collection_name()]) + Book._get_collection(), get_db("testdb-2")[Book._get_collection_name()] + ) self.assertEqual( AuthorBooks._get_collection(), - get_db("testdb-3")[AuthorBooks._get_collection_name()]) + get_db("testdb-3")[AuthorBooks._get_collection_name()], + ) def test_db_alias_overrides(self): """Test db_alias can be overriden.""" # Register a connection with db_alias testdb-2 - register_connection('testdb-2', 'mongoenginetest2') + register_connection("testdb-2", "mongoenginetest2") class A(Document): """Uses default db_alias """ + name = StringField() meta = {"allow_inheritance": True} class B(A): """Uses testdb-2 db_alias """ + meta = {"db_alias": "testdb-2"} A.objects.all() - self.assertEqual('testdb-2', B._meta.get('db_alias')) - self.assertEqual('mongoenginetest', - A._get_collection().database.name) - self.assertEqual('mongoenginetest2', - B._get_collection().database.name) + self.assertEqual("testdb-2", B._meta.get("db_alias")) + self.assertEqual("mongoenginetest", A._get_collection().database.name) + self.assertEqual("mongoenginetest2", B._get_collection().database.name) def test_db_alias_propagates(self): """db_alias propagates?""" - register_connection('testdb-1', 'mongoenginetest2') + register_connection("testdb-1", "mongoenginetest2") class A(Document): name = StringField() @@ -2763,10 +2831,11 @@ class InstanceTest(MongoDBTestCase): class B(A): pass - self.assertEqual('testdb-1', B._meta.get('db_alias')) + self.assertEqual("testdb-1", B._meta.get("db_alias")) def test_db_ref_usage(self): """DB Ref usage in dict_fields.""" + class User(Document): name = StringField() @@ -2774,9 +2843,7 @@ class InstanceTest(MongoDBTestCase): name = StringField() author = ReferenceField(User) extra = DictField() - meta = { - 'ordering': ['+name'] - } + meta = {"ordering": ["+name"]} def __unicode__(self): return self.name @@ -2798,12 +2865,19 @@ class InstanceTest(MongoDBTestCase): peter = User.objects.create(name="Peter") # Bob - Book.objects.create(name="1", author=bob, extra={ - "a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]}) - Book.objects.create(name="2", author=bob, extra={ - "a": bob.to_dbref(), "b": karl.to_dbref()}) - Book.objects.create(name="3", author=bob, extra={ - "a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]}) + Book.objects.create( + name="1", + author=bob, + extra={"a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]}, + ) + Book.objects.create( + name="2", author=bob, extra={"a": bob.to_dbref(), "b": karl.to_dbref()} + ) + Book.objects.create( + name="3", + author=bob, + extra={"a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]}, + ) Book.objects.create(name="4", author=bob) # Jon @@ -2811,56 +2885,77 @@ class InstanceTest(MongoDBTestCase): Book.objects.create(name="6", author=peter) Book.objects.create(name="7", author=jon) Book.objects.create(name="8", author=jon) - Book.objects.create(name="9", author=jon, - extra={"a": peter.to_dbref()}) + Book.objects.create(name="9", author=jon, extra={"a": peter.to_dbref()}) # Checks - self.assertEqual(",".join([str(b) for b in Book.objects.all()]), - "1,2,3,4,5,6,7,8,9") + self.assertEqual( + ",".join([str(b) for b in Book.objects.all()]), "1,2,3,4,5,6,7,8,9" + ) # bob related books - self.assertEqual(",".join([str(b) for b in Book.objects.filter( - Q(extra__a=bob) | - Q(author=bob) | - Q(extra__b=bob))]), - "1,2,3,4") + self.assertEqual( + ",".join( + [ + str(b) + for b in Book.objects.filter( + Q(extra__a=bob) | Q(author=bob) | Q(extra__b=bob) + ) + ] + ), + "1,2,3,4", + ) # Susan & Karl related books - self.assertEqual(",".join([str(b) for b in Book.objects.filter( - Q(extra__a__all=[karl, susan]) | - Q(author__all=[karl, susan]) | - Q(extra__b__all=[ - karl.to_dbref(), susan.to_dbref()])) - ]), "1") + self.assertEqual( + ",".join( + [ + str(b) + for b in Book.objects.filter( + Q(extra__a__all=[karl, susan]) + | Q(author__all=[karl, susan]) + | Q(extra__b__all=[karl.to_dbref(), susan.to_dbref()]) + ) + ] + ), + "1", + ) # $Where - self.assertEqual(u",".join([str(b) for b in Book.objects.filter( - __raw__={ - "$where": """ + self.assertEqual( + u",".join( + [ + str(b) + for b in Book.objects.filter( + __raw__={ + "$where": """ function(){ return this.name == '1' || this.name == '2';}""" - })]), - "1,2") + } + ) + ] + ), + "1,2", + ) def test_switch_db_instance(self): - register_connection('testdb-1', 'mongoenginetest2') + register_connection("testdb-1", "mongoenginetest2") class Group(Document): name = StringField() Group.drop_collection() - with switch_db(Group, 'testdb-1') as Group: + with switch_db(Group, "testdb-1") as Group: Group.drop_collection() Group(name="hello - default").save() self.assertEqual(1, Group.objects.count()) group = Group.objects.first() - group.switch_db('testdb-1') + group.switch_db("testdb-1") group.name = "hello - testdb!" group.save() - with switch_db(Group, 'testdb-1') as Group: + with switch_db(Group, "testdb-1") as Group: group = Group.objects.first() self.assertEqual("hello - testdb!", group.name) @@ -2869,10 +2964,10 @@ class InstanceTest(MongoDBTestCase): # Slightly contrived now - perform an update # Only works as they have the same object_id - group.switch_db('testdb-1') + group.switch_db("testdb-1") group.update(set__name="hello - update") - with switch_db(Group, 'testdb-1') as Group: + with switch_db(Group, "testdb-1") as Group: group = Group.objects.first() self.assertEqual("hello - update", group.name) Group.drop_collection() @@ -2883,10 +2978,10 @@ class InstanceTest(MongoDBTestCase): # Totally contrived now - perform a delete # Only works as they have the same object_id - group.switch_db('testdb-1') + group.switch_db("testdb-1") group.delete() - with switch_db(Group, 'testdb-1') as Group: + with switch_db(Group, "testdb-1") as Group: self.assertEqual(0, Group.objects.count()) group = Group.objects.first() @@ -2898,11 +2993,9 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().insert_one({ - 'name': 'John', - 'foo': 'Bar', - 'data': [1, 2, 3] - }) + User._get_collection().insert_one( + {"name": "John", "foo": "Bar", "data": [1, 2, 3]} + ) self.assertRaises(FieldDoesNotExist, User.objects.first) @@ -2910,22 +3003,20 @@ class InstanceTest(MongoDBTestCase): class User(Document): name = StringField() - meta = {'strict': False} + meta = {"strict": False} User.drop_collection() - User._get_collection().insert_one({ - 'name': 'John', - 'foo': 'Bar', - 'data': [1, 2, 3] - }) + User._get_collection().insert_one( + {"name": "John", "foo": "Bar", "data": [1, 2, 3]} + ) user = User.objects.first() - self.assertEqual(user.name, 'John') - self.assertFalse(hasattr(user, 'foo')) - self.assertEqual(user._data['foo'], 'Bar') - self.assertFalse(hasattr(user, 'data')) - self.assertEqual(user._data['data'], [1, 2, 3]) + self.assertEqual(user.name, "John") + self.assertFalse(hasattr(user, "foo")) + self.assertEqual(user._data["foo"], "Bar") + self.assertFalse(hasattr(user, "data")) + self.assertEqual(user._data["data"], [1, 2, 3]) def test_load_undefined_fields_on_embedded_document(self): class Thing(EmbeddedDocument): @@ -2937,14 +3028,12 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().insert_one({ - 'name': 'John', - 'thing': { - 'name': 'My thing', - 'foo': 'Bar', - 'data': [1, 2, 3] + User._get_collection().insert_one( + { + "name": "John", + "thing": {"name": "My thing", "foo": "Bar", "data": [1, 2, 3]}, } - }) + ) self.assertRaises(FieldDoesNotExist, User.objects.first) @@ -2956,18 +3045,16 @@ class InstanceTest(MongoDBTestCase): name = StringField() thing = EmbeddedDocumentField(Thing) - meta = {'strict': False} + meta = {"strict": False} User.drop_collection() - User._get_collection().insert_one({ - 'name': 'John', - 'thing': { - 'name': 'My thing', - 'foo': 'Bar', - 'data': [1, 2, 3] + User._get_collection().insert_one( + { + "name": "John", + "thing": {"name": "My thing", "foo": "Bar", "data": [1, 2, 3]}, } - }) + ) self.assertRaises(FieldDoesNotExist, User.objects.first) @@ -2975,7 +3062,7 @@ class InstanceTest(MongoDBTestCase): class Thing(EmbeddedDocument): name = StringField() - meta = {'strict': False} + meta = {"strict": False} class User(Document): name = StringField() @@ -2983,22 +3070,20 @@ class InstanceTest(MongoDBTestCase): User.drop_collection() - User._get_collection().insert_one({ - 'name': 'John', - 'thing': { - 'name': 'My thing', - 'foo': 'Bar', - 'data': [1, 2, 3] + User._get_collection().insert_one( + { + "name": "John", + "thing": {"name": "My thing", "foo": "Bar", "data": [1, 2, 3]}, } - }) + ) user = User.objects.first() - self.assertEqual(user.name, 'John') - self.assertEqual(user.thing.name, 'My thing') - self.assertFalse(hasattr(user.thing, 'foo')) - self.assertEqual(user.thing._data['foo'], 'Bar') - self.assertFalse(hasattr(user.thing, 'data')) - self.assertEqual(user.thing._data['data'], [1, 2, 3]) + self.assertEqual(user.name, "John") + self.assertEqual(user.thing.name, "My thing") + self.assertFalse(hasattr(user.thing, "foo")) + self.assertEqual(user.thing._data["foo"], "Bar") + self.assertFalse(hasattr(user.thing, "data")) + self.assertEqual(user.thing._data["data"], [1, 2, 3]) def test_spaces_in_keys(self): class Embedded(DynamicEmbeddedDocument): @@ -3009,10 +3094,10 @@ class InstanceTest(MongoDBTestCase): Doc.drop_collection() doc = Doc() - setattr(doc, 'hello world', 1) + setattr(doc, "hello world", 1) doc.save() - one = Doc.objects.filter(**{'hello world': 1}).count() + one = Doc.objects.filter(**{"hello world": 1}).count() self.assertEqual(1, one) def test_shard_key(self): @@ -3020,9 +3105,7 @@ class InstanceTest(MongoDBTestCase): machine = StringField() log = StringField() - meta = { - 'shard_key': ('machine',) - } + meta = {"shard_key": ("machine",)} LogEntry.drop_collection() @@ -3044,24 +3127,22 @@ class InstanceTest(MongoDBTestCase): foo = StringField() class Bar(Document): - meta = { - 'shard_key': ('foo.foo',) - } + meta = {"shard_key": ("foo.foo",)} foo = EmbeddedDocumentField(Foo) bar = StringField() - foo_doc = Foo(foo='hello') - bar_doc = Bar(foo=foo_doc, bar='world') + foo_doc = Foo(foo="hello") + bar_doc = Bar(foo=foo_doc, bar="world") bar_doc.save() self.assertTrue(bar_doc.id is not None) - bar_doc.bar = 'baz' + bar_doc.bar = "baz" bar_doc.save() # try to change the shard key with self.assertRaises(OperationError): - bar_doc.foo.foo = 'something' + bar_doc.foo.foo = "something" bar_doc.save() def test_shard_key_primary(self): @@ -3069,9 +3150,7 @@ class InstanceTest(MongoDBTestCase): machine = StringField(primary_key=True) log = StringField() - meta = { - 'shard_key': ('machine',) - } + meta = {"shard_key": ("machine",)} LogEntry.drop_collection() @@ -3097,12 +3176,10 @@ class InstanceTest(MongoDBTestCase): doc = EmbeddedDocumentField(Embedded) def __eq__(self, other): - return (self.doc_name == other.doc_name and - self.doc == other.doc) + return self.doc_name == other.doc_name and self.doc == other.doc classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) - dict_doc = Doc(**{"doc_name": "my doc", - "doc": {"name": "embedded doc"}}) + dict_doc = Doc(**{"doc_name": "my doc", "doc": {"name": "embedded doc"}}) self.assertEqual(classic_doc, dict_doc) self.assertEqual(classic_doc._data, dict_doc._data) @@ -3116,15 +3193,18 @@ class InstanceTest(MongoDBTestCase): docs = ListField(EmbeddedDocumentField(Embedded)) def __eq__(self, other): - return (self.doc_name == other.doc_name and - self.docs == other.docs) + return self.doc_name == other.doc_name and self.docs == other.docs - classic_doc = Doc(doc_name="my doc", docs=[ - Embedded(name="embedded doc1"), - Embedded(name="embedded doc2")]) - dict_doc = Doc(**{"doc_name": "my doc", - "docs": [{"name": "embedded doc1"}, - {"name": "embedded doc2"}]}) + classic_doc = Doc( + doc_name="my doc", + docs=[Embedded(name="embedded doc1"), Embedded(name="embedded doc2")], + ) + dict_doc = Doc( + **{ + "doc_name": "my doc", + "docs": [{"name": "embedded doc1"}, {"name": "embedded doc2"}], + } + ) self.assertEqual(classic_doc, dict_doc) self.assertEqual(classic_doc._data, dict_doc._data) @@ -3134,8 +3214,8 @@ class InstanceTest(MongoDBTestCase): with self.assertRaises(TypeError) as e: person = self.Person("Test User", 42) expected_msg = ( - 'Instantiating a document with positional arguments is not ' - 'supported. Please use `field_name=value` keyword arguments.' + "Instantiating a document with positional arguments is not " + "supported. Please use `field_name=value` keyword arguments." ) self.assertEqual(str(e.exception), expected_msg) @@ -3144,8 +3224,8 @@ class InstanceTest(MongoDBTestCase): with self.assertRaises(TypeError) as e: person = self.Person("Test User", age=42) expected_msg = ( - 'Instantiating a document with positional arguments is not ' - 'supported. Please use `field_name=value` keyword arguments.' + "Instantiating a document with positional arguments is not " + "supported. Please use `field_name=value` keyword arguments." ) self.assertEqual(str(e.exception), expected_msg) @@ -3154,8 +3234,8 @@ class InstanceTest(MongoDBTestCase): with self.assertRaises(TypeError) as e: job = self.Job("Test Job", 4) expected_msg = ( - 'Instantiating a document with positional arguments is not ' - 'supported. Please use `field_name=value` keyword arguments.' + "Instantiating a document with positional arguments is not " + "supported. Please use `field_name=value` keyword arguments." ) self.assertEqual(str(e.exception), expected_msg) @@ -3164,13 +3244,14 @@ class InstanceTest(MongoDBTestCase): with self.assertRaises(TypeError) as e: job = self.Job("Test Job", years=4) expected_msg = ( - 'Instantiating a document with positional arguments is not ' - 'supported. Please use `field_name=value` keyword arguments.' + "Instantiating a document with positional arguments is not " + "supported. Please use `field_name=value` keyword arguments." ) self.assertEqual(str(e.exception), expected_msg) def test_data_contains_id_field(self): """Ensure that asking for _data returns 'id'.""" + class Person(Document): name = StringField() @@ -3178,8 +3259,8 @@ class InstanceTest(MongoDBTestCase): Person(name="Harry Potter").save() person = Person.objects.first() - self.assertIn('id', person._data.keys()) - self.assertEqual(person._data.get('id'), person.id) + self.assertIn("id", person._data.keys()) + self.assertEqual(person._data.get("id"), person.id) def test_complex_nesting_document_and_embedded_document(self): class Macro(EmbeddedDocument): @@ -3220,8 +3301,8 @@ class InstanceTest(MongoDBTestCase): system = NodesSystem.objects.first() self.assertEqual( - "UNDEFINED", - system.nodes["node"].parameters["param"].macros["test"].value) + "UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value + ) def test_embedded_document_equality(self): class Test(Document): @@ -3231,7 +3312,7 @@ class InstanceTest(MongoDBTestCase): ref = ReferenceField(Test) Test.drop_collection() - test = Test(field='123').save() # has id + test = Test(field="123").save() # has id e = Embedded(ref=test) f1 = Embedded._from_son(e.to_mongo()) @@ -3250,25 +3331,25 @@ class InstanceTest(MongoDBTestCase): class Test(Document): name = StringField() - test2 = ReferenceField('Test2') - test3 = ReferenceField('Test3') + test2 = ReferenceField("Test2") + test3 = ReferenceField("Test3") Test.drop_collection() Test2.drop_collection() Test3.drop_collection() - t2 = Test2(name='a') + t2 = Test2(name="a") t2.save() - t3 = Test3(name='x') + t3 = Test3(name="x") t3.id = t2.id t3.save() - t = Test(name='b', test2=t2, test3=t3) + t = Test(name="b", test2=t2, test3=t3) f = Test._from_son(t.to_mongo()) - dbref2 = f._data['test2'] + dbref2 = f._data["test2"] obj2 = f.test2 self.assertIsInstance(dbref2, DBRef) self.assertIsInstance(obj2, Test2) @@ -3276,7 +3357,7 @@ class InstanceTest(MongoDBTestCase): self.assertEqual(obj2, dbref2) self.assertEqual(dbref2, obj2) - dbref3 = f._data['test3'] + dbref3 = f._data["test3"] obj3 = f.test3 self.assertIsInstance(dbref3, DBRef) self.assertIsInstance(obj3, Test3) @@ -3306,14 +3387,14 @@ class InstanceTest(MongoDBTestCase): created_on = DateTimeField(default=lambda: datetime.utcnow()) name = StringField() - p = Person(name='alon') + p = Person(name="alon") p.save() - orig_created_on = Person.objects().only('created_on')[0].created_on + orig_created_on = Person.objects().only("created_on")[0].created_on - p2 = Person.objects().only('name')[0] - p2.name = 'alon2' + p2 = Person.objects().only("name")[0] + p2.name = "alon2" p2.save() - p3 = Person.objects().only('created_on')[0] + p3 = Person.objects().only("created_on")[0] self.assertEqual(orig_created_on, p3.created_on) class Person(Document): @@ -3331,8 +3412,8 @@ class InstanceTest(MongoDBTestCase): # alter DB for the new default coll = Person._get_collection() for person in Person.objects.as_pymongo(): - if 'height' not in person: - coll.update_one({'_id': person['_id']}, {'$set': {'height': 189}}) + if "height" not in person: + coll.update_one({"_id": person["_id"]}, {"$set": {"height": 189}}) self.assertEqual(Person.objects(height=189).count(), 1) @@ -3340,12 +3421,17 @@ class InstanceTest(MongoDBTestCase): # 771 class MyPerson(self.Person): meta = dict(shard_key=["id"]) + p = MyPerson.from_json('{"name": "name", "age": 27}', created=True) self.assertEqual(p.id, None) - p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here + p.id = ( + "12345" + ) # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here p = MyPerson._from_son({"name": "name", "age": 27}, created=True) self.assertEqual(p.id, None) - p.id = "12345" # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here + p.id = ( + "12345" + ) # in case it is not working: "OperationError: Shard Keys are immutable..." will be raised here def test_from_son_created_False_without_id(self): class MyPerson(Document): @@ -3359,7 +3445,7 @@ class InstanceTest(MongoDBTestCase): p.save() self.assertIsNotNone(p.id) saved_p = MyPerson.objects.get(id=p.id) - self.assertEqual(saved_p.name, 'a_fancy_name') + self.assertEqual(saved_p.name, "a_fancy_name") def test_from_son_created_False_with_id(self): # 1854 @@ -3368,11 +3454,13 @@ class InstanceTest(MongoDBTestCase): MyPerson.objects.delete() - p = MyPerson.from_json('{"_id": "5b85a8b04ec5dc2da388296e", "name": "a_fancy_name"}', created=False) + p = MyPerson.from_json( + '{"_id": "5b85a8b04ec5dc2da388296e", "name": "a_fancy_name"}', created=False + ) self.assertFalse(p._created) self.assertEqual(p._changed_fields, []) - self.assertEqual(p.name, 'a_fancy_name') - self.assertEqual(p.id, ObjectId('5b85a8b04ec5dc2da388296e')) + self.assertEqual(p.name, "a_fancy_name") + self.assertEqual(p.id, ObjectId("5b85a8b04ec5dc2da388296e")) p.save() with self.assertRaises(DoesNotExist): @@ -3382,8 +3470,8 @@ class InstanceTest(MongoDBTestCase): MyPerson.objects.get(id=p.id) self.assertFalse(p._created) - p.name = 'a new fancy name' - self.assertEqual(p._changed_fields, ['name']) + p.name = "a new fancy name" + self.assertEqual(p._changed_fields, ["name"]) p.save() saved_p = MyPerson.objects.get(id=p.id) self.assertEqual(saved_p.name, p.name) @@ -3394,16 +3482,18 @@ class InstanceTest(MongoDBTestCase): MyPerson.objects.delete() - p = MyPerson.from_json('{"_id": "5b85a8b04ec5dc2da388296e", "name": "a_fancy_name"}', created=True) + p = MyPerson.from_json( + '{"_id": "5b85a8b04ec5dc2da388296e", "name": "a_fancy_name"}', created=True + ) self.assertTrue(p._created) self.assertEqual(p._changed_fields, []) - self.assertEqual(p.name, 'a_fancy_name') - self.assertEqual(p.id, ObjectId('5b85a8b04ec5dc2da388296e')) + self.assertEqual(p.name, "a_fancy_name") + self.assertEqual(p.id, ObjectId("5b85a8b04ec5dc2da388296e")) p.save() saved_p = MyPerson.objects.get(id=p.id) self.assertEqual(saved_p, p) - self.assertEqual(p.name, 'a_fancy_name') + self.assertEqual(p.name, "a_fancy_name") def test_null_field(self): # 734 @@ -3417,9 +3507,9 @@ class InstanceTest(MongoDBTestCase): cdt_fld = ComplexDateTimeField(null=True) User.objects.delete() - u = User(name='user') + u = User(name="user") u.save() - u_from_db = User.objects.get(name='user') + u_from_db = User.objects.get(name="user") u_from_db.height = None u_from_db.save() self.assertEqual(u_from_db.height, None) @@ -3432,15 +3522,16 @@ class InstanceTest(MongoDBTestCase): # 735 User.objects.delete() - u = User(name='user') + u = User(name="user") u.save() - User.objects(name='user').update_one(set__height=None, upsert=True) - u_from_db = User.objects.get(name='user') + User.objects(name="user").update_one(set__height=None, upsert=True) + u_from_db = User.objects.get(name="user") self.assertEqual(u_from_db.height, None) def test_not_saved_eq(self): """Ensure we can compare documents not saved. """ + class Person(Document): pass @@ -3458,7 +3549,7 @@ class InstanceTest(MongoDBTestCase): l = ListField(EmbeddedDocumentField(B)) A.objects.delete() - A(l=[B(v='1'), B(v='2'), B(v='3')]).save() + A(l=[B(v="1"), B(v="2"), B(v="3")]).save() a = A.objects.get() self.assertEqual(a.l._instance, a) for idx, b in enumerate(a.l): @@ -3467,6 +3558,7 @@ class InstanceTest(MongoDBTestCase): def test_falsey_pk(self): """Ensure that we can create and update a document with Falsey PK.""" + class Person(Document): age = IntField(primary_key=True) height = FloatField() @@ -3480,6 +3572,7 @@ class InstanceTest(MongoDBTestCase): def test_push_with_position(self): """Ensure that push with position works properly for an instance.""" + class BlogPost(Document): slug = StringField() tags = ListField(StringField()) @@ -3491,10 +3584,11 @@ class InstanceTest(MongoDBTestCase): blog.update(push__tags__0=["mongodb", "code"]) blog.reload() - self.assertEqual(blog.tags, ['mongodb', 'code', 'python']) + self.assertEqual(blog.tags, ["mongodb", "code", "python"]) def test_push_nested_list(self): """Ensure that push update works in nested list""" + class BlogPost(Document): slug = StringField() tags = ListField() @@ -3505,10 +3599,11 @@ class InstanceTest(MongoDBTestCase): self.assertEqual(blog.tags, [["value1", 123]]) def test_accessing_objects_with_indexes_error(self): - insert_result = self.db.company.insert_many([{'name': 'Foo'}, - {'name': 'Foo'}]) # Force 2 doc with same name + insert_result = self.db.company.insert_many( + [{"name": "Foo"}, {"name": "Foo"}] + ) # Force 2 doc with same name REF_OID = insert_result.inserted_ids[0] - self.db.user.insert_one({'company': REF_OID}) # Force 2 doc with same name + self.db.user.insert_one({"company": REF_OID}) # Force 2 doc with same name class Company(Document): name = StringField(unique=True) @@ -3521,5 +3616,5 @@ class InstanceTest(MongoDBTestCase): User.objects().select_related() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index 251b65a2..33d5a6d9 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -13,9 +13,8 @@ __all__ = ("TestJson",) class TestJson(unittest.TestCase): - def setUp(self): - connect(db='mongoenginetest') + connect(db="mongoenginetest") def test_json_names(self): """ @@ -25,22 +24,24 @@ class TestJson(unittest.TestCase): a to_json with the original class names and not the abreviated mongodb document keys """ + class Embedded(EmbeddedDocument): - string = StringField(db_field='s') + string = StringField(db_field="s") class Doc(Document): - string = StringField(db_field='s') - embedded = EmbeddedDocumentField(Embedded, db_field='e') + string = StringField(db_field="s") + embedded = EmbeddedDocumentField(Embedded, db_field="e") doc = Doc(string="Hello", embedded=Embedded(string="Inner Hello")) - doc_json = doc.to_json(sort_keys=True, use_db_field=False, separators=(',', ':')) + doc_json = doc.to_json( + sort_keys=True, use_db_field=False, separators=(",", ":") + ) expected_json = """{"embedded":{"string":"Inner Hello"},"string":"Hello"}""" self.assertEqual(doc_json, expected_json) def test_json_simple(self): - class Embedded(EmbeddedDocument): string = StringField() @@ -49,12 +50,14 @@ class TestJson(unittest.TestCase): embedded_field = EmbeddedDocumentField(Embedded) def __eq__(self, other): - return (self.string == other.string and - self.embedded_field == other.embedded_field) + return ( + self.string == other.string + and self.embedded_field == other.embedded_field + ) doc = Doc(string="Hi", embedded_field=Embedded(string="Hi")) - doc_json = doc.to_json(sort_keys=True, separators=(',', ':')) + doc_json = doc.to_json(sort_keys=True, separators=(",", ":")) expected_json = """{"embedded_field":{"string":"Hi"},"string":"Hi"}""" self.assertEqual(doc_json, expected_json) @@ -68,41 +71,43 @@ class TestJson(unittest.TestCase): pass class Doc(Document): - string_field = StringField(default='1') + string_field = StringField(default="1") int_field = IntField(default=1) float_field = FloatField(default=1.1) boolean_field = BooleanField(default=True) datetime_field = DateTimeField(default=datetime.now) - embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, - default=lambda: EmbeddedDoc()) + embedded_document_field = EmbeddedDocumentField( + EmbeddedDoc, default=lambda: EmbeddedDoc() + ) list_field = ListField(default=lambda: [1, 2, 3]) dict_field = DictField(default=lambda: {"hello": "world"}) objectid_field = ObjectIdField(default=ObjectId) - reference_field = ReferenceField(Simple, default=lambda: - Simple().save()) + reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) generic_reference_field = GenericReferenceField( - default=lambda: Simple().save()) - sorted_list_field = SortedListField(IntField(), - default=lambda: [1, 2, 3]) + default=lambda: Simple().save() + ) + sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3]) email_field = EmailField(default="ross@example.com") geo_point_field = GeoPointField(default=lambda: [1, 2]) sequence_field = SequenceField() uuid_field = UUIDField(default=uuid.uuid4) generic_embedded_document_field = GenericEmbeddedDocumentField( - default=lambda: EmbeddedDoc()) + default=lambda: EmbeddedDoc() + ) def __eq__(self, other): import json + return json.loads(self.to_json()) == json.loads(other.to_json()) doc = Doc() self.assertEqual(doc, Doc.from_json(doc.to_json())) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/document/validation.py b/tests/document/validation.py index 30a285b2..78199231 100644 --- a/tests/document/validation.py +++ b/tests/document/validation.py @@ -8,49 +8,56 @@ __all__ = ("ValidatorErrorTest",) class ValidatorErrorTest(unittest.TestCase): - def setUp(self): - connect(db='mongoenginetest') + connect(db="mongoenginetest") def test_to_dict(self): """Ensure a ValidationError handles error to_dict correctly. """ - error = ValidationError('root') + error = ValidationError("root") self.assertEqual(error.to_dict(), {}) # 1st level error schema - error.errors = {'1st': ValidationError('bad 1st'), } - self.assertIn('1st', error.to_dict()) - self.assertEqual(error.to_dict()['1st'], 'bad 1st') + error.errors = {"1st": ValidationError("bad 1st")} + self.assertIn("1st", error.to_dict()) + self.assertEqual(error.to_dict()["1st"], "bad 1st") # 2nd level error schema - error.errors = {'1st': ValidationError('bad 1st', errors={ - '2nd': ValidationError('bad 2nd'), - })} - self.assertIn('1st', error.to_dict()) - self.assertIsInstance(error.to_dict()['1st'], dict) - self.assertIn('2nd', error.to_dict()['1st']) - self.assertEqual(error.to_dict()['1st']['2nd'], 'bad 2nd') + error.errors = { + "1st": ValidationError( + "bad 1st", errors={"2nd": ValidationError("bad 2nd")} + ) + } + self.assertIn("1st", error.to_dict()) + self.assertIsInstance(error.to_dict()["1st"], dict) + self.assertIn("2nd", error.to_dict()["1st"]) + self.assertEqual(error.to_dict()["1st"]["2nd"], "bad 2nd") # moar levels - error.errors = {'1st': ValidationError('bad 1st', errors={ - '2nd': ValidationError('bad 2nd', errors={ - '3rd': ValidationError('bad 3rd', errors={ - '4th': ValidationError('Inception'), - }), - }), - })} - self.assertIn('1st', error.to_dict()) - self.assertIn('2nd', error.to_dict()['1st']) - self.assertIn('3rd', error.to_dict()['1st']['2nd']) - self.assertIn('4th', error.to_dict()['1st']['2nd']['3rd']) - self.assertEqual(error.to_dict()['1st']['2nd']['3rd']['4th'], - 'Inception') + error.errors = { + "1st": ValidationError( + "bad 1st", + errors={ + "2nd": ValidationError( + "bad 2nd", + errors={ + "3rd": ValidationError( + "bad 3rd", errors={"4th": ValidationError("Inception")} + ) + }, + ) + }, + ) + } + self.assertIn("1st", error.to_dict()) + self.assertIn("2nd", error.to_dict()["1st"]) + self.assertIn("3rd", error.to_dict()["1st"]["2nd"]) + self.assertIn("4th", error.to_dict()["1st"]["2nd"]["3rd"]) + self.assertEqual(error.to_dict()["1st"]["2nd"]["3rd"]["4th"], "Inception") self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])") def test_model_validation(self): - class User(Document): username = StringField(primary_key=True) name = StringField(required=True) @@ -59,9 +66,10 @@ class ValidatorErrorTest(unittest.TestCase): User().validate() except ValidationError as e: self.assertIn("User:None", e.message) - self.assertEqual(e.to_dict(), { - 'username': 'Field is required', - 'name': 'Field is required'}) + self.assertEqual( + e.to_dict(), + {"username": "Field is required", "name": "Field is required"}, + ) user = User(username="RossC0", name="Ross").save() user.name = None @@ -69,14 +77,13 @@ class ValidatorErrorTest(unittest.TestCase): user.save() except ValidationError as e: self.assertIn("User:RossC0", e.message) - self.assertEqual(e.to_dict(), { - 'name': 'Field is required'}) + self.assertEqual(e.to_dict(), {"name": "Field is required"}) def test_fields_rewrite(self): class BasePerson(Document): name = StringField() age = IntField() - meta = {'abstract': True} + meta = {"abstract": True} class Person(BasePerson): name = StringField(required=True) @@ -87,6 +94,7 @@ class ValidatorErrorTest(unittest.TestCase): def test_embedded_document_validation(self): """Ensure that embedded documents may be validated. """ + class Comment(EmbeddedDocument): date = DateTimeField() content = StringField(required=True) @@ -94,7 +102,7 @@ class ValidatorErrorTest(unittest.TestCase): comment = Comment() self.assertRaises(ValidationError, comment.validate) - comment.content = 'test' + comment.content = "test" comment.validate() comment.date = 4 @@ -105,20 +113,20 @@ class ValidatorErrorTest(unittest.TestCase): self.assertEqual(comment._instance, None) def test_embedded_db_field_validate(self): - class SubDoc(EmbeddedDocument): val = IntField(required=True) class Doc(Document): id = StringField(primary_key=True) - e = EmbeddedDocumentField(SubDoc, db_field='eb') + e = EmbeddedDocumentField(SubDoc, db_field="eb") try: Doc(id="bad").validate() except ValidationError as e: self.assertIn("SubDoc:None", e.message) - self.assertEqual(e.to_dict(), { - "e": {'val': 'OK could not be converted to int'}}) + self.assertEqual( + e.to_dict(), {"e": {"val": "OK could not be converted to int"}} + ) Doc.drop_collection() @@ -127,24 +135,24 @@ class ValidatorErrorTest(unittest.TestCase): doc = Doc.objects.first() keys = doc._data.keys() self.assertEqual(2, len(keys)) - self.assertIn('e', keys) - self.assertIn('id', keys) + self.assertIn("e", keys) + self.assertIn("id", keys) doc.e.val = "OK" try: doc.save() except ValidationError as e: self.assertIn("Doc:test", e.message) - self.assertEqual(e.to_dict(), { - "e": {'val': 'OK could not be converted to int'}}) + self.assertEqual( + e.to_dict(), {"e": {"val": "OK could not be converted to int"}} + ) def test_embedded_weakref(self): - class SubDoc(EmbeddedDocument): val = IntField(required=True) class Doc(Document): - e = EmbeddedDocumentField(SubDoc, db_field='eb') + e = EmbeddedDocumentField(SubDoc, db_field="eb") Doc.drop_collection() @@ -167,9 +175,10 @@ class ValidatorErrorTest(unittest.TestCase): Test to ensure a ReferenceField can store a reference to a parent class when inherited. Issue #954. """ + class Parent(Document): - meta = {'allow_inheritance': True} - reference = ReferenceField('self') + meta = {"allow_inheritance": True} + reference = ReferenceField("self") class Child(Parent): pass @@ -190,9 +199,10 @@ class ValidatorErrorTest(unittest.TestCase): Test to ensure a ReferenceField can store a reference to a parent class when inherited and when set via attribute. Issue #954. """ + class Parent(Document): - meta = {'allow_inheritance': True} - reference = ReferenceField('self') + meta = {"allow_inheritance": True} + reference = ReferenceField("self") class Child(Parent): pass @@ -210,5 +220,5 @@ class ValidatorErrorTest(unittest.TestCase): self.fail("ValidationError raised: %s" % e.message) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 68baab46..87acf27f 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -6,27 +6,52 @@ from nose.plugins.skip import SkipTest from bson import DBRef, ObjectId, SON -from mongoengine import Document, StringField, IntField, DateTimeField, DateField, ValidationError, \ - ComplexDateTimeField, FloatField, ListField, ReferenceField, DictField, EmbeddedDocument, EmbeddedDocumentField, \ - GenericReferenceField, DoesNotExist, NotRegistered, OperationError, DynamicField, \ - FieldDoesNotExist, EmbeddedDocumentListField, MultipleObjectsReturned, NotUniqueError, BooleanField,\ - ObjectIdField, SortedListField, GenericLazyReferenceField, LazyReferenceField, DynamicDocument -from mongoengine.base import (BaseField, EmbeddedDocumentList, _document_registry) +from mongoengine import ( + Document, + StringField, + IntField, + DateTimeField, + DateField, + ValidationError, + ComplexDateTimeField, + FloatField, + ListField, + ReferenceField, + DictField, + EmbeddedDocument, + EmbeddedDocumentField, + GenericReferenceField, + DoesNotExist, + NotRegistered, + OperationError, + DynamicField, + FieldDoesNotExist, + EmbeddedDocumentListField, + MultipleObjectsReturned, + NotUniqueError, + BooleanField, + ObjectIdField, + SortedListField, + GenericLazyReferenceField, + LazyReferenceField, + DynamicDocument, +) +from mongoengine.base import BaseField, EmbeddedDocumentList, _document_registry from mongoengine.errors import DeprecatedError from tests.utils import MongoDBTestCase class FieldTest(MongoDBTestCase): - def test_default_values_nothing_set(self): """Ensure that default field values are used when creating a document. """ + class Person(Document): name = StringField() age = IntField(default=30, required=False) - userid = StringField(default=lambda: 'test', required=True) + userid = StringField(default=lambda: "test", required=True) created = DateTimeField(default=datetime.datetime.utcnow) day = DateField(default=datetime.date.today) @@ -34,9 +59,7 @@ class FieldTest(MongoDBTestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, - ['age', 'created', 'day', 'name', 'userid'] - ) + self.assertEqual(data_to_be_saved, ["age", "created", "day", "name", "userid"]) self.assertTrue(person.validate() is None) @@ -46,18 +69,19 @@ class FieldTest(MongoDBTestCase): self.assertEqual(person.created, person.created) self.assertEqual(person.day, person.day) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) - self.assertEqual(person._data['day'], person.day) + self.assertEqual(person._data["name"], person.name) + self.assertEqual(person._data["age"], person.age) + self.assertEqual(person._data["userid"], person.userid) + self.assertEqual(person._data["created"], person.created) + self.assertEqual(person._data["day"], person.day) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual( - data_to_be_saved, ['age', 'created', 'day', 'name', 'userid']) + self.assertEqual(data_to_be_saved, ["age", "created", "day", "name", "userid"]) - def test_custom_field_validation_raise_deprecated_error_when_validation_return_something(self): + def test_custom_field_validation_raise_deprecated_error_when_validation_return_something( + self + ): # Covers introduction of a breaking change in the validation parameter (0.18) def _not_empty(z): return bool(z) @@ -67,8 +91,10 @@ class FieldTest(MongoDBTestCase): Person.drop_collection() - error = ("validation argument for `name` must not return anything, " - "it should raise a ValidationError if validation fails") + error = ( + "validation argument for `name` must not return anything, " + "it should raise a ValidationError if validation fails" + ) with self.assertRaises(DeprecatedError) as ctx_err: Person(name="").validate() @@ -81,7 +107,7 @@ class FieldTest(MongoDBTestCase): def test_custom_field_validation_raise_validation_error(self): def _not_empty(z): if not z: - raise ValidationError('cantbeempty') + raise ValidationError("cantbeempty") class Person(Document): name = StringField(validation=_not_empty) @@ -90,11 +116,17 @@ class FieldTest(MongoDBTestCase): with self.assertRaises(ValidationError) as ctx_err: Person(name="").validate() - self.assertEqual("ValidationError (Person:None) (cantbeempty: ['name'])", str(ctx_err.exception)) + self.assertEqual( + "ValidationError (Person:None) (cantbeempty: ['name'])", + str(ctx_err.exception), + ) with self.assertRaises(ValidationError): Person(name="").save() - self.assertEqual("ValidationError (Person:None) (cantbeempty: ['name'])", str(ctx_err.exception)) + self.assertEqual( + "ValidationError (Person:None) (cantbeempty: ['name'])", + str(ctx_err.exception), + ) Person(name="garbage").validate() Person(name="garbage").save() @@ -103,10 +135,11 @@ class FieldTest(MongoDBTestCase): """Ensure that default field values are used even when we explcitly initialize the doc with None values. """ + class Person(Document): name = StringField() age = IntField(default=30, required=False) - userid = StringField(default=lambda: 'test', required=True) + userid = StringField(default=lambda: "test", required=True) created = DateTimeField(default=datetime.datetime.utcnow) # Trying setting values to None @@ -114,7 +147,7 @@ class FieldTest(MongoDBTestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) self.assertTrue(person.validate() is None) @@ -123,23 +156,24 @@ class FieldTest(MongoDBTestCase): self.assertEqual(person.userid, person.userid) self.assertEqual(person.created, person.created) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + self.assertEqual(person._data["name"], person.name) + self.assertEqual(person._data["age"], person.age) + self.assertEqual(person._data["userid"], person.userid) + self.assertEqual(person._data["created"], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) def test_default_values_when_setting_to_None(self): """Ensure that default field values are used when creating a document. """ + class Person(Document): name = StringField() age = IntField(default=30, required=False) - userid = StringField(default=lambda: 'test', required=True) + userid = StringField(default=lambda: "test", required=True) created = DateTimeField(default=datetime.datetime.utcnow) person = Person() @@ -150,25 +184,27 @@ class FieldTest(MongoDBTestCase): # Confirm saving now would store values data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) self.assertTrue(person.validate() is None) self.assertEqual(person.name, None) self.assertEqual(person.age, 30) - self.assertEqual(person.userid, 'test') + self.assertEqual(person.userid, "test") self.assertIsInstance(person.created, datetime.datetime) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + self.assertEqual(person._data["name"], person.name) + self.assertEqual(person._data["age"], person.age) + self.assertEqual(person._data["userid"], person.userid) + self.assertEqual(person._data["created"], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) - def test_default_value_is_not_used_when_changing_value_to_empty_list_for_strict_doc(self): + def test_default_value_is_not_used_when_changing_value_to_empty_list_for_strict_doc( + self + ): """List field with default can be set to the empty list (strict)""" # Issue #1733 class Doc(Document): @@ -180,7 +216,9 @@ class FieldTest(MongoDBTestCase): reloaded = Doc.objects.get(id=doc.id) self.assertEqual(reloaded.x, []) - def test_default_value_is_not_used_when_changing_value_to_empty_list_for_dyn_doc(self): + def test_default_value_is_not_used_when_changing_value_to_empty_list_for_dyn_doc( + self + ): """List field with default can be set to the empty list (dynamic)""" # Issue #1733 class Doc(DynamicDocument): @@ -188,7 +226,7 @@ class FieldTest(MongoDBTestCase): doc = Doc(x=[1]).save() doc.x = [] - doc.y = 2 # Was triggering the bug + doc.y = 2 # Was triggering the bug doc.save() reloaded = Doc.objects.get(id=doc.id) self.assertEqual(reloaded.x, []) @@ -197,41 +235,47 @@ class FieldTest(MongoDBTestCase): """Ensure that default field values are used after non-default values are explicitly deleted. """ + class Person(Document): name = StringField() age = IntField(default=30, required=False) - userid = StringField(default=lambda: 'test', required=True) + userid = StringField(default=lambda: "test", required=True) created = DateTimeField(default=datetime.datetime.utcnow) - person = Person(name="Ross", age=50, userid='different', - created=datetime.datetime(2014, 6, 12)) + person = Person( + name="Ross", + age=50, + userid="different", + created=datetime.datetime(2014, 6, 12), + ) del person.name del person.age del person.userid del person.created data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) self.assertTrue(person.validate() is None) self.assertEqual(person.name, None) self.assertEqual(person.age, 30) - self.assertEqual(person.userid, 'test') + self.assertEqual(person.userid, "test") self.assertIsInstance(person.created, datetime.datetime) self.assertNotEqual(person.created, datetime.datetime(2014, 6, 12)) - self.assertEqual(person._data['name'], person.name) - self.assertEqual(person._data['age'], person.age) - self.assertEqual(person._data['userid'], person.userid) - self.assertEqual(person._data['created'], person.created) + self.assertEqual(person._data["name"], person.name) + self.assertEqual(person._data["age"], person.age) + self.assertEqual(person._data["userid"], person.userid) + self.assertEqual(person._data["created"], person.created) # Confirm introspection changes nothing data_to_be_saved = sorted(person.to_mongo().keys()) - self.assertEqual(data_to_be_saved, ['age', 'created', 'userid']) + self.assertEqual(data_to_be_saved, ["age", "created", "userid"]) def test_required_values(self): """Ensure that required field constraints are enforced.""" + class Person(Document): name = StringField(required=True) age = IntField(required=True) @@ -246,6 +290,7 @@ class FieldTest(MongoDBTestCase): """Ensure that every fields should accept None if required is False. """ + class HandleNoneFields(Document): str_fld = StringField() int_fld = IntField() @@ -255,7 +300,7 @@ class FieldTest(MongoDBTestCase): HandleNoneFields.drop_collection() doc = HandleNoneFields() - doc.str_fld = u'spam ham egg' + doc.str_fld = u"spam ham egg" doc.int_fld = 42 doc.flt_fld = 4.2 doc.com_dt_fld = datetime.datetime.utcnow() @@ -281,6 +326,7 @@ class FieldTest(MongoDBTestCase): """Ensure that every field can handle null values from the database. """ + class HandleNoneFields(Document): str_fld = StringField(required=True) int_fld = IntField(required=True) @@ -290,21 +336,17 @@ class FieldTest(MongoDBTestCase): HandleNoneFields.drop_collection() doc = HandleNoneFields() - doc.str_fld = u'spam ham egg' + doc.str_fld = u"spam ham egg" doc.int_fld = 42 doc.flt_fld = 4.2 doc.comp_dt_fld = datetime.datetime.utcnow() doc.save() # Unset all the fields - obj = HandleNoneFields._get_collection().update({"_id": doc.id}, { - "$unset": { - "str_fld": 1, - "int_fld": 1, - "flt_fld": 1, - "comp_dt_fld": 1 - } - }) + obj = HandleNoneFields._get_collection().update( + {"_id": doc.id}, + {"$unset": {"str_fld": 1, "int_fld": 1, "flt_fld": 1, "comp_dt_fld": 1}}, + ) # Retrive data from db and verify it. ret = HandleNoneFields.objects.first() @@ -321,16 +363,17 @@ class FieldTest(MongoDBTestCase): """Ensure that invalid values cannot be assigned to an ObjectIdField. """ + class Person(Document): name = StringField() - person = Person(name='Test User') + person = Person(name="Test User") self.assertEqual(person.id, None) person.id = 47 self.assertRaises(ValidationError, person.validate) - person.id = 'abc' + person.id = "abc" self.assertRaises(ValidationError, person.validate) person.id = str(ObjectId()) @@ -338,26 +381,27 @@ class FieldTest(MongoDBTestCase): def test_string_validation(self): """Ensure that invalid values cannot be assigned to string fields.""" + class Person(Document): name = StringField(max_length=20) - userid = StringField(r'[0-9a-z_]+$') + userid = StringField(r"[0-9a-z_]+$") person = Person(name=34) self.assertRaises(ValidationError, person.validate) # Test regex validation on userid - person = Person(userid='test.User') + person = Person(userid="test.User") self.assertRaises(ValidationError, person.validate) - person.userid = 'test_user' - self.assertEqual(person.userid, 'test_user') + person.userid = "test_user" + self.assertEqual(person.userid, "test_user") person.validate() # Test max length validation on name - person = Person(name='Name that is more than twenty characters') + person = Person(name="Name that is more than twenty characters") self.assertRaises(ValidationError, person.validate) - person.name = 'Shorter name' + person.name = "Shorter name" person.validate() def test_db_field_validation(self): @@ -365,25 +409,28 @@ class FieldTest(MongoDBTestCase): # dot in the name with self.assertRaises(ValueError): + class User(Document): - name = StringField(db_field='user.name') + name = StringField(db_field="user.name") # name starting with $ with self.assertRaises(ValueError): + class User(Document): - name = StringField(db_field='$name') + name = StringField(db_field="$name") # name containing a null character with self.assertRaises(ValueError): + class User(Document): - name = StringField(db_field='name\0') + name = StringField(db_field="name\0") def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements.""" access_level_choices = ( - ('a', u'Administration'), - ('b', u'Manager'), - ('c', u'Staff'), + ("a", u"Administration"), + ("b", u"Manager"), + ("c", u"Staff"), ) class User(Document): @@ -400,41 +447,41 @@ class FieldTest(MongoDBTestCase): authors_as_lazy = ListField(LazyReferenceField(User)) generic = ListField(GenericReferenceField()) generic_as_lazy = ListField(GenericLazyReferenceField()) - access_list = ListField(choices=access_level_choices, display_sep=', ') + access_list = ListField(choices=access_level_choices, display_sep=", ") User.drop_collection() BlogPost.drop_collection() - post = BlogPost(content='Went for a walk today...') + post = BlogPost(content="Went for a walk today...") post.validate() - post.tags = 'fun' + post.tags = "fun" self.assertRaises(ValidationError, post.validate) post.tags = [1, 2] self.assertRaises(ValidationError, post.validate) - post.tags = ['fun', 'leisure'] + post.tags = ["fun", "leisure"] post.validate() - post.tags = ('fun', 'leisure') + post.tags = ("fun", "leisure") post.validate() - post.access_list = 'a,b' + post.access_list = "a,b" self.assertRaises(ValidationError, post.validate) - post.access_list = ['c', 'd'] + post.access_list = ["c", "d"] self.assertRaises(ValidationError, post.validate) - post.access_list = ['a', 'b'] + post.access_list = ["a", "b"] post.validate() - self.assertEqual(post.get_access_list_display(), u'Administration, Manager') + self.assertEqual(post.get_access_list_display(), u"Administration, Manager") - post.comments = ['a'] + post.comments = ["a"] self.assertRaises(ValidationError, post.validate) - post.comments = 'yay' + post.comments = "yay" self.assertRaises(ValidationError, post.validate) - comments = [Comment(content='Good for you'), Comment(content='Yay.')] + comments = [Comment(content="Good for you"), Comment(content="Yay.")] post.comments = comments post.validate() @@ -485,28 +532,28 @@ class FieldTest(MongoDBTestCase): def test_sorted_list_sorting(self): """Ensure that a sorted list field properly sorts values. """ + class Comment(EmbeddedDocument): order = IntField() content = StringField() class BlogPost(Document): content = StringField() - comments = SortedListField(EmbeddedDocumentField(Comment), - ordering='order') + comments = SortedListField(EmbeddedDocumentField(Comment), ordering="order") tags = SortedListField(StringField()) BlogPost.drop_collection() - post = BlogPost(content='Went for a walk today...') + post = BlogPost(content="Went for a walk today...") post.save() - post.tags = ['leisure', 'fun'] + post.tags = ["leisure", "fun"] post.save() post.reload() - self.assertEqual(post.tags, ['fun', 'leisure']) + self.assertEqual(post.tags, ["fun", "leisure"]) - comment1 = Comment(content='Good for you', order=1) - comment2 = Comment(content='Yay.', order=0) + comment1 = Comment(content="Good for you", order=1) + comment2 = Comment(content="Yay.", order=0) comments = [comment1, comment2] post.comments = comments post.save() @@ -529,16 +576,17 @@ class FieldTest(MongoDBTestCase): name = StringField() class CategoryList(Document): - categories = SortedListField(EmbeddedDocumentField(Category), - ordering='count', reverse=True) + categories = SortedListField( + EmbeddedDocumentField(Category), ordering="count", reverse=True + ) name = StringField() CategoryList.drop_collection() catlist = CategoryList(name="Top categories") - cat1 = Category(name='posts', count=10) - cat2 = Category(name='food', count=100) - cat3 = Category(name='drink', count=40) + cat1 = Category(name="posts", count=10) + cat2 = Category(name="food", count=100) + cat3 = Category(name="drink", count=40) catlist.categories = [cat1, cat2, cat3] catlist.save() catlist.reload() @@ -549,57 +597,59 @@ class FieldTest(MongoDBTestCase): def test_list_field(self): """Ensure that list types work as expected.""" + class BlogPost(Document): info = ListField() BlogPost.drop_collection() post = BlogPost() - post.info = 'my post' + post.info = "my post" self.assertRaises(ValidationError, post.validate) - post.info = {'title': 'test'} + post.info = {"title": "test"} self.assertRaises(ValidationError, post.validate) - post.info = ['test'] + post.info = ["test"] post.save() post = BlogPost() - post.info = [{'test': 'test'}] + post.info = [{"test": "test"}] post.save() post = BlogPost() - post.info = [{'test': 3}] + post.info = [{"test": 3}] post.save() self.assertEqual(BlogPost.objects.count(), 3) - self.assertEqual( - BlogPost.objects.filter(info__exact='test').count(), 1) - self.assertEqual( - BlogPost.objects.filter(info__0__test='test').count(), 1) + self.assertEqual(BlogPost.objects.filter(info__exact="test").count(), 1) + self.assertEqual(BlogPost.objects.filter(info__0__test="test").count(), 1) # Confirm handles non strings or non existing keys + self.assertEqual(BlogPost.objects.filter(info__0__test__exact="5").count(), 0) self.assertEqual( - BlogPost.objects.filter(info__0__test__exact='5').count(), 0) - self.assertEqual( - BlogPost.objects.filter(info__100__test__exact='test').count(), 0) + BlogPost.objects.filter(info__100__test__exact="test").count(), 0 + ) # test queries by list post = BlogPost() - post.info = ['1', '2'] + post.info = ["1", "2"] post.save() - post = BlogPost.objects(info=['1', '2']).get() - post.info += ['3', '4'] + post = BlogPost.objects(info=["1", "2"]).get() + post.info += ["3", "4"] post.save() - self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4']).count(), 1) - post = BlogPost.objects(info=['1', '2', '3', '4']).get() + self.assertEqual(BlogPost.objects(info=["1", "2", "3", "4"]).count(), 1) + post = BlogPost.objects(info=["1", "2", "3", "4"]).get() post.info *= 2 post.save() - self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4', '1', '2', '3', '4']).count(), 1) + self.assertEqual( + BlogPost.objects(info=["1", "2", "3", "4", "1", "2", "3", "4"]).count(), 1 + ) def test_list_field_manipulative_operators(self): """Ensure that ListField works with standard list operators that manipulate the list. """ + class BlogPost(Document): ref = StringField() info = ListField(StringField()) @@ -608,162 +658,178 @@ class FieldTest(MongoDBTestCase): post = BlogPost() post.ref = "1234" - post.info = ['0', '1', '2', '3', '4', '5'] + post.info = ["0", "1", "2", "3", "4", "5"] post.save() def reset_post(): - post.info = ['0', '1', '2', '3', '4', '5'] + post.info = ["0", "1", "2", "3", "4", "5"] post.save() # '__add__(listB)' # listA+listB # operator.add(listA, listB) reset_post() - temp = ['a', 'b'] + temp = ["a", "b"] post.info = post.info + temp - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b']) + self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b']) + self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) # '__delitem__(index)' # aka 'del list[index]' # aka 'operator.delitem(list, index)' reset_post() del post.info[2] # del from middle ('2') - self.assertEqual(post.info, ['0', '1', '3', '4', '5']) + self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '3', '4', '5']) + self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) # '__delitem__(slice(i, j))' # aka 'del list[i:j]' # aka 'operator.delitem(list, slice(i,j))' reset_post() del post.info[1:3] # removes '1', '2' - self.assertEqual(post.info, ['0', '3', '4', '5']) + self.assertEqual(post.info, ["0", "3", "4", "5"]) post.save() post.reload() - self.assertEqual(post.info, ['0', '3', '4', '5']) + self.assertEqual(post.info, ["0", "3", "4", "5"]) # '__iadd__' # aka 'list += list' reset_post() - temp = ['a', 'b'] + temp = ["a", "b"] post.info += temp - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b']) + self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'a', 'b']) + self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "a", "b"]) # '__imul__' # aka 'list *= number' reset_post() post.info *= 2 - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) + self.assertEqual( + post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] + ) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) + self.assertEqual( + post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] + ) # '__mul__' # aka 'listA*listB' reset_post() post.info = post.info * 2 - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) + self.assertEqual( + post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] + ) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) + self.assertEqual( + post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] + ) # '__rmul__' # aka 'listB*listA' reset_post() post.info = 2 * post.info - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) + self.assertEqual( + post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] + ) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', '0', '1', '2', '3', '4', '5']) + self.assertEqual( + post.info, ["0", "1", "2", "3", "4", "5", "0", "1", "2", "3", "4", "5"] + ) # '__setitem__(index, value)' # aka 'list[index]=value' # aka 'setitem(list, value)' reset_post() - post.info[4] = 'a' - self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5']) + post.info[4] = "a" + self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5']) + self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) # __setitem__(index, value) with a negative index reset_post() - post.info[-2] = 'a' - self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5']) + post.info[-2] = "a" + self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '2', '3', 'a', '5']) + self.assertEqual(post.info, ["0", "1", "2", "3", "a", "5"]) # '__setitem__(slice(i, j), listB)' # aka 'listA[i:j] = listB' # aka 'setitem(listA, slice(i, j), listB)' reset_post() - post.info[1:3] = ['h', 'e', 'l', 'l', 'o'] - self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5']) + post.info[1:3] = ["h", "e", "l", "l", "o"] + self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) post.save() post.reload() - self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5']) + self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) # '__setitem__(slice(i, j), listB)' with negative i and j reset_post() - post.info[-5:-3] = ['h', 'e', 'l', 'l', 'o'] - self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5']) + post.info[-5:-3] = ["h", "e", "l", "l", "o"] + self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) post.save() post.reload() - self.assertEqual(post.info, ['0', 'h', 'e', 'l', 'l', 'o', '3', '4', '5']) + self.assertEqual(post.info, ["0", "h", "e", "l", "l", "o", "3", "4", "5"]) # negative # 'append' reset_post() - post.info.append('h') - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h']) + post.info.append("h") + self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "h"]) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h']) + self.assertEqual(post.info, ["0", "1", "2", "3", "4", "5", "h"]) # 'extend' reset_post() - post.info.extend(['h', 'e', 'l', 'l', 'o']) - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h', 'e', 'l', 'l', 'o']) + post.info.extend(["h", "e", "l", "l", "o"]) + self.assertEqual( + post.info, ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] + ) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '2', '3', '4', '5', 'h', 'e', 'l', 'l', 'o']) + self.assertEqual( + post.info, ["0", "1", "2", "3", "4", "5", "h", "e", "l", "l", "o"] + ) # 'insert' # 'pop' reset_post() x = post.info.pop(2) y = post.info.pop() - self.assertEqual(post.info, ['0', '1', '3', '4']) - self.assertEqual(x, '2') - self.assertEqual(y, '5') + self.assertEqual(post.info, ["0", "1", "3", "4"]) + self.assertEqual(x, "2") + self.assertEqual(y, "5") post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '3', '4']) + self.assertEqual(post.info, ["0", "1", "3", "4"]) # 'remove' reset_post() - post.info.remove('2') - self.assertEqual(post.info, ['0', '1', '3', '4', '5']) + post.info.remove("2") + self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) post.save() post.reload() - self.assertEqual(post.info, ['0', '1', '3', '4', '5']) + self.assertEqual(post.info, ["0", "1", "3", "4", "5"]) # 'reverse' reset_post() post.info.reverse() - self.assertEqual(post.info, ['5', '4', '3', '2', '1', '0']) + self.assertEqual(post.info, ["5", "4", "3", "2", "1", "0"]) post.save() post.reload() - self.assertEqual(post.info, ['5', '4', '3', '2', '1', '0']) + self.assertEqual(post.info, ["5", "4", "3", "2", "1", "0"]) # 'sort': though this operator method does manipulate the list, it is # tested in the 'test_list_field_lexicograpic_operators' function @@ -775,7 +841,7 @@ class FieldTest(MongoDBTestCase): post = BlogPost() post.ref = "1234" - post.info = ['0', '1', '2', '3', '4', '5'] + post.info = ["0", "1", "2", "3", "4", "5"] # '__hash__' # aka 'hash(list)' @@ -785,6 +851,7 @@ class FieldTest(MongoDBTestCase): """Ensure that ListField works with standard list operators that do lexigraphic ordering. """ + class BlogPost(Document): ref = StringField() text_info = ListField(StringField()) @@ -810,7 +877,7 @@ class FieldTest(MongoDBTestCase): blogLargeB.oid_info = [ "54495ad94c934721ede76f90", "54495ad94c934721ede76d23", - "54495ad94c934721ede76d00" + "54495ad94c934721ede76d00", ] blogLargeB.bool_info = [False, True] blogLargeB.save() @@ -852,7 +919,7 @@ class FieldTest(MongoDBTestCase): sorted_target_list = [ ObjectId("54495ad94c934721ede76d00"), ObjectId("54495ad94c934721ede76d23"), - ObjectId("54495ad94c934721ede76f90") + ObjectId("54495ad94c934721ede76f90"), ] self.assertEqual(blogLargeB.text_info, ["a", "j", "z"]) self.assertEqual(blogLargeB.oid_info, sorted_target_list) @@ -865,13 +932,14 @@ class FieldTest(MongoDBTestCase): def test_list_assignment(self): """Ensure that list field element assignment and slicing work.""" + class BlogPost(Document): info = ListField() BlogPost.drop_collection() post = BlogPost() - post.info = ['e1', 'e2', 3, '4', 5] + post.info = ["e1", "e2", 3, "4", 5] post.save() post.info[0] = 1 @@ -879,35 +947,35 @@ class FieldTest(MongoDBTestCase): post.reload() self.assertEqual(post.info[0], 1) - post.info[1:3] = ['n2', 'n3'] + post.info[1:3] = ["n2", "n3"] post.save() post.reload() - self.assertEqual(post.info, [1, 'n2', 'n3', '4', 5]) + self.assertEqual(post.info, [1, "n2", "n3", "4", 5]) - post.info[-1] = 'n5' + post.info[-1] = "n5" post.save() post.reload() - self.assertEqual(post.info, [1, 'n2', 'n3', '4', 'n5']) + self.assertEqual(post.info, [1, "n2", "n3", "4", "n5"]) post.info[-2] = 4 post.save() post.reload() - self.assertEqual(post.info, [1, 'n2', 'n3', 4, 'n5']) + self.assertEqual(post.info, [1, "n2", "n3", 4, "n5"]) post.info[1:-1] = [2] post.save() post.reload() - self.assertEqual(post.info, [1, 2, 'n5']) + self.assertEqual(post.info, [1, 2, "n5"]) - post.info[:-1] = [1, 'n2', 'n3', 4] + post.info[:-1] = [1, "n2", "n3", 4] post.save() post.reload() - self.assertEqual(post.info, [1, 'n2', 'n3', 4, 'n5']) + self.assertEqual(post.info, [1, "n2", "n3", 4, "n5"]) post.info[-4:3] = [2, 3] post.save() post.reload() - self.assertEqual(post.info, [1, 2, 3, 4, 'n5']) + self.assertEqual(post.info, [1, 2, 3, 4, "n5"]) def test_list_field_passed_in_value(self): class Foo(Document): @@ -921,12 +989,13 @@ class FieldTest(MongoDBTestCase): foo = Foo(bars=[]) foo.bars.append(bar) - self.assertEqual(repr(foo.bars), '[]') + self.assertEqual(repr(foo.bars), "[]") def test_list_field_strict(self): """Ensure that list field handles validation if provided a strict field type. """ + class Simple(Document): mapping = ListField(field=IntField()) @@ -943,17 +1012,19 @@ class FieldTest(MongoDBTestCase): def test_list_field_rejects_strings(self): """Strings aren't valid list field data types.""" + class Simple(Document): mapping = ListField() Simple.drop_collection() e = Simple() - e.mapping = 'hello world' + e.mapping = "hello world" self.assertRaises(ValidationError, e.save) def test_complex_field_required(self): """Ensure required cant be None / Empty.""" + class Simple(Document): mapping = ListField(required=True) @@ -975,6 +1046,7 @@ class FieldTest(MongoDBTestCase): """If a complex field is set to the same value, it should not be marked as changed. """ + class Simple(Document): mapping = ListField() @@ -999,7 +1071,7 @@ class FieldTest(MongoDBTestCase): simple = Simple(widgets=[1, 2, 3, 4]).save() simple.widgets[:3] = [] - self.assertEqual(['widgets'], simple._changed_fields) + self.assertEqual(["widgets"], simple._changed_fields) simple.save() simple = simple.reload() @@ -1011,7 +1083,7 @@ class FieldTest(MongoDBTestCase): simple = Simple(widgets=[1, 2, 3, 4]).save() del simple.widgets[:3] - self.assertEqual(['widgets'], simple._changed_fields) + self.assertEqual(["widgets"], simple._changed_fields) simple.save() simple = simple.reload() @@ -1023,7 +1095,7 @@ class FieldTest(MongoDBTestCase): simple = Simple(widgets=[1, 2, 3, 4]).save() simple.widgets[-1] = 5 - self.assertEqual(['widgets.3'], simple._changed_fields) + self.assertEqual(["widgets.3"], simple._changed_fields) simple.save() simple = simple.reload() @@ -1031,8 +1103,9 @@ class FieldTest(MongoDBTestCase): def test_list_field_complex(self): """Ensure that the list fields can handle the complex types.""" + class SettingBase(EmbeddedDocument): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class StringSetting(SettingBase): value = StringField() @@ -1046,12 +1119,17 @@ class FieldTest(MongoDBTestCase): Simple.drop_collection() e = Simple() - e.mapping.append(StringSetting(value='foo')) + e.mapping.append(StringSetting(value="foo")) e.mapping.append(IntegerSetting(value=42)) - e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001, - 'complex': IntegerSetting(value=42), - 'list': [IntegerSetting(value=42), - StringSetting(value='foo')]}) + e.mapping.append( + { + "number": 1, + "string": "Hi!", + "float": 1.001, + "complex": IntegerSetting(value=42), + "list": [IntegerSetting(value=42), StringSetting(value="foo")], + } + ) e.save() e2 = Simple.objects.get(id=e.id) @@ -1059,35 +1137,36 @@ class FieldTest(MongoDBTestCase): self.assertIsInstance(e2.mapping[1], IntegerSetting) # Test querying + self.assertEqual(Simple.objects.filter(mapping__1__value=42).count(), 1) + self.assertEqual(Simple.objects.filter(mapping__2__number=1).count(), 1) self.assertEqual( - Simple.objects.filter(mapping__1__value=42).count(), 1) + Simple.objects.filter(mapping__2__complex__value=42).count(), 1 + ) self.assertEqual( - Simple.objects.filter(mapping__2__number=1).count(), 1) + Simple.objects.filter(mapping__2__list__0__value=42).count(), 1 + ) self.assertEqual( - Simple.objects.filter(mapping__2__complex__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) + Simple.objects.filter(mapping__2__list__1__value="foo").count(), 1 + ) # Confirm can update Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) - self.assertEqual( - Simple.objects.filter(mapping__1__value=10).count(), 1) + self.assertEqual(Simple.objects.filter(mapping__1__value=10).count(), 1) - Simple.objects().update( - set__mapping__2__list__1=StringSetting(value='Boo')) + Simple.objects().update(set__mapping__2__list__1=StringSetting(value="Boo")) self.assertEqual( - Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) + Simple.objects.filter(mapping__2__list__1__value="foo").count(), 0 + ) self.assertEqual( - Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) + Simple.objects.filter(mapping__2__list__1__value="Boo").count(), 1 + ) def test_embedded_db_field(self): class Embedded(EmbeddedDocument): - number = IntField(default=0, db_field='i') + number = IntField(default=0, db_field="i") class Test(Document): - embedded = EmbeddedDocumentField(Embedded, db_field='x') + embedded = EmbeddedDocumentField(Embedded, db_field="x") Test.drop_collection() @@ -1100,58 +1179,54 @@ class FieldTest(MongoDBTestCase): test = Test.objects.get() self.assertEqual(test.embedded.number, 2) doc = self.db.test.find_one() - self.assertEqual(doc['x']['i'], 2) + self.assertEqual(doc["x"]["i"], 2) def test_double_embedded_db_field(self): """Make sure multiple layers of embedded docs resolve db fields properly and can be initialized using dicts. """ + class C(EmbeddedDocument): txt = StringField() class B(EmbeddedDocument): - c = EmbeddedDocumentField(C, db_field='fc') + c = EmbeddedDocumentField(C, db_field="fc") class A(Document): - b = EmbeddedDocumentField(B, db_field='fb') + b = EmbeddedDocumentField(B, db_field="fb") - a = A( - b=B( - c=C(txt='hi') - ) - ) + a = A(b=B(c=C(txt="hi"))) a.validate() - a = A(b={'c': {'txt': 'hi'}}) + a = A(b={"c": {"txt": "hi"}}) a.validate() def test_double_embedded_db_field_from_son(self): """Make sure multiple layers of embedded docs resolve db fields from SON properly. """ + class C(EmbeddedDocument): txt = StringField() class B(EmbeddedDocument): - c = EmbeddedDocumentField(C, db_field='fc') + c = EmbeddedDocumentField(C, db_field="fc") class A(Document): - b = EmbeddedDocumentField(B, db_field='fb') + b = EmbeddedDocumentField(B, db_field="fb") - a = A._from_son(SON([ - ('fb', SON([ - ('fc', SON([ - ('txt', 'hi') - ])) - ])) - ])) - self.assertEqual(a.b.c.txt, 'hi') + a = A._from_son(SON([("fb", SON([("fc", SON([("txt", "hi")]))]))])) + self.assertEqual(a.b.c.txt, "hi") - def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet(self): - raise SkipTest("Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet") + def test_embedded_document_field_cant_reference_using_a_str_if_it_does_not_exist_yet( + self + ): + raise SkipTest( + "Using a string reference in an EmbeddedDocumentField does not work if the class isnt registerd yet" + ) class MyDoc2(Document): - emb = EmbeddedDocumentField('MyDoc') + emb = EmbeddedDocumentField("MyDoc") class MyDoc(EmbeddedDocument): name = StringField() @@ -1160,6 +1235,7 @@ class FieldTest(MongoDBTestCase): """Ensure that invalid embedded documents cannot be assigned to embedded document fields. """ + class Comment(EmbeddedDocument): content = StringField() @@ -1173,30 +1249,31 @@ class FieldTest(MongoDBTestCase): Person.drop_collection() - person = Person(name='Test User') - person.preferences = 'My Preferences' + person = Person(name="Test User") + person.preferences = "My Preferences" self.assertRaises(ValidationError, person.validate) # Check that only the right embedded doc works - person.preferences = Comment(content='Nice blog post...') + person.preferences = Comment(content="Nice blog post...") self.assertRaises(ValidationError, person.validate) # Check that the embedded doc is valid person.preferences = PersonPreferences() self.assertRaises(ValidationError, person.validate) - person.preferences = PersonPreferences(food='Cheese', number=47) - self.assertEqual(person.preferences.food, 'Cheese') + person.preferences = PersonPreferences(food="Cheese", number=47) + self.assertEqual(person.preferences.food, "Cheese") person.validate() def test_embedded_document_inheritance(self): """Ensure that subclasses of embedded documents may be provided to EmbeddedDocumentFields of the superclass' type. """ + class User(EmbeddedDocument): name = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class PowerUser(User): power = IntField() @@ -1207,8 +1284,8 @@ class FieldTest(MongoDBTestCase): BlogPost.drop_collection() - post = BlogPost(content='What I did today...') - post.author = PowerUser(name='Test User', power=47) + post = BlogPost(content="What I did today...") + post.author = PowerUser(name="Test User", power=47) post.save() self.assertEqual(47, BlogPost.objects.first().author.power) @@ -1217,21 +1294,22 @@ class FieldTest(MongoDBTestCase): """Ensure that nested list of subclassed embedded documents is handled correctly. """ + class Group(EmbeddedDocument): name = StringField() content = ListField(StringField()) class Basedoc(Document): groups = ListField(EmbeddedDocumentField(Group)) - meta = {'abstract': True} + meta = {"abstract": True} class User(Basedoc): - doctype = StringField(require=True, default='userdata') + doctype = StringField(require=True, default="userdata") User.drop_collection() - content = ['la', 'le', 'lu'] - group = Group(name='foo', content=content) + content = ["la", "le", "lu"] + group = Group(name="foo", content=content) foobar = User(groups=[group]) foobar.save() @@ -1241,6 +1319,7 @@ class FieldTest(MongoDBTestCase): """Ensure an exception is raised when dereferencing an unknown document. """ + class Foo(Document): pass @@ -1257,20 +1336,21 @@ class FieldTest(MongoDBTestCase): # Reference is no longer valid foo.delete() bar = Bar.objects.get() - self.assertRaises(DoesNotExist, getattr, bar, 'ref') - self.assertRaises(DoesNotExist, getattr, bar, 'generic_ref') + self.assertRaises(DoesNotExist, getattr, bar, "ref") + self.assertRaises(DoesNotExist, getattr, bar, "generic_ref") # When auto_dereference is disabled, there is no trouble returning DBRef bar = Bar.objects.get() expected = foo.to_dbref() - bar._fields['ref']._auto_dereference = False + bar._fields["ref"]._auto_dereference = False self.assertEqual(bar.ref, expected) - bar._fields['generic_ref']._auto_dereference = False - self.assertEqual(bar.generic_ref, {'_ref': expected, '_cls': 'Foo'}) + bar._fields["generic_ref"]._auto_dereference = False + self.assertEqual(bar.generic_ref, {"_ref": expected, "_cls": "Foo"}) def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ + class User(Document): name = StringField() @@ -1280,9 +1360,9 @@ class FieldTest(MongoDBTestCase): User.drop_collection() Group.drop_collection() - user1 = User(name='user1') + user1 = User(name="user1") user1.save() - user2 = User(name='user2') + user2 = User(name="user2") user2.save() group = Group(members=[user1, user2]) @@ -1296,24 +1376,25 @@ class FieldTest(MongoDBTestCase): def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. """ + class Employee(Document): name = StringField() - boss = ReferenceField('self') - friends = ListField(ReferenceField('self')) + boss = ReferenceField("self") + friends = ListField(ReferenceField("self")) Employee.drop_collection() - bill = Employee(name='Bill Lumbergh') + bill = Employee(name="Bill Lumbergh") bill.save() - michael = Employee(name='Michael Bolton') + michael = Employee(name="Michael Bolton") michael.save() - samir = Employee(name='Samir Nagheenanajar') + samir = Employee(name="Samir Nagheenanajar") samir.save() friends = [michael, samir] - peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) + peter = Employee(name="Peter Gibbons", boss=bill, friends=friends) peter.save() peter = Employee.objects.with_id(peter.id) @@ -1323,13 +1404,14 @@ class FieldTest(MongoDBTestCase): def test_recursive_embedding(self): """Ensure that EmbeddedDocumentFields can contain their own documents. """ + class TreeNode(EmbeddedDocument): name = StringField() - children = ListField(EmbeddedDocumentField('self')) + children = ListField(EmbeddedDocumentField("self")) class Tree(Document): name = StringField() - children = ListField(EmbeddedDocumentField('TreeNode')) + children = ListField(EmbeddedDocumentField("TreeNode")) Tree.drop_collection() @@ -1356,18 +1438,18 @@ class FieldTest(MongoDBTestCase): self.assertEqual(tree.children[0].children[1].name, third_child.name) # Test updating - tree.children[0].name = 'I am Child 1' - tree.children[0].children[0].name = 'I am Child 2' - tree.children[0].children[1].name = 'I am Child 3' + tree.children[0].name = "I am Child 1" + tree.children[0].children[0].name = "I am Child 2" + tree.children[0].children[1].name = "I am Child 3" tree.save() - self.assertEqual(tree.children[0].name, 'I am Child 1') - self.assertEqual(tree.children[0].children[0].name, 'I am Child 2') - self.assertEqual(tree.children[0].children[1].name, 'I am Child 3') + self.assertEqual(tree.children[0].name, "I am Child 1") + self.assertEqual(tree.children[0].children[0].name, "I am Child 2") + self.assertEqual(tree.children[0].children[1].name, "I am Child 3") # Test removal self.assertEqual(len(tree.children[0].children), 2) - del(tree.children[0].children[1]) + del tree.children[0].children[1] tree.save() self.assertEqual(len(tree.children[0].children), 1) @@ -1388,6 +1470,7 @@ class FieldTest(MongoDBTestCase): """Ensure that an abstract document cannot be dropped given it has no underlying collection. """ + class AbstractDoc(Document): name = StringField() meta = {"abstract": True} @@ -1397,6 +1480,7 @@ class FieldTest(MongoDBTestCase): def test_reference_class_with_abstract_parent(self): """Ensure that a class with an abstract parent can be referenced. """ + class Sibling(Document): name = StringField() meta = {"abstract": True} @@ -1421,6 +1505,7 @@ class FieldTest(MongoDBTestCase): """Ensure that an abstract class instance cannot be used in the reference of that abstract class. """ + class Sibling(Document): name = StringField() meta = {"abstract": True} @@ -1442,6 +1527,7 @@ class FieldTest(MongoDBTestCase): """Ensure that an an abstract reference fails validation when given a Document that does not inherit from the abstract type. """ + class Sibling(Document): name = StringField() meta = {"abstract": True} @@ -1463,9 +1549,10 @@ class FieldTest(MongoDBTestCase): def test_generic_reference(self): """Ensure that a GenericReferenceField properly dereferences items. """ + class Link(Document): title = StringField() - meta = {'allow_inheritance': False} + meta = {"allow_inheritance": False} class Post(Document): title = StringField() @@ -1502,6 +1589,7 @@ class FieldTest(MongoDBTestCase): def test_generic_reference_list(self): """Ensure that a ListField properly dereferences generic references. """ + class Link(Document): title = StringField() @@ -1533,6 +1621,7 @@ class FieldTest(MongoDBTestCase): """Ensure dereferencing out of the document registry throws a `NotRegistered` error. """ + class Link(Document): title = StringField() @@ -1550,7 +1639,7 @@ class FieldTest(MongoDBTestCase): # Mimic User and Link definitions being in a different file # and the Link model not being imported in the User file. - del(_document_registry["Link"]) + del _document_registry["Link"] user = User.objects.first() try: @@ -1560,7 +1649,6 @@ class FieldTest(MongoDBTestCase): pass def test_generic_reference_is_none(self): - class Person(Document): name = StringField() city = GenericReferenceField() @@ -1568,11 +1656,11 @@ class FieldTest(MongoDBTestCase): Person.drop_collection() Person(name="Wilson Jr").save() - self.assertEqual(repr(Person.objects(city=None)), - "[]") + self.assertEqual(repr(Person.objects(city=None)), "[]") def test_generic_reference_choices(self): """Ensure that a GenericReferenceField can handle choices.""" + class Link(Document): title = StringField() @@ -1604,6 +1692,7 @@ class FieldTest(MongoDBTestCase): def test_generic_reference_string_choices(self): """Ensure that a GenericReferenceField can handle choices as strings """ + class Link(Document): title = StringField() @@ -1611,7 +1700,7 @@ class FieldTest(MongoDBTestCase): title = StringField() class Bookmark(Document): - bookmark_object = GenericReferenceField(choices=('Post', Link)) + bookmark_object = GenericReferenceField(choices=("Post", Link)) Link.drop_collection() Post.drop_collection() @@ -1636,11 +1725,12 @@ class FieldTest(MongoDBTestCase): """Ensure that a GenericReferenceField can handle choices on non-derefenreced (i.e. DBRef) elements """ + class Post(Document): title = StringField() class Bookmark(Document): - bookmark_object = GenericReferenceField(choices=(Post, )) + bookmark_object = GenericReferenceField(choices=(Post,)) other_field = StringField() Post.drop_collection() @@ -1654,13 +1744,14 @@ class FieldTest(MongoDBTestCase): bm = Bookmark.objects.get(id=bm.id) # bookmark_object is now a DBRef - bm.other_field = 'dummy_change' + bm.other_field = "dummy_change" bm.save() def test_generic_reference_list_choices(self): """Ensure that a ListField properly dereferences generic references and respects choices. """ + class Link(Document): title = StringField() @@ -1692,6 +1783,7 @@ class FieldTest(MongoDBTestCase): def test_generic_reference_list_item_modification(self): """Ensure that modifications of related documents (through generic reference) don't influence on querying """ + class Post(Document): title = StringField() @@ -1721,6 +1813,7 @@ class FieldTest(MongoDBTestCase): """Ensure we can search for a specific generic reference by providing its ObjectId. """ + class Doc(Document): ref = GenericReferenceField() @@ -1729,13 +1822,14 @@ class FieldTest(MongoDBTestCase): doc1 = Doc.objects.create() doc2 = Doc.objects.create(ref=doc1) - doc = Doc.objects.get(ref=DBRef('doc', doc1.pk)) + doc = Doc.objects.get(ref=DBRef("doc", doc1.pk)) self.assertEqual(doc, doc2) def test_generic_reference_is_not_tracked_in_parent_doc(self): """Ensure that modifications of related documents (through generic reference) don't influence the owner changed fields (#1934) """ + class Doc1(Document): name = StringField() @@ -1746,14 +1840,14 @@ class FieldTest(MongoDBTestCase): Doc1.drop_collection() Doc2.drop_collection() - doc1 = Doc1(name='garbage1').save() - doc11 = Doc1(name='garbage11').save() + doc1 = Doc1(name="garbage1").save() + doc11 = Doc1(name="garbage11").save() doc2 = Doc2(ref=doc1, refs=[doc11]).save() - doc2.ref.name = 'garbage2' + doc2.ref.name = "garbage2" self.assertEqual(doc2._get_changed_fields(), []) - doc2.refs[0].name = 'garbage3' + doc2.refs[0].name = "garbage3" self.assertEqual(doc2._get_changed_fields(), []) self.assertEqual(doc2._delta(), ({}, {})) @@ -1761,6 +1855,7 @@ class FieldTest(MongoDBTestCase): """Ensure we can search for a specific generic reference by providing its DBRef. """ + class Doc(Document): ref = GenericReferenceField() @@ -1777,17 +1872,19 @@ class FieldTest(MongoDBTestCase): def test_choices_allow_using_sets_as_choices(self): """Ensure that sets can be used when setting choices """ - class Shirt(Document): - size = StringField(choices={'M', 'L'}) - Shirt(size='M').validate() + class Shirt(Document): + size = StringField(choices={"M", "L"}) + + Shirt(size="M").validate() def test_choices_validation_allow_no_value(self): """Ensure that .validate passes and no value was provided for a field setup with choices """ + class Shirt(Document): - size = StringField(choices=('S', 'M')) + size = StringField(choices=("S", "M")) shirt = Shirt() shirt.validate() @@ -1795,17 +1892,19 @@ class FieldTest(MongoDBTestCase): def test_choices_validation_accept_possible_value(self): """Ensure that value is in a container of allowed values. """ - class Shirt(Document): - size = StringField(choices=('S', 'M')) - shirt = Shirt(size='S') + class Shirt(Document): + size = StringField(choices=("S", "M")) + + shirt = Shirt(size="S") shirt.validate() def test_choices_validation_reject_unknown_value(self): """Ensure that unallowed value are rejected upon validation """ + class Shirt(Document): - size = StringField(choices=('S', 'M')) + size = StringField(choices=("S", "M")) shirt = Shirt(size="XS") with self.assertRaises(ValidationError): @@ -1815,12 +1914,23 @@ class FieldTest(MongoDBTestCase): """Test dynamic helper for returning the display value of a choices field. """ + class Shirt(Document): - size = StringField(max_length=3, choices=( - ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), - ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) - style = StringField(max_length=3, choices=( - ('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W') + size = StringField( + max_length=3, + choices=( + ("S", "Small"), + ("M", "Medium"), + ("L", "Large"), + ("XL", "Extra Large"), + ("XXL", "Extra Extra Large"), + ), + ) + style = StringField( + max_length=3, + choices=(("S", "Small"), ("B", "Baggy"), ("W", "Wide")), + default="W", + ) Shirt.drop_collection() @@ -1829,30 +1939,30 @@ class FieldTest(MongoDBTestCase): # Make sure get__display returns the default value (or None) self.assertEqual(shirt1.get_size_display(), None) - self.assertEqual(shirt1.get_style_display(), 'Wide') + self.assertEqual(shirt1.get_style_display(), "Wide") - shirt1.size = 'XXL' - shirt1.style = 'B' - shirt2.size = 'M' - shirt2.style = 'S' - self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large') - self.assertEqual(shirt1.get_style_display(), 'Baggy') - self.assertEqual(shirt2.get_size_display(), 'Medium') - self.assertEqual(shirt2.get_style_display(), 'Small') + shirt1.size = "XXL" + shirt1.style = "B" + shirt2.size = "M" + shirt2.style = "S" + self.assertEqual(shirt1.get_size_display(), "Extra Extra Large") + self.assertEqual(shirt1.get_style_display(), "Baggy") + self.assertEqual(shirt2.get_size_display(), "Medium") + self.assertEqual(shirt2.get_style_display(), "Small") # Set as Z - an invalid choice - shirt1.size = 'Z' - shirt1.style = 'Z' - self.assertEqual(shirt1.get_size_display(), 'Z') - self.assertEqual(shirt1.get_style_display(), 'Z') + shirt1.size = "Z" + shirt1.style = "Z" + self.assertEqual(shirt1.get_size_display(), "Z") + self.assertEqual(shirt1.get_style_display(), "Z") self.assertRaises(ValidationError, shirt1.validate) def test_simple_choices_validation(self): """Ensure that value is in a container of allowed values. """ + class Shirt(Document): - size = StringField(max_length=3, - choices=('S', 'M', 'L', 'XL', 'XXL')) + size = StringField(max_length=3, choices=("S", "M", "L", "XL", "XXL")) Shirt.drop_collection() @@ -1869,37 +1979,37 @@ class FieldTest(MongoDBTestCase): """Test dynamic helper for returning the display value of a choices field. """ + class Shirt(Document): - size = StringField(max_length=3, - choices=('S', 'M', 'L', 'XL', 'XXL')) - style = StringField(max_length=3, - choices=('Small', 'Baggy', 'wide'), - default='Small') + size = StringField(max_length=3, choices=("S", "M", "L", "XL", "XXL")) + style = StringField( + max_length=3, choices=("Small", "Baggy", "wide"), default="Small" + ) Shirt.drop_collection() shirt = Shirt() self.assertEqual(shirt.get_size_display(), None) - self.assertEqual(shirt.get_style_display(), 'Small') + self.assertEqual(shirt.get_style_display(), "Small") shirt.size = "XXL" shirt.style = "Baggy" - self.assertEqual(shirt.get_size_display(), 'XXL') - self.assertEqual(shirt.get_style_display(), 'Baggy') + self.assertEqual(shirt.get_size_display(), "XXL") + self.assertEqual(shirt.get_style_display(), "Baggy") # Set as Z - an invalid choice shirt.size = "Z" shirt.style = "Z" - self.assertEqual(shirt.get_size_display(), 'Z') - self.assertEqual(shirt.get_style_display(), 'Z') + self.assertEqual(shirt.get_size_display(), "Z") + self.assertEqual(shirt.get_style_display(), "Z") self.assertRaises(ValidationError, shirt.validate) def test_simple_choices_validation_invalid_value(self): """Ensure that error messages are correct. """ - SIZES = ('S', 'M', 'L', 'XL', 'XXL') - COLORS = (('R', 'Red'), ('B', 'Blue')) + SIZES = ("S", "M", "L", "XL", "XXL") + COLORS = (("R", "Red"), ("B", "Blue")) SIZE_MESSAGE = u"Value must be one of ('S', 'M', 'L', 'XL', 'XXL')" COLOR_MESSAGE = u"Value must be one of ['R', 'B']" @@ -1924,11 +2034,12 @@ class FieldTest(MongoDBTestCase): except ValidationError as error: # get the validation rules error_dict = error.to_dict() - self.assertEqual(error_dict['size'], SIZE_MESSAGE) - self.assertEqual(error_dict['color'], COLOR_MESSAGE) + self.assertEqual(error_dict["size"], SIZE_MESSAGE) + self.assertEqual(error_dict["color"], COLOR_MESSAGE) def test_recursive_validation(self): """Ensure that a validation result to_dict is available.""" + class Author(EmbeddedDocument): name = StringField(required=True) @@ -1940,9 +2051,9 @@ class FieldTest(MongoDBTestCase): title = StringField(required=True) comments = ListField(EmbeddedDocumentField(Comment)) - bob = Author(name='Bob') - post = Post(title='hello world') - post.comments.append(Comment(content='hello', author=bob)) + bob = Author(name="Bob") + post = Post(title="hello world") + post.comments.append(Comment(content="hello", author=bob)) post.comments.append(Comment(author=bob)) self.assertRaises(ValidationError, post.validate) @@ -1950,30 +2061,31 @@ class FieldTest(MongoDBTestCase): post.validate() except ValidationError as error: # ValidationError.errors property - self.assertTrue(hasattr(error, 'errors')) + self.assertTrue(hasattr(error, "errors")) self.assertIsInstance(error.errors, dict) - self.assertIn('comments', error.errors) - self.assertIn(1, error.errors['comments']) - self.assertIsInstance(error.errors['comments'][1]['content'], ValidationError) + self.assertIn("comments", error.errors) + self.assertIn(1, error.errors["comments"]) + self.assertIsInstance( + error.errors["comments"][1]["content"], ValidationError + ) # ValidationError.schema property error_dict = error.to_dict() self.assertIsInstance(error_dict, dict) - self.assertIn('comments', error_dict) - self.assertIn(1, error_dict['comments']) - self.assertIn('content', error_dict['comments'][1]) - self.assertEqual(error_dict['comments'][1]['content'], - u'Field is required') + self.assertIn("comments", error_dict) + self.assertIn(1, error_dict["comments"]) + self.assertIn("content", error_dict["comments"][1]) + self.assertEqual(error_dict["comments"][1]["content"], u"Field is required") - post.comments[1].content = 'here we go' + post.comments[1].content = "here we go" post.validate() def test_tuples_as_tuples(self): """Ensure that tuples remain tuples when they are inside a ComplexBaseField. """ - class EnumField(BaseField): + class EnumField(BaseField): def __init__(self, **kwargs): super(EnumField, self).__init__(**kwargs) @@ -1988,7 +2100,7 @@ class FieldTest(MongoDBTestCase): TestDoc.drop_collection() - tuples = [(100, 'Testing')] + tuples = [(100, "Testing")] doc = TestDoc() doc.items = tuples doc.save() @@ -2000,12 +2112,12 @@ class FieldTest(MongoDBTestCase): def test_dynamic_fields_class(self): class Doc2(Document): - field_1 = StringField(db_field='f') + field_1 = StringField(db_field="f") class Doc(Document): my_id = IntField(primary_key=True) - embed_me = DynamicField(db_field='e') - field_x = StringField(db_field='x') + embed_me = DynamicField(db_field="e") + field_x = StringField(db_field="x") Doc.drop_collection() Doc2.drop_collection() @@ -2022,12 +2134,12 @@ class FieldTest(MongoDBTestCase): def test_dynamic_fields_embedded_class(self): class Embed(EmbeddedDocument): - field_1 = StringField(db_field='f') + field_1 = StringField(db_field="f") class Doc(Document): my_id = IntField(primary_key=True) - embed_me = DynamicField(db_field='e') - field_x = StringField(db_field='x') + embed_me = DynamicField(db_field="e") + field_x = StringField(db_field="x") Doc.drop_collection() @@ -2038,6 +2150,7 @@ class FieldTest(MongoDBTestCase): def test_dynamicfield_dump_document(self): """Ensure a DynamicField can handle another document's dump.""" + class Doc(Document): field = DynamicField() @@ -2049,7 +2162,7 @@ class FieldTest(MongoDBTestCase): id = IntField(primary_key=True, default=1) recursive = DynamicField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class ToEmbedChild(ToEmbedParent): pass @@ -2070,7 +2183,7 @@ class FieldTest(MongoDBTestCase): def test_cls_field(self): class Animal(Document): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class Fish(Animal): pass @@ -2088,7 +2201,9 @@ class FieldTest(MongoDBTestCase): Dog().save() Fish().save() Human().save() - self.assertEqual(Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2) + self.assertEqual( + Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2 + ) self.assertEqual(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0) def test_sparse_field(self): @@ -2104,32 +2219,34 @@ class FieldTest(MongoDBTestCase): trying to instantiate a document with a field that's not defined. """ + class Doc(Document): foo = StringField() with self.assertRaises(FieldDoesNotExist): - Doc(bar='test') + Doc(bar="test") def test_undefined_field_exception_with_strict(self): """Tests if a `FieldDoesNotExist` exception is raised when trying to instantiate a document with a field that's not defined, even when strict is set to False. """ + class Doc(Document): foo = StringField() - meta = {'strict': False} + meta = {"strict": False} with self.assertRaises(FieldDoesNotExist): - Doc(bar='test') + Doc(bar="test") class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): - def setUp(self): """ Create two BlogPost entries in the database, each with several EmbeddedDocuments. """ + class Comments(EmbeddedDocument): author = StringField() message = StringField() @@ -2142,20 +2259,24 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): self.Comments = Comments self.BlogPost = BlogPost - self.post1 = self.BlogPost(comments=[ - self.Comments(author='user1', message='message1'), - self.Comments(author='user2', message='message1') - ]).save() + self.post1 = self.BlogPost( + comments=[ + self.Comments(author="user1", message="message1"), + self.Comments(author="user2", message="message1"), + ] + ).save() - self.post2 = self.BlogPost(comments=[ - self.Comments(author='user2', message='message2'), - self.Comments(author='user2', message='message3'), - self.Comments(author='user3', message='message1') - ]).save() + self.post2 = self.BlogPost( + comments=[ + self.Comments(author="user2", message="message2"), + self.Comments(author="user2", message="message3"), + self.Comments(author="user3", message="message1"), + ] + ).save() def test_fails_upon_validate_if_provide_a_doc_instead_of_a_list_of_doc(self): # Relates to Issue #1464 - comment = self.Comments(author='John') + comment = self.Comments(author="John") class Title(Document): content = StringField() @@ -2166,14 +2287,18 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): with self.assertRaises(ValidationError) as ctx_err: post.validate() self.assertIn("'comments'", str(ctx_err.exception)) - self.assertIn('Only lists and tuples may be used in a list field', str(ctx_err.exception)) + self.assertIn( + "Only lists and tuples may be used in a list field", str(ctx_err.exception) + ) # Test with a Document - post = self.BlogPost(comments=Title(content='garbage')) + post = self.BlogPost(comments=Title(content="garbage")) with self.assertRaises(ValidationError) as e: post.validate() self.assertIn("'comments'", str(ctx_err.exception)) - self.assertIn('Only lists and tuples may be used in a list field', str(ctx_err.exception)) + self.assertIn( + "Only lists and tuples may be used in a list field", str(ctx_err.exception) + ) def test_no_keyword_filter(self): """ @@ -2190,44 +2315,40 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): Tests the filter method of a List of Embedded Documents with a single keyword. """ - filtered = self.post1.comments.filter(author='user1') + filtered = self.post1.comments.filter(author="user1") # Ensure only 1 entry was returned. self.assertEqual(len(filtered), 1) # Ensure the entry returned is the correct entry. - self.assertEqual(filtered[0].author, 'user1') + self.assertEqual(filtered[0].author, "user1") def test_multi_keyword_filter(self): """ Tests the filter method of a List of Embedded Documents with multiple keywords. """ - filtered = self.post2.comments.filter( - author='user2', message='message2' - ) + filtered = self.post2.comments.filter(author="user2", message="message2") # Ensure only 1 entry was returned. self.assertEqual(len(filtered), 1) # Ensure the entry returned is the correct entry. - self.assertEqual(filtered[0].author, 'user2') - self.assertEqual(filtered[0].message, 'message2') + self.assertEqual(filtered[0].author, "user2") + self.assertEqual(filtered[0].message, "message2") def test_chained_filter(self): """ Tests chained filter methods of a List of Embedded Documents """ - filtered = self.post2.comments.filter(author='user2').filter( - message='message2' - ) + filtered = self.post2.comments.filter(author="user2").filter(message="message2") # Ensure only 1 entry was returned. self.assertEqual(len(filtered), 1) # Ensure the entry returned is the correct entry. - self.assertEqual(filtered[0].author, 'user2') - self.assertEqual(filtered[0].message, 'message2') + self.assertEqual(filtered[0].author, "user2") + self.assertEqual(filtered[0].message, "message2") def test_unknown_keyword_filter(self): """ @@ -2252,36 +2373,34 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): Tests the exclude method of a List of Embedded Documents with a single keyword. """ - excluded = self.post1.comments.exclude(author='user1') + excluded = self.post1.comments.exclude(author="user1") # Ensure only 1 entry was returned. self.assertEqual(len(excluded), 1) # Ensure the entry returned is the correct entry. - self.assertEqual(excluded[0].author, 'user2') + self.assertEqual(excluded[0].author, "user2") def test_multi_keyword_exclude(self): """ Tests the exclude method of a List of Embedded Documents with multiple keywords. """ - excluded = self.post2.comments.exclude( - author='user3', message='message1' - ) + excluded = self.post2.comments.exclude(author="user3", message="message1") # Ensure only 2 entries were returned. self.assertEqual(len(excluded), 2) # Ensure the entries returned are the correct entries. - self.assertEqual(excluded[0].author, 'user2') - self.assertEqual(excluded[1].author, 'user2') + self.assertEqual(excluded[0].author, "user2") + self.assertEqual(excluded[1].author, "user2") def test_non_matching_exclude(self): """ Tests the exclude method of a List of Embedded Documents when the keyword does not match any entries. """ - excluded = self.post2.comments.exclude(author='user4') + excluded = self.post2.comments.exclude(author="user4") # Ensure the 3 entries still exist. self.assertEqual(len(excluded), 3) @@ -2299,16 +2418,16 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): Tests the exclude method after a filter method of a List of Embedded Documents. """ - excluded = self.post2.comments.filter(author='user2').exclude( - message='message2' + excluded = self.post2.comments.filter(author="user2").exclude( + message="message2" ) # Ensure only 1 entry was returned. self.assertEqual(len(excluded), 1) # Ensure the entry returned is the correct entry. - self.assertEqual(excluded[0].author, 'user2') - self.assertEqual(excluded[0].message, 'message3') + self.assertEqual(excluded[0].author, "user2") + self.assertEqual(excluded[0].message, "message3") def test_count(self): """ @@ -2321,7 +2440,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): """ Tests the filter + count method of a List of Embedded Documents. """ - count = self.post1.comments.filter(author='user1').count() + count = self.post1.comments.filter(author="user1").count() self.assertEqual(count, 1) def test_single_keyword_get(self): @@ -2329,19 +2448,19 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): Tests the get method of a List of Embedded Documents using a single keyword. """ - comment = self.post1.comments.get(author='user1') + comment = self.post1.comments.get(author="user1") self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, 'user1') + self.assertEqual(comment.author, "user1") def test_multi_keyword_get(self): """ Tests the get method of a List of Embedded Documents using multiple keywords. """ - comment = self.post2.comments.get(author='user2', message='message2') + comment = self.post2.comments.get(author="user2", message="message2") self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, 'user2') - self.assertEqual(comment.message, 'message2') + self.assertEqual(comment.author, "user2") + self.assertEqual(comment.message, "message2") def test_no_keyword_multiple_return_get(self): """ @@ -2357,7 +2476,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): to return multiple documents. """ with self.assertRaises(MultipleObjectsReturned): - self.post2.comments.get(author='user2') + self.post2.comments.get(author="user2") def test_unknown_keyword_get(self): """ @@ -2373,7 +2492,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): returns no results. """ with self.assertRaises(DoesNotExist): - self.post1.comments.get(author='user3') + self.post1.comments.get(author="user3") def test_first(self): """ @@ -2390,20 +2509,17 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): """ Test the create method of a List of Embedded Documents. """ - comment = self.post1.comments.create( - author='user4', message='message1' - ) + comment = self.post1.comments.create(author="user4", message="message1") self.post1.save() # Ensure the returned value is the comment object. self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, 'user4') - self.assertEqual(comment.message, 'message1') + self.assertEqual(comment.author, "user4") + self.assertEqual(comment.message, "message1") # Ensure the new comment was actually saved to the database. self.assertIn( - comment, - self.BlogPost.objects(comments__author='user4')[0].comments + comment, self.BlogPost.objects(comments__author="user4")[0].comments ) def test_filtered_create(self): @@ -2412,20 +2528,19 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): to a call to the filter method. Filtering should have no effect on creation. """ - comment = self.post1.comments.filter(author='user1').create( - author='user4', message='message1' + comment = self.post1.comments.filter(author="user1").create( + author="user4", message="message1" ) self.post1.save() # Ensure the returned value is the comment object. self.assertIsInstance(comment, self.Comments) - self.assertEqual(comment.author, 'user4') - self.assertEqual(comment.message, 'message1') + self.assertEqual(comment.author, "user4") + self.assertEqual(comment.message, "message1") # Ensure the new comment was actually saved to the database. self.assertIn( - comment, - self.BlogPost.objects(comments__author='user4')[0].comments + comment, self.BlogPost.objects(comments__author="user4")[0].comments ) def test_no_keyword_update(self): @@ -2438,15 +2553,9 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): self.post1.save() # Ensure that nothing was altered. - self.assertIn( - original[0], - self.BlogPost.objects(id=self.post1.id)[0].comments - ) + self.assertIn(original[0], self.BlogPost.objects(id=self.post1.id)[0].comments) - self.assertIn( - original[1], - self.BlogPost.objects(id=self.post1.id)[0].comments - ) + self.assertIn(original[1], self.BlogPost.objects(id=self.post1.id)[0].comments) # Ensure the method returned 0 as the number of entries # modified @@ -2457,14 +2566,14 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): Tests the update method of a List of Embedded Documents with a single keyword. """ - number = self.post1.comments.update(author='user4') + number = self.post1.comments.update(author="user4") self.post1.save() comments = self.BlogPost.objects(id=self.post1.id)[0].comments # Ensure that the database was updated properly. - self.assertEqual(comments[0].author, 'user4') - self.assertEqual(comments[1].author, 'user4') + self.assertEqual(comments[0].author, "user4") + self.assertEqual(comments[1].author, "user4") # Ensure the method returned 2 as the number of entries # modified @@ -2474,27 +2583,25 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): """ Tests that unicode strings handled correctly """ - post = self.BlogPost(comments=[ - self.Comments(author='user1', message=u'сообщение'), - self.Comments(author='user2', message=u'хабарлама') - ]).save() - self.assertEqual(post.comments.get(message=u'сообщение').author, - 'user1') + post = self.BlogPost( + comments=[ + self.Comments(author="user1", message=u"сообщение"), + self.Comments(author="user2", message=u"хабарлама"), + ] + ).save() + self.assertEqual(post.comments.get(message=u"сообщение").author, "user1") def test_save(self): """ Tests the save method of a List of Embedded Documents. """ comments = self.post1.comments - new_comment = self.Comments(author='user4') + new_comment = self.Comments(author="user4") comments.append(new_comment) comments.save() # Ensure that the new comment has been added to the database. - self.assertIn( - new_comment, - self.BlogPost.objects(id=self.post1.id)[0].comments - ) + self.assertIn(new_comment, self.BlogPost.objects(id=self.post1.id)[0].comments) def test_delete(self): """ @@ -2505,9 +2612,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): # Ensure that all the comments under post1 were deleted in the # database. - self.assertListEqual( - self.BlogPost.objects(id=self.post1.id)[0].comments, [] - ) + self.assertListEqual(self.BlogPost.objects(id=self.post1.id)[0].comments, []) # Ensure that post1 comments were deleted from the list. self.assertListEqual(self.post1.comments, []) @@ -2525,6 +2630,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): that have a unique field can be saved, but if the unique field is also sparse than multiple documents with an empty list can be saved. """ + class EmbeddedWithUnique(EmbeddedDocument): number = IntField(unique=True) @@ -2553,16 +2659,12 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): after the filter method has been called. """ comment = self.post1.comments[1] - number = self.post1.comments.filter(author='user2').delete() + number = self.post1.comments.filter(author="user2").delete() self.post1.save() # Ensure that only the user2 comment was deleted. - self.assertNotIn( - comment, self.BlogPost.objects(id=self.post1.id)[0].comments - ) - self.assertEqual( - len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1 - ) + self.assertNotIn(comment, self.BlogPost.objects(id=self.post1.id)[0].comments) + self.assertEqual(len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1) # Ensure that the user2 comment no longer exists in the list. self.assertNotIn(comment, self.post1.comments) @@ -2577,7 +2679,7 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): Tests that custom data is saved in the field object and doesn't interfere with the rest of field functionalities. """ - custom_data = {'a': 'a_value', 'b': [1, 2]} + custom_data = {"a": "a_value", "b": [1, 2]} class CustomData(Document): a_field = IntField() @@ -2587,10 +2689,10 @@ class EmbeddedDocumentListFieldTestCase(MongoDBTestCase): a1 = CustomData(a_field=1, c_field=2).save() self.assertEqual(2, a1.c_field) - self.assertFalse(hasattr(a1.c_field, 'custom_data')) - self.assertTrue(hasattr(CustomData.c_field, 'custom_data')) - self.assertEqual(custom_data['a'], CustomData.c_field.custom_data['a']) + self.assertFalse(hasattr(a1.c_field, "custom_data")) + self.assertTrue(hasattr(CustomData.c_field, "custom_data")) + self.assertEqual(custom_data["a"], CustomData.c_field.custom_data["a"]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index a7722458..dd2fe609 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -14,36 +14,37 @@ from mongoengine.python_support import StringIO try: from PIL import Image + HAS_PIL = True except ImportError: HAS_PIL = False from tests.utils import MongoDBTestCase -TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') -TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') +TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "mongoengine.png") +TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), "mongodb_leaf.png") def get_file(path): """Use a BytesIO instead of a file to allow to have a one-liner and avoid that the file remains opened""" bytes_io = StringIO() - with open(path, 'rb') as f: + with open(path, "rb") as f: bytes_io.write(f.read()) bytes_io.seek(0) return bytes_io class FileTest(MongoDBTestCase): - def tearDown(self): - self.db.drop_collection('fs.files') - self.db.drop_collection('fs.chunks') + self.db.drop_collection("fs.files") + self.db.drop_collection("fs.chunks") def test_file_field_optional(self): # Make sure FileField is optional and not required class DemoFile(Document): the_file = FileField() + DemoFile.objects.create() def test_file_fields(self): @@ -55,8 +56,8 @@ class FileTest(MongoDBTestCase): PutFile.drop_collection() - text = six.b('Hello, World!') - content_type = 'text/plain' + text = six.b("Hello, World!") + content_type = "text/plain" putfile = PutFile() putfile.the_file.put(text, content_type=content_type, filename="hello") @@ -64,7 +65,10 @@ class FileTest(MongoDBTestCase): result = PutFile.objects.first() self.assertEqual(putfile, result) - self.assertEqual("%s" % result.the_file, "" % result.the_file.grid_id) + self.assertEqual( + "%s" % result.the_file, + "" % result.the_file.grid_id, + ) self.assertEqual(result.the_file.read(), text) self.assertEqual(result.the_file.content_type, content_type) result.the_file.delete() # Remove file from GridFS @@ -89,14 +93,15 @@ class FileTest(MongoDBTestCase): def test_file_fields_stream(self): """Ensure that file fields can be written to and their data retrieved """ + class StreamFile(Document): the_file = FileField() StreamFile.drop_collection() - text = six.b('Hello, World!') - more_text = six.b('Foo Bar') - content_type = 'text/plain' + text = six.b("Hello, World!") + more_text = six.b("Foo Bar") + content_type = "text/plain" streamfile = StreamFile() streamfile.the_file.new_file(content_type=content_type) @@ -124,14 +129,15 @@ class FileTest(MongoDBTestCase): """Ensure that a file field can be written to after it has been saved as None """ + class StreamFile(Document): the_file = FileField() StreamFile.drop_collection() - text = six.b('Hello, World!') - more_text = six.b('Foo Bar') - content_type = 'text/plain' + text = six.b("Hello, World!") + more_text = six.b("Foo Bar") + content_type = "text/plain" streamfile = StreamFile() streamfile.save() @@ -157,12 +163,11 @@ class FileTest(MongoDBTestCase): self.assertTrue(result.the_file.read() is None) def test_file_fields_set(self): - class SetFile(Document): the_file = FileField() - text = six.b('Hello, World!') - more_text = six.b('Foo Bar') + text = six.b("Hello, World!") + more_text = six.b("Foo Bar") SetFile.drop_collection() @@ -184,7 +189,6 @@ class FileTest(MongoDBTestCase): result.the_file.delete() def test_file_field_no_default(self): - class GridDocument(Document): the_file = FileField() @@ -199,7 +203,7 @@ class FileTest(MongoDBTestCase): doc_a.save() doc_b = GridDocument.objects.with_id(doc_a.id) - doc_b.the_file.replace(f, filename='doc_b') + doc_b.the_file.replace(f, filename="doc_b") doc_b.save() self.assertNotEqual(doc_b.the_file.grid_id, None) @@ -208,13 +212,13 @@ class FileTest(MongoDBTestCase): self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) # Test with default - doc_d = GridDocument(the_file=six.b('')) + doc_d = GridDocument(the_file=six.b("")) doc_d.save() doc_e = GridDocument.objects.with_id(doc_d.id) self.assertEqual(doc_d.the_file.grid_id, doc_e.the_file.grid_id) - doc_e.the_file.replace(f, filename='doc_e') + doc_e.the_file.replace(f, filename="doc_e") doc_e.save() doc_f = GridDocument.objects.with_id(doc_e.id) @@ -222,11 +226,12 @@ class FileTest(MongoDBTestCase): db = GridDocument._get_db() grid_fs = gridfs.GridFS(db) - self.assertEqual(['doc_b', 'doc_e'], grid_fs.list()) + self.assertEqual(["doc_b", "doc_e"], grid_fs.list()) def test_file_uniqueness(self): """Ensure that each instance of a FileField is unique """ + class TestFile(Document): name = StringField() the_file = FileField() @@ -234,7 +239,7 @@ class FileTest(MongoDBTestCase): # First instance test_file = TestFile() test_file.name = "Hello, World!" - test_file.the_file.put(six.b('Hello, World!')) + test_file.the_file.put(six.b("Hello, World!")) test_file.save() # Second instance @@ -255,20 +260,21 @@ class FileTest(MongoDBTestCase): photo = FileField() Animal.drop_collection() - marmot = Animal(genus='Marmota', family='Sciuridae') + marmot = Animal(genus="Marmota", family="Sciuridae") marmot_photo_content = get_file(TEST_IMAGE_PATH) # Retrieve a photo from disk - marmot.photo.put(marmot_photo_content, content_type='image/jpeg', foo='bar') + marmot.photo.put(marmot_photo_content, content_type="image/jpeg", foo="bar") marmot.photo.close() marmot.save() marmot = Animal.objects.get() - self.assertEqual(marmot.photo.content_type, 'image/jpeg') - self.assertEqual(marmot.photo.foo, 'bar') + self.assertEqual(marmot.photo.content_type, "image/jpeg") + self.assertEqual(marmot.photo.foo, "bar") def test_file_reassigning(self): class TestFile(Document): the_file = FileField() + TestFile.drop_collection() test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() @@ -282,13 +288,15 @@ class FileTest(MongoDBTestCase): def test_file_boolean(self): """Ensure that a boolean test of a FileField indicates its presence """ + class TestFile(Document): the_file = FileField() + TestFile.drop_collection() test_file = TestFile() self.assertFalse(bool(test_file.the_file)) - test_file.the_file.put(six.b('Hello, World!'), content_type='text/plain') + test_file.the_file.put(six.b("Hello, World!"), content_type="text/plain") test_file.save() self.assertTrue(bool(test_file.the_file)) @@ -297,6 +305,7 @@ class FileTest(MongoDBTestCase): def test_file_cmp(self): """Test comparing against other types""" + class TestFile(Document): the_file = FileField() @@ -305,11 +314,12 @@ class FileTest(MongoDBTestCase): def test_file_disk_space(self): """ Test disk space usage when we delete/replace a file """ + class TestFile(Document): the_file = FileField() - text = six.b('Hello, World!') - content_type = 'text/plain' + text = six.b("Hello, World!") + content_type = "text/plain" testfile = TestFile() testfile.the_file.put(text, content_type=content_type, filename="hello") @@ -352,7 +362,7 @@ class FileTest(MongoDBTestCase): testfile.the_file.put(text, content_type=content_type, filename="hello") testfile.save() - text = six.b('Bonjour, World!') + text = six.b("Bonjour, World!") testfile.the_file.replace(text, content_type=content_type, filename="hello") testfile.save() @@ -370,7 +380,7 @@ class FileTest(MongoDBTestCase): def test_image_field(self): if not HAS_PIL: - raise SkipTest('PIL not installed') + raise SkipTest("PIL not installed") class TestImage(Document): image = ImageField() @@ -386,7 +396,9 @@ class FileTest(MongoDBTestCase): t.image.put(f) self.fail("Should have raised an invalidation error") except ValidationError as e: - self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) + self.assertEqual( + "%s" % e, "Invalid image: cannot identify image file %s" % f + ) t = TestImage() t.image.put(get_file(TEST_IMAGE_PATH)) @@ -394,7 +406,7 @@ class FileTest(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.format, 'PNG') + self.assertEqual(t.image.format, "PNG") w, h = t.image.size self.assertEqual(w, 371) @@ -404,10 +416,11 @@ class FileTest(MongoDBTestCase): def test_image_field_reassigning(self): if not HAS_PIL: - raise SkipTest('PIL not installed') + raise SkipTest("PIL not installed") class TestFile(Document): the_file = ImageField() + TestFile.drop_collection() test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save() @@ -420,7 +433,7 @@ class FileTest(MongoDBTestCase): def test_image_field_resize(self): if not HAS_PIL: - raise SkipTest('PIL not installed') + raise SkipTest("PIL not installed") class TestImage(Document): image = ImageField(size=(185, 37)) @@ -433,7 +446,7 @@ class FileTest(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.format, 'PNG') + self.assertEqual(t.image.format, "PNG") w, h = t.image.size self.assertEqual(w, 185) @@ -443,7 +456,7 @@ class FileTest(MongoDBTestCase): def test_image_field_resize_force(self): if not HAS_PIL: - raise SkipTest('PIL not installed') + raise SkipTest("PIL not installed") class TestImage(Document): image = ImageField(size=(185, 37, True)) @@ -456,7 +469,7 @@ class FileTest(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.format, 'PNG') + self.assertEqual(t.image.format, "PNG") w, h = t.image.size self.assertEqual(w, 185) @@ -466,7 +479,7 @@ class FileTest(MongoDBTestCase): def test_image_field_thumbnail(self): if not HAS_PIL: - raise SkipTest('PIL not installed') + raise SkipTest("PIL not installed") class TestImage(Document): image = ImageField(thumbnail_size=(92, 18)) @@ -479,19 +492,18 @@ class FileTest(MongoDBTestCase): t = TestImage.objects.first() - self.assertEqual(t.image.thumbnail.format, 'PNG') + self.assertEqual(t.image.thumbnail.format, "PNG") self.assertEqual(t.image.thumbnail.width, 92) self.assertEqual(t.image.thumbnail.height, 18) t.image.delete() def test_file_multidb(self): - register_connection('test_files', 'test_files') + register_connection("test_files", "test_files") class TestFile(Document): name = StringField() - the_file = FileField(db_alias="test_files", - collection_name="macumba") + the_file = FileField(db_alias="test_files", collection_name="macumba") TestFile.drop_collection() @@ -502,23 +514,21 @@ class FileTest(MongoDBTestCase): # First instance test_file = TestFile() test_file.name = "Hello, World!" - test_file.the_file.put(six.b('Hello, World!'), - name="hello.txt") + test_file.the_file.put(six.b("Hello, World!"), name="hello.txt") test_file.save() data = get_db("test_files").macumba.files.find_one() - self.assertEqual(data.get('name'), 'hello.txt') + self.assertEqual(data.get("name"), "hello.txt") test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.read(), six.b('Hello, World!')) + self.assertEqual(test_file.the_file.read(), six.b("Hello, World!")) test_file = TestFile.objects.first() - test_file.the_file = six.b('HELLO, WORLD!') + test_file.the_file = six.b("HELLO, WORLD!") test_file.save() test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.read(), - six.b('HELLO, WORLD!')) + self.assertEqual(test_file.the_file.read(), six.b("HELLO, WORLD!")) def test_copyable(self): class PutFile(Document): @@ -526,8 +536,8 @@ class FileTest(MongoDBTestCase): PutFile.drop_collection() - text = six.b('Hello, World!') - content_type = 'text/plain' + text = six.b("Hello, World!") + content_type = "text/plain" putfile = PutFile() putfile.the_file.put(text, content_type=content_type) @@ -542,7 +552,7 @@ class FileTest(MongoDBTestCase): def test_get_image_by_grid_id(self): if not HAS_PIL: - raise SkipTest('PIL not installed') + raise SkipTest("PIL not installed") class TestImage(Document): @@ -559,8 +569,9 @@ class FileTest(MongoDBTestCase): test = TestImage.objects.first() grid_id = test.image1.grid_id - self.assertEqual(1, TestImage.objects(Q(image1=grid_id) - or Q(image2=grid_id)).count()) + self.assertEqual( + 1, TestImage.objects(Q(image1=grid_id) or Q(image2=grid_id)).count() + ) def test_complex_field_filefield(self): """Ensure you can add meta data to file""" @@ -571,21 +582,21 @@ class FileTest(MongoDBTestCase): photos = ListField(FileField()) Animal.drop_collection() - marmot = Animal(genus='Marmota', family='Sciuridae') + marmot = Animal(genus="Marmota", family="Sciuridae") - with open(TEST_IMAGE_PATH, 'rb') as marmot_photo: # Retrieve a photo from disk - photos_field = marmot._fields['photos'].field - new_proxy = photos_field.get_proxy_obj('photos', marmot) - new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar') + with open(TEST_IMAGE_PATH, "rb") as marmot_photo: # Retrieve a photo from disk + photos_field = marmot._fields["photos"].field + new_proxy = photos_field.get_proxy_obj("photos", marmot) + new_proxy.put(marmot_photo, content_type="image/jpeg", foo="bar") marmot.photos.append(new_proxy) marmot.save() marmot = Animal.objects.get() - self.assertEqual(marmot.photos[0].content_type, 'image/jpeg') - self.assertEqual(marmot.photos[0].foo, 'bar') + self.assertEqual(marmot.photos[0].content_type, "image/jpeg") + self.assertEqual(marmot.photos[0].foo, "bar") self.assertEqual(marmot.photos[0].get().length, 8313) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/fields/geo.py b/tests/fields/geo.py index 37ed97f5..446d7171 100644 --- a/tests/fields/geo.py +++ b/tests/fields/geo.py @@ -4,28 +4,27 @@ import unittest from mongoengine import * from mongoengine.connection import get_db -__all__ = ("GeoFieldTest", ) +__all__ = ("GeoFieldTest",) class GeoFieldTest(unittest.TestCase): - def setUp(self): - connect(db='mongoenginetest') + connect(db="mongoenginetest") self.db = get_db() def _test_for_expected_error(self, Cls, loc, expected): try: Cls(loc=loc).validate() - self.fail('Should not validate the location {0}'.format(loc)) + self.fail("Should not validate the location {0}".format(loc)) except ValidationError as e: - self.assertEqual(expected, e.to_dict()['loc']) + self.assertEqual(expected, e.to_dict()["loc"]) def test_geopoint_validation(self): class Location(Document): loc = GeoPointField() invalid_coords = [{"x": 1, "y": 2}, 5, "a"] - expected = 'GeoPointField can only accept tuples or lists of (x, y)' + expected = "GeoPointField can only accept tuples or lists of (x, y)" for coord in invalid_coords: self._test_for_expected_error(Location, coord, expected) @@ -40,7 +39,7 @@ class GeoFieldTest(unittest.TestCase): expected = "Both values (%s) in point must be float or int" % repr(coord) self._test_for_expected_error(Location, coord, expected) - invalid_coords = [21, 4, 'a'] + invalid_coords = [21, 4, "a"] for coord in invalid_coords: expected = "GeoPointField can only accept tuples or lists of (x, y)" self._test_for_expected_error(Location, coord, expected) @@ -50,7 +49,9 @@ class GeoFieldTest(unittest.TestCase): loc = PointField() invalid_coords = {"x": 1, "y": 2} - expected = 'PointField can only accept a valid GeoJson dictionary or lists of (x, y)' + expected = ( + "PointField can only accept a valid GeoJson dictionary or lists of (x, y)" + ) self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = {"type": "MadeUp", "coordinates": []} @@ -77,19 +78,16 @@ class GeoFieldTest(unittest.TestCase): self._test_for_expected_error(Location, coord, expected) Location(loc=[1, 2]).validate() - Location(loc={ - "type": "Point", - "coordinates": [ - 81.4471435546875, - 23.61432859499169 - ]}).validate() + Location( + loc={"type": "Point", "coordinates": [81.4471435546875, 23.61432859499169]} + ).validate() def test_linestring_validation(self): class Location(Document): loc = LineStringField() invalid_coords = {"x": 1, "y": 2} - expected = 'LineStringField can only accept a valid GeoJson dictionary or lists of (x, y)' + expected = "LineStringField can only accept a valid GeoJson dictionary or lists of (x, y)" self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = {"type": "MadeUp", "coordinates": [[]]} @@ -97,7 +95,9 @@ class GeoFieldTest(unittest.TestCase): self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = {"type": "LineString", "coordinates": [[1, 2, 3]]} - expected = "Invalid LineString:\nValue ([1, 2, 3]) must be a two-dimensional point" + expected = ( + "Invalid LineString:\nValue ([1, 2, 3]) must be a two-dimensional point" + ) self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [5, "a"] @@ -105,16 +105,25 @@ class GeoFieldTest(unittest.TestCase): self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[1]] - expected = "Invalid LineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0]) + expected = ( + "Invalid LineString:\nValue (%s) must be a two-dimensional point" + % repr(invalid_coords[0]) + ) self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[1, 2, 3]] - expected = "Invalid LineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0]) + expected = ( + "Invalid LineString:\nValue (%s) must be a two-dimensional point" + % repr(invalid_coords[0]) + ) self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[[{}, {}]], [("a", "b")]] for coord in invalid_coords: - expected = "Invalid LineString:\nBoth values (%s) in point must be float or int" % repr(coord[0]) + expected = ( + "Invalid LineString:\nBoth values (%s) in point must be float or int" + % repr(coord[0]) + ) self._test_for_expected_error(Location, coord, expected) Location(loc=[[1, 2], [3, 4], [5, 6], [1, 2]]).validate() @@ -124,7 +133,9 @@ class GeoFieldTest(unittest.TestCase): loc = PolygonField() invalid_coords = {"x": 1, "y": 2} - expected = 'PolygonField can only accept a valid GeoJson dictionary or lists of (x, y)' + expected = ( + "PolygonField can only accept a valid GeoJson dictionary or lists of (x, y)" + ) self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = {"type": "MadeUp", "coordinates": [[]]} @@ -136,7 +147,9 @@ class GeoFieldTest(unittest.TestCase): self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[[5, "a"]]] - expected = "Invalid Polygon:\nBoth values ([5, 'a']) in point must be float or int" + expected = ( + "Invalid Polygon:\nBoth values ([5, 'a']) in point must be float or int" + ) self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[[]]] @@ -162,7 +175,7 @@ class GeoFieldTest(unittest.TestCase): loc = MultiPointField() invalid_coords = {"x": 1, "y": 2} - expected = 'MultiPointField can only accept a valid GeoJson dictionary or lists of (x, y)' + expected = "MultiPointField can only accept a valid GeoJson dictionary or lists of (x, y)" self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = {"type": "MadeUp", "coordinates": [[]]} @@ -188,19 +201,19 @@ class GeoFieldTest(unittest.TestCase): self._test_for_expected_error(Location, coord, expected) Location(loc=[[1, 2]]).validate() - Location(loc={ - "type": "MultiPoint", - "coordinates": [ - [1, 2], - [81.4471435546875, 23.61432859499169] - ]}).validate() + Location( + loc={ + "type": "MultiPoint", + "coordinates": [[1, 2], [81.4471435546875, 23.61432859499169]], + } + ).validate() def test_multilinestring_validation(self): class Location(Document): loc = MultiLineStringField() invalid_coords = {"x": 1, "y": 2} - expected = 'MultiLineStringField can only accept a valid GeoJson dictionary or lists of (x, y)' + expected = "MultiLineStringField can only accept a valid GeoJson dictionary or lists of (x, y)" self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = {"type": "MadeUp", "coordinates": [[]]} @@ -216,16 +229,25 @@ class GeoFieldTest(unittest.TestCase): self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[[1]]] - expected = "Invalid MultiLineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0][0]) + expected = ( + "Invalid MultiLineString:\nValue (%s) must be a two-dimensional point" + % repr(invalid_coords[0][0]) + ) self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[[1, 2, 3]]] - expected = "Invalid MultiLineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0][0]) + expected = ( + "Invalid MultiLineString:\nValue (%s) must be a two-dimensional point" + % repr(invalid_coords[0][0]) + ) self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[[[{}, {}]]], [[("a", "b")]]] for coord in invalid_coords: - expected = "Invalid MultiLineString:\nBoth values (%s) in point must be float or int" % repr(coord[0][0]) + expected = ( + "Invalid MultiLineString:\nBoth values (%s) in point must be float or int" + % repr(coord[0][0]) + ) self._test_for_expected_error(Location, coord, expected) Location(loc=[[[1, 2], [3, 4], [5, 6], [1, 2]]]).validate() @@ -235,7 +257,7 @@ class GeoFieldTest(unittest.TestCase): loc = MultiPolygonField() invalid_coords = {"x": 1, "y": 2} - expected = 'MultiPolygonField can only accept a valid GeoJson dictionary or lists of (x, y)' + expected = "MultiPolygonField can only accept a valid GeoJson dictionary or lists of (x, y)" self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = {"type": "MadeUp", "coordinates": [[]]} @@ -243,7 +265,9 @@ class GeoFieldTest(unittest.TestCase): self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = {"type": "MultiPolygon", "coordinates": [[[[1, 2, 3]]]]} - expected = "Invalid MultiPolygon:\nValue ([1, 2, 3]) must be a two-dimensional point" + expected = ( + "Invalid MultiPolygon:\nValue ([1, 2, 3]) must be a two-dimensional point" + ) self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[[[5, "a"]]]] @@ -255,7 +279,9 @@ class GeoFieldTest(unittest.TestCase): self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[[[1, 2, 3]]]] - expected = "Invalid MultiPolygon:\nValue ([1, 2, 3]) must be a two-dimensional point" + expected = ( + "Invalid MultiPolygon:\nValue ([1, 2, 3]) must be a two-dimensional point" + ) self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[[[{}, {}]]], [[("a", "b")]]] @@ -263,7 +289,9 @@ class GeoFieldTest(unittest.TestCase): self._test_for_expected_error(Location, invalid_coords, expected) invalid_coords = [[[[1, 2], [3, 4]]]] - expected = "Invalid MultiPolygon:\nLineStrings must start and end at the same point" + expected = ( + "Invalid MultiPolygon:\nLineStrings must start and end at the same point" + ) self._test_for_expected_error(Location, invalid_coords, expected) Location(loc=[[[[1, 2], [3, 4], [5, 6], [1, 2]]]]).validate() @@ -271,17 +299,19 @@ class GeoFieldTest(unittest.TestCase): def test_indexes_geopoint(self): """Ensure that indexes are created automatically for GeoPointFields. """ + class Event(Document): title = StringField() location = GeoPointField() geo_indicies = Event._geo_indices() - self.assertEqual(geo_indicies, [{'fields': [('location', '2d')]}]) + self.assertEqual(geo_indicies, [{"fields": [("location", "2d")]}]) def test_geopoint_embedded_indexes(self): """Ensure that indexes are created automatically for GeoPointFields on embedded documents. """ + class Venue(EmbeddedDocument): location = GeoPointField() name = StringField() @@ -291,11 +321,12 @@ class GeoFieldTest(unittest.TestCase): venue = EmbeddedDocumentField(Venue) geo_indicies = Event._geo_indices() - self.assertEqual(geo_indicies, [{'fields': [('venue.location', '2d')]}]) + self.assertEqual(geo_indicies, [{"fields": [("venue.location", "2d")]}]) def test_indexes_2dsphere(self): """Ensure that indexes are created automatically for GeoPointFields. """ + class Event(Document): title = StringField() point = PointField() @@ -303,13 +334,14 @@ class GeoFieldTest(unittest.TestCase): polygon = PolygonField() geo_indicies = Event._geo_indices() - self.assertIn({'fields': [('line', '2dsphere')]}, geo_indicies) - self.assertIn({'fields': [('polygon', '2dsphere')]}, geo_indicies) - self.assertIn({'fields': [('point', '2dsphere')]}, geo_indicies) + self.assertIn({"fields": [("line", "2dsphere")]}, geo_indicies) + self.assertIn({"fields": [("polygon", "2dsphere")]}, geo_indicies) + self.assertIn({"fields": [("point", "2dsphere")]}, geo_indicies) def test_indexes_2dsphere_embedded(self): """Ensure that indexes are created automatically for GeoPointFields. """ + class Venue(EmbeddedDocument): name = StringField() point = PointField() @@ -321,12 +353,11 @@ class GeoFieldTest(unittest.TestCase): venue = EmbeddedDocumentField(Venue) geo_indicies = Event._geo_indices() - self.assertIn({'fields': [('venue.line', '2dsphere')]}, geo_indicies) - self.assertIn({'fields': [('venue.polygon', '2dsphere')]}, geo_indicies) - self.assertIn({'fields': [('venue.point', '2dsphere')]}, geo_indicies) + self.assertIn({"fields": [("venue.line", "2dsphere")]}, geo_indicies) + self.assertIn({"fields": [("venue.polygon", "2dsphere")]}, geo_indicies) + self.assertIn({"fields": [("venue.point", "2dsphere")]}, geo_indicies) def test_geo_indexes_recursion(self): - class Location(Document): name = StringField() location = GeoPointField() @@ -338,11 +369,11 @@ class GeoFieldTest(unittest.TestCase): Location.drop_collection() Parent.drop_collection() - Parent(name='Berlin').save() + Parent(name="Berlin").save() info = Parent._get_collection().index_information() - self.assertNotIn('location_2d', info) + self.assertNotIn("location_2d", info) info = Location._get_collection().index_information() - self.assertIn('location_2d', info) + self.assertIn("location_2d", info) self.assertEqual(len(Parent._geo_indices()), 0) self.assertEqual(len(Location._geo_indices()), 1) @@ -354,9 +385,7 @@ class GeoFieldTest(unittest.TestCase): location = PointField(auto_index=False) datetime = DateTimeField() - meta = { - 'indexes': [[("location", "2dsphere"), ("datetime", 1)]] - } + meta = {"indexes": [[("location", "2dsphere"), ("datetime", 1)]]} self.assertEqual([], Log._geo_indices()) @@ -364,8 +393,10 @@ class GeoFieldTest(unittest.TestCase): Log.ensure_indexes() info = Log._get_collection().index_information() - self.assertEqual(info["location_2dsphere_datetime_1"]["key"], - [('location', '2dsphere'), ('datetime', 1)]) + self.assertEqual( + info["location_2dsphere_datetime_1"]["key"], + [("location", "2dsphere"), ("datetime", 1)], + ) # Test listing explicitly class Log(Document): @@ -373,9 +404,7 @@ class GeoFieldTest(unittest.TestCase): datetime = DateTimeField() meta = { - 'indexes': [ - {'fields': [("location", "2dsphere"), ("datetime", 1)]} - ] + "indexes": [{"fields": [("location", "2dsphere"), ("datetime", 1)]}] } self.assertEqual([], Log._geo_indices()) @@ -384,9 +413,11 @@ class GeoFieldTest(unittest.TestCase): Log.ensure_indexes() info = Log._get_collection().index_information() - self.assertEqual(info["location_2dsphere_datetime_1"]["key"], - [('location', '2dsphere'), ('datetime', 1)]) + self.assertEqual( + info["location_2dsphere_datetime_1"]["key"], + [("location", "2dsphere"), ("datetime", 1)], + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/fields/test_binary_field.py b/tests/fields/test_binary_field.py index 8af75d4e..df4bf2de 100644 --- a/tests/fields/test_binary_field.py +++ b/tests/fields/test_binary_field.py @@ -9,19 +9,22 @@ from bson import Binary from mongoengine import * from tests.utils import MongoDBTestCase -BIN_VALUE = six.b('\xa9\xf3\x8d(\xd7\x03\x84\xb4k[\x0f\xe3\xa2\x19\x85p[J\xa3\xd2>\xde\xe6\x87\xb1\x7f\xc6\xe6\xd9r\x18\xf5') +BIN_VALUE = six.b( + "\xa9\xf3\x8d(\xd7\x03\x84\xb4k[\x0f\xe3\xa2\x19\x85p[J\xa3\xd2>\xde\xe6\x87\xb1\x7f\xc6\xe6\xd9r\x18\xf5" +) class TestBinaryField(MongoDBTestCase): def test_binary_fields(self): """Ensure that binary fields can be stored and retrieved. """ + class Attachment(Document): content_type = StringField() blob = BinaryField() - BLOB = six.b('\xe6\x00\xc4\xff\x07') - MIME_TYPE = 'application/octet-stream' + BLOB = six.b("\xe6\x00\xc4\xff\x07") + MIME_TYPE = "application/octet-stream" Attachment.drop_collection() @@ -35,6 +38,7 @@ class TestBinaryField(MongoDBTestCase): def test_validation_succeeds(self): """Ensure that valid values can be assigned to binary fields. """ + class AttachmentRequired(Document): blob = BinaryField(required=True) @@ -43,11 +47,11 @@ class TestBinaryField(MongoDBTestCase): attachment_required = AttachmentRequired() self.assertRaises(ValidationError, attachment_required.validate) - attachment_required.blob = Binary(six.b('\xe6\x00\xc4\xff\x07')) + attachment_required.blob = Binary(six.b("\xe6\x00\xc4\xff\x07")) attachment_required.validate() - _5_BYTES = six.b('\xe6\x00\xc4\xff\x07') - _4_BYTES = six.b('\xe6\x00\xc4\xff') + _5_BYTES = six.b("\xe6\x00\xc4\xff\x07") + _4_BYTES = six.b("\xe6\x00\xc4\xff") self.assertRaises(ValidationError, AttachmentSizeLimit(blob=_5_BYTES).validate) AttachmentSizeLimit(blob=_4_BYTES).validate() @@ -57,7 +61,7 @@ class TestBinaryField(MongoDBTestCase): class Attachment(Document): blob = BinaryField() - for invalid_data in (2, u'Im_a_unicode', ['some_str']): + for invalid_data in (2, u"Im_a_unicode", ["some_str"]): self.assertRaises(ValidationError, Attachment(blob=invalid_data).validate) def test__primary(self): @@ -108,17 +112,17 @@ class TestBinaryField(MongoDBTestCase): def test_modify_operation__set(self): """Ensures no regression of bug #1127""" + class MyDocument(Document): some_field = StringField() bin_field = BinaryField() MyDocument.drop_collection() - doc = MyDocument.objects(some_field='test').modify( - upsert=True, new=True, - set__bin_field=BIN_VALUE + doc = MyDocument.objects(some_field="test").modify( + upsert=True, new=True, set__bin_field=BIN_VALUE ) - self.assertEqual(doc.some_field, 'test') + self.assertEqual(doc.some_field, "test") if six.PY3: self.assertEqual(doc.bin_field, BIN_VALUE) else: @@ -126,15 +130,18 @@ class TestBinaryField(MongoDBTestCase): def test_update_one(self): """Ensures no regression of bug #1127""" + class MyDocument(Document): bin_field = BinaryField() MyDocument.drop_collection() - bin_data = six.b('\xe6\x00\xc4\xff\x07') + bin_data = six.b("\xe6\x00\xc4\xff\x07") doc = MyDocument(bin_field=bin_data).save() - n_updated = MyDocument.objects(bin_field=bin_data).update_one(bin_field=BIN_VALUE) + n_updated = MyDocument.objects(bin_field=bin_data).update_one( + bin_field=BIN_VALUE + ) self.assertEqual(n_updated, 1) fetched = MyDocument.objects.with_id(doc.id) if six.PY3: diff --git a/tests/fields/test_boolean_field.py b/tests/fields/test_boolean_field.py index 7a2a3db6..22ebb6f7 100644 --- a/tests/fields/test_boolean_field.py +++ b/tests/fields/test_boolean_field.py @@ -11,15 +11,13 @@ class TestBooleanField(MongoDBTestCase): person = Person(admin=True) person.save() - self.assertEqual( - get_as_pymongo(person), - {'_id': person.id, - 'admin': True}) + self.assertEqual(get_as_pymongo(person), {"_id": person.id, "admin": True}) def test_validation(self): """Ensure that invalid values cannot be assigned to boolean fields. """ + class Person(Document): admin = BooleanField() @@ -29,9 +27,9 @@ class TestBooleanField(MongoDBTestCase): person.admin = 2 self.assertRaises(ValidationError, person.validate) - person.admin = 'Yes' + person.admin = "Yes" self.assertRaises(ValidationError, person.validate) - person.admin = 'False' + person.admin = "False" self.assertRaises(ValidationError, person.validate) def test_weirdness_constructor(self): @@ -39,11 +37,12 @@ class TestBooleanField(MongoDBTestCase): which causes some weird behavior. We dont necessarily want to maintain this behavior but its a known issue """ + class Person(Document): admin = BooleanField() - new_person = Person(admin='False') + new_person = Person(admin="False") self.assertTrue(new_person.admin) - new_person = Person(admin='0') + new_person = Person(admin="0") self.assertTrue(new_person.admin) diff --git a/tests/fields/test_cached_reference_field.py b/tests/fields/test_cached_reference_field.py index 470ecc5d..4e467587 100644 --- a/tests/fields/test_cached_reference_field.py +++ b/tests/fields/test_cached_reference_field.py @@ -7,12 +7,12 @@ from tests.utils import MongoDBTestCase class TestCachedReferenceField(MongoDBTestCase): - def test_get_and_save(self): """ Tests #1047: CachedReferenceField creates DBRefs on to_python, but can't save them on to_mongo. """ + class Animal(Document): name = StringField() tag = StringField() @@ -24,10 +24,11 @@ class TestCachedReferenceField(MongoDBTestCase): Animal.drop_collection() Ocorrence.drop_collection() - Ocorrence(person="testte", - animal=Animal(name="Leopard", tag="heavy").save()).save() + Ocorrence( + person="testte", animal=Animal(name="Leopard", tag="heavy").save() + ).save() p = Ocorrence.objects.get() - p.person = 'new_testte' + p.person = "new_testte" p.save() def test_general_things(self): @@ -37,8 +38,7 @@ class TestCachedReferenceField(MongoDBTestCase): class Ocorrence(Document): person = StringField() - animal = CachedReferenceField( - Animal, fields=['tag']) + animal = CachedReferenceField(Animal, fields=["tag"]) Animal.drop_collection() Ocorrence.drop_collection() @@ -55,19 +55,18 @@ class TestCachedReferenceField(MongoDBTestCase): self.assertEqual(Ocorrence.objects(animal=None).count(), 1) - self.assertEqual( - a.to_mongo(fields=['tag']), {'tag': 'heavy', "_id": a.pk}) + self.assertEqual(a.to_mongo(fields=["tag"]), {"tag": "heavy", "_id": a.pk}) - self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') + self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy") # counts Ocorrence(person="teste 2").save() Ocorrence(person="teste 3").save() - count = Ocorrence.objects(animal__tag='heavy').count() + count = Ocorrence.objects(animal__tag="heavy").count() self.assertEqual(count, 1) - ocorrence = Ocorrence.objects(animal__tag='heavy').first() + ocorrence = Ocorrence.objects(animal__tag="heavy").first() self.assertEqual(ocorrence.person, "teste") self.assertIsInstance(ocorrence.animal, Animal) @@ -78,28 +77,21 @@ class TestCachedReferenceField(MongoDBTestCase): class SocialTest(Document): group = StringField() - person = CachedReferenceField( - PersonAuto, - fields=('salary',)) + person = CachedReferenceField(PersonAuto, fields=("salary",)) PersonAuto.drop_collection() SocialTest.drop_collection() - p = PersonAuto(name="Alberto", salary=Decimal('7000.00')) + p = PersonAuto(name="Alberto", salary=Decimal("7000.00")) p.save() s = SocialTest(group="dev", person=p) s.save() self.assertEqual( - SocialTest.objects._collection.find_one({'person.salary': 7000.00}), { - '_id': s.pk, - 'group': s.group, - 'person': { - '_id': p.pk, - 'salary': 7000.00 - } - }) + SocialTest.objects._collection.find_one({"person.salary": 7000.00}), + {"_id": s.pk, "group": s.group, "person": {"_id": p.pk, "salary": 7000.00}}, + ) def test_cached_reference_field_reference(self): class Group(Document): @@ -111,17 +103,14 @@ class TestCachedReferenceField(MongoDBTestCase): class SocialData(Document): obs = StringField() - tags = ListField( - StringField()) - person = CachedReferenceField( - Person, - fields=('group',)) + tags = ListField(StringField()) + person = CachedReferenceField(Person, fields=("group",)) Group.drop_collection() Person.drop_collection() SocialData.drop_collection() - g1 = Group(name='dev') + g1 = Group(name="dev") g1.save() g2 = Group(name="designers") @@ -136,22 +125,21 @@ class TestCachedReferenceField(MongoDBTestCase): p3 = Person(name="Afro design", group=g2) p3.save() - s1 = SocialData(obs="testing 123", person=p1, tags=['tag1', 'tag2']) + s1 = SocialData(obs="testing 123", person=p1, tags=["tag1", "tag2"]) s1.save() - s2 = SocialData(obs="testing 321", person=p3, tags=['tag3', 'tag4']) + s2 = SocialData(obs="testing 321", person=p3, tags=["tag3", "tag4"]) s2.save() - self.assertEqual(SocialData.objects._collection.find_one( - {'tags': 'tag2'}), { - '_id': s1.pk, - 'obs': 'testing 123', - 'tags': ['tag1', 'tag2'], - 'person': { - '_id': p1.pk, - 'group': g1.pk - } - }) + self.assertEqual( + SocialData.objects._collection.find_one({"tags": "tag2"}), + { + "_id": s1.pk, + "obs": "testing 123", + "tags": ["tag1", "tag2"], + "person": {"_id": p1.pk, "group": g1.pk}, + }, + ) self.assertEqual(SocialData.objects(person__group=g2).count(), 1) self.assertEqual(SocialData.objects(person__group=g2).first(), s2) @@ -163,23 +151,18 @@ class TestCachedReferenceField(MongoDBTestCase): Product.drop_collection() class Basket(Document): - products = ListField(CachedReferenceField(Product, fields=['name'])) + products = ListField(CachedReferenceField(Product, fields=["name"])) Basket.drop_collection() - product1 = Product(name='abc').save() - product2 = Product(name='def').save() + product1 = Product(name="abc").save() + product2 = Product(name="def").save() basket = Basket(products=[product1]).save() self.assertEqual( Basket.objects._collection.find_one(), { - '_id': basket.pk, - 'products': [ - { - '_id': product1.pk, - 'name': product1.name - } - ] - } + "_id": basket.pk, + "products": [{"_id": product1.pk, "name": product1.name}], + }, ) # push to list basket.update(push__products=product2) @@ -187,161 +170,135 @@ class TestCachedReferenceField(MongoDBTestCase): self.assertEqual( Basket.objects._collection.find_one(), { - '_id': basket.pk, - 'products': [ - { - '_id': product1.pk, - 'name': product1.name - }, - { - '_id': product2.pk, - 'name': product2.name - } - ] - } + "_id": basket.pk, + "products": [ + {"_id": product1.pk, "name": product1.name}, + {"_id": product2.pk, "name": product2.name}, + ], + }, ) def test_cached_reference_field_update_all(self): class Person(Document): - TYPES = ( - ('pf', "PF"), - ('pj', "PJ") - ) + TYPES = (("pf", "PF"), ("pj", "PJ")) name = StringField() tp = StringField(choices=TYPES) - father = CachedReferenceField('self', fields=('tp',)) + father = CachedReferenceField("self", fields=("tp",)) Person.drop_collection() a1 = Person(name="Wilson Father", tp="pj") a1.save() - a2 = Person(name='Wilson Junior', tp='pf', father=a1) + a2 = Person(name="Wilson Junior", tp="pf", father=a1) a2.save() a2 = Person.objects.with_id(a2.id) self.assertEqual(a2.father.tp, a1.tp) - self.assertEqual(dict(a2.to_mongo()), { - "_id": a2.pk, - "name": u"Wilson Junior", - "tp": u"pf", - "father": { - "_id": a1.pk, - "tp": u"pj" - } - }) + self.assertEqual( + dict(a2.to_mongo()), + { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": {"_id": a1.pk, "tp": u"pj"}, + }, + ) - self.assertEqual(Person.objects(father=a1)._query, { - 'father._id': a1.pk - }) + self.assertEqual(Person.objects(father=a1)._query, {"father._id": a1.pk}) self.assertEqual(Person.objects(father=a1).count(), 1) Person.objects.update(set__tp="pf") Person.father.sync_all() a2.reload() - self.assertEqual(dict(a2.to_mongo()), { - "_id": a2.pk, - "name": u"Wilson Junior", - "tp": u"pf", - "father": { - "_id": a1.pk, - "tp": u"pf" - } - }) + self.assertEqual( + dict(a2.to_mongo()), + { + "_id": a2.pk, + "name": u"Wilson Junior", + "tp": u"pf", + "father": {"_id": a1.pk, "tp": u"pf"}, + }, + ) def test_cached_reference_fields_on_embedded_documents(self): with self.assertRaises(InvalidDocumentError): + class Test(Document): name = StringField() - type('WrongEmbeddedDocument', ( - EmbeddedDocument,), { - 'test': CachedReferenceField(Test) - }) + type( + "WrongEmbeddedDocument", + (EmbeddedDocument,), + {"test": CachedReferenceField(Test)}, + ) def test_cached_reference_auto_sync(self): class Person(Document): - TYPES = ( - ('pf', "PF"), - ('pj', "PJ") - ) + TYPES = (("pf", "PF"), ("pj", "PJ")) name = StringField() - tp = StringField( - choices=TYPES - ) + tp = StringField(choices=TYPES) - father = CachedReferenceField('self', fields=('tp',)) + father = CachedReferenceField("self", fields=("tp",)) Person.drop_collection() a1 = Person(name="Wilson Father", tp="pj") a1.save() - a2 = Person(name='Wilson Junior', tp='pf', father=a1) + a2 = Person(name="Wilson Junior", tp="pf", father=a1) a2.save() - a1.tp = 'pf' + a1.tp = "pf" a1.save() a2.reload() - self.assertEqual(dict(a2.to_mongo()), { - '_id': a2.pk, - 'name': 'Wilson Junior', - 'tp': 'pf', - 'father': { - '_id': a1.pk, - 'tp': 'pf' - } - }) + self.assertEqual( + dict(a2.to_mongo()), + { + "_id": a2.pk, + "name": "Wilson Junior", + "tp": "pf", + "father": {"_id": a1.pk, "tp": "pf"}, + }, + ) def test_cached_reference_auto_sync_disabled(self): class Persone(Document): - TYPES = ( - ('pf', "PF"), - ('pj', "PJ") - ) + TYPES = (("pf", "PF"), ("pj", "PJ")) name = StringField() - tp = StringField( - choices=TYPES - ) + tp = StringField(choices=TYPES) - father = CachedReferenceField( - 'self', fields=('tp',), auto_sync=False) + father = CachedReferenceField("self", fields=("tp",), auto_sync=False) Persone.drop_collection() a1 = Persone(name="Wilson Father", tp="pj") a1.save() - a2 = Persone(name='Wilson Junior', tp='pf', father=a1) + a2 = Persone(name="Wilson Junior", tp="pf", father=a1) a2.save() - a1.tp = 'pf' + a1.tp = "pf" a1.save() - self.assertEqual(Persone.objects._collection.find_one({'_id': a2.pk}), { - '_id': a2.pk, - 'name': 'Wilson Junior', - 'tp': 'pf', - 'father': { - '_id': a1.pk, - 'tp': 'pj' - } - }) + self.assertEqual( + Persone.objects._collection.find_one({"_id": a2.pk}), + { + "_id": a2.pk, + "name": "Wilson Junior", + "tp": "pf", + "father": {"_id": a1.pk, "tp": "pj"}, + }, + ) def test_cached_reference_embedded_fields(self): class Owner(EmbeddedDocument): - TPS = ( - ('n', "Normal"), - ('u', "Urgent") - ) + TPS = (("n", "Normal"), ("u", "Urgent")) name = StringField() - tp = StringField( - verbose_name="Type", - db_field="t", - choices=TPS) + tp = StringField(verbose_name="Type", db_field="t", choices=TPS) class Animal(Document): name = StringField() @@ -351,43 +308,38 @@ class TestCachedReferenceField(MongoDBTestCase): class Ocorrence(Document): person = StringField() - animal = CachedReferenceField( - Animal, fields=['tag', 'owner.tp']) + animal = CachedReferenceField(Animal, fields=["tag", "owner.tp"]) Animal.drop_collection() Ocorrence.drop_collection() - a = Animal(name="Leopard", tag="heavy", - owner=Owner(tp='u', name="Wilson Júnior") - ) + a = Animal( + name="Leopard", tag="heavy", owner=Owner(tp="u", name="Wilson Júnior") + ) a.save() o = Ocorrence(person="teste", animal=a) o.save() - self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tp'])), { - '_id': a.pk, - 'tag': 'heavy', - 'owner': { - 't': 'u' - } - }) - self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') - self.assertEqual(o.to_mongo()['animal']['owner']['t'], 'u') + self.assertEqual( + dict(a.to_mongo(fields=["tag", "owner.tp"])), + {"_id": a.pk, "tag": "heavy", "owner": {"t": "u"}}, + ) + self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy") + self.assertEqual(o.to_mongo()["animal"]["owner"]["t"], "u") # Check to_mongo with fields - self.assertNotIn('animal', o.to_mongo(fields=['person'])) + self.assertNotIn("animal", o.to_mongo(fields=["person"])) # counts Ocorrence(person="teste 2").save() Ocorrence(person="teste 3").save() - count = Ocorrence.objects( - animal__tag='heavy', animal__owner__tp='u').count() + count = Ocorrence.objects(animal__tag="heavy", animal__owner__tp="u").count() self.assertEqual(count, 1) ocorrence = Ocorrence.objects( - animal__tag='heavy', - animal__owner__tp='u').first() + animal__tag="heavy", animal__owner__tp="u" + ).first() self.assertEqual(ocorrence.person, "teste") self.assertIsInstance(ocorrence.animal, Animal) @@ -404,43 +356,39 @@ class TestCachedReferenceField(MongoDBTestCase): class Ocorrence(Document): person = StringField() - animal = CachedReferenceField( - Animal, fields=['tag', 'owner.tags']) + animal = CachedReferenceField(Animal, fields=["tag", "owner.tags"]) Animal.drop_collection() Ocorrence.drop_collection() - a = Animal(name="Leopard", tag="heavy", - owner=Owner(tags=['cool', 'funny'], - name="Wilson Júnior") - ) + a = Animal( + name="Leopard", + tag="heavy", + owner=Owner(tags=["cool", "funny"], name="Wilson Júnior"), + ) a.save() o = Ocorrence(person="teste 2", animal=a) o.save() - self.assertEqual(dict(a.to_mongo(fields=['tag', 'owner.tags'])), { - '_id': a.pk, - 'tag': 'heavy', - 'owner': { - 'tags': ['cool', 'funny'] - } - }) + self.assertEqual( + dict(a.to_mongo(fields=["tag", "owner.tags"])), + {"_id": a.pk, "tag": "heavy", "owner": {"tags": ["cool", "funny"]}}, + ) - self.assertEqual(o.to_mongo()['animal']['tag'], 'heavy') - self.assertEqual(o.to_mongo()['animal']['owner']['tags'], - ['cool', 'funny']) + self.assertEqual(o.to_mongo()["animal"]["tag"], "heavy") + self.assertEqual(o.to_mongo()["animal"]["owner"]["tags"], ["cool", "funny"]) # counts Ocorrence(person="teste 2").save() Ocorrence(person="teste 3").save() query = Ocorrence.objects( - animal__tag='heavy', animal__owner__tags='cool')._query - self.assertEqual( - query, {'animal.owner.tags': 'cool', 'animal.tag': 'heavy'}) + animal__tag="heavy", animal__owner__tags="cool" + )._query + self.assertEqual(query, {"animal.owner.tags": "cool", "animal.tag": "heavy"}) ocorrence = Ocorrence.objects( - animal__tag='heavy', - animal__owner__tags='cool').first() + animal__tag="heavy", animal__owner__tags="cool" + ).first() self.assertEqual(ocorrence.person, "teste 2") self.assertIsInstance(ocorrence.animal, Animal) diff --git a/tests/fields/test_complex_datetime_field.py b/tests/fields/test_complex_datetime_field.py index 58dc4b43..4eea5bdc 100644 --- a/tests/fields/test_complex_datetime_field.py +++ b/tests/fields/test_complex_datetime_field.py @@ -14,9 +14,10 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): """Tests for complex datetime fields - which can handle microseconds without rounding. """ + class LogEntry(Document): date = ComplexDateTimeField() - date_with_dots = ComplexDateTimeField(separator='.') + date_with_dots = ComplexDateTimeField(separator=".") LogEntry.drop_collection() @@ -62,17 +63,25 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): mm = dd = hh = ii = ss = [1, 10] for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond): - stored = LogEntry(date=datetime.datetime(*values)).to_mongo()['date'] - self.assertTrue(re.match('^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$', stored) is not None) + stored = LogEntry(date=datetime.datetime(*values)).to_mongo()["date"] + self.assertTrue( + re.match("^\d{4},\d{2},\d{2},\d{2},\d{2},\d{2},\d{6}$", stored) + is not None + ) # Test separator - stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()['date_with_dots'] - self.assertTrue(re.match('^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$', stored) is not None) + stored = LogEntry(date_with_dots=datetime.datetime(2014, 1, 1)).to_mongo()[ + "date_with_dots" + ] + self.assertTrue( + re.match("^\d{4}.\d{2}.\d{2}.\d{2}.\d{2}.\d{2}.\d{6}$", stored) is not None + ) def test_complexdatetime_usage(self): """Tests for complex datetime fields - which can handle microseconds without rounding. """ + class LogEntry(Document): date = ComplexDateTimeField() @@ -123,22 +132,21 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): # Test microsecond-level ordering/filtering for microsecond in (99, 999, 9999, 10000): - LogEntry( - date=datetime.datetime(2015, 1, 1, 0, 0, 0, microsecond) - ).save() + LogEntry(date=datetime.datetime(2015, 1, 1, 0, 0, 0, microsecond)).save() - logs = list(LogEntry.objects.order_by('date')) + logs = list(LogEntry.objects.order_by("date")) for next_idx, log in enumerate(logs[:-1], start=1): next_log = logs[next_idx] self.assertTrue(log.date < next_log.date) - logs = list(LogEntry.objects.order_by('-date')) + logs = list(LogEntry.objects.order_by("-date")) for next_idx, log in enumerate(logs[:-1], start=1): next_log = logs[next_idx] self.assertTrue(log.date > next_log.date) logs = LogEntry.objects.filter( - date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000)) + date__lte=datetime.datetime(2015, 1, 1, 0, 0, 0, 10000) + ) self.assertEqual(logs.count(), 4) def test_no_default_value(self): @@ -156,6 +164,7 @@ class ComplexDateTimeFieldTest(MongoDBTestCase): def test_default_static_value(self): NOW = datetime.datetime.utcnow() + class Log(Document): timestamp = ComplexDateTimeField(default=NOW) diff --git a/tests/fields/test_date_field.py b/tests/fields/test_date_field.py index 82adb514..da572134 100644 --- a/tests/fields/test_date_field.py +++ b/tests/fields/test_date_field.py @@ -18,10 +18,11 @@ class TestDateField(MongoDBTestCase): Ensure an exception is raised when trying to cast an empty string to datetime. """ + class MyDoc(Document): dt = DateField() - md = MyDoc(dt='') + md = MyDoc(dt="") self.assertRaises(ValidationError, md.save) def test_date_from_whitespace_string(self): @@ -29,16 +30,18 @@ class TestDateField(MongoDBTestCase): Ensure an exception is raised when trying to cast a whitespace-only string to datetime. """ + class MyDoc(Document): dt = DateField() - md = MyDoc(dt=' ') + md = MyDoc(dt=" ") self.assertRaises(ValidationError, md.save) def test_default_values_today(self): """Ensure that default field values are used when creating a document. """ + class Person(Document): day = DateField(default=datetime.date.today) @@ -46,13 +49,14 @@ class TestDateField(MongoDBTestCase): person.validate() self.assertEqual(person.day, person.day) self.assertEqual(person.day, datetime.date.today()) - self.assertEqual(person._data['day'], person.day) + self.assertEqual(person._data["day"], person.day) def test_date(self): """Tests showing pymongo date fields See: http://api.mongodb.org/python/current/api/bson/son.html#dt """ + class LogEntry(Document): date = DateField() @@ -95,6 +99,7 @@ class TestDateField(MongoDBTestCase): def test_regular_usage(self): """Tests for regular datetime fields""" + class LogEntry(Document): date = DateField() @@ -106,12 +111,12 @@ class TestDateField(MongoDBTestCase): log.validate() log.save() - for query in (d1, d1.isoformat(' ')): + for query in (d1, d1.isoformat(" ")): log1 = LogEntry.objects.get(date=query) self.assertEqual(log, log1) if dateutil: - log1 = LogEntry.objects.get(date=d1.isoformat('T')) + log1 = LogEntry.objects.get(date=d1.isoformat("T")) self.assertEqual(log, log1) # create additional 19 log entries for a total of 20 @@ -142,6 +147,7 @@ class TestDateField(MongoDBTestCase): """Ensure that invalid values cannot be assigned to datetime fields. """ + class LogEntry(Document): time = DateField() @@ -152,14 +158,14 @@ class TestDateField(MongoDBTestCase): log.time = datetime.date.today() log.validate() - log.time = datetime.datetime.now().isoformat(' ') + log.time = datetime.datetime.now().isoformat(" ") log.validate() if dateutil: - log.time = datetime.datetime.now().isoformat('T') + log.time = datetime.datetime.now().isoformat("T") log.validate() log.time = -1 self.assertRaises(ValidationError, log.validate) - log.time = 'ABC' + log.time = "ABC" self.assertRaises(ValidationError, log.validate) diff --git a/tests/fields/test_datetime_field.py b/tests/fields/test_datetime_field.py index 92f0668a..c911390a 100644 --- a/tests/fields/test_datetime_field.py +++ b/tests/fields/test_datetime_field.py @@ -19,10 +19,11 @@ class TestDateTimeField(MongoDBTestCase): Ensure an exception is raised when trying to cast an empty string to datetime. """ + class MyDoc(Document): dt = DateTimeField() - md = MyDoc(dt='') + md = MyDoc(dt="") self.assertRaises(ValidationError, md.save) def test_datetime_from_whitespace_string(self): @@ -30,16 +31,18 @@ class TestDateTimeField(MongoDBTestCase): Ensure an exception is raised when trying to cast a whitespace-only string to datetime. """ + class MyDoc(Document): dt = DateTimeField() - md = MyDoc(dt=' ') + md = MyDoc(dt=" ") self.assertRaises(ValidationError, md.save) def test_default_value_utcnow(self): """Ensure that default field values are used when creating a document. """ + class Person(Document): created = DateTimeField(default=dt.datetime.utcnow) @@ -48,8 +51,10 @@ class TestDateTimeField(MongoDBTestCase): person.validate() person_created_t0 = person.created self.assertLess(person.created - utcnow, dt.timedelta(seconds=1)) - self.assertEqual(person_created_t0, person.created) # make sure it does not change - self.assertEqual(person._data['created'], person.created) + self.assertEqual( + person_created_t0, person.created + ) # make sure it does not change + self.assertEqual(person._data["created"], person.created) def test_handling_microseconds(self): """Tests showing pymongo datetime fields handling of microseconds. @@ -58,6 +63,7 @@ class TestDateTimeField(MongoDBTestCase): See: http://api.mongodb.org/python/current/api/bson/son.html#dt """ + class LogEntry(Document): date = DateTimeField() @@ -103,6 +109,7 @@ class TestDateTimeField(MongoDBTestCase): def test_regular_usage(self): """Tests for regular datetime fields""" + class LogEntry(Document): date = DateTimeField() @@ -114,12 +121,12 @@ class TestDateTimeField(MongoDBTestCase): log.validate() log.save() - for query in (d1, d1.isoformat(' ')): + for query in (d1, d1.isoformat(" ")): log1 = LogEntry.objects.get(date=query) self.assertEqual(log, log1) if dateutil: - log1 = LogEntry.objects.get(date=d1.isoformat('T')) + log1 = LogEntry.objects.get(date=d1.isoformat("T")) self.assertEqual(log, log1) # create additional 19 log entries for a total of 20 @@ -150,8 +157,7 @@ class TestDateTimeField(MongoDBTestCase): self.assertEqual(logs.count(), 10) logs = LogEntry.objects.filter( - date__lte=dt.datetime(1980, 1, 1), - date__gte=dt.datetime(1975, 1, 1), + date__lte=dt.datetime(1980, 1, 1), date__gte=dt.datetime(1975, 1, 1) ) self.assertEqual(logs.count(), 5) @@ -159,6 +165,7 @@ class TestDateTimeField(MongoDBTestCase): """Ensure that invalid values cannot be assigned to datetime fields. """ + class LogEntry(Document): time = DateTimeField() @@ -169,32 +176,32 @@ class TestDateTimeField(MongoDBTestCase): log.time = dt.date.today() log.validate() - log.time = dt.datetime.now().isoformat(' ') + log.time = dt.datetime.now().isoformat(" ") log.validate() - log.time = '2019-05-16 21:42:57.897847' + log.time = "2019-05-16 21:42:57.897847" log.validate() if dateutil: - log.time = dt.datetime.now().isoformat('T') + log.time = dt.datetime.now().isoformat("T") log.validate() log.time = -1 self.assertRaises(ValidationError, log.validate) - log.time = 'ABC' + log.time = "ABC" self.assertRaises(ValidationError, log.validate) - log.time = '2019-05-16 21:GARBAGE:12' + log.time = "2019-05-16 21:GARBAGE:12" self.assertRaises(ValidationError, log.validate) - log.time = '2019-05-16 21:42:57.GARBAGE' + log.time = "2019-05-16 21:42:57.GARBAGE" self.assertRaises(ValidationError, log.validate) - log.time = '2019-05-16 21:42:57.123.456' + log.time = "2019-05-16 21:42:57.123.456" self.assertRaises(ValidationError, log.validate) def test_parse_datetime_as_str(self): class DTDoc(Document): date = DateTimeField() - date_str = '2019-03-02 22:26:01' + date_str = "2019-03-02 22:26:01" # make sure that passing a parsable datetime works dtd = DTDoc() @@ -206,7 +213,7 @@ class TestDateTimeField(MongoDBTestCase): self.assertIsInstance(dtd.date, dt.datetime) self.assertEqual(str(dtd.date), date_str) - dtd.date = 'January 1st, 9999999999' + dtd.date = "January 1st, 9999999999" self.assertRaises(ValidationError, dtd.validate) @@ -217,7 +224,7 @@ class TestDateTimeTzAware(MongoDBTestCase): connection._connections = {} connection._dbs = {} - connect(db='mongoenginetest', tz_aware=True) + connect(db="mongoenginetest", tz_aware=True) class LogEntry(Document): time = DateTimeField() @@ -228,4 +235,4 @@ class TestDateTimeTzAware(MongoDBTestCase): log = LogEntry.objects.first() log.time = dt.datetime(2013, 1, 1, 0, 0, 0) - self.assertEqual(['time'], log._changed_fields) + self.assertEqual(["time"], log._changed_fields) diff --git a/tests/fields/test_decimal_field.py b/tests/fields/test_decimal_field.py index 0213b880..30b7e5ea 100644 --- a/tests/fields/test_decimal_field.py +++ b/tests/fields/test_decimal_field.py @@ -7,32 +7,31 @@ from tests.utils import MongoDBTestCase class TestDecimalField(MongoDBTestCase): - def test_validation(self): """Ensure that invalid values cannot be assigned to decimal fields. """ + class Person(Document): - height = DecimalField(min_value=Decimal('0.1'), - max_value=Decimal('3.5')) + height = DecimalField(min_value=Decimal("0.1"), max_value=Decimal("3.5")) Person.drop_collection() - Person(height=Decimal('1.89')).save() + Person(height=Decimal("1.89")).save() person = Person.objects.first() - self.assertEqual(person.height, Decimal('1.89')) + self.assertEqual(person.height, Decimal("1.89")) - person.height = '2.0' + person.height = "2.0" person.save() person.height = 0.01 self.assertRaises(ValidationError, person.validate) - person.height = Decimal('0.01') + person.height = Decimal("0.01") self.assertRaises(ValidationError, person.validate) - person.height = Decimal('4.0') + person.height = Decimal("4.0") self.assertRaises(ValidationError, person.validate) - person.height = 'something invalid' + person.height = "something invalid" self.assertRaises(ValidationError, person.validate) - person_2 = Person(height='something invalid') + person_2 = Person(height="something invalid") self.assertRaises(ValidationError, person_2.validate) def test_comparison(self): @@ -58,7 +57,14 @@ class TestDecimalField(MongoDBTestCase): string_value = DecimalField(precision=4, force_string=True) Person.drop_collection() - values_to_store = [10, 10.1, 10.11, "10.111", Decimal("10.1111"), Decimal("10.11111")] + values_to_store = [ + 10, + 10.1, + 10.11, + "10.111", + Decimal("10.1111"), + Decimal("10.11111"), + ] for store_at_creation in [True, False]: for value in values_to_store: # to_python is called explicitly if values were sent in the kwargs of __init__ @@ -72,20 +78,27 @@ class TestDecimalField(MongoDBTestCase): # How its stored expected = [ - {'float_value': 10.0, 'string_value': '10.0000'}, - {'float_value': 10.1, 'string_value': '10.1000'}, - {'float_value': 10.11, 'string_value': '10.1100'}, - {'float_value': 10.111, 'string_value': '10.1110'}, - {'float_value': 10.1111, 'string_value': '10.1111'}, - {'float_value': 10.1111, 'string_value': '10.1111'}] + {"float_value": 10.0, "string_value": "10.0000"}, + {"float_value": 10.1, "string_value": "10.1000"}, + {"float_value": 10.11, "string_value": "10.1100"}, + {"float_value": 10.111, "string_value": "10.1110"}, + {"float_value": 10.1111, "string_value": "10.1111"}, + {"float_value": 10.1111, "string_value": "10.1111"}, + ] expected.extend(expected) - actual = list(Person.objects.exclude('id').as_pymongo()) + actual = list(Person.objects.exclude("id").as_pymongo()) self.assertEqual(expected, actual) # How it comes out locally - expected = [Decimal('10.0000'), Decimal('10.1000'), Decimal('10.1100'), - Decimal('10.1110'), Decimal('10.1111'), Decimal('10.1111')] + expected = [ + Decimal("10.0000"), + Decimal("10.1000"), + Decimal("10.1100"), + Decimal("10.1110"), + Decimal("10.1111"), + Decimal("10.1111"), + ] expected.extend(expected) - for field_name in ['float_value', 'string_value']: + for field_name in ["float_value", "string_value"]: actual = list(Person.objects().scalar(field_name)) self.assertEqual(expected, actual) diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py index ade02ccf..07bab85b 100644 --- a/tests/fields/test_dict_field.py +++ b/tests/fields/test_dict_field.py @@ -6,95 +6,92 @@ from tests.utils import MongoDBTestCase, get_as_pymongo class TestDictField(MongoDBTestCase): - def test_storage(self): class BlogPost(Document): info = DictField() BlogPost.drop_collection() - info = {'testkey': 'testvalue'} + info = {"testkey": "testvalue"} post = BlogPost(info=info).save() - self.assertEqual( - get_as_pymongo(post), - { - '_id': post.id, - 'info': info - } - ) + self.assertEqual(get_as_pymongo(post), {"_id": post.id, "info": info}) def test_general_things(self): """Ensure that dict types work as expected.""" + class BlogPost(Document): info = DictField() BlogPost.drop_collection() post = BlogPost() - post.info = 'my post' + post.info = "my post" self.assertRaises(ValidationError, post.validate) - post.info = ['test', 'test'] + post.info = ["test", "test"] self.assertRaises(ValidationError, post.validate) - post.info = {'$title': 'test'} + post.info = {"$title": "test"} self.assertRaises(ValidationError, post.validate) - post.info = {'nested': {'$title': 'test'}} + post.info = {"nested": {"$title": "test"}} self.assertRaises(ValidationError, post.validate) - post.info = {'the.title': 'test'} + post.info = {"the.title": "test"} self.assertRaises(ValidationError, post.validate) - post.info = {'nested': {'the.title': 'test'}} + post.info = {"nested": {"the.title": "test"}} self.assertRaises(ValidationError, post.validate) - post.info = {1: 'test'} + post.info = {1: "test"} self.assertRaises(ValidationError, post.validate) - post.info = {'title': 'test'} + post.info = {"title": "test"} post.save() post = BlogPost() - post.info = {'title': 'dollar_sign', 'details': {'te$t': 'test'}} + post.info = {"title": "dollar_sign", "details": {"te$t": "test"}} post.save() post = BlogPost() - post.info = {'details': {'test': 'test'}} + post.info = {"details": {"test": "test"}} post.save() post = BlogPost() - post.info = {'details': {'test': 3}} + post.info = {"details": {"test": 3}} post.save() self.assertEqual(BlogPost.objects.count(), 4) + self.assertEqual(BlogPost.objects.filter(info__title__exact="test").count(), 1) self.assertEqual( - BlogPost.objects.filter(info__title__exact='test').count(), 1) - self.assertEqual( - BlogPost.objects.filter(info__details__test__exact='test').count(), 1) + BlogPost.objects.filter(info__details__test__exact="test").count(), 1 + ) - post = BlogPost.objects.filter(info__title__exact='dollar_sign').first() - self.assertIn('te$t', post['info']['details']) + post = BlogPost.objects.filter(info__title__exact="dollar_sign").first() + self.assertIn("te$t", post["info"]["details"]) # Confirm handles non strings or non existing keys self.assertEqual( - BlogPost.objects.filter(info__details__test__exact=5).count(), 0) + BlogPost.objects.filter(info__details__test__exact=5).count(), 0 + ) self.assertEqual( - BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) + BlogPost.objects.filter(info__made_up__test__exact="test").count(), 0 + ) - post = BlogPost.objects.create(info={'title': 'original'}) - post.info.update({'title': 'updated'}) + post = BlogPost.objects.create(info={"title": "original"}) + post.info.update({"title": "updated"}) post.save() post.reload() - self.assertEqual('updated', post.info['title']) + self.assertEqual("updated", post.info["title"]) - post.info.setdefault('authors', []) + post.info.setdefault("authors", []) post.save() post.reload() - self.assertEqual([], post.info['authors']) + self.assertEqual([], post.info["authors"]) def test_dictfield_dump_document(self): """Ensure a DictField can handle another document's dump.""" + class Doc(Document): field = DictField() @@ -106,51 +103,62 @@ class TestDictField(MongoDBTestCase): id = IntField(primary_key=True, default=1) recursive = DictField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class ToEmbedChild(ToEmbedParent): pass to_embed_recursive = ToEmbed(id=1).save() to_embed = ToEmbed( - id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() + id=2, recursive=to_embed_recursive.to_mongo().to_dict() + ).save() doc = Doc(field=to_embed.to_mongo().to_dict()) doc.save() self.assertIsInstance(doc.field, dict) - self.assertEqual(doc.field, {'_id': 2, 'recursive': {'_id': 1, 'recursive': {}}}) + self.assertEqual( + doc.field, {"_id": 2, "recursive": {"_id": 1, "recursive": {}}} + ) # Same thing with a Document with a _cls field to_embed_recursive = ToEmbedChild(id=1).save() to_embed_child = ToEmbedChild( - id=2, recursive=to_embed_recursive.to_mongo().to_dict()).save() + id=2, recursive=to_embed_recursive.to_mongo().to_dict() + ).save() doc = Doc(field=to_embed_child.to_mongo().to_dict()) doc.save() self.assertIsInstance(doc.field, dict) expected = { - '_id': 2, '_cls': 'ToEmbedParent.ToEmbedChild', - 'recursive': {'_id': 1, '_cls': 'ToEmbedParent.ToEmbedChild', 'recursive': {}} + "_id": 2, + "_cls": "ToEmbedParent.ToEmbedChild", + "recursive": { + "_id": 1, + "_cls": "ToEmbedParent.ToEmbedChild", + "recursive": {}, + }, } self.assertEqual(doc.field, expected) def test_dictfield_strict(self): """Ensure that dict field handles validation if provided a strict field type.""" + class Simple(Document): mapping = DictField(field=IntField()) Simple.drop_collection() e = Simple() - e.mapping['someint'] = 1 + e.mapping["someint"] = 1 e.save() # try creating an invalid mapping with self.assertRaises(ValidationError): - e.mapping['somestring'] = "abc" + e.mapping["somestring"] = "abc" e.save() def test_dictfield_complex(self): """Ensure that the dict field can handle the complex types.""" + class SettingBase(EmbeddedDocument): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class StringSetting(SettingBase): value = StringField() @@ -164,70 +172,72 @@ class TestDictField(MongoDBTestCase): Simple.drop_collection() e = Simple() - e.mapping['somestring'] = StringSetting(value='foo') - e.mapping['someint'] = IntegerSetting(value=42) - e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', - 'float': 1.001, - 'complex': IntegerSetting(value=42), - 'list': [IntegerSetting(value=42), - StringSetting(value='foo')]} + e.mapping["somestring"] = StringSetting(value="foo") + e.mapping["someint"] = IntegerSetting(value=42) + e.mapping["nested_dict"] = { + "number": 1, + "string": "Hi!", + "float": 1.001, + "complex": IntegerSetting(value=42), + "list": [IntegerSetting(value=42), StringSetting(value="foo")], + } e.save() e2 = Simple.objects.get(id=e.id) - self.assertIsInstance(e2.mapping['somestring'], StringSetting) - self.assertIsInstance(e2.mapping['someint'], IntegerSetting) + self.assertIsInstance(e2.mapping["somestring"], StringSetting) + self.assertIsInstance(e2.mapping["someint"], IntegerSetting) # Test querying + self.assertEqual(Simple.objects.filter(mapping__someint__value=42).count(), 1) self.assertEqual( - Simple.objects.filter(mapping__someint__value=42).count(), 1) + Simple.objects.filter(mapping__nested_dict__number=1).count(), 1 + ) self.assertEqual( - Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) + Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1 + ) self.assertEqual( - Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) + Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1 + ) self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) - self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) + Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count(), 1 + ) # Confirm can update + Simple.objects().update(set__mapping={"someint": IntegerSetting(value=10)}) Simple.objects().update( - set__mapping={"someint": IntegerSetting(value=10)}) - Simple.objects().update( - set__mapping__nested_dict__list__1=StringSetting(value='Boo')) + set__mapping__nested_dict__list__1=StringSetting(value="Boo") + ) self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) + Simple.objects.filter(mapping__nested_dict__list__1__value="foo").count(), 0 + ) self.assertEqual( - Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) + Simple.objects.filter(mapping__nested_dict__list__1__value="Boo").count(), 1 + ) def test_push_dict(self): class MyModel(Document): events = ListField(DictField()) - doc = MyModel(events=[{'a': 1}]).save() + doc = MyModel(events=[{"a": 1}]).save() raw_doc = get_as_pymongo(doc) - expected_raw_doc = { - '_id': doc.id, - 'events': [{'a': 1}] - } + expected_raw_doc = {"_id": doc.id, "events": [{"a": 1}]} self.assertEqual(raw_doc, expected_raw_doc) MyModel.objects(id=doc.id).update(push__events={}) raw_doc = get_as_pymongo(doc) - expected_raw_doc = { - '_id': doc.id, - 'events': [{'a': 1}, {}] - } + expected_raw_doc = {"_id": doc.id, "events": [{"a": 1}, {}]} self.assertEqual(raw_doc, expected_raw_doc) def test_ensure_unique_default_instances(self): """Ensure that every field has it's own unique default instance.""" + class D(Document): data = DictField() data2 = DictField(default=lambda: {}) d1 = D() - d1.data['foo'] = 'bar' - d1.data2['foo'] = 'bar' + d1.data["foo"] = "bar" + d1.data2["foo"] = "bar" d2 = D() self.assertEqual(d2.data, {}) self.assertEqual(d2.data2, {}) @@ -255,22 +265,25 @@ class TestDictField(MongoDBTestCase): class Embedded(EmbeddedDocument): name = StringField() - embed = Embedded(name='garbage') + embed = Embedded(name="garbage") doc = DictFieldTest(dictionary=embed) with self.assertRaises(ValidationError) as ctx_err: doc.validate() self.assertIn("'dictionary'", str(ctx_err.exception)) - self.assertIn('Only dictionaries may be used in a DictField', str(ctx_err.exception)) + self.assertIn( + "Only dictionaries may be used in a DictField", str(ctx_err.exception) + ) def test_atomic_update_dict_field(self): """Ensure that the entire DictField can be atomically updated.""" + class Simple(Document): mapping = DictField(field=ListField(IntField(required=True))) Simple.drop_collection() e = Simple() - e.mapping['someints'] = [1, 2] + e.mapping["someints"] = [1, 2] e.save() e.update(set__mapping={"ints": [3, 4]}) e.reload() @@ -279,7 +292,7 @@ class TestDictField(MongoDBTestCase): # try creating an invalid mapping with self.assertRaises(ValueError): - e.update(set__mapping={"somestrings": ["foo", "bar", ]}) + e.update(set__mapping={"somestrings": ["foo", "bar"]}) def test_dictfield_with_referencefield_complex_nesting_cases(self): """Ensure complex nesting inside DictField handles dereferencing of ReferenceField(dbref=True | False)""" @@ -296,29 +309,33 @@ class TestDictField(MongoDBTestCase): mapping5 = DictField(DictField(field=ReferenceField(Doc, dbref=False))) mapping6 = DictField(ListField(DictField(ReferenceField(Doc, dbref=True)))) mapping7 = DictField(ListField(DictField(ReferenceField(Doc, dbref=False)))) - mapping8 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=True))))) - mapping9 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=False))))) + mapping8 = DictField( + ListField(DictField(ListField(ReferenceField(Doc, dbref=True)))) + ) + mapping9 = DictField( + ListField(DictField(ListField(ReferenceField(Doc, dbref=False)))) + ) Doc.drop_collection() Simple.drop_collection() - d = Doc(s='aa').save() + d = Doc(s="aa").save() e = Simple() - e.mapping0['someint'] = e.mapping1['someint'] = d - e.mapping2['someint'] = e.mapping3['someint'] = [d] - e.mapping4['someint'] = e.mapping5['someint'] = {'d': d} - e.mapping6['someint'] = e.mapping7['someint'] = [{'d': d}] - e.mapping8['someint'] = e.mapping9['someint'] = [{'d': [d]}] + e.mapping0["someint"] = e.mapping1["someint"] = d + e.mapping2["someint"] = e.mapping3["someint"] = [d] + e.mapping4["someint"] = e.mapping5["someint"] = {"d": d} + e.mapping6["someint"] = e.mapping7["someint"] = [{"d": d}] + e.mapping8["someint"] = e.mapping9["someint"] = [{"d": [d]}] e.save() s = Simple.objects.first() - self.assertIsInstance(s.mapping0['someint'], Doc) - self.assertIsInstance(s.mapping1['someint'], Doc) - self.assertIsInstance(s.mapping2['someint'][0], Doc) - self.assertIsInstance(s.mapping3['someint'][0], Doc) - self.assertIsInstance(s.mapping4['someint']['d'], Doc) - self.assertIsInstance(s.mapping5['someint']['d'], Doc) - self.assertIsInstance(s.mapping6['someint'][0]['d'], Doc) - self.assertIsInstance(s.mapping7['someint'][0]['d'], Doc) - self.assertIsInstance(s.mapping8['someint'][0]['d'][0], Doc) - self.assertIsInstance(s.mapping9['someint'][0]['d'][0], Doc) + self.assertIsInstance(s.mapping0["someint"], Doc) + self.assertIsInstance(s.mapping1["someint"], Doc) + self.assertIsInstance(s.mapping2["someint"][0], Doc) + self.assertIsInstance(s.mapping3["someint"][0], Doc) + self.assertIsInstance(s.mapping4["someint"]["d"], Doc) + self.assertIsInstance(s.mapping5["someint"]["d"], Doc) + self.assertIsInstance(s.mapping6["someint"][0]["d"], Doc) + self.assertIsInstance(s.mapping7["someint"][0]["d"], Doc) + self.assertIsInstance(s.mapping8["someint"][0]["d"][0], Doc) + self.assertIsInstance(s.mapping9["someint"][0]["d"][0], Doc) diff --git a/tests/fields/test_email_field.py b/tests/fields/test_email_field.py index 3ce49d62..06ec5151 100644 --- a/tests/fields/test_email_field.py +++ b/tests/fields/test_email_field.py @@ -12,28 +12,29 @@ class TestEmailField(MongoDBTestCase): class User(Document): email = EmailField() - user = User(email='ross@example.com') + user = User(email="ross@example.com") user.validate() - user = User(email='ross@example.co.uk') + user = User(email="ross@example.co.uk") user.validate() - user = User(email=('Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S' - 'aJIazqqWkm7.net')) + user = User( + email=("Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5SaJIazqqWkm7.net") + ) user.validate() - user = User(email='new-tld@example.technology') + user = User(email="new-tld@example.technology") user.validate() - user = User(email='ross@example.com.') + user = User(email="ross@example.com.") self.assertRaises(ValidationError, user.validate) # unicode domain - user = User(email=u'user@пример.рф') + user = User(email=u"user@пример.рф") user.validate() # invalid unicode domain - user = User(email=u'user@пример') + user = User(email=u"user@пример") self.assertRaises(ValidationError, user.validate) # invalid data type @@ -44,20 +45,20 @@ class TestEmailField(MongoDBTestCase): # Don't run this test on pypy3, which doesn't support unicode regex: # https://bitbucket.org/pypy/pypy/issues/1821/regular-expression-doesnt-find-unicode if sys.version_info[:2] == (3, 2): - raise SkipTest('unicode email addresses are not supported on PyPy 3') + raise SkipTest("unicode email addresses are not supported on PyPy 3") class User(Document): email = EmailField() # unicode user shouldn't validate by default... - user = User(email=u'Dörte@Sörensen.example.com') + user = User(email=u"Dörte@Sörensen.example.com") self.assertRaises(ValidationError, user.validate) # ...but it should be fine with allow_utf8_user set to True class User(Document): email = EmailField(allow_utf8_user=True) - user = User(email=u'Dörte@Sörensen.example.com') + user = User(email=u"Dörte@Sörensen.example.com") user.validate() def test_email_field_domain_whitelist(self): @@ -65,22 +66,22 @@ class TestEmailField(MongoDBTestCase): email = EmailField() # localhost domain shouldn't validate by default... - user = User(email='me@localhost') + user = User(email="me@localhost") self.assertRaises(ValidationError, user.validate) # ...but it should be fine if it's whitelisted class User(Document): - email = EmailField(domain_whitelist=['localhost']) + email = EmailField(domain_whitelist=["localhost"]) - user = User(email='me@localhost') + user = User(email="me@localhost") user.validate() def test_email_domain_validation_fails_if_invalid_idn(self): class User(Document): email = EmailField() - invalid_idn = '.google.com' - user = User(email='me@%s' % invalid_idn) + invalid_idn = ".google.com" + user = User(email="me@%s" % invalid_idn) with self.assertRaises(ValidationError) as ctx_err: user.validate() self.assertIn("domain failed IDN encoding", str(ctx_err.exception)) @@ -89,9 +90,9 @@ class TestEmailField(MongoDBTestCase): class User(Document): email = EmailField() - valid_ipv4 = 'email@[127.0.0.1]' - valid_ipv6 = 'email@[2001:dB8::1]' - invalid_ip = 'email@[324.0.0.1]' + valid_ipv4 = "email@[127.0.0.1]" + valid_ipv6 = "email@[2001:dB8::1]" + invalid_ip = "email@[324.0.0.1]" # IP address as a domain shouldn't validate by default... user = User(email=valid_ipv4) @@ -119,12 +120,12 @@ class TestEmailField(MongoDBTestCase): def test_email_field_honors_regex(self): class User(Document): - email = EmailField(regex=r'\w+@example.com') + email = EmailField(regex=r"\w+@example.com") # Fails regex validation - user = User(email='me@foo.com') + user = User(email="me@foo.com") self.assertRaises(ValidationError, user.validate) # Passes regex validation - user = User(email='me@example.com') + user = User(email="me@example.com") self.assertIsNone(user.validate()) diff --git a/tests/fields/test_embedded_document_field.py b/tests/fields/test_embedded_document_field.py index a262d054..6b420781 100644 --- a/tests/fields/test_embedded_document_field.py +++ b/tests/fields/test_embedded_document_field.py @@ -1,7 +1,18 @@ # -*- coding: utf-8 -*- -from mongoengine import Document, StringField, ValidationError, EmbeddedDocument, EmbeddedDocumentField, \ - InvalidQueryError, LookUpError, IntField, GenericEmbeddedDocumentField, ListField, EmbeddedDocumentListField, \ - ReferenceField +from mongoengine import ( + Document, + StringField, + ValidationError, + EmbeddedDocument, + EmbeddedDocumentField, + InvalidQueryError, + LookUpError, + IntField, + GenericEmbeddedDocumentField, + ListField, + EmbeddedDocumentListField, + ReferenceField, +) from tests.utils import MongoDBTestCase @@ -14,22 +25,24 @@ class TestEmbeddedDocumentField(MongoDBTestCase): field = EmbeddedDocumentField(MyDoc) self.assertEqual(field.document_type_obj, MyDoc) - field2 = EmbeddedDocumentField('MyDoc') - self.assertEqual(field2.document_type_obj, 'MyDoc') + field2 = EmbeddedDocumentField("MyDoc") + self.assertEqual(field2.document_type_obj, "MyDoc") def test___init___throw_error_if_document_type_is_not_EmbeddedDocument(self): with self.assertRaises(ValidationError): EmbeddedDocumentField(dict) def test_document_type_throw_error_if_not_EmbeddedDocument_subclass(self): - class MyDoc(Document): name = StringField() - emb = EmbeddedDocumentField('MyDoc') + emb = EmbeddedDocumentField("MyDoc") with self.assertRaises(ValidationError) as ctx: emb.document_type - self.assertIn('Invalid embedded document class provided to an EmbeddedDocumentField', str(ctx.exception)) + self.assertIn( + "Invalid embedded document class provided to an EmbeddedDocumentField", + str(ctx.exception), + ) def test_embedded_document_field_only_allow_subclasses_of_embedded_document(self): # Relates to #1661 @@ -37,12 +50,14 @@ class TestEmbeddedDocumentField(MongoDBTestCase): name = StringField() with self.assertRaises(ValidationError): + class MyFailingDoc(Document): emb = EmbeddedDocumentField(MyDoc) with self.assertRaises(ValidationError): + class MyFailingdoc2(Document): - emb = EmbeddedDocumentField('MyDoc') + emb = EmbeddedDocumentField("MyDoc") def test_query_embedded_document_attribute(self): class AdminSettings(EmbeddedDocument): @@ -55,34 +70,31 @@ class TestEmbeddedDocumentField(MongoDBTestCase): Person.drop_collection() - p = Person( - settings=AdminSettings(foo1='bar1', foo2='bar2'), - name='John', - ).save() + p = Person(settings=AdminSettings(foo1="bar1", foo2="bar2"), name="John").save() # Test non exiting attribute with self.assertRaises(InvalidQueryError) as ctx_err: - Person.objects(settings__notexist='bar').first() + Person.objects(settings__notexist="bar").first() self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') with self.assertRaises(LookUpError): - Person.objects.only('settings.notexist') + Person.objects.only("settings.notexist") # Test existing attribute - self.assertEqual(Person.objects(settings__foo1='bar1').first().id, p.id) - only_p = Person.objects.only('settings.foo1').first() + self.assertEqual(Person.objects(settings__foo1="bar1").first().id, p.id) + only_p = Person.objects.only("settings.foo1").first() self.assertEqual(only_p.settings.foo1, p.settings.foo1) self.assertIsNone(only_p.settings.foo2) self.assertIsNone(only_p.name) - exclude_p = Person.objects.exclude('settings.foo1').first() + exclude_p = Person.objects.exclude("settings.foo1").first() self.assertIsNone(exclude_p.settings.foo1) self.assertEqual(exclude_p.settings.foo2, p.settings.foo2) self.assertEqual(exclude_p.name, p.name) def test_query_embedded_document_attribute_with_inheritance(self): class BaseSettings(EmbeddedDocument): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} base_foo = StringField() class AdminSettings(BaseSettings): @@ -93,26 +105,26 @@ class TestEmbeddedDocumentField(MongoDBTestCase): Person.drop_collection() - p = Person(settings=AdminSettings(base_foo='basefoo', sub_foo='subfoo')) + p = Person(settings=AdminSettings(base_foo="basefoo", sub_foo="subfoo")) p.save() # Test non exiting attribute with self.assertRaises(InvalidQueryError) as ctx_err: - self.assertEqual(Person.objects(settings__notexist='bar').first().id, p.id) + self.assertEqual(Person.objects(settings__notexist="bar").first().id, p.id) self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') # Test existing attribute - self.assertEqual(Person.objects(settings__base_foo='basefoo').first().id, p.id) - self.assertEqual(Person.objects(settings__sub_foo='subfoo').first().id, p.id) + self.assertEqual(Person.objects(settings__base_foo="basefoo").first().id, p.id) + self.assertEqual(Person.objects(settings__sub_foo="subfoo").first().id, p.id) - only_p = Person.objects.only('settings.base_foo', 'settings._cls').first() - self.assertEqual(only_p.settings.base_foo, 'basefoo') + only_p = Person.objects.only("settings.base_foo", "settings._cls").first() + self.assertEqual(only_p.settings.base_foo, "basefoo") self.assertIsNone(only_p.settings.sub_foo) def test_query_list_embedded_document_with_inheritance(self): class Post(EmbeddedDocument): title = StringField(max_length=120, required=True) - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class TextPost(Post): content = StringField() @@ -123,8 +135,8 @@ class TestEmbeddedDocumentField(MongoDBTestCase): class Record(Document): posts = ListField(EmbeddedDocumentField(Post)) - record_movie = Record(posts=[MoviePost(author='John', title='foo')]).save() - record_text = Record(posts=[TextPost(content='a', title='foo')]).save() + record_movie = Record(posts=[MoviePost(author="John", title="foo")]).save() + record_text = Record(posts=[TextPost(content="a", title="foo")]).save() records = list(Record.objects(posts__author=record_movie.posts[0].author)) self.assertEqual(len(records), 1) @@ -134,11 +146,10 @@ class TestEmbeddedDocumentField(MongoDBTestCase): self.assertEqual(len(records), 1) self.assertEqual(records[0].id, record_text.id) - self.assertEqual(Record.objects(posts__title='foo').count(), 2) + self.assertEqual(Record.objects(posts__title="foo").count(), 2) class TestGenericEmbeddedDocumentField(MongoDBTestCase): - def test_generic_embedded_document(self): class Car(EmbeddedDocument): name = StringField() @@ -153,8 +164,8 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): Person.drop_collection() - person = Person(name='Test User') - person.like = Car(name='Fiat') + person = Person(name="Test User") + person.like = Car(name="Fiat") person.save() person = Person.objects.first() @@ -168,6 +179,7 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): def test_generic_embedded_document_choices(self): """Ensure you can limit GenericEmbeddedDocument choices.""" + class Car(EmbeddedDocument): name = StringField() @@ -181,8 +193,8 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): Person.drop_collection() - person = Person(name='Test User') - person.like = Car(name='Fiat') + person = Person(name="Test User") + person.like = Car(name="Fiat") self.assertRaises(ValidationError, person.validate) person.like = Dish(food="arroz", number=15) @@ -195,6 +207,7 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): """Ensure you can limit GenericEmbeddedDocument choices inside a list field. """ + class Car(EmbeddedDocument): name = StringField() @@ -208,8 +221,8 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): Person.drop_collection() - person = Person(name='Test User') - person.likes = [Car(name='Fiat')] + person = Person(name="Test User") + person.likes = [Car(name="Fiat")] self.assertRaises(ValidationError, person.validate) person.likes = [Dish(food="arroz", number=15)] @@ -222,25 +235,23 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): """ Ensure fields with document choices validate given a valid choice. """ + class UserComments(EmbeddedDocument): author = StringField() message = StringField() class BlogPost(Document): - comments = ListField( - GenericEmbeddedDocumentField(choices=(UserComments,)) - ) + comments = ListField(GenericEmbeddedDocumentField(choices=(UserComments,))) # Ensure Validation Passes - BlogPost(comments=[ - UserComments(author='user2', message='message2'), - ]).save() + BlogPost(comments=[UserComments(author="user2", message="message2")]).save() def test_choices_validation_documents_invalid(self): """ Ensure fields with document choices validate given an invalid choice. This should throw a ValidationError exception. """ + class UserComments(EmbeddedDocument): author = StringField() message = StringField() @@ -250,31 +261,28 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): message = StringField() class BlogPost(Document): - comments = ListField( - GenericEmbeddedDocumentField(choices=(UserComments,)) - ) + comments = ListField(GenericEmbeddedDocumentField(choices=(UserComments,))) # Single Entry Failure - post = BlogPost(comments=[ - ModeratorComments(author='mod1', message='message1'), - ]) + post = BlogPost(comments=[ModeratorComments(author="mod1", message="message1")]) self.assertRaises(ValidationError, post.save) # Mixed Entry Failure - post = BlogPost(comments=[ - ModeratorComments(author='mod1', message='message1'), - UserComments(author='user2', message='message2'), - ]) + post = BlogPost( + comments=[ + ModeratorComments(author="mod1", message="message1"), + UserComments(author="user2", message="message2"), + ] + ) self.assertRaises(ValidationError, post.save) def test_choices_validation_documents_inheritance(self): """ Ensure fields with document choices validate given subclass of choice. """ + class Comments(EmbeddedDocument): - meta = { - 'abstract': True - } + meta = {"abstract": True} author = StringField() message = StringField() @@ -282,14 +290,10 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): pass class BlogPost(Document): - comments = ListField( - GenericEmbeddedDocumentField(choices=(Comments,)) - ) + comments = ListField(GenericEmbeddedDocumentField(choices=(Comments,))) # Save Valid EmbeddedDocument Type - BlogPost(comments=[ - UserComments(author='user2', message='message2'), - ]).save() + BlogPost(comments=[UserComments(author="user2", message="message2")]).save() def test_query_generic_embedded_document_attribute(self): class AdminSettings(EmbeddedDocument): @@ -299,28 +303,30 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): foo2 = StringField() class Person(Document): - settings = GenericEmbeddedDocumentField(choices=(AdminSettings, NonAdminSettings)) + settings = GenericEmbeddedDocumentField( + choices=(AdminSettings, NonAdminSettings) + ) Person.drop_collection() - p1 = Person(settings=AdminSettings(foo1='bar1')).save() - p2 = Person(settings=NonAdminSettings(foo2='bar2')).save() + p1 = Person(settings=AdminSettings(foo1="bar1")).save() + p2 = Person(settings=NonAdminSettings(foo2="bar2")).save() # Test non exiting attribute with self.assertRaises(InvalidQueryError) as ctx_err: - Person.objects(settings__notexist='bar').first() + Person.objects(settings__notexist="bar").first() self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') with self.assertRaises(LookUpError): - Person.objects.only('settings.notexist') + Person.objects.only("settings.notexist") # Test existing attribute - self.assertEqual(Person.objects(settings__foo1='bar1').first().id, p1.id) - self.assertEqual(Person.objects(settings__foo2='bar2').first().id, p2.id) + self.assertEqual(Person.objects(settings__foo1="bar1").first().id, p1.id) + self.assertEqual(Person.objects(settings__foo2="bar2").first().id, p2.id) def test_query_generic_embedded_document_attribute_with_inheritance(self): class BaseSettings(EmbeddedDocument): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} base_foo = StringField() class AdminSettings(BaseSettings): @@ -331,14 +337,14 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase): Person.drop_collection() - p = Person(settings=AdminSettings(base_foo='basefoo', sub_foo='subfoo')) + p = Person(settings=AdminSettings(base_foo="basefoo", sub_foo="subfoo")) p.save() # Test non exiting attribute with self.assertRaises(InvalidQueryError) as ctx_err: - self.assertEqual(Person.objects(settings__notexist='bar').first().id, p.id) + self.assertEqual(Person.objects(settings__notexist="bar").first().id, p.id) self.assertEqual(unicode(ctx_err.exception), u'Cannot resolve field "notexist"') # Test existing attribute - self.assertEqual(Person.objects(settings__base_foo='basefoo').first().id, p.id) - self.assertEqual(Person.objects(settings__sub_foo='subfoo').first().id, p.id) + self.assertEqual(Person.objects(settings__base_foo="basefoo").first().id, p.id) + self.assertEqual(Person.objects(settings__sub_foo="subfoo").first().id, p.id) diff --git a/tests/fields/test_float_field.py b/tests/fields/test_float_field.py index fa92cf20..9f357ce5 100644 --- a/tests/fields/test_float_field.py +++ b/tests/fields/test_float_field.py @@ -7,7 +7,6 @@ from tests.utils import MongoDBTestCase class TestFloatField(MongoDBTestCase): - def test_float_ne_operator(self): class TestDocument(Document): float_fld = FloatField() @@ -23,6 +22,7 @@ class TestFloatField(MongoDBTestCase): def test_validation(self): """Ensure that invalid values cannot be assigned to float fields. """ + class Person(Document): height = FloatField(min_value=0.1, max_value=3.5) @@ -33,7 +33,7 @@ class TestFloatField(MongoDBTestCase): person.height = 1.89 person.validate() - person.height = '2.0' + person.height = "2.0" self.assertRaises(ValidationError, person.validate) person.height = 0.01 @@ -42,7 +42,7 @@ class TestFloatField(MongoDBTestCase): person.height = 4.0 self.assertRaises(ValidationError, person.validate) - person_2 = Person(height='something invalid') + person_2 = Person(height="something invalid") self.assertRaises(ValidationError, person_2.validate) big_person = BigPerson() diff --git a/tests/fields/test_int_field.py b/tests/fields/test_int_field.py index 1b1f7ad9..b7db0416 100644 --- a/tests/fields/test_int_field.py +++ b/tests/fields/test_int_field.py @@ -5,10 +5,10 @@ from tests.utils import MongoDBTestCase class TestIntField(MongoDBTestCase): - def test_int_validation(self): """Ensure that invalid values cannot be assigned to int fields. """ + class Person(Document): age = IntField(min_value=0, max_value=110) @@ -26,7 +26,7 @@ class TestIntField(MongoDBTestCase): self.assertRaises(ValidationError, person.validate) person.age = 120 self.assertRaises(ValidationError, person.validate) - person.age = 'ten' + person.age = "ten" self.assertRaises(ValidationError, person.validate) def test_ne_operator(self): diff --git a/tests/fields/test_lazy_reference_field.py b/tests/fields/test_lazy_reference_field.py index 1d6e6e79..2a686d7f 100644 --- a/tests/fields/test_lazy_reference_field.py +++ b/tests/fields/test_lazy_reference_field.py @@ -25,7 +25,7 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal() oc = Ocurrence(animal=animal) - self.assertIn('LazyReference', repr(oc.animal)) + self.assertIn("LazyReference", repr(oc.animal)) def test___getattr___unknown_attr_raises_attribute_error(self): class Animal(Document): @@ -93,7 +93,7 @@ class TestLazyReferenceField(MongoDBTestCase): def test_lazy_reference_set(self): class Animal(Document): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} name = StringField() tag = StringField() @@ -109,18 +109,17 @@ class TestLazyReferenceField(MongoDBTestCase): nick = StringField() animal = Animal(name="Leopard", tag="heavy").save() - sub_animal = SubAnimal(nick='doggo', name='dog').save() + sub_animal = SubAnimal(nick="doggo", name="dog").save() for ref in ( - animal, - animal.pk, - DBRef(animal._get_collection_name(), animal.pk), - LazyReference(Animal, animal.pk), - - sub_animal, - sub_animal.pk, - DBRef(sub_animal._get_collection_name(), sub_animal.pk), - LazyReference(SubAnimal, sub_animal.pk), - ): + animal, + animal.pk, + DBRef(animal._get_collection_name(), animal.pk), + LazyReference(Animal, animal.pk), + sub_animal, + sub_animal.pk, + DBRef(sub_animal._get_collection_name(), sub_animal.pk), + LazyReference(SubAnimal, sub_animal.pk), + ): p = Ocurrence(person="test", animal=ref).save() p.reload() self.assertIsInstance(p.animal, LazyReference) @@ -144,12 +143,12 @@ class TestLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() baddoc = BadDoc().save() for bad in ( - 42, - 'foo', - baddoc, - DBRef(baddoc._get_collection_name(), animal.pk), - LazyReference(BadDoc, animal.pk) - ): + 42, + "foo", + baddoc, + DBRef(baddoc._get_collection_name(), animal.pk), + LazyReference(BadDoc, animal.pk), + ): with self.assertRaises(ValidationError): p = Ocurrence(person="test", animal=bad).save() @@ -157,6 +156,7 @@ class TestLazyReferenceField(MongoDBTestCase): """Ensure that LazyReferenceFields can be queried using objects and values of the type of the primary key of the referenced object. """ + class Member(Document): user_num = IntField(primary_key=True) @@ -172,10 +172,10 @@ class TestLazyReferenceField(MongoDBTestCase): m2 = Member(user_num=2) m2.save() - post1 = BlogPost(title='post 1', author=m1) + post1 = BlogPost(title="post 1", author=m1) post1.save() - post2 = BlogPost(title='post 2', author=m2) + post2 = BlogPost(title="post 2", author=m2) post2.save() post = BlogPost.objects(author=m1).first() @@ -192,6 +192,7 @@ class TestLazyReferenceField(MongoDBTestCase): """Ensure that LazyReferenceFields can be queried using objects and values of the type of the primary key of the referenced object. """ + class Member(Document): user_num = IntField(primary_key=True) @@ -207,10 +208,10 @@ class TestLazyReferenceField(MongoDBTestCase): m2 = Member(user_num=2) m2.save() - post1 = BlogPost(title='post 1', author=m1) + post1 = BlogPost(title="post 1", author=m1) post1.save() - post2 = BlogPost(title='post 2', author=m2) + post2 = BlogPost(title="post 2", author=m2) post2.save() post = BlogPost.objects(author=m1).first() @@ -240,19 +241,19 @@ class TestLazyReferenceField(MongoDBTestCase): p = Ocurrence.objects.get() self.assertIsInstance(p.animal, LazyReference) with self.assertRaises(KeyError): - p.animal['name'] + p.animal["name"] with self.assertRaises(AttributeError): p.animal.name self.assertEqual(p.animal.pk, animal.pk) self.assertEqual(p.animal_passthrough.name, "Leopard") - self.assertEqual(p.animal_passthrough['name'], "Leopard") + self.assertEqual(p.animal_passthrough["name"], "Leopard") # Should not be able to access referenced document's methods with self.assertRaises(AttributeError): p.animal.save with self.assertRaises(KeyError): - p.animal['save'] + p.animal["save"] def test_lazy_reference_not_set(self): class Animal(Document): @@ -266,7 +267,7 @@ class TestLazyReferenceField(MongoDBTestCase): Animal.drop_collection() Ocurrence.drop_collection() - Ocurrence(person='foo').save() + Ocurrence(person="foo").save() p = Ocurrence.objects.get() self.assertIs(p.animal, None) @@ -303,8 +304,8 @@ class TestLazyReferenceField(MongoDBTestCase): Animal.drop_collection() Ocurrence.drop_collection() - animal1 = Animal(name='doggo').save() - animal2 = Animal(name='cheeta').save() + animal1 = Animal(name="doggo").save() + animal2 = Animal(name="cheeta").save() def check_fields_type(occ): self.assertIsInstance(occ.direct, LazyReference) @@ -316,8 +317,8 @@ class TestLazyReferenceField(MongoDBTestCase): occ = Ocurrence( in_list=[animal1, animal2], - in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, - direct=animal1 + in_embedded={"in_list": [animal1, animal2], "direct": animal1}, + direct=animal1, ).save() check_fields_type(occ) occ.reload() @@ -403,7 +404,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): def test_generic_lazy_reference_set(self): class Animal(Document): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} name = StringField() tag = StringField() @@ -419,16 +420,18 @@ class TestGenericLazyReferenceField(MongoDBTestCase): nick = StringField() animal = Animal(name="Leopard", tag="heavy").save() - sub_animal = SubAnimal(nick='doggo', name='dog').save() + sub_animal = SubAnimal(nick="doggo", name="dog").save() for ref in ( - animal, - LazyReference(Animal, animal.pk), - {'_cls': 'Animal', '_ref': DBRef(animal._get_collection_name(), animal.pk)}, - - sub_animal, - LazyReference(SubAnimal, sub_animal.pk), - {'_cls': 'SubAnimal', '_ref': DBRef(sub_animal._get_collection_name(), sub_animal.pk)}, - ): + animal, + LazyReference(Animal, animal.pk), + {"_cls": "Animal", "_ref": DBRef(animal._get_collection_name(), animal.pk)}, + sub_animal, + LazyReference(SubAnimal, sub_animal.pk), + { + "_cls": "SubAnimal", + "_ref": DBRef(sub_animal._get_collection_name(), sub_animal.pk), + }, + ): p = Ocurrence(person="test", animal=ref).save() p.reload() self.assertIsInstance(p.animal, (LazyReference, Document)) @@ -441,7 +444,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): class Ocurrence(Document): person = StringField() - animal = GenericLazyReferenceField(choices=['Animal']) + animal = GenericLazyReferenceField(choices=["Animal"]) Animal.drop_collection() Ocurrence.drop_collection() @@ -451,12 +454,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): animal = Animal(name="Leopard", tag="heavy").save() baddoc = BadDoc().save() - for bad in ( - 42, - 'foo', - baddoc, - LazyReference(BadDoc, animal.pk) - ): + for bad in (42, "foo", baddoc, LazyReference(BadDoc, animal.pk)): with self.assertRaises(ValidationError): p = Ocurrence(person="test", animal=bad).save() @@ -476,10 +474,10 @@ class TestGenericLazyReferenceField(MongoDBTestCase): m2 = Member(user_num=2) m2.save() - post1 = BlogPost(title='post 1', author=m1) + post1 = BlogPost(title="post 1", author=m1) post1.save() - post2 = BlogPost(title='post 2', author=m2) + post2 = BlogPost(title="post 2", author=m2) post2.save() post = BlogPost.objects(author=m1).first() @@ -504,7 +502,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): Animal.drop_collection() Ocurrence.drop_collection() - Ocurrence(person='foo').save() + Ocurrence(person="foo").save() p = Ocurrence.objects.get() self.assertIs(p.animal, None) @@ -515,7 +513,7 @@ class TestGenericLazyReferenceField(MongoDBTestCase): class Ocurrence(Document): person = StringField() - animal = GenericLazyReferenceField('Animal') + animal = GenericLazyReferenceField("Animal") Animal.drop_collection() Ocurrence.drop_collection() @@ -542,8 +540,8 @@ class TestGenericLazyReferenceField(MongoDBTestCase): Animal.drop_collection() Ocurrence.drop_collection() - animal1 = Animal(name='doggo').save() - animal2 = Animal(name='cheeta').save() + animal1 = Animal(name="doggo").save() + animal2 = Animal(name="cheeta").save() def check_fields_type(occ): self.assertIsInstance(occ.direct, LazyReference) @@ -555,14 +553,20 @@ class TestGenericLazyReferenceField(MongoDBTestCase): occ = Ocurrence( in_list=[animal1, animal2], - in_embedded={'in_list': [animal1, animal2], 'direct': animal1}, - direct=animal1 + in_embedded={"in_list": [animal1, animal2], "direct": animal1}, + direct=animal1, ).save() check_fields_type(occ) occ.reload() check_fields_type(occ) - animal1_ref = {'_cls': 'Animal', '_ref': DBRef(animal1._get_collection_name(), animal1.pk)} - animal2_ref = {'_cls': 'Animal', '_ref': DBRef(animal2._get_collection_name(), animal2.pk)} + animal1_ref = { + "_cls": "Animal", + "_ref": DBRef(animal1._get_collection_name(), animal1.pk), + } + animal2_ref = { + "_cls": "Animal", + "_ref": DBRef(animal2._get_collection_name(), animal2.pk), + } occ.direct = animal1_ref occ.in_list = [animal1_ref, animal2_ref] occ.in_embedded.direct = animal1_ref diff --git a/tests/fields/test_long_field.py b/tests/fields/test_long_field.py index 3f307809..ab86eccd 100644 --- a/tests/fields/test_long_field.py +++ b/tests/fields/test_long_field.py @@ -13,23 +13,26 @@ from tests.utils import MongoDBTestCase class TestLongField(MongoDBTestCase): - def test_long_field_is_considered_as_int64(self): """ Tests that long fields are stored as long in mongo, even if long value is small enough to be an int. """ + class TestLongFieldConsideredAsInt64(Document): some_long = LongField() doc = TestLongFieldConsideredAsInt64(some_long=42).save() db = get_db() - self.assertIsInstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64) + self.assertIsInstance( + db.test_long_field_considered_as_int64.find()[0]["some_long"], Int64 + ) self.assertIsInstance(doc.some_long, six.integer_types) def test_long_validation(self): """Ensure that invalid values cannot be assigned to long fields. """ + class TestDocument(Document): value = LongField(min_value=0, max_value=110) @@ -41,7 +44,7 @@ class TestLongField(MongoDBTestCase): self.assertRaises(ValidationError, doc.validate) doc.value = 120 self.assertRaises(ValidationError, doc.validate) - doc.value = 'ten' + doc.value = "ten" self.assertRaises(ValidationError, doc.validate) def test_long_ne_operator(self): diff --git a/tests/fields/test_map_field.py b/tests/fields/test_map_field.py index cb27cfff..54f70aa1 100644 --- a/tests/fields/test_map_field.py +++ b/tests/fields/test_map_field.py @@ -7,23 +7,24 @@ from tests.utils import MongoDBTestCase class TestMapField(MongoDBTestCase): - def test_mapfield(self): """Ensure that the MapField handles the declared type.""" + class Simple(Document): mapping = MapField(IntField()) Simple.drop_collection() e = Simple() - e.mapping['someint'] = 1 + e.mapping["someint"] = 1 e.save() with self.assertRaises(ValidationError): - e.mapping['somestring'] = "abc" + e.mapping["somestring"] = "abc" e.save() with self.assertRaises(ValidationError): + class NoDeclaredType(Document): mapping = MapField() @@ -45,38 +46,37 @@ class TestMapField(MongoDBTestCase): Extensible.drop_collection() e = Extensible() - e.mapping['somestring'] = StringSetting(value='foo') - e.mapping['someint'] = IntegerSetting(value=42) + e.mapping["somestring"] = StringSetting(value="foo") + e.mapping["someint"] = IntegerSetting(value=42) e.save() e2 = Extensible.objects.get(id=e.id) - self.assertIsInstance(e2.mapping['somestring'], StringSetting) - self.assertIsInstance(e2.mapping['someint'], IntegerSetting) + self.assertIsInstance(e2.mapping["somestring"], StringSetting) + self.assertIsInstance(e2.mapping["someint"], IntegerSetting) with self.assertRaises(ValidationError): - e.mapping['someint'] = 123 + e.mapping["someint"] = 123 e.save() def test_embedded_mapfield_db_field(self): class Embedded(EmbeddedDocument): - number = IntField(default=0, db_field='i') + number = IntField(default=0, db_field="i") class Test(Document): - my_map = MapField(field=EmbeddedDocumentField(Embedded), - db_field='x') + my_map = MapField(field=EmbeddedDocumentField(Embedded), db_field="x") Test.drop_collection() test = Test() - test.my_map['DICTIONARY_KEY'] = Embedded(number=1) + test.my_map["DICTIONARY_KEY"] = Embedded(number=1) test.save() Test.objects.update_one(inc__my_map__DICTIONARY_KEY__number=1) test = Test.objects.get() - self.assertEqual(test.my_map['DICTIONARY_KEY'].number, 2) + self.assertEqual(test.my_map["DICTIONARY_KEY"].number, 2) doc = self.db.test.find_one() - self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2) + self.assertEqual(doc["x"]["DICTIONARY_KEY"]["i"], 2) def test_mapfield_numerical_index(self): """Ensure that MapField accept numeric strings as indexes.""" @@ -90,9 +90,9 @@ class TestMapField(MongoDBTestCase): Test.drop_collection() test = Test() - test.my_map['1'] = Embedded(name='test') + test.my_map["1"] = Embedded(name="test") test.save() - test.my_map['1'].name = 'test updated' + test.my_map["1"].name = "test updated" test.save() def test_map_field_lookup(self): @@ -110,15 +110,20 @@ class TestMapField(MongoDBTestCase): actions = MapField(EmbeddedDocumentField(Action)) Log.drop_collection() - Log(name="wilson", visited={'friends': datetime.datetime.now()}, - actions={'friends': Action(operation='drink', object='beer')}).save() + Log( + name="wilson", + visited={"friends": datetime.datetime.now()}, + actions={"friends": Action(operation="drink", object="beer")}, + ).save() - self.assertEqual(1, Log.objects( - visited__friends__exists=True).count()) + self.assertEqual(1, Log.objects(visited__friends__exists=True).count()) - self.assertEqual(1, Log.objects( - actions__friends__operation='drink', - actions__friends__object='beer').count()) + self.assertEqual( + 1, + Log.objects( + actions__friends__operation="drink", actions__friends__object="beer" + ).count(), + ) def test_map_field_unicode(self): class Info(EmbeddedDocument): @@ -130,15 +135,11 @@ class TestMapField(MongoDBTestCase): BlogPost.drop_collection() - tree = BlogPost(info_dict={ - u"éééé": { - 'description': u"VALUE: éééé" - } - }) + tree = BlogPost(info_dict={u"éééé": {"description": u"VALUE: éééé"}}) tree.save() self.assertEqual( BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, - u"VALUE: éééé" + u"VALUE: éééé", ) diff --git a/tests/fields/test_reference_field.py b/tests/fields/test_reference_field.py index 5e1fc605..5fd053fe 100644 --- a/tests/fields/test_reference_field.py +++ b/tests/fields/test_reference_field.py @@ -26,15 +26,15 @@ class TestReferenceField(MongoDBTestCase): # with a document class name. self.assertRaises(ValidationError, ReferenceField, EmbeddedDocument) - user = User(name='Test User') + user = User(name="Test User") # Ensure that the referenced object must have been saved - post1 = BlogPost(content='Chips and gravy taste good.') + post1 = BlogPost(content="Chips and gravy taste good.") post1.author = user self.assertRaises(ValidationError, post1.save) # Check that an invalid object type cannot be used - post2 = BlogPost(content='Chips and chilli taste good.') + post2 = BlogPost(content="Chips and chilli taste good.") post1.author = post2 self.assertRaises(ValidationError, post1.validate) @@ -59,7 +59,7 @@ class TestReferenceField(MongoDBTestCase): class Person(Document): name = StringField() - parent = ReferenceField('self') + parent = ReferenceField("self") Person.drop_collection() @@ -74,7 +74,7 @@ class TestReferenceField(MongoDBTestCase): class Person(Document): name = StringField() - parent = ReferenceField('self', dbref=True) + parent = ReferenceField("self", dbref=True) Person.drop_collection() @@ -82,8 +82,8 @@ class TestReferenceField(MongoDBTestCase): Person(name="Ross", parent=p1).save() self.assertEqual( - Person._get_collection().find_one({'name': 'Ross'})['parent'], - DBRef('person', p1.pk) + Person._get_collection().find_one({"name": "Ross"})["parent"], + DBRef("person", p1.pk), ) p = Person.objects.get(name="Ross") @@ -97,21 +97,17 @@ class TestReferenceField(MongoDBTestCase): class Person(Document): name = StringField() - parent = ReferenceField('self', dbref=False) + parent = ReferenceField("self", dbref=False) - p = Person( - name='Steve', - parent=DBRef('person', 'abcdefghijklmnop') + p = Person(name="Steve", parent=DBRef("person", "abcdefghijklmnop")) + self.assertEqual( + p.to_mongo(), SON([("name", u"Steve"), ("parent", "abcdefghijklmnop")]) ) - self.assertEqual(p.to_mongo(), SON([ - ('name', u'Steve'), - ('parent', 'abcdefghijklmnop') - ])) def test_objectid_reference_fields(self): class Person(Document): name = StringField() - parent = ReferenceField('self', dbref=False) + parent = ReferenceField("self", dbref=False) Person.drop_collection() @@ -119,8 +115,8 @@ class TestReferenceField(MongoDBTestCase): Person(name="Ross", parent=p1).save() col = Person._get_collection() - data = col.find_one({'name': 'Ross'}) - self.assertEqual(data['parent'], p1.pk) + data = col.find_one({"name": "Ross"}) + self.assertEqual(data["parent"], p1.pk) p = Person.objects.get(name="Ross") self.assertEqual(p.parent, p1) @@ -128,9 +124,10 @@ class TestReferenceField(MongoDBTestCase): def test_undefined_reference(self): """Ensure that ReferenceFields may reference undefined Documents. """ + class Product(Document): name = StringField() - company = ReferenceField('Company') + company = ReferenceField("Company") class Company(Document): name = StringField() @@ -138,12 +135,12 @@ class TestReferenceField(MongoDBTestCase): Product.drop_collection() Company.drop_collection() - ten_gen = Company(name='10gen') + ten_gen = Company(name="10gen") ten_gen.save() - mongodb = Product(name='MongoDB', company=ten_gen) + mongodb = Product(name="MongoDB", company=ten_gen) mongodb.save() - me = Product(name='MongoEngine') + me = Product(name="MongoEngine") me.save() obj = Product.objects(company=ten_gen).first() @@ -160,6 +157,7 @@ class TestReferenceField(MongoDBTestCase): """Ensure that ReferenceFields can be queried using objects and values of the type of the primary key of the referenced object. """ + class Member(Document): user_num = IntField(primary_key=True) @@ -175,10 +173,10 @@ class TestReferenceField(MongoDBTestCase): m2 = Member(user_num=2) m2.save() - post1 = BlogPost(title='post 1', author=m1) + post1 = BlogPost(title="post 1", author=m1) post1.save() - post2 = BlogPost(title='post 2', author=m2) + post2 = BlogPost(title="post 2", author=m2) post2.save() post = BlogPost.objects(author=m1).first() @@ -191,6 +189,7 @@ class TestReferenceField(MongoDBTestCase): """Ensure that ReferenceFields can be queried using objects and values of the type of the primary key of the referenced object. """ + class Member(Document): user_num = IntField(primary_key=True) @@ -206,10 +205,10 @@ class TestReferenceField(MongoDBTestCase): m2 = Member(user_num=2) m2.save() - post1 = BlogPost(title='post 1', author=m1) + post1 = BlogPost(title="post 1", author=m1) post1.save() - post2 = BlogPost(title='post 2', author=m2) + post2 = BlogPost(title="post 2", author=m2) post2.save() post = BlogPost.objects(author=m1).first() diff --git a/tests/fields/test_sequence_field.py b/tests/fields/test_sequence_field.py index 6124c65e..f2c8388b 100644 --- a/tests/fields/test_sequence_field.py +++ b/tests/fields/test_sequence_field.py @@ -11,38 +11,38 @@ class TestSequenceField(MongoDBTestCase): id = SequenceField(primary_key=True) name = StringField() - self.db['mongoengine.counters'].drop() + self.db["mongoengine.counters"].drop() Person.drop_collection() for x in range(10): Person(name="Person %s" % x).save() - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) + self.assertEqual(c["next"], 10) ids = [i.id for i in Person.objects] self.assertEqual(ids, range(1, 11)) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) + self.assertEqual(c["next"], 10) Person.id.set_next_value(1000) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 1000) + c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) + self.assertEqual(c["next"], 1000) def test_sequence_field_get_next_value(self): class Person(Document): id = SequenceField(primary_key=True) name = StringField() - self.db['mongoengine.counters'].drop() + self.db["mongoengine.counters"].drop() Person.drop_collection() for x in range(10): Person(name="Person %s" % x).save() self.assertEqual(Person.id.get_next_value(), 11) - self.db['mongoengine.counters'].drop() + self.db["mongoengine.counters"].drop() self.assertEqual(Person.id.get_next_value(), 1) @@ -50,40 +50,40 @@ class TestSequenceField(MongoDBTestCase): id = SequenceField(primary_key=True, value_decorator=str) name = StringField() - self.db['mongoengine.counters'].drop() + self.db["mongoengine.counters"].drop() Person.drop_collection() for x in range(10): Person(name="Person %s" % x).save() - self.assertEqual(Person.id.get_next_value(), '11') - self.db['mongoengine.counters'].drop() + self.assertEqual(Person.id.get_next_value(), "11") + self.db["mongoengine.counters"].drop() - self.assertEqual(Person.id.get_next_value(), '1') + self.assertEqual(Person.id.get_next_value(), "1") def test_sequence_field_sequence_name(self): class Person(Document): - id = SequenceField(primary_key=True, sequence_name='jelly') + id = SequenceField(primary_key=True, sequence_name="jelly") name = StringField() - self.db['mongoengine.counters'].drop() + self.db["mongoengine.counters"].drop() Person.drop_collection() for x in range(10): Person(name="Person %s" % x).save() - c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"}) + self.assertEqual(c["next"], 10) ids = [i.id for i in Person.objects] self.assertEqual(ids, range(1, 11)) - c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"}) + self.assertEqual(c["next"], 10) Person.id.set_next_value(1000) - c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) - self.assertEqual(c['next'], 1000) + c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"}) + self.assertEqual(c["next"], 1000) def test_multiple_sequence_fields(self): class Person(Document): @@ -91,14 +91,14 @@ class TestSequenceField(MongoDBTestCase): counter = SequenceField() name = StringField() - self.db['mongoengine.counters'].drop() + self.db["mongoengine.counters"].drop() Person.drop_collection() for x in range(10): Person(name="Person %s" % x).save() - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) + self.assertEqual(c["next"], 10) ids = [i.id for i in Person.objects] self.assertEqual(ids, range(1, 11)) @@ -106,23 +106,23 @@ class TestSequenceField(MongoDBTestCase): counters = [i.counter for i in Person.objects] self.assertEqual(counters, range(1, 11)) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) + self.assertEqual(c["next"], 10) Person.id.set_next_value(1000) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 1000) + c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) + self.assertEqual(c["next"], 1000) Person.counter.set_next_value(999) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.counter'}) - self.assertEqual(c['next'], 999) + c = self.db["mongoengine.counters"].find_one({"_id": "person.counter"}) + self.assertEqual(c["next"], 999) def test_sequence_fields_reload(self): class Animal(Document): counter = SequenceField() name = StringField() - self.db['mongoengine.counters'].drop() + self.db["mongoengine.counters"].drop() Animal.drop_collection() a = Animal(name="Boi").save() @@ -151,7 +151,7 @@ class TestSequenceField(MongoDBTestCase): id = SequenceField(primary_key=True) name = StringField() - self.db['mongoengine.counters'].drop() + self.db["mongoengine.counters"].drop() Animal.drop_collection() Person.drop_collection() @@ -159,11 +159,11 @@ class TestSequenceField(MongoDBTestCase): Animal(name="Animal %s" % x).save() Person(name="Person %s" % x).save() - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) + self.assertEqual(c["next"], 10) - c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "animal.id"}) + self.assertEqual(c["next"], 10) ids = [i.id for i in Person.objects] self.assertEqual(ids, range(1, 11)) @@ -171,32 +171,32 @@ class TestSequenceField(MongoDBTestCase): id = [i.id for i in Animal.objects] self.assertEqual(id, range(1, 11)) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) + self.assertEqual(c["next"], 10) - c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "animal.id"}) + self.assertEqual(c["next"], 10) def test_sequence_field_value_decorator(self): class Person(Document): id = SequenceField(primary_key=True, value_decorator=str) name = StringField() - self.db['mongoengine.counters'].drop() + self.db["mongoengine.counters"].drop() Person.drop_collection() for x in range(10): p = Person(name="Person %s" % x) p.save() - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) + self.assertEqual(c["next"], 10) ids = [i.id for i in Person.objects] self.assertEqual(ids, map(str, range(1, 11))) - c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) - self.assertEqual(c['next'], 10) + c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) + self.assertEqual(c["next"], 10) def test_embedded_sequence_field(self): class Comment(EmbeddedDocument): @@ -207,14 +207,18 @@ class TestSequenceField(MongoDBTestCase): title = StringField(required=True) comments = ListField(EmbeddedDocumentField(Comment)) - self.db['mongoengine.counters'].drop() + self.db["mongoengine.counters"].drop() Post.drop_collection() - Post(title="MongoEngine", - comments=[Comment(content="NoSQL Rocks"), - Comment(content="MongoEngine Rocks")]).save() - c = self.db['mongoengine.counters'].find_one({'_id': 'comment.id'}) - self.assertEqual(c['next'], 2) + Post( + title="MongoEngine", + comments=[ + Comment(content="NoSQL Rocks"), + Comment(content="MongoEngine Rocks"), + ], + ).save() + c = self.db["mongoengine.counters"].find_one({"_id": "comment.id"}) + self.assertEqual(c["next"], 2) post = Post.objects.first() self.assertEqual(1, post.comments[0].id) self.assertEqual(2, post.comments[1].id) @@ -223,7 +227,7 @@ class TestSequenceField(MongoDBTestCase): class Base(Document): name = StringField() counter = SequenceField() - meta = {'abstract': True} + meta = {"abstract": True} class Foo(Base): pass @@ -231,24 +235,27 @@ class TestSequenceField(MongoDBTestCase): class Bar(Base): pass - bar = Bar(name='Bar') + bar = Bar(name="Bar") bar.save() - foo = Foo(name='Foo') + foo = Foo(name="Foo") foo.save() - self.assertTrue('base.counter' in - self.db['mongoengine.counters'].find().distinct('_id')) - self.assertFalse(('foo.counter' or 'bar.counter') in - self.db['mongoengine.counters'].find().distinct('_id')) + self.assertTrue( + "base.counter" in self.db["mongoengine.counters"].find().distinct("_id") + ) + self.assertFalse( + ("foo.counter" or "bar.counter") + in self.db["mongoengine.counters"].find().distinct("_id") + ) self.assertNotEqual(foo.counter, bar.counter) - self.assertEqual(foo._fields['counter'].owner_document, Base) - self.assertEqual(bar._fields['counter'].owner_document, Base) + self.assertEqual(foo._fields["counter"].owner_document, Base) + self.assertEqual(bar._fields["counter"].owner_document, Base) def test_no_inherited_sequencefield(self): class Base(Document): name = StringField() - meta = {'abstract': True} + meta = {"abstract": True} class Foo(Base): counter = SequenceField() @@ -256,16 +263,19 @@ class TestSequenceField(MongoDBTestCase): class Bar(Base): counter = SequenceField() - bar = Bar(name='Bar') + bar = Bar(name="Bar") bar.save() - foo = Foo(name='Foo') + foo = Foo(name="Foo") foo.save() - self.assertFalse('base.counter' in - self.db['mongoengine.counters'].find().distinct('_id')) - self.assertTrue(('foo.counter' and 'bar.counter') in - self.db['mongoengine.counters'].find().distinct('_id')) + self.assertFalse( + "base.counter" in self.db["mongoengine.counters"].find().distinct("_id") + ) + self.assertTrue( + ("foo.counter" and "bar.counter") + in self.db["mongoengine.counters"].find().distinct("_id") + ) self.assertEqual(foo.counter, bar.counter) - self.assertEqual(foo._fields['counter'].owner_document, Foo) - self.assertEqual(bar._fields['counter'].owner_document, Bar) + self.assertEqual(foo._fields["counter"].owner_document, Foo) + self.assertEqual(bar._fields["counter"].owner_document, Bar) diff --git a/tests/fields/test_url_field.py b/tests/fields/test_url_field.py index ddbf707e..81baf8d0 100644 --- a/tests/fields/test_url_field.py +++ b/tests/fields/test_url_field.py @@ -5,49 +5,53 @@ from tests.utils import MongoDBTestCase class TestURLField(MongoDBTestCase): - def test_validation(self): """Ensure that URLFields validate urls properly.""" + class Link(Document): url = URLField() link = Link() - link.url = 'google' + link.url = "google" self.assertRaises(ValidationError, link.validate) - link.url = 'http://www.google.com:8080' + link.url = "http://www.google.com:8080" link.validate() def test_unicode_url_validation(self): """Ensure unicode URLs are validated properly.""" + class Link(Document): url = URLField() link = Link() - link.url = u'http://привет.com' + link.url = u"http://привет.com" # TODO fix URL validation - this *IS* a valid URL # For now we just want to make sure that the error message is correct with self.assertRaises(ValidationError) as ctx_err: link.validate() - self.assertEqual(unicode(ctx_err.exception), - u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])") + self.assertEqual( + unicode(ctx_err.exception), + u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])", + ) def test_url_scheme_validation(self): """Ensure that URLFields validate urls with specific schemes properly. """ + class Link(Document): url = URLField() class SchemeLink(Document): - url = URLField(schemes=['ws', 'irc']) + url = URLField(schemes=["ws", "irc"]) link = Link() - link.url = 'ws://google.com' + link.url = "ws://google.com" self.assertRaises(ValidationError, link.validate) scheme_link = SchemeLink() - scheme_link.url = 'ws://google.com' + scheme_link.url = "ws://google.com" scheme_link.validate() def test_underscore_allowed_in_domains_names(self): @@ -55,5 +59,5 @@ class TestURLField(MongoDBTestCase): url = URLField() link = Link() - link.url = 'https://san_leandro-ca.geebo.com' + link.url = "https://san_leandro-ca.geebo.com" link.validate() diff --git a/tests/fields/test_uuid_field.py b/tests/fields/test_uuid_field.py index 7b7faaf2..647dceaf 100644 --- a/tests/fields/test_uuid_field.py +++ b/tests/fields/test_uuid_field.py @@ -15,11 +15,8 @@ class TestUUIDField(MongoDBTestCase): uid = uuid.uuid4() person = Person(api_key=uid).save() self.assertEqual( - get_as_pymongo(person), - {'_id': person.id, - 'api_key': str(uid) - } - ) + get_as_pymongo(person), {"_id": person.id, "api_key": str(uid)} + ) def test_field_string(self): """Test UUID fields storing as String @@ -37,8 +34,10 @@ class TestUUIDField(MongoDBTestCase): person.api_key = api_key person.validate() - invalid = ('9d159858-549b-4975-9f98-dd2f987c113g', - '9d159858-549b-4975-9f98-dd2f987c113') + invalid = ( + "9d159858-549b-4975-9f98-dd2f987c113g", + "9d159858-549b-4975-9f98-dd2f987c113", + ) for api_key in invalid: person.api_key = api_key self.assertRaises(ValidationError, person.validate) @@ -58,8 +57,10 @@ class TestUUIDField(MongoDBTestCase): person.api_key = api_key person.validate() - invalid = ('9d159858-549b-4975-9f98-dd2f987c113g', - '9d159858-549b-4975-9f98-dd2f987c113') + invalid = ( + "9d159858-549b-4975-9f98-dd2f987c113g", + "9d159858-549b-4975-9f98-dd2f987c113", + ) for api_key in invalid: person.api_key = api_key self.assertRaises(ValidationError, person.validate) diff --git a/tests/fixtures.py b/tests/fixtures.py index b8303b99..9f06f1ab 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -11,7 +11,7 @@ class PickleEmbedded(EmbeddedDocument): class PickleTest(Document): number = IntField() - string = StringField(choices=(('One', '1'), ('Two', '2'))) + string = StringField(choices=(("One", "1"), ("Two", "2"))) embedded = EmbeddedDocumentField(PickleEmbedded) lists = ListField(StringField()) photo = FileField() @@ -19,7 +19,7 @@ class PickleTest(Document): class NewDocumentPickleTest(Document): number = IntField() - string = StringField(choices=(('One', '1'), ('Two', '2'))) + string = StringField(choices=(("One", "1"), ("Two", "2"))) embedded = EmbeddedDocumentField(PickleEmbedded) lists = ListField(StringField()) photo = FileField() @@ -36,7 +36,7 @@ class PickleDynamicTest(DynamicDocument): class PickleSignalsTest(Document): number = IntField() - string = StringField(choices=(('One', '1'), ('Two', '2'))) + string = StringField(choices=(("One", "1"), ("Two", "2"))) embedded = EmbeddedDocumentField(PickleEmbedded) lists = ListField(StringField()) @@ -58,4 +58,4 @@ class Mixin(object): class Base(Document): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py index 250e2601..9f0fe827 100644 --- a/tests/queryset/field_list.py +++ b/tests/queryset/field_list.py @@ -7,79 +7,78 @@ __all__ = ("QueryFieldListTest", "OnlyExcludeAllTest") class QueryFieldListTest(unittest.TestCase): - def test_empty(self): q = QueryFieldList() self.assertFalse(q) - q = QueryFieldList(always_include=['_cls']) + q = QueryFieldList(always_include=["_cls"]) self.assertFalse(q) def test_include_include(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY, _only_called=True) - self.assertEqual(q.as_dict(), {'a': 1, 'b': 1}) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'a': 1, 'b': 1, 'c': 1}) + q += QueryFieldList( + fields=["a", "b"], value=QueryFieldList.ONLY, _only_called=True + ) + self.assertEqual(q.as_dict(), {"a": 1, "b": 1}) + q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {"a": 1, "b": 1, "c": 1}) def test_include_exclude(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'a': 1, 'b': 1}) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': 1}) + q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {"a": 1, "b": 1}) + q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {"a": 1}) def test_exclude_exclude(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': 0, 'b': 0}) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': 0, 'b': 0, 'c': 0}) + q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {"a": 0, "b": 0}) + q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {"a": 0, "b": 0, "c": 0}) def test_exclude_include(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) - self.assertEqual(q.as_dict(), {'a': 0, 'b': 0}) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'c': 1}) + q += QueryFieldList(fields=["a", "b"], value=QueryFieldList.EXCLUDE) + self.assertEqual(q.as_dict(), {"a": 0, "b": 0}) + q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {"c": 1}) def test_always_include(self): - q = QueryFieldList(always_include=['x', 'y']) - q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'c': 1}) + q = QueryFieldList(always_include=["x", "y"]) + q += QueryFieldList(fields=["a", "b", "x"], value=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "c": 1}) def test_reset(self): - q = QueryFieldList(always_include=['x', 'y']) - q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'c': 1}) + q = QueryFieldList(always_include=["x", "y"]) + q += QueryFieldList(fields=["a", "b", "x"], value=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "c": 1}) q.reset() self.assertFalse(q) - q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) - self.assertEqual(q.as_dict(), {'x': 1, 'y': 1, 'b': 1, 'c': 1}) + q += QueryFieldList(fields=["b", "c"], value=QueryFieldList.ONLY) + self.assertEqual(q.as_dict(), {"x": 1, "y": 1, "b": 1, "c": 1}) def test_using_a_slice(self): q = QueryFieldList() - q += QueryFieldList(fields=['a'], value={"$slice": 5}) - self.assertEqual(q.as_dict(), {'a': {"$slice": 5}}) + q += QueryFieldList(fields=["a"], value={"$slice": 5}) + self.assertEqual(q.as_dict(), {"a": {"$slice": 5}}) class OnlyExcludeAllTest(unittest.TestCase): - def setUp(self): - connect(db='mongoenginetest') + connect(db="mongoenginetest") class Person(Document): name = StringField() age = IntField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} Person.drop_collection() self.Person = Person def test_mixing_only_exclude(self): - class MyDoc(Document): a = StringField() b = StringField() @@ -88,32 +87,32 @@ class OnlyExcludeAllTest(unittest.TestCase): e = StringField() f = StringField() - include = ['a', 'b', 'c', 'd', 'e'] - exclude = ['d', 'e'] - only = ['b', 'c'] + include = ["a", "b", "c", "d", "e"] + exclude = ["d", "e"] + only = ["b", "c"] qs = MyDoc.objects.fields(**{i: 1 for i in include}) - self.assertEqual(qs._loaded_fields.as_dict(), - {'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1}) + self.assertEqual( + qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1, "d": 1, "e": 1} + ) qs = qs.only(*only) - self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) + self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) qs = qs.exclude(*exclude) - self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) + self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) qs = MyDoc.objects.fields(**{i: 1 for i in include}) qs = qs.exclude(*exclude) - self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) + self.assertEqual(qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1}) qs = qs.only(*only) - self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) + self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) qs = MyDoc.objects.exclude(*exclude) qs = qs.fields(**{i: 1 for i in include}) - self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) + self.assertEqual(qs._loaded_fields.as_dict(), {"a": 1, "b": 1, "c": 1}) qs = qs.only(*only) - self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) + self.assertEqual(qs._loaded_fields.as_dict(), {"b": 1, "c": 1}) def test_slicing(self): - class MyDoc(Document): a = ListField() b = ListField() @@ -122,24 +121,23 @@ class OnlyExcludeAllTest(unittest.TestCase): e = ListField() f = ListField() - include = ['a', 'b', 'c', 'd', 'e'] - exclude = ['d', 'e'] - only = ['b', 'c'] + include = ["a", "b", "c", "d", "e"] + exclude = ["d", "e"] + only = ["b", "c"] qs = MyDoc.objects.fields(**{i: 1 for i in include}) qs = qs.exclude(*exclude) qs = qs.only(*only) qs = qs.fields(slice__b=5) - self.assertEqual(qs._loaded_fields.as_dict(), - {'b': {'$slice': 5}, 'c': 1}) + self.assertEqual(qs._loaded_fields.as_dict(), {"b": {"$slice": 5}, "c": 1}) qs = qs.fields(slice__c=[5, 1]) - self.assertEqual(qs._loaded_fields.as_dict(), - {'b': {'$slice': 5}, 'c': {'$slice': [5, 1]}}) + self.assertEqual( + qs._loaded_fields.as_dict(), {"b": {"$slice": 5}, "c": {"$slice": [5, 1]}} + ) - qs = qs.exclude('c') - self.assertEqual(qs._loaded_fields.as_dict(), - {'b': {'$slice': 5}}) + qs = qs.exclude("c") + self.assertEqual(qs._loaded_fields.as_dict(), {"b": {"$slice": 5}}) def test_mix_slice_with_other_fields(self): class MyDoc(Document): @@ -148,43 +146,42 @@ class OnlyExcludeAllTest(unittest.TestCase): c = ListField() qs = MyDoc.objects.fields(a=1, b=0, slice__c=2) - self.assertEqual(qs._loaded_fields.as_dict(), - {'c': {'$slice': 2}, 'a': 1}) + self.assertEqual(qs._loaded_fields.as_dict(), {"c": {"$slice": 2}, "a": 1}) def test_only(self): """Ensure that QuerySet.only only returns the requested fields. """ - person = self.Person(name='test', age=25) + person = self.Person(name="test", age=25) person.save() - obj = self.Person.objects.only('name').get() + obj = self.Person.objects.only("name").get() self.assertEqual(obj.name, person.name) self.assertEqual(obj.age, None) - obj = self.Person.objects.only('age').get() + obj = self.Person.objects.only("age").get() self.assertEqual(obj.name, None) self.assertEqual(obj.age, person.age) - obj = self.Person.objects.only('name', 'age').get() + obj = self.Person.objects.only("name", "age").get() self.assertEqual(obj.name, person.name) self.assertEqual(obj.age, person.age) - obj = self.Person.objects.only(*('id', 'name',)).get() + obj = self.Person.objects.only(*("id", "name")).get() self.assertEqual(obj.name, person.name) self.assertEqual(obj.age, None) # Check polymorphism still works class Employee(self.Person): - salary = IntField(db_field='wage') + salary = IntField(db_field="wage") - employee = Employee(name='test employee', age=40, salary=30000) + employee = Employee(name="test employee", age=40, salary=30000) employee.save() - obj = self.Person.objects(id=employee.id).only('age').get() + obj = self.Person.objects(id=employee.id).only("age").get() self.assertIsInstance(obj, Employee) # Check field names are looked up properly - obj = Employee.objects(id=employee.id).only('salary').get() + obj = Employee.objects(id=employee.id).only("salary").get() self.assertEqual(obj.salary, employee.salary) self.assertEqual(obj.name, None) @@ -208,35 +205,41 @@ class OnlyExcludeAllTest(unittest.TestCase): BlogPost.drop_collection() - post = BlogPost(content='Had a good coffee today...', various={'test_dynamic': {'some': True}}) - post.author = User(name='Test User') - post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] + post = BlogPost( + content="Had a good coffee today...", + various={"test_dynamic": {"some": True}}, + ) + post.author = User(name="Test User") + post.comments = [ + Comment(title="I aggree", text="Great post!"), + Comment(title="Coffee", text="I hate coffee"), + ] post.save() - obj = BlogPost.objects.only('author.name',).get() + obj = BlogPost.objects.only("author.name").get() self.assertEqual(obj.content, None) self.assertEqual(obj.author.email, None) - self.assertEqual(obj.author.name, 'Test User') + self.assertEqual(obj.author.name, "Test User") self.assertEqual(obj.comments, []) - obj = BlogPost.objects.only('various.test_dynamic.some').get() + obj = BlogPost.objects.only("various.test_dynamic.some").get() self.assertEqual(obj.various["test_dynamic"].some, True) - obj = BlogPost.objects.only('content', 'comments.title',).get() - self.assertEqual(obj.content, 'Had a good coffee today...') + obj = BlogPost.objects.only("content", "comments.title").get() + self.assertEqual(obj.content, "Had a good coffee today...") self.assertEqual(obj.author, None) - self.assertEqual(obj.comments[0].title, 'I aggree') - self.assertEqual(obj.comments[1].title, 'Coffee') + self.assertEqual(obj.comments[0].title, "I aggree") + self.assertEqual(obj.comments[1].title, "Coffee") self.assertEqual(obj.comments[0].text, None) self.assertEqual(obj.comments[1].text, None) - obj = BlogPost.objects.only('comments',).get() + obj = BlogPost.objects.only("comments").get() self.assertEqual(obj.content, None) self.assertEqual(obj.author, None) - self.assertEqual(obj.comments[0].title, 'I aggree') - self.assertEqual(obj.comments[1].title, 'Coffee') - self.assertEqual(obj.comments[0].text, 'Great post!') - self.assertEqual(obj.comments[1].text, 'I hate coffee') + self.assertEqual(obj.comments[0].title, "I aggree") + self.assertEqual(obj.comments[1].title, "Coffee") + self.assertEqual(obj.comments[0].text, "Great post!") + self.assertEqual(obj.comments[1].text, "I hate coffee") BlogPost.drop_collection() @@ -256,15 +259,18 @@ class OnlyExcludeAllTest(unittest.TestCase): BlogPost.drop_collection() - post = BlogPost(content='Had a good coffee today...') - post.author = User(name='Test User') - post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] + post = BlogPost(content="Had a good coffee today...") + post.author = User(name="Test User") + post.comments = [ + Comment(title="I aggree", text="Great post!"), + Comment(title="Coffee", text="I hate coffee"), + ] post.save() - obj = BlogPost.objects.exclude('author', 'comments.text').get() + obj = BlogPost.objects.exclude("author", "comments.text").get() self.assertEqual(obj.author, None) - self.assertEqual(obj.content, 'Had a good coffee today...') - self.assertEqual(obj.comments[0].title, 'I aggree') + self.assertEqual(obj.content, "Had a good coffee today...") + self.assertEqual(obj.comments[0].title, "I aggree") self.assertEqual(obj.comments[0].text, None) BlogPost.drop_collection() @@ -283,32 +289,43 @@ class OnlyExcludeAllTest(unittest.TestCase): attachments = ListField(EmbeddedDocumentField(Attachment)) Email.drop_collection() - email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain') + email = Email( + sender="me", + to="you", + subject="From Russia with Love", + body="Hello!", + content_type="text/plain", + ) email.attachments = [ - Attachment(name='file1.doc', content='ABC'), - Attachment(name='file2.doc', content='XYZ'), + Attachment(name="file1.doc", content="ABC"), + Attachment(name="file2.doc", content="XYZ"), ] email.save() - obj = Email.objects.exclude('content_type').exclude('body').get() - self.assertEqual(obj.sender, 'me') - self.assertEqual(obj.to, 'you') - self.assertEqual(obj.subject, 'From Russia with Love') + obj = Email.objects.exclude("content_type").exclude("body").get() + self.assertEqual(obj.sender, "me") + self.assertEqual(obj.to, "you") + self.assertEqual(obj.subject, "From Russia with Love") self.assertEqual(obj.body, None) self.assertEqual(obj.content_type, None) - obj = Email.objects.only('sender', 'to').exclude('body', 'sender').get() + obj = Email.objects.only("sender", "to").exclude("body", "sender").get() self.assertEqual(obj.sender, None) - self.assertEqual(obj.to, 'you') + self.assertEqual(obj.to, "you") self.assertEqual(obj.subject, None) self.assertEqual(obj.body, None) self.assertEqual(obj.content_type, None) - obj = Email.objects.exclude('attachments.content').exclude('body').only('to', 'attachments.name').get() - self.assertEqual(obj.attachments[0].name, 'file1.doc') + obj = ( + Email.objects.exclude("attachments.content") + .exclude("body") + .only("to", "attachments.name") + .get() + ) + self.assertEqual(obj.attachments[0].name, "file1.doc") self.assertEqual(obj.attachments[0].content, None) self.assertEqual(obj.sender, None) - self.assertEqual(obj.to, 'you') + self.assertEqual(obj.to, "you") self.assertEqual(obj.subject, None) self.assertEqual(obj.body, None) self.assertEqual(obj.content_type, None) @@ -316,7 +333,6 @@ class OnlyExcludeAllTest(unittest.TestCase): Email.drop_collection() def test_all_fields(self): - class Email(Document): sender = StringField() to = StringField() @@ -326,21 +342,33 @@ class OnlyExcludeAllTest(unittest.TestCase): Email.drop_collection() - email = Email(sender='me', to='you', subject='From Russia with Love', body='Hello!', content_type='text/plain') + email = Email( + sender="me", + to="you", + subject="From Russia with Love", + body="Hello!", + content_type="text/plain", + ) email.save() - obj = Email.objects.exclude('content_type', 'body').only('to', 'body').all_fields().get() - self.assertEqual(obj.sender, 'me') - self.assertEqual(obj.to, 'you') - self.assertEqual(obj.subject, 'From Russia with Love') - self.assertEqual(obj.body, 'Hello!') - self.assertEqual(obj.content_type, 'text/plain') + obj = ( + Email.objects.exclude("content_type", "body") + .only("to", "body") + .all_fields() + .get() + ) + self.assertEqual(obj.sender, "me") + self.assertEqual(obj.to, "you") + self.assertEqual(obj.subject, "From Russia with Love") + self.assertEqual(obj.body, "Hello!") + self.assertEqual(obj.content_type, "text/plain") Email.drop_collection() def test_slicing_fields(self): """Ensure that query slicing an array works. """ + class Numbers(Document): n = ListField(IntField()) @@ -414,11 +442,10 @@ class OnlyExcludeAllTest(unittest.TestCase): self.assertEqual(numbers.embedded.n, [-5, -4, -3, -2, -1]) def test_exclude_from_subclasses_docs(self): - class Base(Document): username = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class Anon(Base): anon = BooleanField() @@ -436,5 +463,5 @@ class OnlyExcludeAllTest(unittest.TestCase): self.assertRaises(LookUpError, Base.objects.exclude, "made_up") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py index 45e6a089..95dc913d 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/geo.py @@ -10,9 +10,9 @@ __all__ = ("GeoQueriesTest",) class GeoQueriesTest(MongoDBTestCase): - def _create_event_data(self, point_field_class=GeoPointField): """Create some sample data re-used in many of the tests below.""" + class Event(Document): title = StringField() date = DateTimeField() @@ -28,15 +28,18 @@ class GeoQueriesTest(MongoDBTestCase): event1 = Event.objects.create( title="Coltrane Motion @ Double Door", date=datetime.datetime.now() - datetime.timedelta(days=1), - location=[-87.677137, 41.909889]) + location=[-87.677137, 41.909889], + ) event2 = Event.objects.create( title="Coltrane Motion @ Bottom of the Hill", date=datetime.datetime.now() - datetime.timedelta(days=10), - location=[-122.4194155, 37.7749295]) + location=[-122.4194155, 37.7749295], + ) event3 = Event.objects.create( title="Coltrane Motion @ Empty Bottle", date=datetime.datetime.now(), - location=[-87.686638, 41.900474]) + location=[-87.686638, 41.900474], + ) return event1, event2, event3 @@ -65,8 +68,7 @@ class GeoQueriesTest(MongoDBTestCase): # find events within 10 degrees of san francisco point = [-122.415579, 37.7566023] - events = self.Event.objects(location__near=point, - location__max_distance=10) + events = self.Event.objects(location__near=point, location__max_distance=10) self.assertEqual(events.count(), 1) self.assertEqual(events[0], event2) @@ -78,8 +80,7 @@ class GeoQueriesTest(MongoDBTestCase): # find events at least 10 degrees away of san francisco point = [-122.415579, 37.7566023] - events = self.Event.objects(location__near=point, - location__min_distance=10) + events = self.Event.objects(location__near=point, location__min_distance=10) self.assertEqual(events.count(), 2) def test_within_distance(self): @@ -88,8 +89,7 @@ class GeoQueriesTest(MongoDBTestCase): # find events within 5 degrees of pitchfork office, chicago point_and_distance = [[-87.67892, 41.9120459], 5] - events = self.Event.objects( - location__within_distance=point_and_distance) + events = self.Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 2) events = list(events) self.assertNotIn(event2, events) @@ -98,21 +98,18 @@ class GeoQueriesTest(MongoDBTestCase): # find events within 10 degrees of san francisco point_and_distance = [[-122.415579, 37.7566023], 10] - events = self.Event.objects( - location__within_distance=point_and_distance) + events = self.Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 1) self.assertEqual(events[0], event2) # find events within 1 degree of greenpoint, broolyn, nyc, ny point_and_distance = [[-73.9509714, 40.7237134], 1] - events = self.Event.objects( - location__within_distance=point_and_distance) + events = self.Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 0) # ensure ordering is respected by "within_distance" point_and_distance = [[-87.67892, 41.9120459], 10] - events = self.Event.objects( - location__within_distance=point_and_distance) + events = self.Event.objects(location__within_distance=point_and_distance) events = events.order_by("-date") self.assertEqual(events.count(), 2) self.assertEqual(events[0], event3) @@ -145,7 +142,7 @@ class GeoQueriesTest(MongoDBTestCase): polygon2 = [ (-1.742249, 54.033586), (-1.225891, 52.792797), - (-4.40094, 53.389881) + (-4.40094, 53.389881), ] events = self.Event.objects(location__within_polygon=polygon2) self.assertEqual(events.count(), 0) @@ -154,9 +151,7 @@ class GeoQueriesTest(MongoDBTestCase): """Make sure the "near" operator works with a PointField, which corresponds to a 2dsphere index. """ - event1, event2, event3 = self._create_event_data( - point_field_class=PointField - ) + event1, event2, event3 = self._create_event_data(point_field_class=PointField) # find all events "near" pitchfork office, chicago. # note that "near" will show the san francisco event, too, @@ -175,26 +170,23 @@ class GeoQueriesTest(MongoDBTestCase): """Ensure the "max_distance" operator works alongside the "near" operator with a 2dsphere index. """ - event1, event2, event3 = self._create_event_data( - point_field_class=PointField - ) + event1, event2, event3 = self._create_event_data(point_field_class=PointField) # find events within 10km of san francisco point = [-122.415579, 37.7566023] - events = self.Event.objects(location__near=point, - location__max_distance=10000) + events = self.Event.objects(location__near=point, location__max_distance=10000) self.assertEqual(events.count(), 1) self.assertEqual(events[0], event2) # find events within 1km of greenpoint, broolyn, nyc, ny - events = self.Event.objects(location__near=[-73.9509714, 40.7237134], - location__max_distance=1000) + events = self.Event.objects( + location__near=[-73.9509714, 40.7237134], location__max_distance=1000 + ) self.assertEqual(events.count(), 0) # ensure ordering is respected by "near" events = self.Event.objects( - location__near=[-87.67892, 41.9120459], - location__max_distance=10000 + location__near=[-87.67892, 41.9120459], location__max_distance=10000 ).order_by("-date") self.assertEqual(events.count(), 2) self.assertEqual(events[0], event3) @@ -203,9 +195,7 @@ class GeoQueriesTest(MongoDBTestCase): """Ensure the "geo_within_box" operator works with a 2dsphere index. """ - event1, event2, event3 = self._create_event_data( - point_field_class=PointField - ) + event1, event2, event3 = self._create_event_data(point_field_class=PointField) # check that within_box works box = [(-125.0, 35.0), (-100.0, 40.0)] @@ -217,9 +207,7 @@ class GeoQueriesTest(MongoDBTestCase): """Ensure the "geo_within_polygon" operator works with a 2dsphere index. """ - event1, event2, event3 = self._create_event_data( - point_field_class=PointField - ) + event1, event2, event3 = self._create_event_data(point_field_class=PointField) polygon = [ (-87.694445, 41.912114), @@ -235,7 +223,7 @@ class GeoQueriesTest(MongoDBTestCase): polygon2 = [ (-1.742249, 54.033586), (-1.225891, 52.792797), - (-4.40094, 53.389881) + (-4.40094, 53.389881), ] events = self.Event.objects(location__geo_within_polygon=polygon2) self.assertEqual(events.count(), 0) @@ -244,23 +232,20 @@ class GeoQueriesTest(MongoDBTestCase): """Ensure "min_distace" and "max_distance" operators work well together with the "near" operator in a 2dsphere index. """ - event1, event2, event3 = self._create_event_data( - point_field_class=PointField - ) + event1, event2, event3 = self._create_event_data(point_field_class=PointField) # ensure min_distance and max_distance combine well events = self.Event.objects( location__near=[-87.67892, 41.9120459], location__min_distance=1000, - location__max_distance=10000 + location__max_distance=10000, ).order_by("-date") self.assertEqual(events.count(), 1) self.assertEqual(events[0], event3) # ensure ordering is respected by "near" with "min_distance" events = self.Event.objects( - location__near=[-87.67892, 41.9120459], - location__min_distance=10000 + location__near=[-87.67892, 41.9120459], location__min_distance=10000 ).order_by("-date") self.assertEqual(events.count(), 1) self.assertEqual(events[0], event2) @@ -269,14 +254,11 @@ class GeoQueriesTest(MongoDBTestCase): """Make sure the "geo_within_center" operator works with a 2dsphere index. """ - event1, event2, event3 = self._create_event_data( - point_field_class=PointField - ) + event1, event2, event3 = self._create_event_data(point_field_class=PointField) # find events within 5 degrees of pitchfork office, chicago point_and_distance = [[-87.67892, 41.9120459], 2] - events = self.Event.objects( - location__geo_within_center=point_and_distance) + events = self.Event.objects(location__geo_within_center=point_and_distance) self.assertEqual(events.count(), 2) events = list(events) self.assertNotIn(event2, events) @@ -287,6 +269,7 @@ class GeoQueriesTest(MongoDBTestCase): """Helper test method ensuring given point field class works well in an embedded document. """ + class Venue(EmbeddedDocument): location = point_field_class() name = StringField() @@ -300,12 +283,11 @@ class GeoQueriesTest(MongoDBTestCase): venue1 = Venue(name="The Rock", location=[-87.677137, 41.909889]) venue2 = Venue(name="The Bridge", location=[-122.4194155, 37.7749295]) - event1 = Event(title="Coltrane Motion @ Double Door", - venue=venue1).save() - event2 = Event(title="Coltrane Motion @ Bottom of the Hill", - venue=venue2).save() - event3 = Event(title="Coltrane Motion @ Empty Bottle", - venue=venue1).save() + event1 = Event(title="Coltrane Motion @ Double Door", venue=venue1).save() + event2 = Event( + title="Coltrane Motion @ Bottom of the Hill", venue=venue2 + ).save() + event3 = Event(title="Coltrane Motion @ Empty Bottle", venue=venue1).save() # find all events "near" pitchfork office, chicago. # note that "near" will show the san francisco event, too, @@ -324,6 +306,7 @@ class GeoQueriesTest(MongoDBTestCase): def test_spherical_geospatial_operators(self): """Ensure that spherical geospatial queries are working.""" + class Point(Document): location = GeoPointField() @@ -343,26 +326,26 @@ class GeoQueriesTest(MongoDBTestCase): # Same behavior for _within_spherical_distance points = Point.objects( - location__within_spherical_distance=[ - [-122, 37.5], - 60 / earth_radius - ] + location__within_spherical_distance=[[-122, 37.5], 60 / earth_radius] ) self.assertEqual(points.count(), 2) - points = Point.objects(location__near_sphere=[-122, 37.5], - location__max_distance=60 / earth_radius) + points = Point.objects( + location__near_sphere=[-122, 37.5], location__max_distance=60 / earth_radius + ) self.assertEqual(points.count(), 2) # Test query works with max_distance, being farer from one point - points = Point.objects(location__near_sphere=[-122, 37.8], - location__max_distance=60 / earth_radius) + points = Point.objects( + location__near_sphere=[-122, 37.8], location__max_distance=60 / earth_radius + ) close_point = points.first() self.assertEqual(points.count(), 1) # Test query works with min_distance, being farer from one point - points = Point.objects(location__near_sphere=[-122, 37.8], - location__min_distance=60 / earth_radius) + points = Point.objects( + location__near_sphere=[-122, 37.8], location__min_distance=60 / earth_radius + ) self.assertEqual(points.count(), 1) far_point = points.first() self.assertNotEqual(close_point, far_point) @@ -384,10 +367,7 @@ class GeoQueriesTest(MongoDBTestCase): # Finds only one point because only the first point is within 60km of # the reference point to the south. points = Point.objects( - location__within_spherical_distance=[ - [-122, 36.5], - 60 / earth_radius - ] + location__within_spherical_distance=[[-122, 36.5], 60 / earth_radius] ) self.assertEqual(points.count(), 1) self.assertEqual(points[0].id, south_point.id) @@ -413,8 +393,10 @@ class GeoQueriesTest(MongoDBTestCase): self.assertEqual(1, roads) # Within - polygon = {"type": "Polygon", - "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} + polygon = { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], + } roads = Road.objects.filter(line__geo_within=polygon["coordinates"]).count() self.assertEqual(1, roads) @@ -425,8 +407,7 @@ class GeoQueriesTest(MongoDBTestCase): self.assertEqual(1, roads) # Intersects - line = {"type": "LineString", - "coordinates": [[40, 5], [40, 6]]} + line = {"type": "LineString", "coordinates": [[40, 5], [40, 6]]} roads = Road.objects.filter(line__geo_intersects=line["coordinates"]).count() self.assertEqual(1, roads) @@ -436,8 +417,10 @@ class GeoQueriesTest(MongoDBTestCase): roads = Road.objects.filter(line__geo_intersects={"$geometry": line}).count() self.assertEqual(1, roads) - polygon = {"type": "Polygon", - "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} + polygon = { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], + } roads = Road.objects.filter(line__geo_intersects=polygon["coordinates"]).count() self.assertEqual(1, roads) @@ -468,8 +451,10 @@ class GeoQueriesTest(MongoDBTestCase): self.assertEqual(1, roads) # Within - polygon = {"type": "Polygon", - "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} + polygon = { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], + } roads = Road.objects.filter(poly__geo_within=polygon["coordinates"]).count() self.assertEqual(1, roads) @@ -480,8 +465,7 @@ class GeoQueriesTest(MongoDBTestCase): self.assertEqual(1, roads) # Intersects - line = {"type": "LineString", - "coordinates": [[40, 5], [41, 6]]} + line = {"type": "LineString", "coordinates": [[40, 5], [41, 6]]} roads = Road.objects.filter(poly__geo_intersects=line["coordinates"]).count() self.assertEqual(1, roads) @@ -491,8 +475,10 @@ class GeoQueriesTest(MongoDBTestCase): roads = Road.objects.filter(poly__geo_intersects={"$geometry": line}).count() self.assertEqual(1, roads) - polygon = {"type": "Polygon", - "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} + polygon = { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]], + } roads = Road.objects.filter(poly__geo_intersects=polygon["coordinates"]).count() self.assertEqual(1, roads) @@ -504,20 +490,20 @@ class GeoQueriesTest(MongoDBTestCase): def test_aspymongo_with_only(self): """Ensure as_pymongo works with only""" + class Place(Document): location = PointField() Place.drop_collection() p = Place(location=[24.946861267089844, 60.16311983618494]) p.save() - qs = Place.objects().only('location') + qs = Place.objects().only("location") self.assertDictEqual( - qs.as_pymongo()[0]['location'], - {u'type': u'Point', - u'coordinates': [ - 24.946861267089844, - 60.16311983618494] - } + qs.as_pymongo()[0]["location"], + { + u"type": u"Point", + u"coordinates": [24.946861267089844, 60.16311983618494], + }, ) def test_2dsphere_point_sets_correctly(self): @@ -542,11 +528,15 @@ class GeoQueriesTest(MongoDBTestCase): Location(line=[[1, 2], [2, 2]]).save() loc = Location.objects.as_pymongo()[0] - self.assertEqual(loc["line"], {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}) + self.assertEqual( + loc["line"], {"type": "LineString", "coordinates": [[1, 2], [2, 2]]} + ) Location.objects.update(set__line=[[2, 1], [1, 2]]) loc = Location.objects.as_pymongo()[0] - self.assertEqual(loc["line"], {"type": "LineString", "coordinates": [[2, 1], [1, 2]]}) + self.assertEqual( + loc["line"], {"type": "LineString", "coordinates": [[2, 1], [1, 2]]} + ) def test_geojson_PolygonField(self): class Location(Document): @@ -556,12 +546,18 @@ class GeoQueriesTest(MongoDBTestCase): Location(poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]]).save() loc = Location.objects.as_pymongo()[0] - self.assertEqual(loc["poly"], {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]}) + self.assertEqual( + loc["poly"], + {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]}, + ) Location.objects.update(set__poly=[[[40, 4], [40, 6], [41, 6], [40, 4]]]) loc = Location.objects.as_pymongo()[0] - self.assertEqual(loc["poly"], {"type": "Polygon", "coordinates": [[[40, 4], [40, 6], [41, 6], [40, 4]]]}) + self.assertEqual( + loc["poly"], + {"type": "Polygon", "coordinates": [[[40, 4], [40, 6], [41, 6], [40, 4]]]}, + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/queryset/modify.py b/tests/queryset/modify.py index 3c5879ba..e092d11c 100644 --- a/tests/queryset/modify.py +++ b/tests/queryset/modify.py @@ -11,7 +11,6 @@ class Doc(Document): class FindAndModifyTest(unittest.TestCase): - def setUp(self): connect(db="mongoenginetest") Doc.drop_collection() @@ -82,9 +81,14 @@ class FindAndModifyTest(unittest.TestCase): old_doc = Doc.objects().order_by("-id").modify(set__value=-1) self.assertEqual(old_doc.to_json(), doc.to_json()) - self.assertDbEqual([ - {"_id": 0, "value": 3}, {"_id": 1, "value": 2}, - {"_id": 2, "value": 1}, {"_id": 3, "value": -1}]) + self.assertDbEqual( + [ + {"_id": 0, "value": 3}, + {"_id": 1, "value": 2}, + {"_id": 2, "value": 1}, + {"_id": 3, "value": -1}, + ] + ) def test_modify_with_fields(self): Doc(id=0, value=0).save() @@ -103,27 +107,25 @@ class FindAndModifyTest(unittest.TestCase): blog = BlogPost.objects.create() # Push a new tag via modify with new=False (default). - BlogPost(id=blog.id).modify(push__tags='code') + BlogPost(id=blog.id).modify(push__tags="code") self.assertEqual(blog.tags, []) blog.reload() - self.assertEqual(blog.tags, ['code']) + self.assertEqual(blog.tags, ["code"]) # Push a new tag via modify with new=True. - blog = BlogPost.objects(id=blog.id).modify(push__tags='java', new=True) - self.assertEqual(blog.tags, ['code', 'java']) + blog = BlogPost.objects(id=blog.id).modify(push__tags="java", new=True) + self.assertEqual(blog.tags, ["code", "java"]) # Push a new tag with a positional argument. - blog = BlogPost.objects(id=blog.id).modify( - push__tags__0='python', - new=True) - self.assertEqual(blog.tags, ['python', 'code', 'java']) + blog = BlogPost.objects(id=blog.id).modify(push__tags__0="python", new=True) + self.assertEqual(blog.tags, ["python", "code", "java"]) # Push multiple new tags with a positional argument. blog = BlogPost.objects(id=blog.id).modify( - push__tags__1=['go', 'rust'], - new=True) - self.assertEqual(blog.tags, ['python', 'go', 'rust', 'code', 'java']) + push__tags__1=["go", "rust"], new=True + ) + self.assertEqual(blog.tags, ["python", "go", "rust", "code", "java"]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/queryset/pickable.py b/tests/queryset/pickable.py index bf7bb31c..0945fcbc 100644 --- a/tests/queryset/pickable.py +++ b/tests/queryset/pickable.py @@ -4,7 +4,7 @@ from pymongo.mongo_client import MongoClient from mongoengine import Document, StringField, IntField from mongoengine.connection import connect -__author__ = 'stas' +__author__ = "stas" class Person(Document): @@ -17,6 +17,7 @@ class TestQuerysetPickable(unittest.TestCase): Test for adding pickling support for QuerySet instances See issue https://github.com/MongoEngine/mongoengine/issues/442 """ + def setUp(self): super(TestQuerysetPickable, self).setUp() @@ -24,10 +25,7 @@ class TestQuerysetPickable(unittest.TestCase): connection.drop_database("test") - self.john = Person.objects.create( - name="John", - age=21 - ) + self.john = Person.objects.create(name="John", age=21) def test_picke_simple_qs(self): @@ -54,15 +52,9 @@ class TestQuerysetPickable(unittest.TestCase): self.assertEqual(Person.objects.first().age, 23) def test_pickle_support_filtration(self): - Person.objects.create( - name="Alice", - age=22 - ) + Person.objects.create(name="Alice", age=22) - Person.objects.create( - name="Bob", - age=23 - ) + Person.objects.create(name="Bob", age=23) qs = Person.objects.filter(age__gte=22) self.assertEqual(qs.count(), 2) @@ -71,9 +63,3 @@ class TestQuerysetPickable(unittest.TestCase): self.assertEqual(loaded.count(), 2) self.assertEqual(loaded.filter(name="Bob").first().age, 23) - - - - - - diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index c86e4095..21f35012 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -17,29 +17,34 @@ from mongoengine.connection import get_connection, get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError from mongoengine.mongodb_support import MONGODB_36, get_mongodb_version -from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, - QuerySet, QuerySetManager, queryset_manager) +from mongoengine.queryset import ( + DoesNotExist, + MultipleObjectsReturned, + QuerySet, + QuerySetManager, + queryset_manager, +) class db_ops_tracker(query_counter): - def get_ops(self): ignore_query = dict(self._ignored_query) - ignore_query['command.count'] = {'$ne': 'system.profile'} # Ignore the query issued by query_counter + ignore_query["command.count"] = { + "$ne": "system.profile" + } # Ignore the query issued by query_counter return list(self.db.system.profile.find(ignore_query)) def get_key_compat(mongo_ver): - ORDER_BY_KEY = 'sort' - CMD_QUERY_KEY = 'command' if mongo_ver >= MONGODB_36 else 'query' + ORDER_BY_KEY = "sort" + CMD_QUERY_KEY = "command" if mongo_ver >= MONGODB_36 else "query" return ORDER_BY_KEY, CMD_QUERY_KEY class QuerySetTest(unittest.TestCase): - def setUp(self): - connect(db='mongoenginetest') - connect(db='mongoenginetest2', alias='test2') + connect(db="mongoenginetest") + connect(db="mongoenginetest2", alias="test2") class PersonMeta(EmbeddedDocument): weight = IntField() @@ -48,7 +53,7 @@ class QuerySetTest(unittest.TestCase): name = StringField() age = IntField() person_meta = EmbeddedDocumentField(PersonMeta) - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} Person.drop_collection() self.PersonMeta = PersonMeta @@ -60,12 +65,14 @@ class QuerySetTest(unittest.TestCase): """Ensure that a QuerySet is correctly initialised by QuerySetManager. """ self.assertIsInstance(self.Person.objects, QuerySet) - self.assertEqual(self.Person.objects._collection.name, - self.Person._get_collection_name()) - self.assertIsInstance(self.Person.objects._collection, pymongo.collection.Collection) + self.assertEqual( + self.Person.objects._collection.name, self.Person._get_collection_name() + ) + self.assertIsInstance( + self.Person.objects._collection, pymongo.collection.Collection + ) def test_cannot_perform_joins_references(self): - class BlogPost(Document): author = ReferenceField(self.Person) author2 = GenericReferenceField() @@ -80,8 +87,8 @@ class QuerySetTest(unittest.TestCase): def test_find(self): """Ensure that a query returns a valid set of results.""" - user_a = self.Person.objects.create(name='User A', age=20) - user_b = self.Person.objects.create(name='User B', age=30) + user_a = self.Person.objects.create(name="User A", age=20) + user_b = self.Person.objects.create(name="User B", age=30) # Find all people in the collection people = self.Person.objects @@ -92,11 +99,11 @@ class QuerySetTest(unittest.TestCase): self.assertIsInstance(results[0].id, ObjectId) self.assertEqual(results[0], user_a) - self.assertEqual(results[0].name, 'User A') + self.assertEqual(results[0].name, "User A") self.assertEqual(results[0].age, 20) self.assertEqual(results[1], user_b) - self.assertEqual(results[1].name, 'User B') + self.assertEqual(results[1].name, "User B") self.assertEqual(results[1].age, 30) # Filter people by age @@ -109,8 +116,8 @@ class QuerySetTest(unittest.TestCase): def test_limit(self): """Ensure that QuerySet.limit works as expected.""" - user_a = self.Person.objects.create(name='User A', age=20) - user_b = self.Person.objects.create(name='User B', age=30) + user_a = self.Person.objects.create(name="User A", age=20) + user_b = self.Person.objects.create(name="User B", age=30) # Test limit on a new queryset people = list(self.Person.objects.limit(1)) @@ -131,15 +138,15 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(len(people), 2) # Test chaining of only after limit - person = self.Person.objects().limit(1).only('name').first() + person = self.Person.objects().limit(1).only("name").first() self.assertEqual(person, user_a) - self.assertEqual(person.name, 'User A') + self.assertEqual(person.name, "User A") self.assertEqual(person.age, None) def test_skip(self): """Ensure that QuerySet.skip works as expected.""" - user_a = self.Person.objects.create(name='User A', age=20) - user_b = self.Person.objects.create(name='User B', age=30) + user_a = self.Person.objects.create(name="User A", age=20) + user_b = self.Person.objects.create(name="User B", age=30) # Test skip on a new queryset people = list(self.Person.objects.skip(1)) @@ -155,20 +162,20 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(people2[0], user_b) # Test chaining of only after skip - person = self.Person.objects().skip(1).only('name').first() + person = self.Person.objects().skip(1).only("name").first() self.assertEqual(person, user_b) - self.assertEqual(person.name, 'User B') + self.assertEqual(person.name, "User B") self.assertEqual(person.age, None) def test___getitem___invalid_index(self): """Ensure slicing a queryset works as expected.""" with self.assertRaises(TypeError): - self.Person.objects()['a'] + self.Person.objects()["a"] def test_slice(self): """Ensure slicing a queryset works as expected.""" - user_a = self.Person.objects.create(name='User A', age=20) - user_b = self.Person.objects.create(name='User B', age=30) + user_a = self.Person.objects.create(name="User A", age=20) + user_b = self.Person.objects.create(name="User B", age=30) user_c = self.Person.objects.create(name="User C", age=40) # Test slice limit @@ -202,7 +209,7 @@ class QuerySetTest(unittest.TestCase): qs._cursor_obj = None people = list(qs) self.assertEqual(len(people), 1) - self.assertEqual(people[0].name, 'User B') + self.assertEqual(people[0].name, "User B") # Test empty slice people = list(self.Person.objects[1:1]) @@ -215,14 +222,18 @@ class QuerySetTest(unittest.TestCase): # Test larger slice __repr__ self.Person.objects.delete() for i in range(55): - self.Person(name='A%s' % i, age=i).save() + self.Person(name="A%s" % i, age=i).save() self.assertEqual(self.Person.objects.count(), 55) self.assertEqual("Person object", "%s" % self.Person.objects[0]) - self.assertEqual("[, ]", - "%s" % self.Person.objects[1:3]) - self.assertEqual("[, ]", - "%s" % self.Person.objects[51:53]) + self.assertEqual( + "[, ]", + "%s" % self.Person.objects[1:3], + ) + self.assertEqual( + "[, ]", + "%s" % self.Person.objects[51:53], + ) def test_find_one(self): """Ensure that a query using find_one returns a valid result. @@ -276,8 +287,7 @@ class QuerySetTest(unittest.TestCase): # Retrieve the first person from the database self.assertRaises(MultipleObjectsReturned, self.Person.objects.get) - self.assertRaises(self.Person.MultipleObjectsReturned, - self.Person.objects.get) + self.assertRaises(self.Person.MultipleObjectsReturned, self.Person.objects.get) # Use a query to filter the people found to just person2 person = self.Person.objects.get(age=30) @@ -289,6 +299,7 @@ class QuerySetTest(unittest.TestCase): def test_find_array_position(self): """Ensure that query by array position works. """ + class Comment(EmbeddedDocument): name = StringField() @@ -301,34 +312,34 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() - Blog.objects.create(tags=['a', 'b']) - self.assertEqual(Blog.objects(tags__0='a').count(), 1) - self.assertEqual(Blog.objects(tags__0='b').count(), 0) - self.assertEqual(Blog.objects(tags__1='a').count(), 0) - self.assertEqual(Blog.objects(tags__1='b').count(), 1) + Blog.objects.create(tags=["a", "b"]) + self.assertEqual(Blog.objects(tags__0="a").count(), 1) + self.assertEqual(Blog.objects(tags__0="b").count(), 0) + self.assertEqual(Blog.objects(tags__1="a").count(), 0) + self.assertEqual(Blog.objects(tags__1="b").count(), 1) Blog.drop_collection() - comment1 = Comment(name='testa') - comment2 = Comment(name='testb') + comment1 = Comment(name="testa") + comment2 = Comment(name="testb") post1 = Post(comments=[comment1, comment2]) post2 = Post(comments=[comment2, comment2]) blog1 = Blog.objects.create(posts=[post1, post2]) blog2 = Blog.objects.create(posts=[post2, post1]) - blog = Blog.objects(posts__0__comments__0__name='testa').get() + blog = Blog.objects(posts__0__comments__0__name="testa").get() self.assertEqual(blog, blog1) - blog = Blog.objects(posts__0__comments__0__name='testb').get() + blog = Blog.objects(posts__0__comments__0__name="testb").get() self.assertEqual(blog, blog2) - query = Blog.objects(posts__1__comments__1__name='testb') + query = Blog.objects(posts__1__comments__1__name="testb") self.assertEqual(query.count(), 2) - query = Blog.objects(posts__1__comments__1__name='testa') + query = Blog.objects(posts__1__comments__1__name="testa") self.assertEqual(query.count(), 0) - query = Blog.objects(posts__0__comments__1__name='testa') + query = Blog.objects(posts__0__comments__1__name="testa") self.assertEqual(query.count(), 0) Blog.drop_collection() @@ -367,13 +378,14 @@ class QuerySetTest(unittest.TestCase): q2 = q2.filter(ref=a1)._query self.assertEqual(q1, q2) - a_objects = A.objects(s='test1') + a_objects = A.objects(s="test1") query = B.objects(ref__in=a_objects) query = query.filter(boolfield=True) self.assertEqual(query.count(), 1) def test_batch_size(self): """Ensure that batch_size works.""" + class A(Document): s = StringField() @@ -416,33 +428,33 @@ class QuerySetTest(unittest.TestCase): self.Person.drop_collection() write_concern = {"fsync": True} - author = self.Person.objects.create(name='Test User') + author = self.Person.objects.create(name="Test User") author.save(write_concern=write_concern) # Ensure no regression of #1958 - author = self.Person(name='Test User2') + author = self.Person(name="Test User2") author.save(write_concern=None) # will default to {w: 1} - result = self.Person.objects.update( - set__name='Ross', write_concern={"w": 1}) + result = self.Person.objects.update(set__name="Ross", write_concern={"w": 1}) self.assertEqual(result, 2) - result = self.Person.objects.update( - set__name='Ross', write_concern={"w": 0}) + result = self.Person.objects.update(set__name="Ross", write_concern={"w": 0}) self.assertEqual(result, None) result = self.Person.objects.update_one( - set__name='Test User', write_concern={"w": 1}) + set__name="Test User", write_concern={"w": 1} + ) self.assertEqual(result, 1) result = self.Person.objects.update_one( - set__name='Test User', write_concern={"w": 0}) + set__name="Test User", write_concern={"w": 0} + ) self.assertEqual(result, None) def test_update_update_has_a_value(self): """Test to ensure that update is passed a value to update to""" self.Person.drop_collection() - author = self.Person.objects.create(name='Test User') + author = self.Person.objects.create(name="Test User") with self.assertRaises(OperationError): self.Person.objects(pk=author.pk).update({}) @@ -457,6 +469,7 @@ class QuerySetTest(unittest.TestCase): set__posts__1__comments__1__name="testc" Check that it only works for ListFields. """ + class Comment(EmbeddedDocument): name = StringField() @@ -469,16 +482,16 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() - comment1 = Comment(name='testa') - comment2 = Comment(name='testb') + comment1 = Comment(name="testa") + comment2 = Comment(name="testb") post1 = Post(comments=[comment1, comment2]) post2 = Post(comments=[comment2, comment2]) Blog.objects.create(posts=[post1, post2]) Blog.objects.create(posts=[post2, post1]) # Update all of the first comments of second posts of all blogs - Blog.objects().update(set__posts__1__comments__0__name='testc') - testc_blogs = Blog.objects(posts__1__comments__0__name='testc') + Blog.objects().update(set__posts__1__comments__0__name="testc") + testc_blogs = Blog.objects(posts__1__comments__0__name="testc") self.assertEqual(testc_blogs.count(), 2) Blog.drop_collection() @@ -486,14 +499,13 @@ class QuerySetTest(unittest.TestCase): Blog.objects.create(posts=[post2, post1]) # Update only the first blog returned by the query - Blog.objects().update_one( - set__posts__1__comments__1__name='testc') - testc_blogs = Blog.objects(posts__1__comments__1__name='testc') + Blog.objects().update_one(set__posts__1__comments__1__name="testc") + testc_blogs = Blog.objects(posts__1__comments__1__name="testc") self.assertEqual(testc_blogs.count(), 1) # Check that using this indexing syntax on a non-list fails with self.assertRaises(InvalidQueryError): - Blog.objects().update(set__posts__1__comments__0__name__1='asdf') + Blog.objects().update(set__posts__1__comments__0__name__1="asdf") Blog.drop_collection() @@ -519,7 +531,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects(comments__by="jane").update(inc__comments__S__votes=1) post = BlogPost.objects.first() - self.assertEqual(post.comments[1].by, 'jane') + self.assertEqual(post.comments[1].by, "jane") self.assertEqual(post.comments[1].votes, 8) def test_update_using_positional_operator_matches_first(self): @@ -563,7 +575,7 @@ class QuerySetTest(unittest.TestCase): # Nested updates arent supported yet.. with self.assertRaises(OperationError): Simple.drop_collection() - Simple(x=[{'test': [1, 2, 3, 4]}]).save() + Simple(x=[{"test": [1, 2, 3, 4]}]).save() Simple.objects(x__test=2).update(set__x__S__test__S=3) self.assertEqual(simple.x, [1, 2, 3, 4]) @@ -590,10 +602,11 @@ class QuerySetTest(unittest.TestCase): BlogPost(title="ABC", comments=[c1, c2]).save() BlogPost.objects(comments__by="joe").update( - set__comments__S__votes=Vote(score=4)) + set__comments__S__votes=Vote(score=4) + ) post = BlogPost.objects.first() - self.assertEqual(post.comments[0].by, 'joe') + self.assertEqual(post.comments[0].by, "joe") self.assertEqual(post.comments[0].votes.score, 4) def test_update_min_max(self): @@ -618,16 +631,15 @@ class QuerySetTest(unittest.TestCase): item = StringField() price = FloatField() - product = Product.objects.create(item='ABC', price=10.99) - product = Product.objects.create(item='ABC', price=10.99) + product = Product.objects.create(item="ABC", price=10.99) + product = Product.objects.create(item="ABC", price=10.99) Product.objects(id=product.id).update(mul__price=1.25) self.assertEqual(Product.objects.get(id=product.id).price, 13.7375) - unknown_product = Product.objects.create(item='Unknown') + unknown_product = Product.objects.create(item="Unknown") Product.objects(id=unknown_product.id).update(mul__price=100) self.assertEqual(Product.objects.get(id=unknown_product.id).price, 0) def test_updates_can_have_match_operators(self): - class Comment(EmbeddedDocument): content = StringField() name = StringField(max_length=120) @@ -643,8 +655,11 @@ class QuerySetTest(unittest.TestCase): comm1 = Comment(content="very funny indeed", name="John S", vote=1) comm2 = Comment(content="kind of funny", name="Mark P", vote=0) - Post(title='Fun with MongoEngine', tags=['mongodb', 'mongoengine'], - comments=[comm1, comm2]).save() + Post( + title="Fun with MongoEngine", + tags=["mongodb", "mongoengine"], + comments=[comm1, comm2], + ).save() Post.objects().update_one(pull__comments__vote__lt=1) @@ -652,6 +667,7 @@ class QuerySetTest(unittest.TestCase): def test_mapfield_update(self): """Ensure that the MapField can be updated.""" + class Member(EmbeddedDocument): gender = StringField() age = IntField() @@ -662,37 +678,35 @@ class QuerySetTest(unittest.TestCase): Club.drop_collection() club = Club() - club.members['John'] = Member(gender="M", age=13) + club.members["John"] = Member(gender="M", age=13) club.save() - Club.objects().update( - set__members={"John": Member(gender="F", age=14)}) + Club.objects().update(set__members={"John": Member(gender="F", age=14)}) club = Club.objects().first() - self.assertEqual(club.members['John'].gender, "F") - self.assertEqual(club.members['John'].age, 14) + self.assertEqual(club.members["John"].gender, "F") + self.assertEqual(club.members["John"].age, 14) def test_dictfield_update(self): """Ensure that the DictField can be updated.""" + class Club(Document): members = DictField() club = Club() - club.members['John'] = {'gender': 'M', 'age': 13} + club.members["John"] = {"gender": "M", "age": 13} club.save() - Club.objects().update( - set__members={"John": {'gender': 'F', 'age': 14}}) + Club.objects().update(set__members={"John": {"gender": "F", "age": 14}}) club = Club.objects().first() - self.assertEqual(club.members['John']['gender'], "F") - self.assertEqual(club.members['John']['age'], 14) + self.assertEqual(club.members["John"]["gender"], "F") + self.assertEqual(club.members["John"]["age"], 14) def test_update_results(self): self.Person.drop_collection() - result = self.Person(name="Bob", age=25).update( - upsert=True, full_result=True) + result = self.Person(name="Bob", age=25).update(upsert=True, full_result=True) self.assertIsInstance(result, UpdateResult) self.assertIn("upserted", result.raw_result) self.assertFalse(result.raw_result["updatedExisting"]) @@ -703,8 +717,7 @@ class QuerySetTest(unittest.TestCase): self.assertTrue(result.raw_result["updatedExisting"]) self.Person(name="Bob", age=20).save() - result = self.Person.objects(name="Bob").update( - set__name="bobby", multi=True) + result = self.Person.objects(name="Bob").update(set__name="bobby", multi=True) self.assertEqual(result, 2) def test_update_validate(self): @@ -718,8 +731,12 @@ class QuerySetTest(unittest.TestCase): ed_f = EmbeddedDocumentField(EmDoc) self.assertRaises(ValidationError, Doc.objects().update, str_f=1, upsert=True) - self.assertRaises(ValidationError, Doc.objects().update, dt_f="datetime", upsert=True) - self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True) + self.assertRaises( + ValidationError, Doc.objects().update, dt_f="datetime", upsert=True + ) + self.assertRaises( + ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True + ) def test_update_related_models(self): class TestPerson(Document): @@ -732,34 +749,33 @@ class QuerySetTest(unittest.TestCase): TestPerson.drop_collection() TestOrganization.drop_collection() - p = TestPerson(name='p1') + p = TestPerson(name="p1") p.save() - o = TestOrganization(name='o1') + o = TestOrganization(name="o1") o.save() o.owner = p - p.name = 'p2' + p.name = "p2" - self.assertEqual(o._get_changed_fields(), ['owner']) - self.assertEqual(p._get_changed_fields(), ['name']) + self.assertEqual(o._get_changed_fields(), ["owner"]) + self.assertEqual(p._get_changed_fields(), ["name"]) o.save() self.assertEqual(o._get_changed_fields(), []) - self.assertEqual(p._get_changed_fields(), ['name']) # Fails; it's empty + self.assertEqual(p._get_changed_fields(), ["name"]) # Fails; it's empty # This will do NOTHING at all, even though we changed the name p.save() p.reload() - self.assertEqual(p.name, 'p2') # Fails; it's still `p1` + self.assertEqual(p.name, "p2") # Fails; it's still `p1` def test_upsert(self): self.Person.drop_collection() - self.Person.objects( - pk=ObjectId(), name="Bob", age=30).update(upsert=True) + self.Person.objects(pk=ObjectId(), name="Bob", age=30).update(upsert=True) bob = self.Person.objects.first() self.assertEqual("Bob", bob.name) @@ -786,7 +802,8 @@ class QuerySetTest(unittest.TestCase): self.Person.drop_collection() self.Person.objects(pk=ObjectId()).update( - set__name='Bob', set_on_insert__age=30, upsert=True) + set__name="Bob", set_on_insert__age=30, upsert=True + ) bob = self.Person.objects.first() self.assertEqual("Bob", bob.name) @@ -797,7 +814,7 @@ class QuerySetTest(unittest.TestCase): field = IntField() class B(Document): - meta = {'collection': 'b'} + meta = {"collection": "b"} field = IntField(default=1) embed = EmbeddedDocumentField(Embed, default=Embed) @@ -820,7 +837,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(record.embed.field, 2) # Request only the _id field and save - clone = B.objects().only('id').first() + clone = B.objects().only("id").first() clone.save() # Reload the record and see that the embed data is not lost @@ -831,6 +848,7 @@ class QuerySetTest(unittest.TestCase): def test_bulk_insert(self): """Ensure that bulk insert works""" + class Comment(EmbeddedDocument): name = StringField() @@ -847,14 +865,13 @@ class QuerySetTest(unittest.TestCase): # Recreates the collection self.assertEqual(0, Blog.objects.count()) - comment1 = Comment(name='testa') - comment2 = Comment(name='testb') + comment1 = Comment(name="testa") + comment2 = Comment(name="testb") post1 = Post(comments=[comment1, comment2]) post2 = Post(comments=[comment2, comment2]) # Check bulk insert using load_bulk=False - blogs = [Blog(title="%s" % i, posts=[post1, post2]) - for i in range(99)] + blogs = [Blog(title="%s" % i, posts=[post1, post2]) for i in range(99)] with query_counter() as q: self.assertEqual(q, 0) Blog.objects.insert(blogs, load_bulk=False) @@ -866,8 +883,7 @@ class QuerySetTest(unittest.TestCase): Blog.ensure_indexes() # Check bulk insert using load_bulk=True - blogs = [Blog(title="%s" % i, posts=[post1, post2]) - for i in range(99)] + blogs = [Blog(title="%s" % i, posts=[post1, post2]) for i in range(99)] with query_counter() as q: self.assertEqual(q, 0) Blog.objects.insert(blogs) @@ -875,8 +891,8 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() - comment1 = Comment(name='testa') - comment2 = Comment(name='testb') + comment1 = Comment(name="testa") + comment2 = Comment(name="testb") post1 = Post(comments=[comment1, comment2]) post2 = Post(comments=[comment2, comment2]) blog1 = Blog(title="code", posts=[post1, post2]) @@ -892,8 +908,7 @@ class QuerySetTest(unittest.TestCase): blog = Blog.objects.first() Blog.objects.insert(blog) self.assertEqual( - str(cm.exception), - 'Some documents have ObjectIds, use doc.update() instead' + str(cm.exception), "Some documents have ObjectIds, use doc.update() instead" ) # test inserting a query set @@ -901,8 +916,7 @@ class QuerySetTest(unittest.TestCase): blogs_qs = Blog.objects Blog.objects.insert(blogs_qs) self.assertEqual( - str(cm.exception), - 'Some documents have ObjectIds, use doc.update() instead' + str(cm.exception), "Some documents have ObjectIds, use doc.update() instead" ) # insert 1 new doc @@ -948,13 +962,13 @@ class QuerySetTest(unittest.TestCase): name = StringField() Blog.drop_collection() - Blog(name='test').save() + Blog(name="test").save() with self.assertRaises(OperationError): Blog.objects.insert("HELLO WORLD") with self.assertRaises(OperationError): - Blog.objects.insert({'name': 'garbage'}) + Blog.objects.insert({"name": "garbage"}) def test_bulk_insert_update_input_document_ids(self): class Comment(Document): @@ -1010,10 +1024,11 @@ class QuerySetTest(unittest.TestCase): """Make sure we don't perform unnecessary db operations when none of document's fields were updated. """ + class Person(Document): name = StringField() - owns = ListField(ReferenceField('Organization')) - projects = ListField(ReferenceField('Project')) + owns = ListField(ReferenceField("Organization")) + projects = ListField(ReferenceField("Project")) class Organization(Document): name = StringField() @@ -1070,8 +1085,8 @@ class QuerySetTest(unittest.TestCase): def test_repeated_iteration(self): """Ensure that QuerySet rewinds itself one iteration finishes. """ - self.Person(name='Person 1').save() - self.Person(name='Person 2').save() + self.Person(name="Person 1").save() + self.Person(name="Person 2").save() queryset = self.Person.objects people1 = [person for person in queryset] @@ -1099,7 +1114,7 @@ class QuerySetTest(unittest.TestCase): for i in range(1000): Doc(number=i).save() - docs = Doc.objects.order_by('number') + docs = Doc.objects.order_by("number") self.assertEqual(docs.count(), 1000) @@ -1107,88 +1122,89 @@ class QuerySetTest(unittest.TestCase): self.assertIn("Doc: 0", docs_string) self.assertEqual(docs.count(), 1000) - self.assertIn('(remaining elements truncated)', "%s" % docs) + self.assertIn("(remaining elements truncated)", "%s" % docs) # Limit and skip docs = docs[1:4] - self.assertEqual('[, , ]', "%s" % docs) + self.assertEqual("[, , ]", "%s" % docs) self.assertEqual(docs.count(with_limit_and_skip=True), 3) for doc in docs: - self.assertEqual('.. queryset mid-iteration ..', repr(docs)) + self.assertEqual(".. queryset mid-iteration ..", repr(docs)) def test_regex_query_shortcuts(self): """Ensure that contains, startswith, endswith, etc work. """ - person = self.Person(name='Guido van Rossum') + person = self.Person(name="Guido van Rossum") person.save() # Test contains - obj = self.Person.objects(name__contains='van').first() + obj = self.Person.objects(name__contains="van").first() self.assertEqual(obj, person) - obj = self.Person.objects(name__contains='Van').first() + obj = self.Person.objects(name__contains="Van").first() self.assertEqual(obj, None) # Test icontains - obj = self.Person.objects(name__icontains='Van').first() + obj = self.Person.objects(name__icontains="Van").first() self.assertEqual(obj, person) # Test startswith - obj = self.Person.objects(name__startswith='Guido').first() + obj = self.Person.objects(name__startswith="Guido").first() self.assertEqual(obj, person) - obj = self.Person.objects(name__startswith='guido').first() + obj = self.Person.objects(name__startswith="guido").first() self.assertEqual(obj, None) # Test istartswith - obj = self.Person.objects(name__istartswith='guido').first() + obj = self.Person.objects(name__istartswith="guido").first() self.assertEqual(obj, person) # Test endswith - obj = self.Person.objects(name__endswith='Rossum').first() + obj = self.Person.objects(name__endswith="Rossum").first() self.assertEqual(obj, person) - obj = self.Person.objects(name__endswith='rossuM').first() + obj = self.Person.objects(name__endswith="rossuM").first() self.assertEqual(obj, None) # Test iendswith - obj = self.Person.objects(name__iendswith='rossuM').first() + obj = self.Person.objects(name__iendswith="rossuM").first() self.assertEqual(obj, person) # Test exact - obj = self.Person.objects(name__exact='Guido van Rossum').first() + obj = self.Person.objects(name__exact="Guido van Rossum").first() self.assertEqual(obj, person) - obj = self.Person.objects(name__exact='Guido van rossum').first() + obj = self.Person.objects(name__exact="Guido van rossum").first() self.assertEqual(obj, None) - obj = self.Person.objects(name__exact='Guido van Rossu').first() + obj = self.Person.objects(name__exact="Guido van Rossu").first() self.assertEqual(obj, None) # Test iexact - obj = self.Person.objects(name__iexact='gUIDO VAN rOSSUM').first() + obj = self.Person.objects(name__iexact="gUIDO VAN rOSSUM").first() self.assertEqual(obj, person) - obj = self.Person.objects(name__iexact='gUIDO VAN rOSSU').first() + obj = self.Person.objects(name__iexact="gUIDO VAN rOSSU").first() self.assertEqual(obj, None) # Test unsafe expressions - person = self.Person(name='Guido van Rossum [.\'Geek\']') + person = self.Person(name="Guido van Rossum [.'Geek']") person.save() - obj = self.Person.objects(name__icontains='[.\'Geek').first() + obj = self.Person.objects(name__icontains="[.'Geek").first() self.assertEqual(obj, person) def test_not(self): """Ensure that the __not operator works as expected. """ - alice = self.Person(name='Alice', age=25) + alice = self.Person(name="Alice", age=25) alice.save() - obj = self.Person.objects(name__iexact='alice').first() + obj = self.Person.objects(name__iexact="alice").first() self.assertEqual(obj, alice) - obj = self.Person.objects(name__not__iexact='alice').first() + obj = self.Person.objects(name__not__iexact="alice").first() self.assertEqual(obj, None) def test_filter_chaining(self): """Ensure filters can be chained together. """ + class Blog(Document): id = StringField(primary_key=True) @@ -1217,25 +1233,26 @@ class QuerySetTest(unittest.TestCase): blog=blog_1, title="Blog Post #1", is_published=True, - published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) + published_date=datetime.datetime(2010, 1, 5, 0, 0, 0), ) BlogPost.objects.create( blog=blog_2, title="Blog Post #2", is_published=True, - published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) + published_date=datetime.datetime(2010, 1, 6, 0, 0, 0), ) BlogPost.objects.create( blog=blog_3, title="Blog Post #3", is_published=True, - published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) + published_date=datetime.datetime(2010, 1, 7, 0, 0, 0), ) # find all published blog posts before 2010-01-07 published_posts = BlogPost.published() published_posts = published_posts.filter( - published_date__lt=datetime.datetime(2010, 1, 7, 0, 0, 0)) + published_date__lt=datetime.datetime(2010, 1, 7, 0, 0, 0) + ) self.assertEqual(published_posts.count(), 2) blog_posts = BlogPost.objects @@ -1247,11 +1264,11 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() def test_filter_chaining_with_regex(self): - person = self.Person(name='Guido van Rossum') + person = self.Person(name="Guido van Rossum") person.save() people = self.Person.objects - people = people.filter(name__startswith='Gui').filter(name__not__endswith='tum') + people = people.filter(name__startswith="Gui").filter(name__not__endswith="tum") self.assertEqual(people.count(), 1) def assertSequence(self, qs, expected): @@ -1264,27 +1281,23 @@ class QuerySetTest(unittest.TestCase): def test_ordering(self): """Ensure default ordering is applied and can be overridden. """ + class BlogPost(Document): title = StringField() published_date = DateTimeField() - meta = { - 'ordering': ['-published_date'] - } + meta = {"ordering": ["-published_date"]} BlogPost.drop_collection() blog_post_1 = BlogPost.objects.create( - title="Blog Post #1", - published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) + title="Blog Post #1", published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) ) blog_post_2 = BlogPost.objects.create( - title="Blog Post #2", - published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) + title="Blog Post #2", published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) ) blog_post_3 = BlogPost.objects.create( - title="Blog Post #3", - published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) + title="Blog Post #3", published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) ) # get the "first" BlogPost using default ordering @@ -1307,39 +1320,35 @@ class QuerySetTest(unittest.TestCase): title = StringField() published_date = DateTimeField() - meta = { - 'ordering': ['-published_date'] - } + meta = {"ordering": ["-published_date"]} BlogPost.drop_collection() # default ordering should be used by default with db_ops_tracker() as q: - BlogPost.objects.filter(title='whatever').first() + BlogPost.objects.filter(title="whatever").first() self.assertEqual(len(q.get_ops()), 1) self.assertEqual( - q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], - {'published_date': -1} + q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {"published_date": -1} ) # calling order_by() should clear the default ordering with db_ops_tracker() as q: - BlogPost.objects.filter(title='whatever').order_by().first() + BlogPost.objects.filter(title="whatever").order_by().first() self.assertEqual(len(q.get_ops()), 1) self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) # calling an explicit order_by should use a specified sort with db_ops_tracker() as q: - BlogPost.objects.filter(title='whatever').order_by('published_date').first() + BlogPost.objects.filter(title="whatever").order_by("published_date").first() self.assertEqual(len(q.get_ops()), 1) self.assertEqual( - q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], - {'published_date': 1} + q.get_ops()[0][CMD_QUERY_KEY][ORDER_BY_KEY], {"published_date": 1} ) # calling order_by() after an explicit sort should clear it with db_ops_tracker() as q: - qs = BlogPost.objects.filter(title='whatever').order_by('published_date') + qs = BlogPost.objects.filter(title="whatever").order_by("published_date") qs.order_by().first() self.assertEqual(len(q.get_ops()), 1) self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) @@ -1353,21 +1362,20 @@ class QuerySetTest(unittest.TestCase): title = StringField() published_date = DateTimeField() - meta = { - 'ordering': ['-published_date'] - } + meta = {"ordering": ["-published_date"]} BlogPost.objects.create( - title='whatever', published_date=datetime.datetime.utcnow()) + title="whatever", published_date=datetime.datetime.utcnow() + ) with db_ops_tracker() as q: - BlogPost.objects.get(title='whatever') + BlogPost.objects.get(title="whatever") self.assertEqual(len(q.get_ops()), 1) self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) # Ordering should be ignored for .get even if we set it explicitly with db_ops_tracker() as q: - BlogPost.objects.order_by('-title').get(title='whatever') + BlogPost.objects.order_by("-title").get(title="whatever") self.assertEqual(len(q.get_ops()), 1) self.assertNotIn(ORDER_BY_KEY, q.get_ops()[0][CMD_QUERY_KEY]) @@ -1375,6 +1383,7 @@ class QuerySetTest(unittest.TestCase): """Ensure that an embedded document is properly returned from different manners of querying. """ + class User(EmbeddedDocument): name = StringField() @@ -1384,23 +1393,20 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() - user = User(name='Test User') - BlogPost.objects.create( - author=user, - content='Had a good coffee today...' - ) + user = User(name="Test User") + BlogPost.objects.create(author=user, content="Had a good coffee today...") result = BlogPost.objects.first() self.assertIsInstance(result.author, User) - self.assertEqual(result.author.name, 'Test User') + self.assertEqual(result.author.name, "Test User") result = BlogPost.objects.get(author__name=user.name) self.assertIsInstance(result.author, User) - self.assertEqual(result.author.name, 'Test User') + self.assertEqual(result.author.name, "Test User") - result = BlogPost.objects.get(author={'name': user.name}) + result = BlogPost.objects.get(author={"name": user.name}) self.assertIsInstance(result.author, User) - self.assertEqual(result.author.name, 'Test User') + self.assertEqual(result.author.name, "Test User") # Fails, since the string is not a type that is able to represent the # author's document structure (should be dict) @@ -1409,6 +1415,7 @@ class QuerySetTest(unittest.TestCase): def test_find_empty_embedded(self): """Ensure that you can save and find an empty embedded document.""" + class User(EmbeddedDocument): name = StringField() @@ -1418,7 +1425,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() - BlogPost.objects.create(content='Anonymous post...') + BlogPost.objects.create(content="Anonymous post...") result = BlogPost.objects.get(author=None) self.assertEqual(result.author, None) @@ -1426,15 +1433,16 @@ class QuerySetTest(unittest.TestCase): def test_find_dict_item(self): """Ensure that DictField items may be found. """ + class BlogPost(Document): info = DictField() BlogPost.drop_collection() - post = BlogPost(info={'title': 'test'}) + post = BlogPost(info={"title": "test"}) post.save() - post_obj = BlogPost.objects(info__title='test').first() + post_obj = BlogPost.objects(info__title="test").first() self.assertEqual(post_obj.id, post.id) BlogPost.drop_collection() @@ -1442,6 +1450,7 @@ class QuerySetTest(unittest.TestCase): def test_exec_js_query(self): """Ensure that queries are properly formed for use in exec_js. """ + class BlogPost(Document): hits = IntField() published = BooleanField() @@ -1468,10 +1477,10 @@ class QuerySetTest(unittest.TestCase): """ # Ensure that normal queries work - c = BlogPost.objects(published=True).exec_js(js_func, 'hits') + c = BlogPost.objects(published=True).exec_js(js_func, "hits") self.assertEqual(c, 2) - c = BlogPost.objects(published=False).exec_js(js_func, 'hits') + c = BlogPost.objects(published=False).exec_js(js_func, "hits") self.assertEqual(c, 1) BlogPost.drop_collection() @@ -1479,22 +1488,22 @@ class QuerySetTest(unittest.TestCase): def test_exec_js_field_sub(self): """Ensure that field substitutions occur properly in exec_js functions. """ + class Comment(EmbeddedDocument): - content = StringField(db_field='body') + content = StringField(db_field="body") class BlogPost(Document): - name = StringField(db_field='doc-name') - comments = ListField(EmbeddedDocumentField(Comment), - db_field='cmnts') + name = StringField(db_field="doc-name") + comments = ListField(EmbeddedDocumentField(Comment), db_field="cmnts") BlogPost.drop_collection() - comments1 = [Comment(content='cool'), Comment(content='yay')] - post1 = BlogPost(name='post1', comments=comments1) + comments1 = [Comment(content="cool"), Comment(content="yay")] + post1 = BlogPost(name="post1", comments=comments1) post1.save() - comments2 = [Comment(content='nice stuff')] - post2 = BlogPost(name='post2', comments=comments2) + comments2 = [Comment(content="nice stuff")] + post2 = BlogPost(name="post2", comments=comments2) post2.save() code = """ @@ -1514,16 +1523,15 @@ class QuerySetTest(unittest.TestCase): """ sub_code = BlogPost.objects._sub_js_fields(code) - code_chunks = ['doc["cmnts"];', 'doc["doc-name"],', - 'doc["cmnts"][i]["body"]'] + code_chunks = ['doc["cmnts"];', 'doc["doc-name"],', 'doc["cmnts"][i]["body"]'] for chunk in code_chunks: self.assertIn(chunk, sub_code) results = BlogPost.objects.exec_js(code) expected_results = [ - {u'comment': u'cool', u'document': u'post1'}, - {u'comment': u'yay', u'document': u'post1'}, - {u'comment': u'nice stuff', u'document': u'post2'}, + {u"comment": u"cool", u"document": u"post1"}, + {u"comment": u"yay", u"document": u"post1"}, + {u"comment": u"nice stuff", u"document": u"post2"}, ] self.assertEqual(results, expected_results) @@ -1552,55 +1560,60 @@ class QuerySetTest(unittest.TestCase): def test_reverse_delete_rule_cascade(self): """Ensure cascading deletion of referring documents from the database. """ + class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) + BlogPost.drop_collection() - me = self.Person(name='Test User') + me = self.Person(name="Test User") me.save() - someoneelse = self.Person(name='Some-one Else') + someoneelse = self.Person(name="Some-one Else") someoneelse.save() - BlogPost(content='Watching TV', author=me).save() - BlogPost(content='Chilling out', author=me).save() - BlogPost(content='Pro Testing', author=someoneelse).save() + BlogPost(content="Watching TV", author=me).save() + BlogPost(content="Chilling out", author=me).save() + BlogPost(content="Pro Testing", author=someoneelse).save() self.assertEqual(3, BlogPost.objects.count()) - self.Person.objects(name='Test User').delete() + self.Person.objects(name="Test User").delete() self.assertEqual(1, BlogPost.objects.count()) def test_reverse_delete_rule_cascade_on_abstract_document(self): """Ensure cascading deletion of referring documents from the database does not fail on abstract document. """ + class AbstractBlogPost(Document): - meta = {'abstract': True} + meta = {"abstract": True} author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) class BlogPost(AbstractBlogPost): content = StringField() + BlogPost.drop_collection() - me = self.Person(name='Test User') + me = self.Person(name="Test User") me.save() - someoneelse = self.Person(name='Some-one Else') + someoneelse = self.Person(name="Some-one Else") someoneelse.save() - BlogPost(content='Watching TV', author=me).save() - BlogPost(content='Chilling out', author=me).save() - BlogPost(content='Pro Testing', author=someoneelse).save() + BlogPost(content="Watching TV", author=me).save() + BlogPost(content="Chilling out", author=me).save() + BlogPost(content="Pro Testing", author=someoneelse).save() self.assertEqual(3, BlogPost.objects.count()) - self.Person.objects(name='Test User').delete() + self.Person.objects(name="Test User").delete() self.assertEqual(1, BlogPost.objects.count()) def test_reverse_delete_rule_cascade_cycle(self): """Ensure reference cascading doesn't loop if reference graph isn't a tree """ + class Dummy(Document): - reference = ReferenceField('self', reverse_delete_rule=CASCADE) + reference = ReferenceField("self", reverse_delete_rule=CASCADE) base = Dummy().save() other = Dummy(reference=base).save() @@ -1616,14 +1629,15 @@ class QuerySetTest(unittest.TestCase): """Ensure reference cascading doesn't loop if reference graph isn't a tree """ + class Category(Document): name = StringField() class Dummy(Document): - reference = ReferenceField('self', reverse_delete_rule=CASCADE) + reference = ReferenceField("self", reverse_delete_rule=CASCADE) cat = ReferenceField(Category, reverse_delete_rule=CASCADE) - cat = Category(name='cat').save() + cat = Category(name="cat").save() base = Dummy(cat=cat).save() other = Dummy(reference=base).save() other2 = Dummy(reference=other).save() @@ -1640,24 +1654,25 @@ class QuerySetTest(unittest.TestCase): """Ensure self-referencing CASCADE deletes do not result in infinite loop """ + class Category(Document): name = StringField() - parent = ReferenceField('self', reverse_delete_rule=CASCADE) + parent = ReferenceField("self", reverse_delete_rule=CASCADE) Category.drop_collection() num_children = 3 - base = Category(name='Root') + base = Category(name="Root") base.save() # Create a simple parent-child tree for i in range(num_children): - child_name = 'Child-%i' % i + child_name = "Child-%i" % i child = Category(name=child_name, parent=base) child.save() for i in range(num_children): - child_child_name = 'Child-Child-%i' % i + child_child_name = "Child-Child-%i" % i child_child = Category(name=child_child_name, parent=child) child_child.save() @@ -1673,6 +1688,7 @@ class QuerySetTest(unittest.TestCase): def test_reverse_delete_rule_nullify(self): """Ensure nullification of references to deleted documents. """ + class Category(Document): name = StringField() @@ -1683,14 +1699,14 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() Category.drop_collection() - lameness = Category(name='Lameness') + lameness = Category(name="Lameness") lameness.save() - post = BlogPost(content='Watching TV', category=lameness) + post = BlogPost(content="Watching TV", category=lameness) post.save() self.assertEqual(1, BlogPost.objects.count()) - self.assertEqual('Lameness', BlogPost.objects.first().category.name) + self.assertEqual("Lameness", BlogPost.objects.first().category.name) Category.objects.delete() self.assertEqual(1, BlogPost.objects.count()) self.assertEqual(None, BlogPost.objects.first().category) @@ -1699,24 +1715,26 @@ class QuerySetTest(unittest.TestCase): """Ensure nullification of references to deleted documents when reference is on an abstract document. """ + class AbstractBlogPost(Document): - meta = {'abstract': True} + meta = {"abstract": True} author = ReferenceField(self.Person, reverse_delete_rule=NULLIFY) class BlogPost(AbstractBlogPost): content = StringField() + BlogPost.drop_collection() - me = self.Person(name='Test User') + me = self.Person(name="Test User") me.save() - someoneelse = self.Person(name='Some-one Else') + someoneelse = self.Person(name="Some-one Else") someoneelse.save() - BlogPost(content='Watching TV', author=me).save() + BlogPost(content="Watching TV", author=me).save() self.assertEqual(1, BlogPost.objects.count()) self.assertEqual(me, BlogPost.objects.first().author) - self.Person.objects(name='Test User').delete() + self.Person.objects(name="Test User").delete() self.assertEqual(1, BlogPost.objects.count()) self.assertEqual(None, BlogPost.objects.first().author) @@ -1724,6 +1742,7 @@ class QuerySetTest(unittest.TestCase): """Ensure deletion gets denied on documents that still have references to them. """ + class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=DENY) @@ -1731,10 +1750,10 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() self.Person.drop_collection() - me = self.Person(name='Test User') + me = self.Person(name="Test User") me.save() - post = BlogPost(content='Watching TV', author=me) + post = BlogPost(content="Watching TV", author=me) post.save() self.assertRaises(OperationError, self.Person.objects.delete) @@ -1743,18 +1762,20 @@ class QuerySetTest(unittest.TestCase): """Ensure deletion gets denied on documents that still have references to them, when reference is on an abstract document. """ + class AbstractBlogPost(Document): - meta = {'abstract': True} + meta = {"abstract": True} author = ReferenceField(self.Person, reverse_delete_rule=DENY) class BlogPost(AbstractBlogPost): content = StringField() + BlogPost.drop_collection() - me = self.Person(name='Test User') + me = self.Person(name="Test User") me.save() - BlogPost(content='Watching TV', author=me).save() + BlogPost(content="Watching TV", author=me).save() self.assertEqual(1, BlogPost.objects.count()) self.assertRaises(OperationError, self.Person.objects.delete) @@ -1762,24 +1783,24 @@ class QuerySetTest(unittest.TestCase): def test_reverse_delete_rule_pull(self): """Ensure pulling of references to deleted documents. """ + class BlogPost(Document): content = StringField() - authors = ListField(ReferenceField(self.Person, - reverse_delete_rule=PULL)) + authors = ListField(ReferenceField(self.Person, reverse_delete_rule=PULL)) BlogPost.drop_collection() self.Person.drop_collection() - me = self.Person(name='Test User') + me = self.Person(name="Test User") me.save() - someoneelse = self.Person(name='Some-one Else') + someoneelse = self.Person(name="Some-one Else") someoneelse.save() - post = BlogPost(content='Watching TV', authors=[me, someoneelse]) + post = BlogPost(content="Watching TV", authors=[me, someoneelse]) post.save() - another = BlogPost(content='Chilling Out', authors=[someoneelse]) + another = BlogPost(content="Chilling Out", authors=[someoneelse]) another.save() someoneelse.delete() @@ -1793,10 +1814,10 @@ class QuerySetTest(unittest.TestCase): """Ensure pulling of references to deleted documents when reference is defined on an abstract document.. """ + class AbstractBlogPost(Document): - meta = {'abstract': True} - authors = ListField(ReferenceField(self.Person, - reverse_delete_rule=PULL)) + meta = {"abstract": True} + authors = ListField(ReferenceField(self.Person, reverse_delete_rule=PULL)) class BlogPost(AbstractBlogPost): content = StringField() @@ -1804,16 +1825,16 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() self.Person.drop_collection() - me = self.Person(name='Test User') + me = self.Person(name="Test User") me.save() - someoneelse = self.Person(name='Some-one Else') + someoneelse = self.Person(name="Some-one Else") someoneelse.save() - post = BlogPost(content='Watching TV', authors=[me, someoneelse]) + post = BlogPost(content="Watching TV", authors=[me, someoneelse]) post.save() - another = BlogPost(content='Chilling Out', authors=[someoneelse]) + another = BlogPost(content="Chilling Out", authors=[someoneelse]) another.save() someoneelse.delete() @@ -1824,7 +1845,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(another.authors, []) def test_delete_with_limits(self): - class Log(Document): pass @@ -1839,19 +1859,21 @@ class QuerySetTest(unittest.TestCase): def test_delete_with_limit_handles_delete_rules(self): """Ensure cascading deletion of referring documents from the database. """ + class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) + BlogPost.drop_collection() - me = self.Person(name='Test User') + me = self.Person(name="Test User") me.save() - someoneelse = self.Person(name='Some-one Else') + someoneelse = self.Person(name="Some-one Else") someoneelse.save() - BlogPost(content='Watching TV', author=me).save() - BlogPost(content='Chilling out', author=me).save() - BlogPost(content='Pro Testing', author=someoneelse).save() + BlogPost(content="Watching TV", author=me).save() + BlogPost(content="Chilling out", author=me).save() + BlogPost(content="Pro Testing", author=someoneelse).save() self.assertEqual(3, BlogPost.objects.count()) self.Person.objects()[:1].delete() @@ -1870,6 +1892,7 @@ class QuerySetTest(unittest.TestCase): def test_reference_field_find(self): """Ensure cascading deletion of referring documents from the database. """ + class BlogPost(Document): content = StringField() author = ReferenceField(self.Person) @@ -1877,7 +1900,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() self.Person.drop_collection() - me = self.Person(name='Test User').save() + me = self.Person(name="Test User").save() BlogPost(content="test 123", author=me).save() self.assertEqual(1, BlogPost.objects(author=me).count()) @@ -1886,12 +1909,12 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) - self.assertEqual( - 1, BlogPost.objects(author__in=["%s" % me.pk]).count()) + self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) def test_reference_field_find_dbref(self): """Ensure cascading deletion of referring documents from the database. """ + class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, dbref=True) @@ -1899,7 +1922,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() self.Person.drop_collection() - me = self.Person(name='Test User').save() + me = self.Person(name="Test User").save() BlogPost(content="test 123", author=me).save() self.assertEqual(1, BlogPost.objects(author=me).count()) @@ -1908,8 +1931,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(1, BlogPost.objects(author__in=[me]).count()) self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count()) - self.assertEqual( - 1, BlogPost.objects(author__in=["%s" % me.pk]).count()) + self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count()) def test_update_intfield_operator(self): class BlogPost(Document): @@ -1946,7 +1968,7 @@ class QuerySetTest(unittest.TestCase): post = BlogPost(review=3.5) post.save() - BlogPost.objects.update_one(inc__review=0.1) # test with floats + BlogPost.objects.update_one(inc__review=0.1) # test with floats post.reload() self.assertEqual(float(post.review), 3.6) @@ -1954,7 +1976,7 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertEqual(float(post.review), 3.5) - BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal + BlogPost.objects.update_one(inc__review=Decimal(0.12)) # test with Decimal post.reload() self.assertEqual(float(post.review), 3.62) @@ -1972,38 +1994,39 @@ class QuerySetTest(unittest.TestCase): post.save() with self.assertRaises(OperationError): - BlogPost.objects.update_one(inc__review=0.1) # test with floats + BlogPost.objects.update_one(inc__review=0.1) # test with floats def test_update_listfield_operator(self): """Ensure that atomic updates work properly. """ + class BlogPost(Document): tags = ListField(StringField()) BlogPost.drop_collection() - post = BlogPost(tags=['test']) + post = BlogPost(tags=["test"]) post.save() # ListField operator - BlogPost.objects.update(push__tags='mongo') + BlogPost.objects.update(push__tags="mongo") post.reload() - self.assertIn('mongo', post.tags) + self.assertIn("mongo", post.tags) - BlogPost.objects.update_one(push_all__tags=['db', 'nosql']) + BlogPost.objects.update_one(push_all__tags=["db", "nosql"]) post.reload() - self.assertIn('db', post.tags) - self.assertIn('nosql', post.tags) + self.assertIn("db", post.tags) + self.assertIn("nosql", post.tags) tags = post.tags[:-1] BlogPost.objects.update(pop__tags=1) post.reload() self.assertEqual(post.tags, tags) - BlogPost.objects.update_one(add_to_set__tags='unique') - BlogPost.objects.update_one(add_to_set__tags='unique') + BlogPost.objects.update_one(add_to_set__tags="unique") + BlogPost.objects.update_one(add_to_set__tags="unique") post.reload() - self.assertEqual(post.tags.count('unique'), 1) + self.assertEqual(post.tags.count("unique"), 1) BlogPost.drop_collection() @@ -2013,18 +2036,19 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() - post = BlogPost(title='garbage').save() + post = BlogPost(title="garbage").save() self.assertNotEqual(post.title, None) BlogPost.objects.update_one(unset__title=1) post.reload() self.assertEqual(post.title, None) pymongo_doc = BlogPost.objects.as_pymongo().first() - self.assertNotIn('title', pymongo_doc) + self.assertNotIn("title", pymongo_doc) def test_update_push_with_position(self): """Ensure that the 'push' update with position works properly. """ + class BlogPost(Document): slug = StringField() tags = ListField(StringField()) @@ -2036,20 +2060,21 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects.filter(id=post.id).update(push__tags="code") BlogPost.objects.filter(id=post.id).update(push__tags__0=["mongodb", "python"]) post.reload() - self.assertEqual(post.tags, ['mongodb', 'python', 'code']) + self.assertEqual(post.tags, ["mongodb", "python", "code"]) BlogPost.objects.filter(id=post.id).update(set__tags__2="java") post.reload() - self.assertEqual(post.tags, ['mongodb', 'python', 'java']) + self.assertEqual(post.tags, ["mongodb", "python", "java"]) # test push with singular value - BlogPost.objects.filter(id=post.id).update(push__tags__0='scala') + BlogPost.objects.filter(id=post.id).update(push__tags__0="scala") post.reload() - self.assertEqual(post.tags, ['scala', 'mongodb', 'python', 'java']) + self.assertEqual(post.tags, ["scala", "mongodb", "python", "java"]) def test_update_push_list_of_list(self): """Ensure that the 'push' update operation works in the list of list """ + class BlogPost(Document): slug = StringField() tags = ListField() @@ -2065,6 +2090,7 @@ class QuerySetTest(unittest.TestCase): def test_update_push_and_pull_add_to_set(self): """Ensure that the 'pull' update operation works correctly. """ + class BlogPost(Document): slug = StringField() tags = ListField(StringField()) @@ -2078,8 +2104,7 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertEqual(post.tags, ["code"]) - BlogPost.objects.filter(id=post.id).update( - push_all__tags=["mongodb", "code"]) + BlogPost.objects.filter(id=post.id).update(push_all__tags=["mongodb", "code"]) post.reload() self.assertEqual(post.tags, ["code", "mongodb", "code"]) @@ -2087,13 +2112,13 @@ class QuerySetTest(unittest.TestCase): post.reload() self.assertEqual(post.tags, ["mongodb"]) - BlogPost.objects(slug="test").update( - pull_all__tags=["mongodb", "code"]) + BlogPost.objects(slug="test").update(pull_all__tags=["mongodb", "code"]) post.reload() self.assertEqual(post.tags, []) BlogPost.objects(slug="test").update( - __raw__={"$addToSet": {"tags": {"$each": ["code", "mongodb", "code"]}}}) + __raw__={"$addToSet": {"tags": {"$each": ["code", "mongodb", "code"]}}} + ) post.reload() self.assertEqual(post.tags, ["code", "mongodb"]) @@ -2101,13 +2126,13 @@ class QuerySetTest(unittest.TestCase): class Item(Document): name = StringField(required=True) description = StringField(max_length=50) - parents = ListField(ReferenceField('self')) + parents = ListField(ReferenceField("self")) Item.drop_collection() - item = Item(name='test item').save() - parent_1 = Item(name='parent 1').save() - parent_2 = Item(name='parent 2').save() + item = Item(name="test item").save() + parent_1 = Item(name="parent 1").save() + parent_2 = Item(name="parent 2").save() item.update(add_to_set__parents=[parent_1, parent_2, parent_1]) item.reload() @@ -2115,12 +2140,11 @@ class QuerySetTest(unittest.TestCase): self.assertEqual([parent_1, parent_2], item.parents) def test_pull_nested(self): - class Collaborator(EmbeddedDocument): user = StringField() def __unicode__(self): - return '%s' % self.user + return "%s" % self.user class Site(Document): name = StringField(max_length=75, unique=True, required=True) @@ -2128,23 +2152,21 @@ class QuerySetTest(unittest.TestCase): Site.drop_collection() - c = Collaborator(user='Esteban') + c = Collaborator(user="Esteban") s = Site(name="test", collaborators=[c]).save() - Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') + Site.objects(id=s.id).update_one(pull__collaborators__user="Esteban") self.assertEqual(Site.objects.first().collaborators, []) with self.assertRaises(InvalidQueryError): - Site.objects(id=s.id).update_one( - pull_all__collaborators__user=['Ross']) + Site.objects(id=s.id).update_one(pull_all__collaborators__user=["Ross"]) def test_pull_from_nested_embedded(self): - class User(EmbeddedDocument): name = StringField() def __unicode__(self): - return '%s' % self.name + return "%s" % self.name class Collaborator(EmbeddedDocument): helpful = ListField(EmbeddedDocumentField(User)) @@ -2156,21 +2178,24 @@ class QuerySetTest(unittest.TestCase): Site.drop_collection() - c = User(name='Esteban') - f = User(name='Frank') - s = Site(name="test", collaborators=Collaborator( - helpful=[c], unhelpful=[f])).save() + c = User(name="Esteban") + f = User(name="Frank") + s = Site( + name="test", collaborators=Collaborator(helpful=[c], unhelpful=[f]) + ).save() Site.objects(id=s.id).update_one(pull__collaborators__helpful=c) - self.assertEqual(Site.objects.first().collaborators['helpful'], []) + self.assertEqual(Site.objects.first().collaborators["helpful"], []) Site.objects(id=s.id).update_one( - pull__collaborators__unhelpful={'name': 'Frank'}) - self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) + pull__collaborators__unhelpful={"name": "Frank"} + ) + self.assertEqual(Site.objects.first().collaborators["unhelpful"], []) with self.assertRaises(InvalidQueryError): Site.objects(id=s.id).update_one( - pull_all__collaborators__helpful__name=['Ross']) + pull_all__collaborators__helpful__name=["Ross"] + ) def test_pull_from_nested_embedded_using_in_nin(self): """Ensure that the 'pull' update operation works on embedded documents using 'in' and 'nin' operators. @@ -2180,7 +2205,7 @@ class QuerySetTest(unittest.TestCase): name = StringField() def __unicode__(self): - return '%s' % self.name + return "%s" % self.name class Collaborator(EmbeddedDocument): helpful = ListField(EmbeddedDocumentField(User)) @@ -2192,60 +2217,62 @@ class QuerySetTest(unittest.TestCase): Site.drop_collection() - a = User(name='Esteban') - b = User(name='Frank') - x = User(name='Harry') - y = User(name='John') + a = User(name="Esteban") + b = User(name="Frank") + x = User(name="Harry") + y = User(name="John") - s = Site(name="test", collaborators=Collaborator( - helpful=[a, b], unhelpful=[x, y])).save() + s = Site( + name="test", collaborators=Collaborator(helpful=[a, b], unhelpful=[x, y]) + ).save() - Site.objects(id=s.id).update_one(pull__collaborators__helpful__name__in=['Esteban']) # Pull a - self.assertEqual(Site.objects.first().collaborators['helpful'], [b]) + Site.objects(id=s.id).update_one( + pull__collaborators__helpful__name__in=["Esteban"] + ) # Pull a + self.assertEqual(Site.objects.first().collaborators["helpful"], [b]) - Site.objects(id=s.id).update_one(pull__collaborators__unhelpful__name__nin=['John']) # Pull x - self.assertEqual(Site.objects.first().collaborators['unhelpful'], [y]) + Site.objects(id=s.id).update_one( + pull__collaborators__unhelpful__name__nin=["John"] + ) # Pull x + self.assertEqual(Site.objects.first().collaborators["unhelpful"], [y]) def test_pull_from_nested_mapfield(self): - class Collaborator(EmbeddedDocument): user = StringField() def __unicode__(self): - return '%s' % self.user + return "%s" % self.user class Site(Document): name = StringField(max_length=75, unique=True, required=True) - collaborators = MapField( - ListField(EmbeddedDocumentField(Collaborator))) + collaborators = MapField(ListField(EmbeddedDocumentField(Collaborator))) Site.drop_collection() - c = Collaborator(user='Esteban') - f = Collaborator(user='Frank') - s = Site(name="test", collaborators={'helpful': [c], 'unhelpful': [f]}) + c = Collaborator(user="Esteban") + f = Collaborator(user="Frank") + s = Site(name="test", collaborators={"helpful": [c], "unhelpful": [f]}) s.save() - Site.objects(id=s.id).update_one( - pull__collaborators__helpful__user='Esteban') - self.assertEqual(Site.objects.first().collaborators['helpful'], []) + Site.objects(id=s.id).update_one(pull__collaborators__helpful__user="Esteban") + self.assertEqual(Site.objects.first().collaborators["helpful"], []) Site.objects(id=s.id).update_one( - pull__collaborators__unhelpful={'user': 'Frank'}) - self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) + pull__collaborators__unhelpful={"user": "Frank"} + ) + self.assertEqual(Site.objects.first().collaborators["unhelpful"], []) with self.assertRaises(InvalidQueryError): Site.objects(id=s.id).update_one( - pull_all__collaborators__helpful__user=['Ross']) + pull_all__collaborators__helpful__user=["Ross"] + ) def test_pull_in_genericembedded_field(self): - class Foo(EmbeddedDocument): name = StringField() class Bar(Document): - foos = ListField(GenericEmbeddedDocumentField( - choices=[Foo, ])) + foos = ListField(GenericEmbeddedDocumentField(choices=[Foo])) Bar.drop_collection() @@ -2261,15 +2288,14 @@ class QuerySetTest(unittest.TestCase): BlogTag.drop_collection() - BlogTag(name='garbage').save() - default_update = BlogTag.objects.update_one(name='new') + BlogTag(name="garbage").save() + default_update = BlogTag.objects.update_one(name="new") self.assertEqual(default_update, 1) - full_result_update = BlogTag.objects.update_one(name='new', full_result=True) + full_result_update = BlogTag.objects.update_one(name="new", full_result=True) self.assertIsInstance(full_result_update, UpdateResult) def test_update_one_pop_generic_reference(self): - class BlogTag(Document): name = StringField(required=True) @@ -2280,9 +2306,9 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() BlogTag.drop_collection() - tag_1 = BlogTag(name='code') + tag_1 = BlogTag(name="code") tag_1.save() - tag_2 = BlogTag(name='mongodb') + tag_2 = BlogTag(name="mongodb") tag_2.save() post = BlogPost(slug="test", tags=[tag_1]) @@ -2301,7 +2327,6 @@ class QuerySetTest(unittest.TestCase): BlogTag.drop_collection() def test_editting_embedded_objects(self): - class BlogTag(EmbeddedDocument): name = StringField(required=True) @@ -2311,8 +2336,8 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() - tag_1 = BlogTag(name='code') - tag_2 = BlogTag(name='mongodb') + tag_1 = BlogTag(name="code") + tag_2 = BlogTag(name="mongodb") post = BlogPost(slug="test", tags=[tag_1]) post.save() @@ -2323,7 +2348,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects(slug="test-2").update_one(set__tags__0__name="python") post.reload() - self.assertEqual(post.tags[0].name, 'python') + self.assertEqual(post.tags[0].name, "python") BlogPost.objects(slug="test-2").update_one(pop__tags=-1) post.reload() @@ -2332,13 +2357,12 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() def test_set_list_embedded_documents(self): - class Author(EmbeddedDocument): name = StringField() class Message(Document): title = StringField() - authors = ListField(EmbeddedDocumentField('Author')) + authors = ListField(EmbeddedDocumentField("Author")) Message.drop_collection() @@ -2346,15 +2370,19 @@ class QuerySetTest(unittest.TestCase): message.save() Message.objects(authors__name="Harry").update_one( - set__authors__S=Author(name="Ross")) + set__authors__S=Author(name="Ross") + ) message = message.reload() self.assertEqual(message.authors[0].name, "Ross") Message.objects(authors__name="Ross").update_one( - set__authors=[Author(name="Harry"), - Author(name="Ross"), - Author(name="Adam")]) + set__authors=[ + Author(name="Harry"), + Author(name="Ross"), + Author(name="Adam"), + ] + ) message = message.reload() self.assertEqual(message.authors[0].name, "Harry") @@ -2362,7 +2390,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(message.authors[2].name, "Adam") def test_set_generic_embedded_documents(self): - class Bar(EmbeddedDocument): name = StringField() @@ -2372,15 +2399,13 @@ class QuerySetTest(unittest.TestCase): User.drop_collection() - User(username='abc').save() - User.objects(username='abc').update( - set__bar=Bar(name='test'), upsert=True) + User(username="abc").save() + User.objects(username="abc").update(set__bar=Bar(name="test"), upsert=True) - user = User.objects(username='abc').first() + user = User.objects(username="abc").first() self.assertEqual(user.bar.name, "test") def test_reload_embedded_docs_instance(self): - class SubDoc(EmbeddedDocument): val = IntField() @@ -2393,7 +2418,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(doc.pk, doc.embedded._instance.pk) def test_reload_list_embedded_docs_instance(self): - class SubDoc(EmbeddedDocument): val = IntField() @@ -2412,16 +2436,16 @@ class QuerySetTest(unittest.TestCase): self.Person(name="User A", age=20).save() self.Person(name="User C", age=30).save() - names = [p.name for p in self.Person.objects.order_by('-age')] - self.assertEqual(names, ['User B', 'User C', 'User A']) + names = [p.name for p in self.Person.objects.order_by("-age")] + self.assertEqual(names, ["User B", "User C", "User A"]) - names = [p.name for p in self.Person.objects.order_by('+age')] - self.assertEqual(names, ['User A', 'User C', 'User B']) + names = [p.name for p in self.Person.objects.order_by("+age")] + self.assertEqual(names, ["User A", "User C", "User B"]) - names = [p.name for p in self.Person.objects.order_by('age')] - self.assertEqual(names, ['User A', 'User C', 'User B']) + names = [p.name for p in self.Person.objects.order_by("age")] + self.assertEqual(names, ["User A", "User C", "User B"]) - ages = [p.age for p in self.Person.objects.order_by('-name')] + ages = [p.age for p in self.Person.objects.order_by("-name")] self.assertEqual(ages, [30, 40, 20]) def test_order_by_optional(self): @@ -2432,31 +2456,22 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() blog_post_3 = BlogPost.objects.create( - title="Blog Post #3", - published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) + title="Blog Post #3", published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) ) blog_post_2 = BlogPost.objects.create( - title="Blog Post #2", - published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) + title="Blog Post #2", published_date=datetime.datetime(2010, 1, 5, 0, 0, 0) ) blog_post_4 = BlogPost.objects.create( - title="Blog Post #4", - published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) - ) - blog_post_1 = BlogPost.objects.create( - title="Blog Post #1", - published_date=None + title="Blog Post #4", published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) ) + blog_post_1 = BlogPost.objects.create(title="Blog Post #1", published_date=None) expected = [blog_post_1, blog_post_2, blog_post_3, blog_post_4] - self.assertSequence(BlogPost.objects.order_by('published_date'), - expected) - self.assertSequence(BlogPost.objects.order_by('+published_date'), - expected) + self.assertSequence(BlogPost.objects.order_by("published_date"), expected) + self.assertSequence(BlogPost.objects.order_by("+published_date"), expected) expected.reverse() - self.assertSequence(BlogPost.objects.order_by('-published_date'), - expected) + self.assertSequence(BlogPost.objects.order_by("-published_date"), expected) def test_order_by_list(self): class BlogPost(Document): @@ -2466,23 +2481,20 @@ class QuerySetTest(unittest.TestCase): BlogPost.drop_collection() blog_post_1 = BlogPost.objects.create( - title="A", - published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) + title="A", published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) ) blog_post_2 = BlogPost.objects.create( - title="B", - published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) + title="B", published_date=datetime.datetime(2010, 1, 6, 0, 0, 0) ) blog_post_3 = BlogPost.objects.create( - title="C", - published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) + title="C", published_date=datetime.datetime(2010, 1, 7, 0, 0, 0) ) - qs = BlogPost.objects.order_by('published_date', 'title') + qs = BlogPost.objects.order_by("published_date", "title") expected = [blog_post_1, blog_post_2, blog_post_3] self.assertSequence(qs, expected) - qs = BlogPost.objects.order_by('-published_date', '-title') + qs = BlogPost.objects.order_by("-published_date", "-title") expected.reverse() self.assertSequence(qs, expected) @@ -2493,7 +2505,7 @@ class QuerySetTest(unittest.TestCase): self.Person(name="User A", age=20).save() self.Person(name="User C", age=30).save() - only_age = self.Person.objects.order_by('-age').only('age') + only_age = self.Person.objects.order_by("-age").only("age") names = [p.name for p in only_age] ages = [p.age for p in only_age] @@ -2502,19 +2514,19 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(names, [None, None, None]) self.assertEqual(ages, [40, 30, 20]) - qs = self.Person.objects.all().order_by('-age') + qs = self.Person.objects.all().order_by("-age") qs = qs.limit(10) ages = [p.age for p in qs] self.assertEqual(ages, [40, 30, 20]) qs = self.Person.objects.all().limit(10) - qs = qs.order_by('-age') + qs = qs.order_by("-age") ages = [p.age for p in qs] self.assertEqual(ages, [40, 30, 20]) qs = self.Person.objects.all().skip(0) - qs = qs.order_by('-age') + qs = qs.order_by("-age") ages = [p.age for p in qs] self.assertEqual(ages, [40, 30, 20]) @@ -2538,47 +2550,47 @@ class QuerySetTest(unittest.TestCase): Author(author=person_b).save() Author(author=person_c).save() - names = [ - a.author.name for a in Author.objects.order_by('-author__age')] - self.assertEqual(names, ['User A', 'User B', 'User C']) + names = [a.author.name for a in Author.objects.order_by("-author__age")] + self.assertEqual(names, ["User A", "User B", "User C"]) def test_comment(self): """Make sure adding a comment to the query gets added to the query""" MONGO_VER = self.mongodb_version _, CMD_QUERY_KEY = get_key_compat(MONGO_VER) - QUERY_KEY = 'filter' - COMMENT_KEY = 'comment' + QUERY_KEY = "filter" + COMMENT_KEY = "comment" class User(Document): age = IntField() with db_ops_tracker() as q: - adult1 = (User.objects.filter(age__gte=18) - .comment('looking for an adult') - .first()) + adult1 = ( + User.objects.filter(age__gte=18).comment("looking for an adult").first() + ) - adult2 = (User.objects.comment('looking for an adult') - .filter(age__gte=18) - .first()) + adult2 = ( + User.objects.comment("looking for an adult").filter(age__gte=18).first() + ) ops = q.get_ops() self.assertEqual(len(ops), 2) for op in ops: - self.assertEqual(op[CMD_QUERY_KEY][QUERY_KEY], {'age': {'$gte': 18}}) - self.assertEqual(op[CMD_QUERY_KEY][COMMENT_KEY], 'looking for an adult') + self.assertEqual(op[CMD_QUERY_KEY][QUERY_KEY], {"age": {"$gte": 18}}) + self.assertEqual(op[CMD_QUERY_KEY][COMMENT_KEY], "looking for an adult") def test_map_reduce(self): """Ensure map/reduce is both mapping and reducing. """ + class BlogPost(Document): title = StringField() - tags = ListField(StringField(), db_field='post-tag-list') + tags = ListField(StringField(), db_field="post-tag-list") BlogPost.drop_collection() - BlogPost(title="Post #1", tags=['music', 'film', 'print']).save() - BlogPost(title="Post #2", tags=['music', 'film']).save() - BlogPost(title="Post #3", tags=['film', 'photography']).save() + BlogPost(title="Post #1", tags=["music", "film", "print"]).save() + BlogPost(title="Post #2", tags=["music", "film"]).save() + BlogPost(title="Post #3", tags=["film", "photography"]).save() map_f = """ function() { @@ -2628,8 +2640,8 @@ class QuerySetTest(unittest.TestCase): post2.save() post3.save() - self.assertEqual(BlogPost._fields['title'].db_field, '_id') - self.assertEqual(BlogPost._meta['id_field'], 'title') + self.assertEqual(BlogPost._fields["title"].db_field, "_id") + self.assertEqual(BlogPost._meta["id_field"], "title") map_f = """ function() { @@ -2661,16 +2673,14 @@ class QuerySetTest(unittest.TestCase): """ Test map/reduce custom output """ - register_connection('test2', 'mongoenginetest2') + register_connection("test2", "mongoenginetest2") class Family(Document): - id = IntField( - primary_key=True) + id = IntField(primary_key=True) log = StringField() class Person(Document): - id = IntField( - primary_key=True) + id = IntField(primary_key=True) name = StringField() age = IntField() family = ReferenceField(Family) @@ -2745,7 +2755,8 @@ class QuerySetTest(unittest.TestCase): cursor = Family.objects.map_reduce( map_f=map_family, reduce_f=reduce_f, - output={'replace': 'family_map', 'db_alias': 'test2'}) + output={"replace": "family_map", "db_alias": "test2"}, + ) # start a map/reduce cursor.next() @@ -2753,43 +2764,56 @@ class QuerySetTest(unittest.TestCase): results = Person.objects.map_reduce( map_f=map_person, reduce_f=reduce_f, - output={'reduce': 'family_map', 'db_alias': 'test2'}) + output={"reduce": "family_map", "db_alias": "test2"}, + ) results = list(results) - collection = get_db('test2').family_map + collection = get_db("test2").family_map self.assertEqual( - collection.find_one({'_id': 1}), { - '_id': 1, - 'value': { - 'persons': [ - {'age': 21, 'name': u'Wilson Jr'}, - {'age': 45, 'name': u'Wilson Father'}, - {'age': 40, 'name': u'Eliana Costa'}, - {'age': 17, 'name': u'Tayza Mariana'}], - 'totalAge': 123} - }) + collection.find_one({"_id": 1}), + { + "_id": 1, + "value": { + "persons": [ + {"age": 21, "name": u"Wilson Jr"}, + {"age": 45, "name": u"Wilson Father"}, + {"age": 40, "name": u"Eliana Costa"}, + {"age": 17, "name": u"Tayza Mariana"}, + ], + "totalAge": 123, + }, + }, + ) self.assertEqual( - collection.find_one({'_id': 2}), { - '_id': 2, - 'value': { - 'persons': [ - {'age': 16, 'name': u'Isabella Luanna'}, - {'age': 36, 'name': u'Sandra Mara'}, - {'age': 10, 'name': u'Igor Gabriel'}], - 'totalAge': 62} - }) + collection.find_one({"_id": 2}), + { + "_id": 2, + "value": { + "persons": [ + {"age": 16, "name": u"Isabella Luanna"}, + {"age": 36, "name": u"Sandra Mara"}, + {"age": 10, "name": u"Igor Gabriel"}, + ], + "totalAge": 62, + }, + }, + ) self.assertEqual( - collection.find_one({'_id': 3}), { - '_id': 3, - 'value': { - 'persons': [ - {'age': 30, 'name': u'Arthur WA'}, - {'age': 25, 'name': u'Paula Leonel'}], - 'totalAge': 55} - }) + collection.find_one({"_id": 3}), + { + "_id": 3, + "value": { + "persons": [ + {"age": 30, "name": u"Arthur WA"}, + {"age": 25, "name": u"Paula Leonel"}, + ], + "totalAge": 55, + }, + }, + ) def test_map_reduce_finalize(self): """Ensure that map, reduce, and finalize run and introduce "scope" @@ -2798,10 +2822,10 @@ class QuerySetTest(unittest.TestCase): from time import mktime class Link(Document): - title = StringField(db_field='bpTitle') + title = StringField(db_field="bpTitle") up_votes = IntField() down_votes = IntField() - submitted = DateTimeField(db_field='sTime') + submitted = DateTimeField(db_field="sTime") Link.drop_collection() @@ -2811,30 +2835,42 @@ class QuerySetTest(unittest.TestCase): # Fri, 12 Feb 2010 14:36:00 -0600. Link ordering should # reflect order of insertion below, but is not influenced # by insertion order. - Link(title="Google Buzz auto-followed a woman's abusive ex ...", - up_votes=1079, - down_votes=553, - submitted=now - datetime.timedelta(hours=4)).save() - Link(title="We did it! Barbie is a computer engineer.", - up_votes=481, - down_votes=124, - submitted=now - datetime.timedelta(hours=2)).save() - Link(title="This Is A Mosquito Getting Killed By A Laser", - up_votes=1446, - down_votes=530, - submitted=now - datetime.timedelta(hours=13)).save() - Link(title="Arabic flashcards land physics student in jail.", - up_votes=215, - down_votes=105, - submitted=now - datetime.timedelta(hours=6)).save() - Link(title="The Burger Lab: Presenting, the Flood Burger", - up_votes=48, - down_votes=17, - submitted=now - datetime.timedelta(hours=5)).save() - Link(title="How to see polarization with the naked eye", - up_votes=74, - down_votes=13, - submitted=now - datetime.timedelta(hours=10)).save() + Link( + title="Google Buzz auto-followed a woman's abusive ex ...", + up_votes=1079, + down_votes=553, + submitted=now - datetime.timedelta(hours=4), + ).save() + Link( + title="We did it! Barbie is a computer engineer.", + up_votes=481, + down_votes=124, + submitted=now - datetime.timedelta(hours=2), + ).save() + Link( + title="This Is A Mosquito Getting Killed By A Laser", + up_votes=1446, + down_votes=530, + submitted=now - datetime.timedelta(hours=13), + ).save() + Link( + title="Arabic flashcards land physics student in jail.", + up_votes=215, + down_votes=105, + submitted=now - datetime.timedelta(hours=6), + ).save() + Link( + title="The Burger Lab: Presenting, the Flood Burger", + up_votes=48, + down_votes=17, + submitted=now - datetime.timedelta(hours=5), + ).save() + Link( + title="How to see polarization with the naked eye", + up_votes=74, + down_votes=13, + submitted=now - datetime.timedelta(hours=10), + ).save() map_f = """ function() { @@ -2885,17 +2921,15 @@ class QuerySetTest(unittest.TestCase): # provide the reddit epoch (used for ranking) as a variable available # to all phases of the map/reduce operation: map, reduce, and finalize. reddit_epoch = mktime(datetime.datetime(2005, 12, 8, 7, 46, 43).timetuple()) - scope = {'reddit_epoch': reddit_epoch} + scope = {"reddit_epoch": reddit_epoch} # run a map/reduce operation across all links. ordering is set # to "-value", which orders the "weight" value returned from # "finalize_f" in descending order. results = Link.objects.order_by("-value") - results = results.map_reduce(map_f, - reduce_f, - "myresults", - finalize_f=finalize_f, - scope=scope) + results = results.map_reduce( + map_f, reduce_f, "myresults", finalize_f=finalize_f, scope=scope + ) results = list(results) # assert troublesome Buzz article is ranked 1st @@ -2909,54 +2943,56 @@ class QuerySetTest(unittest.TestCase): def test_item_frequencies(self): """Ensure that item frequencies are properly generated from lists. """ + class BlogPost(Document): hits = IntField() - tags = ListField(StringField(), db_field='blogTags') + tags = ListField(StringField(), db_field="blogTags") BlogPost.drop_collection() - BlogPost(hits=1, tags=['music', 'film', 'actors', 'watch']).save() - BlogPost(hits=2, tags=['music', 'watch']).save() - BlogPost(hits=2, tags=['music', 'actors']).save() + BlogPost(hits=1, tags=["music", "film", "actors", "watch"]).save() + BlogPost(hits=2, tags=["music", "watch"]).save() + BlogPost(hits=2, tags=["music", "actors"]).save() def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual( - set(['music', 'film', 'actors', 'watch']), set(f.keys())) - self.assertEqual(f['music'], 3) - self.assertEqual(f['actors'], 2) - self.assertEqual(f['watch'], 2) - self.assertEqual(f['film'], 1) + self.assertEqual(set(["music", "film", "actors", "watch"]), set(f.keys())) + self.assertEqual(f["music"], 3) + self.assertEqual(f["actors"], 2) + self.assertEqual(f["watch"], 2) + self.assertEqual(f["film"], 1) - exec_js = BlogPost.objects.item_frequencies('tags') - map_reduce = BlogPost.objects.item_frequencies('tags', map_reduce=True) + exec_js = BlogPost.objects.item_frequencies("tags") + map_reduce = BlogPost.objects.item_frequencies("tags", map_reduce=True) test_assertions(exec_js) test_assertions(map_reduce) # Ensure query is taken into account def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual(set(['music', 'actors', 'watch']), set(f.keys())) - self.assertEqual(f['music'], 2) - self.assertEqual(f['actors'], 1) - self.assertEqual(f['watch'], 1) + self.assertEqual(set(["music", "actors", "watch"]), set(f.keys())) + self.assertEqual(f["music"], 2) + self.assertEqual(f["actors"], 1) + self.assertEqual(f["watch"], 1) - exec_js = BlogPost.objects(hits__gt=1).item_frequencies('tags') - map_reduce = BlogPost.objects( - hits__gt=1).item_frequencies('tags', map_reduce=True) + exec_js = BlogPost.objects(hits__gt=1).item_frequencies("tags") + map_reduce = BlogPost.objects(hits__gt=1).item_frequencies( + "tags", map_reduce=True + ) test_assertions(exec_js) test_assertions(map_reduce) # Check that normalization works def test_assertions(f): - self.assertAlmostEqual(f['music'], 3.0 / 8.0) - self.assertAlmostEqual(f['actors'], 2.0 / 8.0) - self.assertAlmostEqual(f['watch'], 2.0 / 8.0) - self.assertAlmostEqual(f['film'], 1.0 / 8.0) + self.assertAlmostEqual(f["music"], 3.0 / 8.0) + self.assertAlmostEqual(f["actors"], 2.0 / 8.0) + self.assertAlmostEqual(f["watch"], 2.0 / 8.0) + self.assertAlmostEqual(f["film"], 1.0 / 8.0) - exec_js = BlogPost.objects.item_frequencies('tags', normalize=True) + exec_js = BlogPost.objects.item_frequencies("tags", normalize=True) map_reduce = BlogPost.objects.item_frequencies( - 'tags', normalize=True, map_reduce=True) + "tags", normalize=True, map_reduce=True + ) test_assertions(exec_js) test_assertions(map_reduce) @@ -2966,8 +3002,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(f[1], 1) self.assertEqual(f[2], 2) - exec_js = BlogPost.objects.item_frequencies('hits') - map_reduce = BlogPost.objects.item_frequencies('hits', map_reduce=True) + exec_js = BlogPost.objects.item_frequencies("hits") + map_reduce = BlogPost.objects.item_frequencies("hits", map_reduce=True) test_assertions(exec_js) test_assertions(map_reduce) @@ -2987,57 +3023,56 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() doc = Person(name="Guido") - doc.phone = Phone(number='62-3331-1656') + doc.phone = Phone(number="62-3331-1656") doc.save() doc = Person(name="Marr") - doc.phone = Phone(number='62-3331-1656') + doc.phone = Phone(number="62-3331-1656") doc.save() doc = Person(name="WP Junior") - doc.phone = Phone(number='62-3332-1656') + doc.phone = Phone(number="62-3332-1656") doc.save() def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual( - set(['62-3331-1656', '62-3332-1656']), set(f.keys())) - self.assertEqual(f['62-3331-1656'], 2) - self.assertEqual(f['62-3332-1656'], 1) + self.assertEqual(set(["62-3331-1656", "62-3332-1656"]), set(f.keys())) + self.assertEqual(f["62-3331-1656"], 2) + self.assertEqual(f["62-3332-1656"], 1) - exec_js = Person.objects.item_frequencies('phone.number') - map_reduce = Person.objects.item_frequencies( - 'phone.number', map_reduce=True) + exec_js = Person.objects.item_frequencies("phone.number") + map_reduce = Person.objects.item_frequencies("phone.number", map_reduce=True) test_assertions(exec_js) test_assertions(map_reduce) # Ensure query is taken into account def test_assertions(f): f = {key: int(val) for key, val in f.items()} - self.assertEqual(set(['62-3331-1656']), set(f.keys())) - self.assertEqual(f['62-3331-1656'], 2) + self.assertEqual(set(["62-3331-1656"]), set(f.keys())) + self.assertEqual(f["62-3331-1656"], 2) - exec_js = Person.objects( - phone__number='62-3331-1656').item_frequencies('phone.number') - map_reduce = Person.objects( - phone__number='62-3331-1656').item_frequencies('phone.number', map_reduce=True) + exec_js = Person.objects(phone__number="62-3331-1656").item_frequencies( + "phone.number" + ) + map_reduce = Person.objects(phone__number="62-3331-1656").item_frequencies( + "phone.number", map_reduce=True + ) test_assertions(exec_js) test_assertions(map_reduce) # Check that normalization works def test_assertions(f): - self.assertEqual(f['62-3331-1656'], 2.0 / 3.0) - self.assertEqual(f['62-3332-1656'], 1.0 / 3.0) + self.assertEqual(f["62-3331-1656"], 2.0 / 3.0) + self.assertEqual(f["62-3332-1656"], 1.0 / 3.0) - exec_js = Person.objects.item_frequencies( - 'phone.number', normalize=True) + exec_js = Person.objects.item_frequencies("phone.number", normalize=True) map_reduce = Person.objects.item_frequencies( - 'phone.number', normalize=True, map_reduce=True) + "phone.number", normalize=True, map_reduce=True + ) test_assertions(exec_js) test_assertions(map_reduce) def test_item_frequencies_null_values(self): - class Person(Document): name = StringField() city = StringField() @@ -3047,16 +3082,15 @@ class QuerySetTest(unittest.TestCase): Person(name="Wilson Snr", city="CRB").save() Person(name="Wilson Jr").save() - freq = Person.objects.item_frequencies('city') - self.assertEqual(freq, {'CRB': 1.0, None: 1.0}) - freq = Person.objects.item_frequencies('city', normalize=True) - self.assertEqual(freq, {'CRB': 0.5, None: 0.5}) + freq = Person.objects.item_frequencies("city") + self.assertEqual(freq, {"CRB": 1.0, None: 1.0}) + freq = Person.objects.item_frequencies("city", normalize=True) + self.assertEqual(freq, {"CRB": 0.5, None: 0.5}) - freq = Person.objects.item_frequencies('city', map_reduce=True) - self.assertEqual(freq, {'CRB': 1.0, None: 1.0}) - freq = Person.objects.item_frequencies( - 'city', normalize=True, map_reduce=True) - self.assertEqual(freq, {'CRB': 0.5, None: 0.5}) + freq = Person.objects.item_frequencies("city", map_reduce=True) + self.assertEqual(freq, {"CRB": 1.0, None: 1.0}) + freq = Person.objects.item_frequencies("city", normalize=True, map_reduce=True) + self.assertEqual(freq, {"CRB": 0.5, None: 0.5}) def test_item_frequencies_with_null_embedded(self): class Data(EmbeddedDocument): @@ -3080,11 +3114,11 @@ class QuerySetTest(unittest.TestCase): p.extra = Extra(tag="friend") p.save() - ot = Person.objects.item_frequencies('extra.tag', map_reduce=False) - self.assertEqual(ot, {None: 1.0, u'friend': 1.0}) + ot = Person.objects.item_frequencies("extra.tag", map_reduce=False) + self.assertEqual(ot, {None: 1.0, u"friend": 1.0}) - ot = Person.objects.item_frequencies('extra.tag', map_reduce=True) - self.assertEqual(ot, {None: 1.0, u'friend': 1.0}) + ot = Person.objects.item_frequencies("extra.tag", map_reduce=True) + self.assertEqual(ot, {None: 1.0, u"friend": 1.0}) def test_item_frequencies_with_0_values(self): class Test(Document): @@ -3095,9 +3129,9 @@ class QuerySetTest(unittest.TestCase): t.val = 0 t.save() - ot = Test.objects.item_frequencies('val', map_reduce=True) + ot = Test.objects.item_frequencies("val", map_reduce=True) self.assertEqual(ot, {0: 1}) - ot = Test.objects.item_frequencies('val', map_reduce=False) + ot = Test.objects.item_frequencies("val", map_reduce=False) self.assertEqual(ot, {0: 1}) def test_item_frequencies_with_False_values(self): @@ -3109,9 +3143,9 @@ class QuerySetTest(unittest.TestCase): t.val = False t.save() - ot = Test.objects.item_frequencies('val', map_reduce=True) + ot = Test.objects.item_frequencies("val", map_reduce=True) self.assertEqual(ot, {False: 1}) - ot = Test.objects.item_frequencies('val', map_reduce=False) + ot = Test.objects.item_frequencies("val", map_reduce=False) self.assertEqual(ot, {False: 1}) def test_item_frequencies_normalize(self): @@ -3126,113 +3160,108 @@ class QuerySetTest(unittest.TestCase): for i in range(20): Test(val=2).save() - freqs = Test.objects.item_frequencies( - 'val', map_reduce=False, normalize=True) + freqs = Test.objects.item_frequencies("val", map_reduce=False, normalize=True) self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70}) - freqs = Test.objects.item_frequencies( - 'val', map_reduce=True, normalize=True) + freqs = Test.objects.item_frequencies("val", map_reduce=True, normalize=True) self.assertEqual(freqs, {1: 50.0 / 70, 2: 20.0 / 70}) def test_average(self): """Ensure that field can be averaged correctly. """ - self.Person(name='person', age=0).save() - self.assertEqual(int(self.Person.objects.average('age')), 0) + self.Person(name="person", age=0).save() + self.assertEqual(int(self.Person.objects.average("age")), 0) ages = [23, 54, 12, 94, 27] for i, age in enumerate(ages): - self.Person(name='test%s' % i, age=age).save() + self.Person(name="test%s" % i, age=age).save() avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0 - self.assertAlmostEqual(int(self.Person.objects.average('age')), avg) + self.assertAlmostEqual(int(self.Person.objects.average("age")), avg) - self.Person(name='ageless person').save() - self.assertEqual(int(self.Person.objects.average('age')), avg) + self.Person(name="ageless person").save() + self.assertEqual(int(self.Person.objects.average("age")), avg) # dot notation - self.Person( - name='person meta', person_meta=self.PersonMeta(weight=0)).save() + self.Person(name="person meta", person_meta=self.PersonMeta(weight=0)).save() self.assertAlmostEqual( - int(self.Person.objects.average('person_meta.weight')), 0) + int(self.Person.objects.average("person_meta.weight")), 0 + ) for i, weight in enumerate(ages): self.Person( - name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save() + name="test meta%i", person_meta=self.PersonMeta(weight=weight) + ).save() self.assertAlmostEqual( - int(self.Person.objects.average('person_meta.weight')), avg + int(self.Person.objects.average("person_meta.weight")), avg ) - self.Person(name='test meta none').save() - self.assertEqual( - int(self.Person.objects.average('person_meta.weight')), avg - ) + self.Person(name="test meta none").save() + self.assertEqual(int(self.Person.objects.average("person_meta.weight")), avg) # test summing over a filtered queryset over_50 = [a for a in ages if a >= 50] avg = float(sum(over_50)) / len(over_50) - self.assertEqual( - self.Person.objects.filter(age__gte=50).average('age'), - avg - ) + self.assertEqual(self.Person.objects.filter(age__gte=50).average("age"), avg) def test_sum(self): """Ensure that field can be summed over correctly. """ ages = [23, 54, 12, 94, 27] for i, age in enumerate(ages): - self.Person(name='test%s' % i, age=age).save() + self.Person(name="test%s" % i, age=age).save() - self.assertEqual(self.Person.objects.sum('age'), sum(ages)) + self.assertEqual(self.Person.objects.sum("age"), sum(ages)) - self.Person(name='ageless person').save() - self.assertEqual(self.Person.objects.sum('age'), sum(ages)) + self.Person(name="ageless person").save() + self.assertEqual(self.Person.objects.sum("age"), sum(ages)) for i, age in enumerate(ages): - self.Person(name='test meta%s' % - i, person_meta=self.PersonMeta(weight=age)).save() + self.Person( + name="test meta%s" % i, person_meta=self.PersonMeta(weight=age) + ).save() - self.assertEqual( - self.Person.objects.sum('person_meta.weight'), sum(ages) - ) + self.assertEqual(self.Person.objects.sum("person_meta.weight"), sum(ages)) - self.Person(name='weightless person').save() - self.assertEqual(self.Person.objects.sum('age'), sum(ages)) + self.Person(name="weightless person").save() + self.assertEqual(self.Person.objects.sum("age"), sum(ages)) # test summing over a filtered queryset self.assertEqual( - self.Person.objects.filter(age__gte=50).sum('age'), - sum([a for a in ages if a >= 50]) + self.Person.objects.filter(age__gte=50).sum("age"), + sum([a for a in ages if a >= 50]), ) def test_sum_over_db_field(self): """Ensure that a field mapped to a db field with a different name can be summed over correctly. """ + class UserVisit(Document): - num_visits = IntField(db_field='visits') + num_visits = IntField(db_field="visits") UserVisit.drop_collection() UserVisit.objects.create(num_visits=10) UserVisit.objects.create(num_visits=5) - self.assertEqual(UserVisit.objects.sum('num_visits'), 15) + self.assertEqual(UserVisit.objects.sum("num_visits"), 15) def test_average_over_db_field(self): """Ensure that a field mapped to a db field with a different name can have its average computed correctly. """ + class UserVisit(Document): - num_visits = IntField(db_field='visits') + num_visits = IntField(db_field="visits") UserVisit.drop_collection() UserVisit.objects.create(num_visits=20) UserVisit.objects.create(num_visits=10) - self.assertEqual(UserVisit.objects.average('num_visits'), 15) + self.assertEqual(UserVisit.objects.average("num_visits"), 15) def test_embedded_average(self): class Pay(EmbeddedDocument): @@ -3240,17 +3269,16 @@ class QuerySetTest(unittest.TestCase): class Doc(Document): name = StringField() - pay = EmbeddedDocumentField( - Pay) + pay = EmbeddedDocumentField(Pay) Doc.drop_collection() - Doc(name='Wilson Junior', pay=Pay(value=150)).save() - Doc(name='Isabella Luanna', pay=Pay(value=530)).save() - Doc(name='Tayza mariana', pay=Pay(value=165)).save() - Doc(name='Eliana Costa', pay=Pay(value=115)).save() + Doc(name="Wilson Junior", pay=Pay(value=150)).save() + Doc(name="Isabella Luanna", pay=Pay(value=530)).save() + Doc(name="Tayza mariana", pay=Pay(value=165)).save() + Doc(name="Eliana Costa", pay=Pay(value=115)).save() - self.assertEqual(Doc.objects.average('pay.value'), 240) + self.assertEqual(Doc.objects.average("pay.value"), 240) def test_embedded_array_average(self): class Pay(EmbeddedDocument): @@ -3262,12 +3290,12 @@ class QuerySetTest(unittest.TestCase): Doc.drop_collection() - Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save() - Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save() - Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save() - Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save() + Doc(name="Wilson Junior", pay=Pay(values=[150, 100])).save() + Doc(name="Isabella Luanna", pay=Pay(values=[530, 100])).save() + Doc(name="Tayza mariana", pay=Pay(values=[165, 100])).save() + Doc(name="Eliana Costa", pay=Pay(values=[115, 100])).save() - self.assertEqual(Doc.objects.average('pay.values'), 170) + self.assertEqual(Doc.objects.average("pay.values"), 170) def test_array_average(self): class Doc(Document): @@ -3280,7 +3308,7 @@ class QuerySetTest(unittest.TestCase): Doc(values=[165, 100]).save() Doc(values=[115, 100]).save() - self.assertEqual(Doc.objects.average('values'), 170) + self.assertEqual(Doc.objects.average("values"), 170) def test_embedded_sum(self): class Pay(EmbeddedDocument): @@ -3292,12 +3320,12 @@ class QuerySetTest(unittest.TestCase): Doc.drop_collection() - Doc(name='Wilson Junior', pay=Pay(value=150)).save() - Doc(name='Isabella Luanna', pay=Pay(value=530)).save() - Doc(name='Tayza mariana', pay=Pay(value=165)).save() - Doc(name='Eliana Costa', pay=Pay(value=115)).save() + Doc(name="Wilson Junior", pay=Pay(value=150)).save() + Doc(name="Isabella Luanna", pay=Pay(value=530)).save() + Doc(name="Tayza mariana", pay=Pay(value=165)).save() + Doc(name="Eliana Costa", pay=Pay(value=115)).save() - self.assertEqual(Doc.objects.sum('pay.value'), 960) + self.assertEqual(Doc.objects.sum("pay.value"), 960) def test_embedded_array_sum(self): class Pay(EmbeddedDocument): @@ -3309,12 +3337,12 @@ class QuerySetTest(unittest.TestCase): Doc.drop_collection() - Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save() - Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save() - Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save() - Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save() + Doc(name="Wilson Junior", pay=Pay(values=[150, 100])).save() + Doc(name="Isabella Luanna", pay=Pay(values=[530, 100])).save() + Doc(name="Tayza mariana", pay=Pay(values=[165, 100])).save() + Doc(name="Eliana Costa", pay=Pay(values=[115, 100])).save() - self.assertEqual(Doc.objects.sum('pay.values'), 1360) + self.assertEqual(Doc.objects.sum("pay.values"), 1360) def test_array_sum(self): class Doc(Document): @@ -3327,21 +3355,24 @@ class QuerySetTest(unittest.TestCase): Doc(values=[165, 100]).save() Doc(values=[115, 100]).save() - self.assertEqual(Doc.objects.sum('values'), 1360) + self.assertEqual(Doc.objects.sum("values"), 1360) def test_distinct(self): """Ensure that the QuerySet.distinct method works. """ - self.Person(name='Mr Orange', age=20).save() - self.Person(name='Mr White', age=20).save() - self.Person(name='Mr Orange', age=30).save() - self.Person(name='Mr Pink', age=30).save() - self.assertEqual(set(self.Person.objects.distinct('name')), - set(['Mr Orange', 'Mr White', 'Mr Pink'])) - self.assertEqual(set(self.Person.objects.distinct('age')), - set([20, 30])) - self.assertEqual(set(self.Person.objects(age=30).distinct('name')), - set(['Mr Orange', 'Mr Pink'])) + self.Person(name="Mr Orange", age=20).save() + self.Person(name="Mr White", age=20).save() + self.Person(name="Mr Orange", age=30).save() + self.Person(name="Mr Pink", age=30).save() + self.assertEqual( + set(self.Person.objects.distinct("name")), + set(["Mr Orange", "Mr White", "Mr Pink"]), + ) + self.assertEqual(set(self.Person.objects.distinct("age")), set([20, 30])) + self.assertEqual( + set(self.Person.objects(age=30).distinct("name")), + set(["Mr Orange", "Mr Pink"]), + ) def test_distinct_handles_references(self): class Foo(Document): @@ -3367,53 +3398,58 @@ class QuerySetTest(unittest.TestCase): content = StringField() is_active = BooleanField(default=True) - meta = {'indexes': [ - {'fields': ['$title', "$content"], - 'default_language': 'portuguese', - 'weights': {'title': 10, 'content': 2} - } - ]} + meta = { + "indexes": [ + { + "fields": ["$title", "$content"], + "default_language": "portuguese", + "weights": {"title": 10, "content": 2}, + } + ] + } News.drop_collection() info = News.objects._collection.index_information() - self.assertIn('title_text_content_text', info) - self.assertIn('textIndexVersion', info['title_text_content_text']) + self.assertIn("title_text_content_text", info) + self.assertIn("textIndexVersion", info["title_text_content_text"]) - News(title="Neymar quebrou a vertebra", - content="O Brasil sofre com a perda de Neymar").save() + News( + title="Neymar quebrou a vertebra", + content="O Brasil sofre com a perda de Neymar", + ).save() - News(title="Brasil passa para as quartas de finais", - content="Com o brasil nas quartas de finais teremos um " - "jogo complicado com a alemanha").save() + News( + title="Brasil passa para as quartas de finais", + content="Com o brasil nas quartas de finais teremos um " + "jogo complicado com a alemanha", + ).save() - count = News.objects.search_text( - "neymar", language="portuguese").count() + count = News.objects.search_text("neymar", language="portuguese").count() self.assertEqual(count, 1) - count = News.objects.search_text( - "brasil -neymar").count() + count = News.objects.search_text("brasil -neymar").count() self.assertEqual(count, 1) - News(title=u"As eleições no Brasil já estão em planejamento", - content=u"A candidata dilma roussef já começa o teu planejamento", - is_active=False).save() + News( + title=u"As eleições no Brasil já estão em planejamento", + content=u"A candidata dilma roussef já começa o teu planejamento", + is_active=False, + ).save() - new = News.objects(is_active=False).search_text( - "dilma", language="pt").first() + new = News.objects(is_active=False).search_text("dilma", language="pt").first() - query = News.objects(is_active=False).search_text( - "dilma", language="pt")._query + query = News.objects(is_active=False).search_text("dilma", language="pt")._query self.assertEqual( - query, {'$text': { - '$search': 'dilma', '$language': 'pt'}, - 'is_active': False}) + query, + {"$text": {"$search": "dilma", "$language": "pt"}, "is_active": False}, + ) self.assertFalse(new.is_active) - self.assertIn('dilma', new.content) - self.assertIn('planejamento', new.title) + self.assertIn("dilma", new.content) + self.assertIn("planejamento", new.title) query = News.objects.search_text("candidata") self.assertEqual(query._search_text, "candidata") @@ -3422,15 +3458,14 @@ class QuerySetTest(unittest.TestCase): self.assertIsInstance(new.get_text_score(), float) # count - query = News.objects.search_text('brasil').order_by('$text_score') + query = News.objects.search_text("brasil").order_by("$text_score") self.assertEqual(query._search_text, "brasil") self.assertEqual(query.count(), 3) - self.assertEqual(query._query, {'$text': {'$search': 'brasil'}}) + self.assertEqual(query._query, {"$text": {"$search": "brasil"}}) cursor_args = query._cursor_args - cursor_args_fields = cursor_args['projection'] - self.assertEqual( - cursor_args_fields, {'_text_score': {'$meta': 'textScore'}}) + cursor_args_fields = cursor_args["projection"] + self.assertEqual(cursor_args_fields, {"_text_score": {"$meta": "textScore"}}) text_scores = [i.get_text_score() for i in query] self.assertEqual(len(text_scores), 3) @@ -3440,20 +3475,19 @@ class QuerySetTest(unittest.TestCase): max_text_score = text_scores[0] # get item - item = News.objects.search_text( - 'brasil').order_by('$text_score').first() + item = News.objects.search_text("brasil").order_by("$text_score").first() self.assertEqual(item.get_text_score(), max_text_score) def test_distinct_handles_references_to_alias(self): - register_connection('testdb', 'mongoenginetest2') + register_connection("testdb", "mongoenginetest2") class Foo(Document): bar = ReferenceField("Bar") - meta = {'db_alias': 'testdb'} + meta = {"db_alias": "testdb"} class Bar(Document): text = StringField() - meta = {'db_alias': 'testdb'} + meta = {"db_alias": "testdb"} Bar.drop_collection() Foo.drop_collection() @@ -3469,8 +3503,9 @@ class QuerySetTest(unittest.TestCase): def test_distinct_handles_db_field(self): """Ensure that distinct resolves field name to db_field as expected. """ + class Product(Document): - product_id = IntField(db_field='pid') + product_id = IntField(db_field="pid") Product.drop_collection() @@ -3478,15 +3513,12 @@ class QuerySetTest(unittest.TestCase): Product(product_id=2).save() Product(product_id=1).save() - self.assertEqual(set(Product.objects.distinct('product_id')), - set([1, 2])) - self.assertEqual(set(Product.objects.distinct('pid')), - set([1, 2])) + self.assertEqual(set(Product.objects.distinct("product_id")), set([1, 2])) + self.assertEqual(set(Product.objects.distinct("pid")), set([1, 2])) Product.drop_collection() def test_distinct_ListField_EmbeddedDocumentField(self): - class Author(EmbeddedDocument): name = StringField() @@ -3524,8 +3556,8 @@ class QuerySetTest(unittest.TestCase): Book.drop_collection() - europe = Continent(continent_name='europe') - asia = Continent(continent_name='asia') + europe = Continent(continent_name="europe") + asia = Continent(continent_name="asia") scotland = Country(country_name="Scotland", continent=europe) tibet = Country(country_name="Tibet", continent=asia) @@ -3544,13 +3576,12 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(continent_list, [europe, asia]) def test_distinct_ListField_ReferenceField(self): - class Bar(Document): text = StringField() class Foo(Document): - bar = ReferenceField('Bar') - bar_lst = ListField(ReferenceField('Bar')) + bar = ReferenceField("Bar") + bar_lst = ListField(ReferenceField("Bar")) Bar.drop_collection() Foo.drop_collection() @@ -3569,6 +3600,7 @@ class QuerySetTest(unittest.TestCase): def test_custom_manager(self): """Ensure that custom QuerySetManager instances work as expected. """ + class BlogPost(Document): tags = ListField(StringField()) deleted = BooleanField(default=False) @@ -3586,32 +3618,30 @@ class QuerySetTest(unittest.TestCase): @queryset_manager def music_posts(doc_cls, queryset, deleted=False): - return queryset(tags='music', - deleted=deleted).order_by('date') + return queryset(tags="music", deleted=deleted).order_by("date") BlogPost.drop_collection() - post1 = BlogPost(tags=['music', 'film']).save() - post2 = BlogPost(tags=['music']).save() - post3 = BlogPost(tags=['film', 'actors']).save() - post4 = BlogPost(tags=['film', 'actors', 'music'], deleted=True).save() + post1 = BlogPost(tags=["music", "film"]).save() + post2 = BlogPost(tags=["music"]).save() + post3 = BlogPost(tags=["film", "actors"]).save() + post4 = BlogPost(tags=["film", "actors", "music"], deleted=True).save() - self.assertEqual([p.id for p in BlogPost.objects()], - [post1.id, post2.id, post3.id]) - self.assertEqual([p.id for p in BlogPost.objects_1_arg()], - [post1.id, post2.id, post3.id]) - self.assertEqual([p.id for p in BlogPost.music_posts()], - [post1.id, post2.id]) + self.assertEqual( + [p.id for p in BlogPost.objects()], [post1.id, post2.id, post3.id] + ) + self.assertEqual( + [p.id for p in BlogPost.objects_1_arg()], [post1.id, post2.id, post3.id] + ) + self.assertEqual([p.id for p in BlogPost.music_posts()], [post1.id, post2.id]) - self.assertEqual([p.id for p in BlogPost.music_posts(True)], - [post4.id]) + self.assertEqual([p.id for p in BlogPost.music_posts(True)], [post4.id]) BlogPost.drop_collection() def test_custom_manager_overriding_objects_works(self): - class Foo(Document): - bar = StringField(default='bar') + bar = StringField(default="bar") active = BooleanField(default=False) @queryset_manager @@ -3635,9 +3665,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(1, Foo.objects.count()) def test_inherit_objects(self): - class Foo(Document): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} active = BooleanField(default=True) @queryset_manager @@ -3652,9 +3681,8 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(0, Bar.objects.count()) def test_inherit_objects_override(self): - class Foo(Document): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} active = BooleanField(default=True) @queryset_manager @@ -3662,7 +3690,6 @@ class QuerySetTest(unittest.TestCase): return queryset(active=True) class Bar(Foo): - @queryset_manager def objects(klass, queryset): return queryset(active=False) @@ -3675,12 +3702,13 @@ class QuerySetTest(unittest.TestCase): def test_query_value_conversion(self): """Ensure that query values are properly converted when necessary. """ + class BlogPost(Document): author = ReferenceField(self.Person) BlogPost.drop_collection() - person = self.Person(name='test', age=30) + person = self.Person(name="test", age=30) person.save() post = BlogPost(author=person) @@ -3701,14 +3729,15 @@ class QuerySetTest(unittest.TestCase): def test_update_value_conversion(self): """Ensure that values used in updates are converted before use. """ + class Group(Document): members = ListField(ReferenceField(self.Person)) Group.drop_collection() - user1 = self.Person(name='user1') + user1 = self.Person(name="user1") user1.save() - user2 = self.Person(name='user2') + user2 = self.Person(name="user2") user2.save() group = Group() @@ -3726,6 +3755,7 @@ class QuerySetTest(unittest.TestCase): def test_bulk(self): """Ensure bulk querying by object id returns a proper dict. """ + class BlogPost(Document): title = StringField() @@ -3764,13 +3794,13 @@ class QuerySetTest(unittest.TestCase): def test_custom_querysets(self): """Ensure that custom QuerySet classes may be used. """ - class CustomQuerySet(QuerySet): + class CustomQuerySet(QuerySet): def not_empty(self): return self.count() > 0 class Post(Document): - meta = {'queryset_class': CustomQuerySet} + meta = {"queryset_class": CustomQuerySet} Post.drop_collection() @@ -3787,7 +3817,6 @@ class QuerySetTest(unittest.TestCase): """ class CustomQuerySet(QuerySet): - def not_empty(self): return self.count() > 0 @@ -3812,7 +3841,6 @@ class QuerySetTest(unittest.TestCase): """ class CustomQuerySetManager(QuerySetManager): - @staticmethod def get_queryset(doc_cls, queryset): return queryset(is_published=True) @@ -3835,12 +3863,11 @@ class QuerySetTest(unittest.TestCase): """ class CustomQuerySet(QuerySet): - def not_empty(self): return self.count() > 0 class Base(Document): - meta = {'abstract': True, 'queryset_class': CustomQuerySet} + meta = {"abstract": True, "queryset_class": CustomQuerySet} class Post(Base): pass @@ -3859,7 +3886,6 @@ class QuerySetTest(unittest.TestCase): """ class CustomQuerySet(QuerySet): - def not_empty(self): return self.count() > 0 @@ -3867,7 +3893,7 @@ class QuerySetTest(unittest.TestCase): queryset_class = CustomQuerySet class Base(Document): - meta = {'abstract': True} + meta = {"abstract": True} objects = CustomQuerySetManager() class Post(Base): @@ -3891,10 +3917,13 @@ class QuerySetTest(unittest.TestCase): for i in range(10): Post(title="Post %s" % i).save() - self.assertEqual(5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True)) + self.assertEqual( + 5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True) + ) self.assertEqual( - 10, Post.objects.limit(5).skip(5).count(with_limit_and_skip=False)) + 10, Post.objects.limit(5).skip(5).count(with_limit_and_skip=False) + ) def test_count_and_none(self): """Test count works with None()""" @@ -3916,11 +3945,12 @@ class QuerySetTest(unittest.TestCase): class A(Document): b = ListField(EmbeddedDocumentField(B)) - self.assertEqual(A.objects(b=[{'c': 'c'}]).count(), 0) + self.assertEqual(A.objects(b=[{"c": "c"}]).count(), 0) def test_call_after_limits_set(self): """Ensure that re-filtering after slicing works """ + class Post(Document): title = StringField() @@ -3937,6 +3967,7 @@ class QuerySetTest(unittest.TestCase): def test_order_then_filter(self): """Ensure that ordering still works after filtering. """ + class Number(Document): n = IntField() @@ -3946,14 +3977,15 @@ class QuerySetTest(unittest.TestCase): n1 = Number.objects.create(n=1) self.assertEqual(list(Number.objects), [n2, n1]) - self.assertEqual(list(Number.objects.order_by('n')), [n1, n2]) - self.assertEqual(list(Number.objects.order_by('n').filter()), [n1, n2]) + self.assertEqual(list(Number.objects.order_by("n")), [n1, n2]) + self.assertEqual(list(Number.objects.order_by("n").filter()), [n1, n2]) Number.drop_collection() def test_clone(self): """Ensure that cloning clones complex querysets """ + class Number(Document): n = IntField() @@ -3983,19 +4015,20 @@ class QuerySetTest(unittest.TestCase): def test_using(self): """Ensure that switching databases for a queryset is possible """ + class Number2(Document): n = IntField() Number2.drop_collection() - with switch_db(Number2, 'test2') as Number2: + with switch_db(Number2, "test2") as Number2: Number2.drop_collection() for i in range(1, 10): t = Number2(n=i) - t.switch_db('test2') + t.switch_db("test2") t.save() - self.assertEqual(len(Number2.objects.using('test2')), 9) + self.assertEqual(len(Number2.objects.using("test2")), 9) def test_unset_reference(self): class Comment(Document): @@ -4007,7 +4040,7 @@ class QuerySetTest(unittest.TestCase): Comment.drop_collection() Post.drop_collection() - comment = Comment.objects.create(text='test') + comment = Comment.objects.create(text="test") post = Post.objects.create(comment=comment) self.assertEqual(post.comment, comment) @@ -4020,7 +4053,7 @@ class QuerySetTest(unittest.TestCase): def test_order_works_with_custom_db_field_names(self): class Number(Document): - n = IntField(db_field='number') + n = IntField(db_field="number") Number.drop_collection() @@ -4028,13 +4061,14 @@ class QuerySetTest(unittest.TestCase): n1 = Number.objects.create(n=1) self.assertEqual(list(Number.objects), [n2, n1]) - self.assertEqual(list(Number.objects.order_by('n')), [n1, n2]) + self.assertEqual(list(Number.objects.order_by("n")), [n1, n2]) Number.drop_collection() def test_order_works_with_primary(self): """Ensure that order_by and primary work. """ + class Number(Document): n = IntField(primary_key=True) @@ -4044,28 +4078,29 @@ class QuerySetTest(unittest.TestCase): Number(n=2).save() Number(n=3).save() - numbers = [n.n for n in Number.objects.order_by('-n')] + numbers = [n.n for n in Number.objects.order_by("-n")] self.assertEqual([3, 2, 1], numbers) - numbers = [n.n for n in Number.objects.order_by('+n')] + numbers = [n.n for n in Number.objects.order_by("+n")] self.assertEqual([1, 2, 3], numbers) Number.drop_collection() def test_ensure_index(self): """Ensure that manual creation of indexes works. """ + class Comment(Document): message = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} - Comment.ensure_index('message') + Comment.ensure_index("message") info = Comment.objects._collection.index_information() - info = [(value['key'], - value.get('unique', False), - value.get('sparse', False)) - for key, value in iteritems(info)] - self.assertIn(([('_cls', 1), ('message', 1)], False, False), info) + info = [ + (value["key"], value.get("unique", False), value.get("sparse", False)) + for key, value in iteritems(info) + ] + self.assertIn(([("_cls", 1), ("message", 1)], False, False), info) def test_where(self): """Ensure that where clauses work. @@ -4084,23 +4119,25 @@ class QuerySetTest(unittest.TestCase): b.save() c.save() - query = IntPair.objects.where('this[~fielda] >= this[~fieldb]') - self.assertEqual( - 'this["fielda"] >= this["fieldb"]', query._where_clause) + query = IntPair.objects.where("this[~fielda] >= this[~fieldb]") + self.assertEqual('this["fielda"] >= this["fieldb"]', query._where_clause) results = list(query) self.assertEqual(2, len(results)) self.assertIn(a, results) self.assertIn(c, results) - query = IntPair.objects.where('this[~fielda] == this[~fieldb]') + query = IntPair.objects.where("this[~fielda] == this[~fieldb]") results = list(query) self.assertEqual(1, len(results)) self.assertIn(a, results) query = IntPair.objects.where( - 'function() { return this[~fielda] >= this[~fieldb] }') + "function() { return this[~fielda] >= this[~fieldb] }" + ) self.assertEqual( - 'function() { return this["fielda"] >= this["fieldb"] }', query._where_clause) + 'function() { return this["fielda"] >= this["fieldb"] }', + query._where_clause, + ) results = list(query) self.assertEqual(2, len(results)) self.assertIn(a, results) @@ -4110,7 +4147,6 @@ class QuerySetTest(unittest.TestCase): list(IntPair.objects.where(fielda__gte=3)) def test_scalar(self): - class Organization(Document): name = StringField() @@ -4127,13 +4163,13 @@ class QuerySetTest(unittest.TestCase): # Efficient way to get all unique organization names for a given # set of users (Pretend this has additional filtering.) - user_orgs = set(User.objects.scalar('organization')) - orgs = Organization.objects(id__in=user_orgs).scalar('name') - self.assertEqual(list(orgs), ['White House']) + user_orgs = set(User.objects.scalar("organization")) + orgs = Organization.objects(id__in=user_orgs).scalar("name") + self.assertEqual(list(orgs), ["White House"]) # Efficient for generating listings, too. - orgs = Organization.objects.scalar('name').in_bulk(list(user_orgs)) - user_map = User.objects.scalar('name', 'organization') + orgs = Organization.objects.scalar("name").in_bulk(list(user_orgs)) + user_map = User.objects.scalar("name", "organization") user_listing = [(user, orgs[org]) for user, org in user_map] self.assertEqual([("Bob Dole", "White House")], user_listing) @@ -4148,7 +4184,7 @@ class QuerySetTest(unittest.TestCase): TestDoc(x=20, y=False).save() TestDoc(x=30, y=True).save() - plist = list(TestDoc.objects.scalar('x', 'y')) + plist = list(TestDoc.objects.scalar("x", "y")) self.assertEqual(len(plist), 3) self.assertEqual(plist[0], (10, True)) @@ -4166,21 +4202,16 @@ class QuerySetTest(unittest.TestCase): UserDoc(name="Eliana", age=37).save() UserDoc(name="Tayza", age=15).save() - ulist = list(UserDoc.objects.scalar('name', 'age')) + ulist = list(UserDoc.objects.scalar("name", "age")) - self.assertEqual(ulist, [ - (u'Wilson Jr', 19), - (u'Wilson', 43), - (u'Eliana', 37), - (u'Tayza', 15)]) + self.assertEqual( + ulist, + [(u"Wilson Jr", 19), (u"Wilson", 43), (u"Eliana", 37), (u"Tayza", 15)], + ) - ulist = list(UserDoc.objects.scalar('name').order_by('age')) + ulist = list(UserDoc.objects.scalar("name").order_by("age")) - self.assertEqual(ulist, [ - (u'Tayza'), - (u'Wilson Jr'), - (u'Eliana'), - (u'Wilson')]) + self.assertEqual(ulist, [(u"Tayza"), (u"Wilson Jr"), (u"Eliana"), (u"Wilson")]) def test_scalar_embedded(self): class Profile(EmbeddedDocument): @@ -4197,30 +4228,45 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() - Person(profile=Profile(name="Wilson Jr", age=19), - locale=Locale(city="Corumba-GO", country="Brazil")).save() + Person( + profile=Profile(name="Wilson Jr", age=19), + locale=Locale(city="Corumba-GO", country="Brazil"), + ).save() - Person(profile=Profile(name="Gabriel Falcao", age=23), - locale=Locale(city="New York", country="USA")).save() + Person( + profile=Profile(name="Gabriel Falcao", age=23), + locale=Locale(city="New York", country="USA"), + ).save() - Person(profile=Profile(name="Lincoln de souza", age=28), - locale=Locale(city="Belo Horizonte", country="Brazil")).save() + Person( + profile=Profile(name="Lincoln de souza", age=28), + locale=Locale(city="Belo Horizonte", country="Brazil"), + ).save() - Person(profile=Profile(name="Walter cruz", age=30), - locale=Locale(city="Brasilia", country="Brazil")).save() + Person( + profile=Profile(name="Walter cruz", age=30), + locale=Locale(city="Brasilia", country="Brazil"), + ).save() self.assertEqual( - list(Person.objects.order_by( - 'profile__age').scalar('profile__name')), - [u'Wilson Jr', u'Gabriel Falcao', u'Lincoln de souza', u'Walter cruz']) + list(Person.objects.order_by("profile__age").scalar("profile__name")), + [u"Wilson Jr", u"Gabriel Falcao", u"Lincoln de souza", u"Walter cruz"], + ) - ulist = list(Person.objects.order_by('locale.city') - .scalar('profile__name', 'profile__age', 'locale__city')) - self.assertEqual(ulist, - [(u'Lincoln de souza', 28, u'Belo Horizonte'), - (u'Walter cruz', 30, u'Brasilia'), - (u'Wilson Jr', 19, u'Corumba-GO'), - (u'Gabriel Falcao', 23, u'New York')]) + ulist = list( + Person.objects.order_by("locale.city").scalar( + "profile__name", "profile__age", "locale__city" + ) + ) + self.assertEqual( + ulist, + [ + (u"Lincoln de souza", 28, u"Belo Horizonte"), + (u"Walter cruz", 30, u"Brasilia"), + (u"Wilson Jr", 19, u"Corumba-GO"), + (u"Gabriel Falcao", 23, u"New York"), + ], + ) def test_scalar_decimal(self): from decimal import Decimal @@ -4230,10 +4276,10 @@ class QuerySetTest(unittest.TestCase): rating = DecimalField() Person.drop_collection() - Person(name="Wilson Jr", rating=Decimal('1.0')).save() + Person(name="Wilson Jr", rating=Decimal("1.0")).save() - ulist = list(Person.objects.scalar('name', 'rating')) - self.assertEqual(ulist, [(u'Wilson Jr', Decimal('1.0'))]) + ulist = list(Person.objects.scalar("name", "rating")) + self.assertEqual(ulist, [(u"Wilson Jr", Decimal("1.0"))]) def test_scalar_reference_field(self): class State(Document): @@ -4251,8 +4297,8 @@ class QuerySetTest(unittest.TestCase): Person(name="Wilson JR", state=s1).save() - plist = list(Person.objects.scalar('name', 'state')) - self.assertEqual(plist, [(u'Wilson JR', s1)]) + plist = list(Person.objects.scalar("name", "state")) + self.assertEqual(plist, [(u"Wilson JR", s1)]) def test_scalar_generic_reference_field(self): class State(Document): @@ -4270,8 +4316,8 @@ class QuerySetTest(unittest.TestCase): Person(name="Wilson JR", state=s1).save() - plist = list(Person.objects.scalar('name', 'state')) - self.assertEqual(plist, [(u'Wilson JR', s1)]) + plist = list(Person.objects.scalar("name", "state")) + self.assertEqual(plist, [(u"Wilson JR", s1)]) def test_generic_reference_field_with_only_and_as_pymongo(self): class TestPerson(Document): @@ -4284,26 +4330,32 @@ class QuerySetTest(unittest.TestCase): TestPerson.drop_collection() TestActivity.drop_collection() - person = TestPerson(name='owner') + person = TestPerson(name="owner") person.save() - a1 = TestActivity(name='a1', owner=person) + a1 = TestActivity(name="a1", owner=person) a1.save() - activity = TestActivity.objects(owner=person).scalar('id', 'owner').no_dereference().first() + activity = ( + TestActivity.objects(owner=person) + .scalar("id", "owner") + .no_dereference() + .first() + ) self.assertEqual(activity[0], a1.pk) - self.assertEqual(activity[1]['_ref'], DBRef('test_person', person.pk)) + self.assertEqual(activity[1]["_ref"], DBRef("test_person", person.pk)) - activity = TestActivity.objects(owner=person).only('id', 'owner')[0] + activity = TestActivity.objects(owner=person).only("id", "owner")[0] self.assertEqual(activity.pk, a1.pk) self.assertEqual(activity.owner, person) - activity = TestActivity.objects(owner=person).only('id', 'owner').as_pymongo().first() - self.assertEqual(activity['_id'], a1.pk) - self.assertTrue(activity['owner']['_ref'], DBRef('test_person', person.pk)) + activity = ( + TestActivity.objects(owner=person).only("id", "owner").as_pymongo().first() + ) + self.assertEqual(activity["_id"], a1.pk) + self.assertTrue(activity["owner"]["_ref"], DBRef("test_person", person.pk)) def test_scalar_db_field(self): - class TestDoc(Document): x = IntField() y = BooleanField() @@ -4314,14 +4366,13 @@ class QuerySetTest(unittest.TestCase): TestDoc(x=20, y=False).save() TestDoc(x=30, y=True).save() - plist = list(TestDoc.objects.scalar('x', 'y')) + plist = list(TestDoc.objects.scalar("x", "y")) self.assertEqual(len(plist), 3) self.assertEqual(plist[0], (10, True)) self.assertEqual(plist[1], (20, False)) self.assertEqual(plist[2], (30, True)) def test_scalar_primary_key(self): - class SettingValue(Document): key = StringField(primary_key=True) value = StringField() @@ -4330,8 +4381,8 @@ class QuerySetTest(unittest.TestCase): s = SettingValue(key="test", value="test value") s.save() - val = SettingValue.objects.scalar('key', 'value') - self.assertEqual(list(val), [('test', 'test value')]) + val = SettingValue.objects.scalar("key", "value") + self.assertEqual(list(val), [("test", "test value")]) def test_scalar_cursor_behaviour(self): """Ensure that a query returns a valid set of results. @@ -4342,83 +4393,94 @@ class QuerySetTest(unittest.TestCase): person2.save() # Find all people in the collection - people = self.Person.objects.scalar('name') + people = self.Person.objects.scalar("name") self.assertEqual(people.count(), 2) results = list(people) self.assertEqual(results[0], "User A") self.assertEqual(results[1], "User B") # Use a query to filter the people found to just person1 - people = self.Person.objects(age=20).scalar('name') + people = self.Person.objects(age=20).scalar("name") self.assertEqual(people.count(), 1) person = people.next() self.assertEqual(person, "User A") # Test limit - people = list(self.Person.objects.limit(1).scalar('name')) + people = list(self.Person.objects.limit(1).scalar("name")) self.assertEqual(len(people), 1) - self.assertEqual(people[0], 'User A') + self.assertEqual(people[0], "User A") # Test skip - people = list(self.Person.objects.skip(1).scalar('name')) + people = list(self.Person.objects.skip(1).scalar("name")) self.assertEqual(len(people), 1) - self.assertEqual(people[0], 'User B') + self.assertEqual(people[0], "User B") person3 = self.Person(name="User C", age=40) person3.save() # Test slice limit - people = list(self.Person.objects[:2].scalar('name')) + people = list(self.Person.objects[:2].scalar("name")) self.assertEqual(len(people), 2) - self.assertEqual(people[0], 'User A') - self.assertEqual(people[1], 'User B') + self.assertEqual(people[0], "User A") + self.assertEqual(people[1], "User B") # Test slice skip - people = list(self.Person.objects[1:].scalar('name')) + people = list(self.Person.objects[1:].scalar("name")) self.assertEqual(len(people), 2) - self.assertEqual(people[0], 'User B') - self.assertEqual(people[1], 'User C') + self.assertEqual(people[0], "User B") + self.assertEqual(people[1], "User C") # Test slice limit and skip - people = list(self.Person.objects[1:2].scalar('name')) + people = list(self.Person.objects[1:2].scalar("name")) self.assertEqual(len(people), 1) - self.assertEqual(people[0], 'User B') + self.assertEqual(people[0], "User B") - people = list(self.Person.objects[1:1].scalar('name')) + people = list(self.Person.objects[1:1].scalar("name")) self.assertEqual(len(people), 0) # Test slice out of range - people = list(self.Person.objects.scalar('name')[80000:80001]) + people = list(self.Person.objects.scalar("name")[80000:80001]) self.assertEqual(len(people), 0) # Test larger slice __repr__ self.Person.objects.delete() for i in range(55): - self.Person(name='A%s' % i, age=i).save() + self.Person(name="A%s" % i, age=i).save() - self.assertEqual(self.Person.objects.scalar('name').count(), 55) + self.assertEqual(self.Person.objects.scalar("name").count(), 55) self.assertEqual( - "A0", "%s" % self.Person.objects.order_by('name').scalar('name').first()) + "A0", "%s" % self.Person.objects.order_by("name").scalar("name").first() + ) self.assertEqual( - "A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) + "A0", "%s" % self.Person.objects.scalar("name").order_by("name")[0] + ) if six.PY3: - self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by( - 'age').scalar('name')[1:3]) - self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by( - 'age').scalar('name')[51:53]) + self.assertEqual( + "['A1', 'A2']", + "%s" % self.Person.objects.order_by("age").scalar("name")[1:3], + ) + self.assertEqual( + "['A51', 'A52']", + "%s" % self.Person.objects.order_by("age").scalar("name")[51:53], + ) else: - self.assertEqual("[u'A1', u'A2']", "%s" % self.Person.objects.order_by( - 'age').scalar('name')[1:3]) - self.assertEqual("[u'A51', u'A52']", "%s" % self.Person.objects.order_by( - 'age').scalar('name')[51:53]) + self.assertEqual( + "[u'A1', u'A2']", + "%s" % self.Person.objects.order_by("age").scalar("name")[1:3], + ) + self.assertEqual( + "[u'A51', u'A52']", + "%s" % self.Person.objects.order_by("age").scalar("name")[51:53], + ) # with_id and in_bulk - person = self.Person.objects.order_by('name').first() - self.assertEqual("A0", "%s" % - self.Person.objects.scalar('name').with_id(person.id)) + person = self.Person.objects.order_by("name").first() + self.assertEqual( + "A0", "%s" % self.Person.objects.scalar("name").with_id(person.id) + ) - pks = self.Person.objects.order_by('age').scalar('pk')[1:3] - names = self.Person.objects.scalar('name').in_bulk(list(pks)).values() + pks = self.Person.objects.order_by("age").scalar("pk")[1:3] + names = self.Person.objects.scalar("name").in_bulk(list(pks)).values() if six.PY3: expected = "['A1', 'A2']" else: @@ -4430,51 +4492,61 @@ class QuerySetTest(unittest.TestCase): shape = StringField() color = StringField() thick = BooleanField() - meta = {'allow_inheritance': False} + meta = {"allow_inheritance": False} class Bar(Document): foo = ListField(EmbeddedDocumentField(Foo)) - meta = {'allow_inheritance': False} + meta = {"allow_inheritance": False} Bar.drop_collection() - b1 = Bar(foo=[Foo(shape="square", color="purple", thick=False), - Foo(shape="circle", color="red", thick=True)]) + b1 = Bar( + foo=[ + Foo(shape="square", color="purple", thick=False), + Foo(shape="circle", color="red", thick=True), + ] + ) b1.save() - b2 = Bar(foo=[Foo(shape="square", color="red", thick=True), - Foo(shape="circle", color="purple", thick=False)]) + b2 = Bar( + foo=[ + Foo(shape="square", color="red", thick=True), + Foo(shape="circle", color="purple", thick=False), + ] + ) b2.save() - b3 = Bar(foo=[Foo(shape="square", thick=True), - Foo(shape="circle", color="purple", thick=False)]) + b3 = Bar( + foo=[ + Foo(shape="square", thick=True), + Foo(shape="circle", color="purple", thick=False), + ] + ) b3.save() - ak = list( - Bar.objects(foo__match={'shape': "square", "color": "purple"})) + ak = list(Bar.objects(foo__match={"shape": "square", "color": "purple"})) self.assertEqual([b1], ak) - ak = list( - Bar.objects(foo__elemMatch={'shape': "square", "color": "purple"})) + ak = list(Bar.objects(foo__elemMatch={"shape": "square", "color": "purple"})) self.assertEqual([b1], ak) ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple"))) self.assertEqual([b1], ak) ak = list( - Bar.objects(foo__elemMatch={'shape': "square", "color__exists": True})) + Bar.objects(foo__elemMatch={"shape": "square", "color__exists": True}) + ) + self.assertEqual([b1, b2], ak) + + ak = list(Bar.objects(foo__match={"shape": "square", "color__exists": True})) self.assertEqual([b1, b2], ak) ak = list( - Bar.objects(foo__match={'shape': "square", "color__exists": True})) - self.assertEqual([b1, b2], ak) - - ak = list( - Bar.objects(foo__elemMatch={'shape': "square", "color__exists": False})) + Bar.objects(foo__elemMatch={"shape": "square", "color__exists": False}) + ) self.assertEqual([b3], ak) - ak = list( - Bar.objects(foo__match={'shape': "square", "color__exists": False})) + ak = list(Bar.objects(foo__match={"shape": "square", "color__exists": False})) self.assertEqual([b3], ak) def test_upsert_includes_cls(self): @@ -4485,24 +4557,25 @@ class QuerySetTest(unittest.TestCase): test = StringField() Test.drop_collection() - Test.objects(test='foo').update_one(upsert=True, set__test='foo') - self.assertNotIn('_cls', Test._collection.find_one()) + Test.objects(test="foo").update_one(upsert=True, set__test="foo") + self.assertNotIn("_cls", Test._collection.find_one()) class Test(Document): - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} test = StringField() Test.drop_collection() - Test.objects(test='foo').update_one(upsert=True, set__test='foo') - self.assertIn('_cls', Test._collection.find_one()) + Test.objects(test="foo").update_one(upsert=True, set__test="foo") + self.assertIn("_cls", Test._collection.find_one()) def test_update_upsert_looks_like_a_digit(self): class MyDoc(DynamicDocument): pass + MyDoc.drop_collection() self.assertEqual(1, MyDoc.objects.update_one(upsert=True, inc__47=1)) - self.assertEqual(MyDoc.objects.get()['47'], 1) + self.assertEqual(MyDoc.objects.get()["47"], 1) def test_dictfield_key_looks_like_a_digit(self): """Only should work with DictField even if they have numeric keys.""" @@ -4511,86 +4584,84 @@ class QuerySetTest(unittest.TestCase): test = DictField() MyDoc.drop_collection() - doc = MyDoc(test={'47': 1}) + doc = MyDoc(test={"47": 1}) doc.save() - self.assertEqual(MyDoc.objects.only('test__47').get().test['47'], 1) + self.assertEqual(MyDoc.objects.only("test__47").get().test["47"], 1) def test_read_preference(self): class Bar(Document): txt = StringField() - meta = { - 'indexes': ['txt'] - } + meta = {"indexes": ["txt"]} Bar.drop_collection() bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY)) self.assertEqual([], bars) - self.assertRaises(TypeError, Bar.objects, read_preference='Primary') + self.assertRaises(TypeError, Bar.objects, read_preference="Primary") # read_preference as a kwarg bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._read_preference, - ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._cursor._Cursor__read_preference, - ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED + ) # read_preference as a query set method bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._read_preference, - ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._cursor._Cursor__read_preference, - ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED + ) # read_preference after skip - bars = Bar.objects.skip(1) \ - .read_preference(ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._read_preference, - ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._cursor._Cursor__read_preference, - ReadPreference.SECONDARY_PREFERRED) + bars = Bar.objects.skip(1).read_preference(ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED + ) # read_preference after limit - bars = Bar.objects.limit(1) \ - .read_preference(ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._read_preference, - ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._cursor._Cursor__read_preference, - ReadPreference.SECONDARY_PREFERRED) + bars = Bar.objects.limit(1).read_preference(ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED + ) # read_preference after order_by - bars = Bar.objects.order_by('txt') \ - .read_preference(ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._read_preference, - ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._cursor._Cursor__read_preference, - ReadPreference.SECONDARY_PREFERRED) + bars = Bar.objects.order_by("txt").read_preference( + ReadPreference.SECONDARY_PREFERRED + ) + self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED + ) # read_preference after hint - bars = Bar.objects.hint([('txt', 1)]) \ - .read_preference(ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._read_preference, - ReadPreference.SECONDARY_PREFERRED) - self.assertEqual(bars._cursor._Cursor__read_preference, - ReadPreference.SECONDARY_PREFERRED) + bars = Bar.objects.hint([("txt", 1)]).read_preference( + ReadPreference.SECONDARY_PREFERRED + ) + self.assertEqual(bars._read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual( + bars._cursor._Cursor__read_preference, ReadPreference.SECONDARY_PREFERRED + ) def test_read_preference_aggregation_framework(self): class Bar(Document): txt = StringField() - meta = { - 'indexes': ['txt'] - } + meta = {"indexes": ["txt"]} + # Aggregates with read_preference - bars = Bar.objects \ - .read_preference(ReadPreference.SECONDARY_PREFERRED) \ - .aggregate() - self.assertEqual(bars._CommandCursor__collection.read_preference, - ReadPreference.SECONDARY_PREFERRED) + bars = Bar.objects.read_preference( + ReadPreference.SECONDARY_PREFERRED + ).aggregate() + self.assertEqual( + bars._CommandCursor__collection.read_preference, + ReadPreference.SECONDARY_PREFERRED, + ) def test_json_simple(self): - class Embedded(EmbeddedDocument): string = StringField() @@ -4603,7 +4674,7 @@ class QuerySetTest(unittest.TestCase): Doc(string="Bye", embedded_field=Embedded(string="Bye")).save() Doc().save() - json_data = Doc.objects.to_json(sort_keys=True, separators=(',', ':')) + json_data = Doc.objects.to_json(sort_keys=True, separators=(",", ":")) doc_objects = list(Doc.objects) self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) @@ -4616,33 +4687,34 @@ class QuerySetTest(unittest.TestCase): pass class Doc(Document): - string_field = StringField(default='1') + string_field = StringField(default="1") int_field = IntField(default=1) float_field = FloatField(default=1.1) boolean_field = BooleanField(default=True) datetime_field = DateTimeField(default=datetime.datetime.now) embedded_document_field = EmbeddedDocumentField( - EmbeddedDoc, default=lambda: EmbeddedDoc()) + EmbeddedDoc, default=lambda: EmbeddedDoc() + ) list_field = ListField(default=lambda: [1, 2, 3]) dict_field = DictField(default=lambda: {"hello": "world"}) objectid_field = ObjectIdField(default=ObjectId) - reference_field = ReferenceField( - Simple, default=lambda: Simple().save()) + reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) generic_reference_field = GenericReferenceField( - default=lambda: Simple().save()) - sorted_list_field = SortedListField(IntField(), - default=lambda: [1, 2, 3]) + default=lambda: Simple().save() + ) + sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3]) email_field = EmailField(default="ross@example.com") geo_point_field = GeoPointField(default=lambda: [1, 2]) sequence_field = SequenceField() uuid_field = UUIDField(default=uuid.uuid4) generic_embedded_document_field = GenericEmbeddedDocumentField( - default=lambda: EmbeddedDoc()) + default=lambda: EmbeddedDoc() + ) Simple.drop_collection() Doc.drop_collection() @@ -4667,111 +4739,96 @@ class QuerySetTest(unittest.TestCase): User.drop_collection() - User.objects.create(id='Bob', name="Bob Dole", age=89, price=Decimal('1.11')) + User.objects.create(id="Bob", name="Bob Dole", age=89, price=Decimal("1.11")) User.objects.create( - id='Barak', + id="Barak", name="Barak Obama", age=51, - price=Decimal('2.22'), - last_login=LastLogin( - location='White House', - ip='104.107.108.116' - ) + price=Decimal("2.22"), + last_login=LastLogin(location="White House", ip="104.107.108.116"), ) results = User.objects.as_pymongo() + self.assertEqual(set(results[0].keys()), set(["_id", "name", "age", "price"])) self.assertEqual( - set(results[0].keys()), - set(['_id', 'name', 'age', 'price']) - ) - self.assertEqual( - set(results[1].keys()), - set(['_id', 'name', 'age', 'price', 'last_login']) + set(results[1].keys()), set(["_id", "name", "age", "price", "last_login"]) ) - results = User.objects.only('id', 'name').as_pymongo() - self.assertEqual(set(results[0].keys()), set(['_id', 'name'])) + results = User.objects.only("id", "name").as_pymongo() + self.assertEqual(set(results[0].keys()), set(["_id", "name"])) - users = User.objects.only('name', 'price').as_pymongo() + users = User.objects.only("name", "price").as_pymongo() results = list(users) self.assertIsInstance(results[0], dict) self.assertIsInstance(results[1], dict) - self.assertEqual(results[0]['name'], 'Bob Dole') - self.assertEqual(results[0]['price'], 1.11) - self.assertEqual(results[1]['name'], 'Barak Obama') - self.assertEqual(results[1]['price'], 2.22) + self.assertEqual(results[0]["name"], "Bob Dole") + self.assertEqual(results[0]["price"], 1.11) + self.assertEqual(results[1]["name"], "Barak Obama") + self.assertEqual(results[1]["price"], 2.22) - users = User.objects.only('name', 'last_login').as_pymongo() + users = User.objects.only("name", "last_login").as_pymongo() results = list(users) self.assertIsInstance(results[0], dict) self.assertIsInstance(results[1], dict) - self.assertEqual(results[0], { - '_id': 'Bob', - 'name': 'Bob Dole' - }) - self.assertEqual(results[1], { - '_id': 'Barak', - 'name': 'Barak Obama', - 'last_login': { - 'location': 'White House', - 'ip': '104.107.108.116' - } - }) + self.assertEqual(results[0], {"_id": "Bob", "name": "Bob Dole"}) + self.assertEqual( + results[1], + { + "_id": "Barak", + "name": "Barak Obama", + "last_login": {"location": "White House", "ip": "104.107.108.116"}, + }, + ) def test_as_pymongo_returns_cls_attribute_when_using_inheritance(self): class User(Document): name = StringField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} User.drop_collection() user = User(name="Bob Dole").save() result = User.objects.as_pymongo().first() - self.assertEqual( - result, - { - '_cls': 'User', - '_id': user.id, - 'name': 'Bob Dole' - } - ) + self.assertEqual(result, {"_cls": "User", "_id": user.id, "name": "Bob Dole"}) def test_as_pymongo_json_limit_fields(self): - class User(Document): email = EmailField(unique=True, required=True) - password_hash = StringField( - db_field='password_hash', required=True) - password_salt = StringField( - db_field='password_salt', required=True) + password_hash = StringField(db_field="password_hash", required=True) + password_salt = StringField(db_field="password_salt", required=True) User.drop_collection() - User(email="ross@example.com", password_salt="SomeSalt", - password_hash="SomeHash").save() + User( + email="ross@example.com", password_salt="SomeSalt", password_hash="SomeHash" + ).save() serialized_user = User.objects.exclude( - 'password_salt', 'password_hash').as_pymongo()[0] - self.assertEqual({'_id', 'email'}, set(serialized_user.keys())) + "password_salt", "password_hash" + ).as_pymongo()[0] + self.assertEqual({"_id", "email"}, set(serialized_user.keys())) serialized_user = User.objects.exclude( - 'id', 'password_salt', 'password_hash').to_json() + "id", "password_salt", "password_hash" + ).to_json() self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) - serialized_user = User.objects.only('email').as_pymongo()[0] - self.assertEqual({'_id', 'email'}, set(serialized_user.keys())) + serialized_user = User.objects.only("email").as_pymongo()[0] + self.assertEqual({"_id", "email"}, set(serialized_user.keys())) - serialized_user = User.objects.exclude( - 'password_salt').only('email').as_pymongo()[0] - self.assertEqual({'_id', 'email'}, set(serialized_user.keys())) + serialized_user = ( + User.objects.exclude("password_salt").only("email").as_pymongo()[0] + ) + self.assertEqual({"_id", "email"}, set(serialized_user.keys())) - serialized_user = User.objects.exclude( - 'password_salt', 'id').only('email').as_pymongo()[0] - self.assertEqual({'email'}, set(serialized_user.keys())) + serialized_user = ( + User.objects.exclude("password_salt", "id").only("email").as_pymongo()[0] + ) + self.assertEqual({"email"}, set(serialized_user.keys())) - serialized_user = User.objects.exclude( - 'password_salt', 'id').only('email').to_json() - self.assertEqual('[{"email": "ross@example.com"}]', - serialized_user) + serialized_user = ( + User.objects.exclude("password_salt", "id").only("email").to_json() + ) + self.assertEqual('[{"email": "ross@example.com"}]', serialized_user) def test_only_after_count(self): """Test that only() works after count()""" @@ -4780,9 +4837,9 @@ class QuerySetTest(unittest.TestCase): name = StringField() age = IntField() address = StringField() + User.drop_collection() - user = User(name="User", age=50, - address="Moscow, Russia").save() + user = User(name="User", age=50, address="Moscow, Russia").save() user_queryset = User.objects(age=50) @@ -4796,7 +4853,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(result, {"_id": user.id, "name": "User", "age": 50}) def test_no_dereference(self): - class Organization(Document): name = StringField() @@ -4842,12 +4898,14 @@ class QuerySetTest(unittest.TestCase): self.assertFalse(qs_no_deref._auto_dereference) # Make sure the instance field is different from the class field - instance_org_field = user_no_deref._fields['organization'] + instance_org_field = user_no_deref._fields["organization"] self.assertIsNot(instance_org_field, cls_organization_field) self.assertFalse(instance_org_field._auto_dereference) self.assertIsInstance(user_no_deref.organization, DBRef) - self.assertTrue(cls_organization_field._auto_dereference, True) # Make sure the class Field wasn't altered + self.assertTrue( + cls_organization_field._auto_dereference, True + ) # Make sure the class Field wasn't altered def test_no_dereference_no_side_effect_on_existing_instance(self): # Relates to issue #1677 - ensures no regression of the bug @@ -4863,8 +4921,7 @@ class QuerySetTest(unittest.TestCase): Organization.drop_collection() org = Organization(name="whatever").save() - User(organization=org, - organization_gen=org).save() + User(organization=org, organization_gen=org).save() qs = User.objects() user = qs.first() @@ -4873,7 +4930,7 @@ class QuerySetTest(unittest.TestCase): user_no_deref = qs_no_deref.first() # ReferenceField - no_derf_org = user_no_deref.organization # was triggering the bug + no_derf_org = user_no_deref.organization # was triggering the bug self.assertIsInstance(no_derf_org, DBRef) self.assertIsInstance(user.organization, Organization) @@ -4883,7 +4940,6 @@ class QuerySetTest(unittest.TestCase): self.assertIsInstance(user.organization_gen, Organization) def test_no_dereference_embedded_doc(self): - class User(Document): name = StringField() @@ -4906,17 +4962,15 @@ class QuerySetTest(unittest.TestCase): member = Member(name="Flash", user=user) - company = Organization(name="Mongo Inc", - ceo=user, - member=member, - admins=[user], - members=[member]) + company = Organization( + name="Mongo Inc", ceo=user, member=member, admins=[user], members=[member] + ) company.save() org = Organization.objects().no_dereference().first() - self.assertNotEqual(id(org._fields['admins']), id(Organization.admins)) - self.assertFalse(org._fields['admins']._auto_dereference) + self.assertNotEqual(id(org._fields["admins"]), id(Organization.admins)) + self.assertFalse(org._fields["admins"]._auto_dereference) admin = org.admins[0] self.assertIsInstance(admin, DBRef) @@ -4981,14 +5035,14 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() qs = Person.objects.no_cache() - self.assertEqual(repr(qs), '[]') + self.assertEqual(repr(qs), "[]") def test_no_cached_on_a_cached_queryset_raise_error(self): class Person(Document): name = StringField() Person.drop_collection() - Person(name='a').save() + Person(name="a").save() qs = Person.objects() _ = list(qs) with self.assertRaises(OperationError) as ctx_err: @@ -5008,7 +5062,6 @@ class QuerySetTest(unittest.TestCase): self.assertIsInstance(qs, QuerySet) def test_cache_not_cloned(self): - class User(Document): name = StringField() @@ -5020,7 +5073,7 @@ class QuerySetTest(unittest.TestCase): User(name="Alice").save() User(name="Bob").save() - users = User.objects.all().order_by('name') + users = User.objects.all().order_by("name") self.assertEqual("%s" % users, "[, ]") self.assertEqual(2, len(users._result_cache)) @@ -5030,6 +5083,7 @@ class QuerySetTest(unittest.TestCase): def test_no_cache(self): """Ensure you can add meta data to file""" + class Noddy(Document): fields = DictField() @@ -5063,7 +5117,7 @@ class QuerySetTest(unittest.TestCase): def test_nested_queryset_iterator(self): # Try iterating the same queryset twice, nested. - names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George'] + names = ["Alice", "Bob", "Chuck", "David", "Eric", "Francis", "George"] class User(Document): name = StringField() @@ -5076,7 +5130,7 @@ class QuerySetTest(unittest.TestCase): for name in names: User(name=name).save() - users = User.objects.all().order_by('name') + users = User.objects.all().order_by("name") outer_count = 0 inner_count = 0 inner_total_count = 0 @@ -5114,7 +5168,7 @@ class QuerySetTest(unittest.TestCase): x = IntField() y = IntField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class B(A): z = IntField() @@ -5151,6 +5205,7 @@ class QuerySetTest(unittest.TestCase): def test_query_generic_embedded_document(self): """Ensure that querying sub field on generic_embedded_field works """ + class A(EmbeddedDocument): a_name = StringField() @@ -5161,19 +5216,16 @@ class QuerySetTest(unittest.TestCase): document = GenericEmbeddedDocumentField(choices=(A, B)) Doc.drop_collection() - Doc(document=A(a_name='A doc')).save() - Doc(document=B(b_name='B doc')).save() + Doc(document=A(a_name="A doc")).save() + Doc(document=B(b_name="B doc")).save() # Using raw in filter working fine - self.assertEqual(Doc.objects( - __raw__={'document.a_name': 'A doc'}).count(), 1) - self.assertEqual(Doc.objects( - __raw__={'document.b_name': 'B doc'}).count(), 1) - self.assertEqual(Doc.objects(document__a_name='A doc').count(), 1) - self.assertEqual(Doc.objects(document__b_name='B doc').count(), 1) + self.assertEqual(Doc.objects(__raw__={"document.a_name": "A doc"}).count(), 1) + self.assertEqual(Doc.objects(__raw__={"document.b_name": "B doc"}).count(), 1) + self.assertEqual(Doc.objects(document__a_name="A doc").count(), 1) + self.assertEqual(Doc.objects(document__b_name="B doc").count(), 1) def test_query_reference_to_custom_pk_doc(self): - class A(Document): id = StringField(primary_key=True) @@ -5183,7 +5235,7 @@ class QuerySetTest(unittest.TestCase): A.drop_collection() B.drop_collection() - a = A.objects.create(id='custom_id') + a = A.objects.create(id="custom_id") B.objects.create(a=a) self.assertEqual(B.objects.count(), 1) @@ -5191,13 +5243,10 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(B.objects.get(a=a.id).a, a) def test_cls_query_in_subclassed_docs(self): - class Animal(Document): name = StringField() - meta = { - 'allow_inheritance': True - } + meta = {"allow_inheritance": True} class Dog(Animal): pass @@ -5205,21 +5254,23 @@ class QuerySetTest(unittest.TestCase): class Cat(Animal): pass - self.assertEqual(Animal.objects(name='Charlie')._query, { - 'name': 'Charlie', - '_cls': {'$in': ('Animal', 'Animal.Dog', 'Animal.Cat')} - }) - self.assertEqual(Dog.objects(name='Charlie')._query, { - 'name': 'Charlie', - '_cls': 'Animal.Dog' - }) - self.assertEqual(Cat.objects(name='Charlie')._query, { - 'name': 'Charlie', - '_cls': 'Animal.Cat' - }) + self.assertEqual( + Animal.objects(name="Charlie")._query, + { + "name": "Charlie", + "_cls": {"$in": ("Animal", "Animal.Dog", "Animal.Cat")}, + }, + ) + self.assertEqual( + Dog.objects(name="Charlie")._query, + {"name": "Charlie", "_cls": "Animal.Dog"}, + ) + self.assertEqual( + Cat.objects(name="Charlie")._query, + {"name": "Charlie", "_cls": "Animal.Cat"}, + ) def test_can_have_field_same_name_as_query_operator(self): - class Size(Document): name = StringField() @@ -5236,7 +5287,6 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1) def test_cursor_in_an_if_stmt(self): - class Test(Document): test_field = StringField() @@ -5244,23 +5294,23 @@ class QuerySetTest(unittest.TestCase): queryset = Test.objects if queryset: - raise AssertionError('Empty cursor returns True') + raise AssertionError("Empty cursor returns True") test = Test() - test.test_field = 'test' + test.test_field = "test" test.save() queryset = Test.objects if not test: - raise AssertionError('Cursor has data and returned False') + raise AssertionError("Cursor has data and returned False") queryset.next() if not queryset: - raise AssertionError('Cursor has data and it must returns True,' - ' even in the last item.') + raise AssertionError( + "Cursor has data and it must returns True, even in the last item." + ) def test_bool_performance(self): - class Person(Document): name = StringField() @@ -5273,10 +5323,11 @@ class QuerySetTest(unittest.TestCase): pass self.assertEqual(q, 1) - op = q.db.system.profile.find({"ns": - {"$ne": "%s.system.indexes" % q.db.name}})[0] + op = q.db.system.profile.find( + {"ns": {"$ne": "%s.system.indexes" % q.db.name}} + )[0] - self.assertEqual(op['nreturned'], 1) + self.assertEqual(op["nreturned"], 1) def test_bool_with_ordering(self): ORDER_BY_KEY, CMD_QUERY_KEY = get_key_compat(self.mongodb_version) @@ -5289,26 +5340,28 @@ class QuerySetTest(unittest.TestCase): Person(name="Test").save() # Check that bool(queryset) does not uses the orderby - qs = Person.objects.order_by('name') + qs = Person.objects.order_by("name") with query_counter() as q: if bool(qs): pass - op = q.db.system.profile.find({"ns": - {"$ne": "%s.system.indexes" % q.db.name}})[0] + op = q.db.system.profile.find( + {"ns": {"$ne": "%s.system.indexes" % q.db.name}} + )[0] self.assertNotIn(ORDER_BY_KEY, op[CMD_QUERY_KEY]) # Check that normal query uses orderby - qs2 = Person.objects.order_by('name') + qs2 = Person.objects.order_by("name") with query_counter() as q: for x in qs2: pass - op = q.db.system.profile.find({"ns": - {"$ne": "%s.system.indexes" % q.db.name}})[0] + op = q.db.system.profile.find( + {"ns": {"$ne": "%s.system.indexes" % q.db.name}} + )[0] self.assertIn(ORDER_BY_KEY, op[CMD_QUERY_KEY]) @@ -5317,9 +5370,7 @@ class QuerySetTest(unittest.TestCase): class Person(Document): name = StringField() - meta = { - 'ordering': ['name'] - } + meta = {"ordering": ["name"]} Person.drop_collection() @@ -5332,15 +5383,20 @@ class QuerySetTest(unittest.TestCase): if Person.objects: pass - op = q.db.system.profile.find({"ns": - {"$ne": "%s.system.indexes" % q.db.name}})[0] + op = q.db.system.profile.find( + {"ns": {"$ne": "%s.system.indexes" % q.db.name}} + )[0] - self.assertNotIn('$orderby', op[CMD_QUERY_KEY], - 'BaseQuerySet must remove orderby from meta in boolen test') + self.assertNotIn( + "$orderby", + op[CMD_QUERY_KEY], + "BaseQuerySet must remove orderby from meta in boolen test", + ) - self.assertEqual(Person.objects.first().name, 'A') - self.assertTrue(Person.objects._has_data(), - 'Cursor has data and returned False') + self.assertEqual(Person.objects.first().name, "A") + self.assertTrue( + Person.objects._has_data(), "Cursor has data and returned False" + ) def test_queryset_aggregation_framework(self): class Person(Document): @@ -5355,40 +5411,44 @@ class QuerySetTest(unittest.TestCase): Person.objects.insert([p1, p2, p3]) data = Person.objects(age__lte=22).aggregate( - {'$project': {'name': {'$toUpper': '$name'}}} + {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual(list(data), [ - {'_id': p1.pk, 'name': "ISABELLA LUANNA"}, - {'_id': p2.pk, 'name': "WILSON JUNIOR"} - ]) - - data = Person.objects(age__lte=22).order_by('-name').aggregate( - {'$project': {'name': {'$toUpper': '$name'}}} + self.assertEqual( + list(data), + [ + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + ], ) - self.assertEqual(list(data), [ - {'_id': p2.pk, 'name': "WILSON JUNIOR"}, - {'_id': p1.pk, 'name': "ISABELLA LUANNA"} - ]) + data = ( + Person.objects(age__lte=22) + .order_by("-name") + .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) + ) - data = Person.objects(age__gte=17, age__lte=40).order_by('-age').aggregate({ - '$group': { - '_id': None, - 'total': {'$sum': 1}, - 'avg': {'$avg': '$age'} - } - }) - self.assertEqual(list(data), [ - {'_id': None, 'avg': 29, 'total': 2} - ]) + self.assertEqual( + list(data), + [ + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + ], + ) - data = Person.objects().aggregate({'$match': {'name': 'Isabella Luanna'}}) - self.assertEqual(list(data), [ - {u'_id': p1.pk, - u'age': 16, - u'name': u'Isabella Luanna'}] - ) + data = ( + Person.objects(age__gte=17, age__lte=40) + .order_by("-age") + .aggregate( + {"$group": {"_id": None, "total": {"$sum": 1}, "avg": {"$avg": "$age"}}} + ) + ) + self.assertEqual(list(data), [{"_id": None, "avg": 29, "total": 2}]) + + data = Person.objects().aggregate({"$match": {"name": "Isabella Luanna"}}) + self.assertEqual( + list(data), [{u"_id": p1.pk, u"age": 16, u"name": u"Isabella Luanna"}] + ) def test_queryset_aggregation_with_skip(self): class Person(Document): @@ -5403,13 +5463,16 @@ class QuerySetTest(unittest.TestCase): Person.objects.insert([p1, p2, p3]) data = Person.objects.skip(1).aggregate( - {'$project': {'name': {'$toUpper': '$name'}}} + {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual(list(data), [ - {'_id': p2.pk, 'name': "WILSON JUNIOR"}, - {'_id': p3.pk, 'name': "SANDRA MARA"} - ]) + self.assertEqual( + list(data), + [ + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + {"_id": p3.pk, "name": "SANDRA MARA"}, + ], + ) def test_queryset_aggregation_with_limit(self): class Person(Document): @@ -5424,12 +5487,10 @@ class QuerySetTest(unittest.TestCase): Person.objects.insert([p1, p2, p3]) data = Person.objects.limit(1).aggregate( - {'$project': {'name': {'$toUpper': '$name'}}} + {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual(list(data), [ - {'_id': p1.pk, 'name': "ISABELLA LUANNA"} - ]) + self.assertEqual(list(data), [{"_id": p1.pk, "name": "ISABELLA LUANNA"}]) def test_queryset_aggregation_with_sort(self): class Person(Document): @@ -5443,15 +5504,18 @@ class QuerySetTest(unittest.TestCase): p3 = Person(name="Sandra Mara", age=37) Person.objects.insert([p1, p2, p3]) - data = Person.objects.order_by('name').aggregate( - {'$project': {'name': {'$toUpper': '$name'}}} + data = Person.objects.order_by("name").aggregate( + {"$project": {"name": {"$toUpper": "$name"}}} ) - self.assertEqual(list(data), [ - {'_id': p1.pk, 'name': "ISABELLA LUANNA"}, - {'_id': p3.pk, 'name': "SANDRA MARA"}, - {'_id': p2.pk, 'name': "WILSON JUNIOR"} - ]) + self.assertEqual( + list(data), + [ + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + {"_id": p3.pk, "name": "SANDRA MARA"}, + {"_id": p2.pk, "name": "WILSON JUNIOR"}, + ], + ) def test_queryset_aggregation_with_skip_with_limit(self): class Person(Document): @@ -5466,18 +5530,18 @@ class QuerySetTest(unittest.TestCase): Person.objects.insert([p1, p2, p3]) data = list( - Person.objects.skip(1).limit(1).aggregate( - {'$project': {'name': {'$toUpper': '$name'}}} - ) + Person.objects.skip(1) + .limit(1) + .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(list(data), [ - {'_id': p2.pk, 'name': "WILSON JUNIOR"}, - ]) + self.assertEqual(list(data), [{"_id": p2.pk, "name": "WILSON JUNIOR"}]) # Make sure limit/skip chaining order has no impact - data2 = Person.objects.limit(1).skip(1).aggregate( - {'$project': {'name': {'$toUpper': '$name'}}} + data2 = ( + Person.objects.limit(1) + .skip(1) + .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) self.assertEqual(data, list(data2)) @@ -5494,34 +5558,40 @@ class QuerySetTest(unittest.TestCase): p3 = Person(name="Sandra Mara", age=37) Person.objects.insert([p1, p2, p3]) - data = Person.objects.order_by('name').limit(2).aggregate( - {'$project': {'name': {'$toUpper': '$name'}}} + data = ( + Person.objects.order_by("name") + .limit(2) + .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(list(data), [ - {'_id': p1.pk, 'name': "ISABELLA LUANNA"}, - {'_id': p3.pk, 'name': "SANDRA MARA"} - ]) + self.assertEqual( + list(data), + [ + {"_id": p1.pk, "name": "ISABELLA LUANNA"}, + {"_id": p3.pk, "name": "SANDRA MARA"}, + ], + ) # Verify adding limit/skip steps works as expected - data = Person.objects.order_by('name').limit(2).aggregate( - {'$project': {'name': {'$toUpper': '$name'}}}, - {'$limit': 1}, + data = ( + Person.objects.order_by("name") + .limit(2) + .aggregate({"$project": {"name": {"$toUpper": "$name"}}}, {"$limit": 1}) ) - self.assertEqual(list(data), [ - {'_id': p1.pk, 'name': "ISABELLA LUANNA"}, - ]) + self.assertEqual(list(data), [{"_id": p1.pk, "name": "ISABELLA LUANNA"}]) - data = Person.objects.order_by('name').limit(2).aggregate( - {'$project': {'name': {'$toUpper': '$name'}}}, - {'$skip': 1}, - {'$limit': 1}, + data = ( + Person.objects.order_by("name") + .limit(2) + .aggregate( + {"$project": {"name": {"$toUpper": "$name"}}}, + {"$skip": 1}, + {"$limit": 1}, + ) ) - self.assertEqual(list(data), [ - {'_id': p3.pk, 'name': "SANDRA MARA"}, - ]) + self.assertEqual(list(data), [{"_id": p3.pk, "name": "SANDRA MARA"}]) def test_queryset_aggregation_with_sort_with_skip(self): class Person(Document): @@ -5535,13 +5605,13 @@ class QuerySetTest(unittest.TestCase): p3 = Person(name="Sandra Mara", age=37) Person.objects.insert([p1, p2, p3]) - data = Person.objects.order_by('name').skip(2).aggregate( - {'$project': {'name': {'$toUpper': '$name'}}} + data = ( + Person.objects.order_by("name") + .skip(2) + .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(list(data), [ - {'_id': p2.pk, 'name': "WILSON JUNIOR"} - ]) + self.assertEqual(list(data), [{"_id": p2.pk, "name": "WILSON JUNIOR"}]) def test_queryset_aggregation_with_sort_with_skip_with_limit(self): class Person(Document): @@ -5555,35 +5625,42 @@ class QuerySetTest(unittest.TestCase): p3 = Person(name="Sandra Mara", age=37) Person.objects.insert([p1, p2, p3]) - data = Person.objects.order_by('name').skip(1).limit(1).aggregate( - {'$project': {'name': {'$toUpper': '$name'}}} + data = ( + Person.objects.order_by("name") + .skip(1) + .limit(1) + .aggregate({"$project": {"name": {"$toUpper": "$name"}}}) ) - self.assertEqual(list(data), [ - {'_id': p3.pk, 'name': "SANDRA MARA"} - ]) + self.assertEqual(list(data), [{"_id": p3.pk, "name": "SANDRA MARA"}]) def test_delete_count(self): [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] - self.assertEqual(self.Person.objects().delete(), 3) # test ordinary QuerySey delete count + self.assertEqual( + self.Person.objects().delete(), 3 + ) # test ordinary QuerySey delete count [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] - self.assertEqual(self.Person.objects().skip(1).delete(), 2) # test Document delete with existing documents + self.assertEqual( + self.Person.objects().skip(1).delete(), 2 + ) # test Document delete with existing documents self.Person.objects().delete() - self.assertEqual(self.Person.objects().skip(1).delete(), 0) # test Document delete without existing documents + self.assertEqual( + self.Person.objects().skip(1).delete(), 0 + ) # test Document delete without existing documents def test_max_time_ms(self): # 778: max_time_ms can get only int or None as input - self.assertRaises(TypeError, - self.Person.objects(name="name").max_time_ms, - 'not a number') + self.assertRaises( + TypeError, self.Person.objects(name="name").max_time_ms, "not a number" + ) def test_subclass_field_query(self): class Animal(Document): is_mamal = BooleanField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class Cat(Animal): whiskers_length = FloatField() @@ -5605,14 +5682,15 @@ class QuerySetTest(unittest.TestCase): Person.drop_collection() - Person._get_collection().insert_one({'name': 'a', 'id': ''}) + Person._get_collection().insert_one({"name": "a", "id": ""}) for p in Person.objects(): - self.assertEqual(p.name, 'a') + self.assertEqual(p.name, "a") def test_len_during_iteration(self): """Tests that calling len on a queyset during iteration doesn't stop paging. """ + class Data(Document): pass @@ -5645,6 +5723,7 @@ class QuerySetTest(unittest.TestCase): in a given queryset even if there are multiple iterations of it happening at the same time. """ + class Data(Document): pass @@ -5663,6 +5742,7 @@ class QuerySetTest(unittest.TestCase): """Ensure that using the `__in` operator on a non-iterable raises an error. """ + class User(Document): name = StringField() @@ -5673,9 +5753,10 @@ class QuerySetTest(unittest.TestCase): User.drop_collection() BlogPost.drop_collection() - author = User.objects.create(name='Test User') - post = BlogPost.objects.create(content='Had a good coffee today...', - authors=[author]) + author = User.objects.create(name="Test User") + post = BlogPost.objects.create( + content="Had a good coffee today...", authors=[author] + ) # Make sure using `__in` with a list works blog_posts = BlogPost.objects(authors__in=[author]) @@ -5699,5 +5780,5 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(self.Person.objects.count(with_limit_and_skip=True), 4) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 2c2d018c..cfcd8c22 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -9,25 +9,29 @@ __all__ = ("TransformTest",) class TransformTest(unittest.TestCase): - def setUp(self): - connect(db='mongoenginetest') + connect(db="mongoenginetest") def test_transform_query(self): """Ensure that the _transform_query function operates correctly. """ - self.assertEqual(transform.query(name='test', age=30), - {'name': 'test', 'age': 30}) - self.assertEqual(transform.query(age__lt=30), - {'age': {'$lt': 30}}) - self.assertEqual(transform.query(age__gt=20, age__lt=50), - {'age': {'$gt': 20, '$lt': 50}}) - self.assertEqual(transform.query(age=20, age__gt=50), - {'$and': [{'age': {'$gt': 50}}, {'age': 20}]}) - self.assertEqual(transform.query(friend__age__gte=30), - {'friend.age': {'$gte': 30}}) - self.assertEqual(transform.query(name__exists=True), - {'name': {'$exists': True}}) + self.assertEqual( + transform.query(name="test", age=30), {"name": "test", "age": 30} + ) + self.assertEqual(transform.query(age__lt=30), {"age": {"$lt": 30}}) + self.assertEqual( + transform.query(age__gt=20, age__lt=50), {"age": {"$gt": 20, "$lt": 50}} + ) + self.assertEqual( + transform.query(age=20, age__gt=50), + {"$and": [{"age": {"$gt": 50}}, {"age": 20}]}, + ) + self.assertEqual( + transform.query(friend__age__gte=30), {"friend.age": {"$gte": 30}} + ) + self.assertEqual( + transform.query(name__exists=True), {"name": {"$exists": True}} + ) def test_transform_update(self): class LisDoc(Document): @@ -46,7 +50,11 @@ class TransformTest(unittest.TestCase): DicDoc().save() doc = Doc().save() - for k, v in (("set", "$set"), ("set_on_insert", "$setOnInsert"), ("push", "$push")): + for k, v in ( + ("set", "$set"), + ("set_on_insert", "$setOnInsert"), + ("push", "$push"), + ): update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc}) self.assertIsInstance(update[v]["dictField.test"], dict) @@ -57,55 +65,61 @@ class TransformTest(unittest.TestCase): update = transform.update(DicDoc, pull__dictField__test=doc) self.assertIsInstance(update["$pull"]["dictField"]["test"], dict) - update = transform.update(LisDoc, pull__foo__in=['a']) - self.assertEqual(update, {'$pull': {'foo': {'$in': ['a']}}}) + update = transform.update(LisDoc, pull__foo__in=["a"]) + self.assertEqual(update, {"$pull": {"foo": {"$in": ["a"]}}}) def test_transform_update_push(self): """Ensure the differences in behvaior between 'push' and 'push_all'""" + class BlogPost(Document): tags = ListField(StringField()) - update = transform.update(BlogPost, push__tags=['mongo', 'db']) - self.assertEqual(update, {'$push': {'tags': ['mongo', 'db']}}) + update = transform.update(BlogPost, push__tags=["mongo", "db"]) + self.assertEqual(update, {"$push": {"tags": ["mongo", "db"]}}) - update = transform.update(BlogPost, push_all__tags=['mongo', 'db']) - self.assertEqual(update, {'$push': {'tags': {'$each': ['mongo', 'db']}}}) + update = transform.update(BlogPost, push_all__tags=["mongo", "db"]) + self.assertEqual(update, {"$push": {"tags": {"$each": ["mongo", "db"]}}}) def test_transform_update_no_operator_default_to_set(self): """Ensure the differences in behvaior between 'push' and 'push_all'""" + class BlogPost(Document): tags = ListField(StringField()) - update = transform.update(BlogPost, tags=['mongo', 'db']) - self.assertEqual(update, {'$set': {'tags': ['mongo', 'db']}}) + update = transform.update(BlogPost, tags=["mongo", "db"]) + self.assertEqual(update, {"$set": {"tags": ["mongo", "db"]}}) def test_query_field_name(self): """Ensure that the correct field name is used when querying. """ + class Comment(EmbeddedDocument): - content = StringField(db_field='commentContent') + content = StringField(db_field="commentContent") class BlogPost(Document): - title = StringField(db_field='postTitle') - comments = ListField(EmbeddedDocumentField(Comment), - db_field='postComments') + title = StringField(db_field="postTitle") + comments = ListField( + EmbeddedDocumentField(Comment), db_field="postComments" + ) BlogPost.drop_collection() - data = {'title': 'Post 1', 'comments': [Comment(content='test')]} + data = {"title": "Post 1", "comments": [Comment(content="test")]} post = BlogPost(**data) post.save() - self.assertIn('postTitle', BlogPost.objects(title=data['title'])._query) - self.assertFalse('title' in - BlogPost.objects(title=data['title'])._query) - self.assertEqual(BlogPost.objects(title=data['title']).count(), 1) + self.assertIn("postTitle", BlogPost.objects(title=data["title"])._query) + self.assertFalse("title" in BlogPost.objects(title=data["title"])._query) + self.assertEqual(BlogPost.objects(title=data["title"]).count(), 1) - self.assertIn('_id', BlogPost.objects(pk=post.id)._query) + self.assertIn("_id", BlogPost.objects(pk=post.id)._query) self.assertEqual(BlogPost.objects(pk=post.id).count(), 1) - self.assertIn('postComments.commentContent', BlogPost.objects(comments__content='test')._query) - self.assertEqual(BlogPost.objects(comments__content='test').count(), 1) + self.assertIn( + "postComments.commentContent", + BlogPost.objects(comments__content="test")._query, + ) + self.assertEqual(BlogPost.objects(comments__content="test").count(), 1) BlogPost.drop_collection() @@ -113,18 +127,19 @@ class TransformTest(unittest.TestCase): """Ensure that the correct "primary key" field name is used when querying """ + class BlogPost(Document): - title = StringField(primary_key=True, db_field='postTitle') + title = StringField(primary_key=True, db_field="postTitle") BlogPost.drop_collection() - data = {'title': 'Post 1'} + data = {"title": "Post 1"} post = BlogPost(**data) post.save() - self.assertIn('_id', BlogPost.objects(pk=data['title'])._query) - self.assertIn('_id', BlogPost.objects(title=data['title'])._query) - self.assertEqual(BlogPost.objects(pk=data['title']).count(), 1) + self.assertIn("_id", BlogPost.objects(pk=data["title"])._query) + self.assertIn("_id", BlogPost.objects(title=data["title"])._query) + self.assertEqual(BlogPost.objects(pk=data["title"]).count(), 1) BlogPost.drop_collection() @@ -156,78 +171,125 @@ class TransformTest(unittest.TestCase): """ Test raw plays nicely """ + class Foo(Document): name = StringField() a = StringField() b = StringField() c = StringField() - meta = { - 'allow_inheritance': False - } + meta = {"allow_inheritance": False} - query = Foo.objects(__raw__={'$nor': [{'name': 'bar'}]})._query - self.assertEqual(query, {'$nor': [{'name': 'bar'}]}) + query = Foo.objects(__raw__={"$nor": [{"name": "bar"}]})._query + self.assertEqual(query, {"$nor": [{"name": "bar"}]}) - q1 = {'$or': [{'a': 1}, {'b': 1}]} + q1 = {"$or": [{"a": 1}, {"b": 1}]} query = Foo.objects(Q(__raw__=q1) & Q(c=1))._query - self.assertEqual(query, {'$or': [{'a': 1}, {'b': 1}], 'c': 1}) + self.assertEqual(query, {"$or": [{"a": 1}, {"b": 1}], "c": 1}) def test_raw_and_merging(self): class Doc(Document): - meta = {'allow_inheritance': False} + meta = {"allow_inheritance": False} - raw_query = Doc.objects(__raw__={ - 'deleted': False, - 'scraped': 'yes', - '$nor': [ - {'views.extracted': 'no'}, - {'attachments.views.extracted': 'no'} - ] - })._query + raw_query = Doc.objects( + __raw__={ + "deleted": False, + "scraped": "yes", + "$nor": [ + {"views.extracted": "no"}, + {"attachments.views.extracted": "no"}, + ], + } + )._query - self.assertEqual(raw_query, { - 'deleted': False, - 'scraped': 'yes', - '$nor': [ - {'views.extracted': 'no'}, - {'attachments.views.extracted': 'no'} - ] - }) + self.assertEqual( + raw_query, + { + "deleted": False, + "scraped": "yes", + "$nor": [ + {"views.extracted": "no"}, + {"attachments.views.extracted": "no"}, + ], + }, + ) def test_geojson_PointField(self): class Location(Document): loc = PointField() update = transform.update(Location, set__loc=[1, 2]) - self.assertEqual(update, {'$set': {'loc': {"type": "Point", "coordinates": [1, 2]}}}) + self.assertEqual( + update, {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} + ) - update = transform.update(Location, set__loc={"type": "Point", "coordinates": [1, 2]}) - self.assertEqual(update, {'$set': {'loc': {"type": "Point", "coordinates": [1, 2]}}}) + update = transform.update( + Location, set__loc={"type": "Point", "coordinates": [1, 2]} + ) + self.assertEqual( + update, {"$set": {"loc": {"type": "Point", "coordinates": [1, 2]}}} + ) def test_geojson_LineStringField(self): class Location(Document): line = LineStringField() update = transform.update(Location, set__line=[[1, 2], [2, 2]]) - self.assertEqual(update, {'$set': {'line': {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}}) + self.assertEqual( + update, + {"$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}}, + ) - update = transform.update(Location, set__line={"type": "LineString", "coordinates": [[1, 2], [2, 2]]}) - self.assertEqual(update, {'$set': {'line': {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}}) + update = transform.update( + Location, set__line={"type": "LineString", "coordinates": [[1, 2], [2, 2]]} + ) + self.assertEqual( + update, + {"$set": {"line": {"type": "LineString", "coordinates": [[1, 2], [2, 2]]}}}, + ) def test_geojson_PolygonField(self): class Location(Document): poly = PolygonField() - update = transform.update(Location, set__poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]]) - self.assertEqual(update, {'$set': {'poly': {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]}}}) + update = transform.update( + Location, set__poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]] + ) + self.assertEqual( + update, + { + "$set": { + "poly": { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], + } + } + }, + ) - update = transform.update(Location, set__poly={"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]}) - self.assertEqual(update, {'$set': {'poly': {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]]}}}) + update = transform.update( + Location, + set__poly={ + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], + }, + ) + self.assertEqual( + update, + { + "$set": { + "poly": { + "type": "Polygon", + "coordinates": [[[40, 5], [40, 6], [41, 6], [40, 5]]], + } + } + }, + ) def test_type(self): class Doc(Document): df = DynamicField() + Doc(df=True).save() Doc(df=7).save() Doc(df="df").save() @@ -252,7 +314,7 @@ class TransformTest(unittest.TestCase): self.assertEqual(1, Doc.objects(item__type__="axe").count()) self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count()) - Doc.objects(id=doc.id).update(set__item__type__='sword') + Doc.objects(id=doc.id).update(set__item__type__="sword") self.assertEqual(1, Doc.objects(item__type__="sword").count()) self.assertEqual(0, Doc.objects(item__type__="axe").count()) @@ -272,6 +334,7 @@ class TransformTest(unittest.TestCase): Test added to check pull operation in update for EmbeddedDocumentListField which is inside a EmbeddedDocumentField """ + class Word(EmbeddedDocument): word = StringField() index = IntField() @@ -284,18 +347,27 @@ class TransformTest(unittest.TestCase): title = StringField() content = EmbeddedDocumentField(SubDoc) - word = Word(word='abc', index=1) + word = Word(word="abc", index=1) update = transform.update(MainDoc, pull__content__text=word) - self.assertEqual(update, {'$pull': {'content.text': SON([('word', u'abc'), ('index', 1)])}}) + self.assertEqual( + update, {"$pull": {"content.text": SON([("word", u"abc"), ("index", 1)])}} + ) - update = transform.update(MainDoc, pull__content__heading='xyz') - self.assertEqual(update, {'$pull': {'content.heading': 'xyz'}}) + update = transform.update(MainDoc, pull__content__heading="xyz") + self.assertEqual(update, {"$pull": {"content.heading": "xyz"}}) - update = transform.update(MainDoc, pull__content__text__word__in=['foo', 'bar']) - self.assertEqual(update, {'$pull': {'content.text': {'word': {'$in': ['foo', 'bar']}}}}) + update = transform.update(MainDoc, pull__content__text__word__in=["foo", "bar"]) + self.assertEqual( + update, {"$pull": {"content.text": {"word": {"$in": ["foo", "bar"]}}}} + ) - update = transform.update(MainDoc, pull__content__text__word__nin=['foo', 'bar']) - self.assertEqual(update, {'$pull': {'content.text': {'word': {'$nin': ['foo', 'bar']}}}}) + update = transform.update( + MainDoc, pull__content__text__word__nin=["foo", "bar"] + ) + self.assertEqual( + update, {"$pull": {"content.text": {"word": {"$nin": ["foo", "bar"]}}}} + ) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/queryset/visitor.py b/tests/queryset/visitor.py index 22d274a8..0a22416f 100644 --- a/tests/queryset/visitor.py +++ b/tests/queryset/visitor.py @@ -12,14 +12,13 @@ __all__ = ("QTest",) class QTest(unittest.TestCase): - def setUp(self): - connect(db='mongoenginetest') + connect(db="mongoenginetest") class Person(Document): name = StringField() age = IntField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} Person.drop_collection() self.Person = Person @@ -30,22 +29,22 @@ class QTest(unittest.TestCase): q1 = Q() q2 = Q(age__gte=18) q3 = Q() - q4 = Q(name='test') + q4 = Q(name="test") q5 = Q() class Person(Document): name = StringField() age = IntField() - query = {'$or': [{'age': {'$gte': 18}}, {'name': 'test'}]} + query = {"$or": [{"age": {"$gte": 18}}, {"name": "test"}]} self.assertEqual((q1 | q2 | q3 | q4 | q5).to_query(Person), query) - query = {'age': {'$gte': 18}, 'name': 'test'} + query = {"age": {"$gte": 18}, "name": "test"} self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) def test_q_with_dbref(self): """Ensure Q objects handle DBRefs correctly""" - connect(db='mongoenginetest') + connect(db="mongoenginetest") class User(Document): pass @@ -62,15 +61,18 @@ class QTest(unittest.TestCase): def test_and_combination(self): """Ensure that Q-objects correctly AND together. """ + class TestDoc(Document): x = IntField() y = StringField() query = (Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc) - self.assertEqual(query, {'$and': [{'x': {'$lt': 7}}, {'x': {'$lt': 3}}]}) + self.assertEqual(query, {"$and": [{"x": {"$lt": 7}}, {"x": {"$lt": 3}}]}) query = (Q(y="a") & Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc) - self.assertEqual(query, {'$and': [{'y': "a"}, {'x': {'$lt': 7}}, {'x': {'$lt': 3}}]}) + self.assertEqual( + query, {"$and": [{"y": "a"}, {"x": {"$lt": 7}}, {"x": {"$lt": 3}}]} + ) # Check normal cases work without an error query = Q(x__lt=7) & Q(x__gt=3) @@ -78,69 +80,74 @@ class QTest(unittest.TestCase): q1 = Q(x__lt=7) q2 = Q(x__gt=3) query = (q1 & q2).to_query(TestDoc) - self.assertEqual(query, {'x': {'$lt': 7, '$gt': 3}}) + self.assertEqual(query, {"x": {"$lt": 7, "$gt": 3}}) # More complex nested example - query = Q(x__lt=100) & Q(y__ne='NotMyString') - query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) + query = Q(x__lt=100) & Q(y__ne="NotMyString") + query &= Q(y__in=["a", "b", "c"]) & Q(x__gt=-100) mongo_query = { - 'x': {'$lt': 100, '$gt': -100}, - 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, + "x": {"$lt": 100, "$gt": -100}, + "y": {"$ne": "NotMyString", "$in": ["a", "b", "c"]}, } self.assertEqual(query.to_query(TestDoc), mongo_query) def test_or_combination(self): """Ensure that Q-objects correctly OR together. """ + class TestDoc(Document): x = IntField() q1 = Q(x__lt=3) q2 = Q(x__gt=7) query = (q1 | q2).to_query(TestDoc) - self.assertEqual(query, { - '$or': [ - {'x': {'$lt': 3}}, - {'x': {'$gt': 7}}, - ] - }) + self.assertEqual(query, {"$or": [{"x": {"$lt": 3}}, {"x": {"$gt": 7}}]}) def test_and_or_combination(self): """Ensure that Q-objects handle ANDing ORed components. """ + class TestDoc(Document): x = IntField() y = BooleanField() TestDoc.drop_collection() - query = (Q(x__gt=0) | Q(x__exists=False)) + query = Q(x__gt=0) | Q(x__exists=False) query &= Q(x__lt=100) - self.assertEqual(query.to_query(TestDoc), {'$and': [ - {'$or': [{'x': {'$gt': 0}}, - {'x': {'$exists': False}}]}, - {'x': {'$lt': 100}}] - }) + self.assertEqual( + query.to_query(TestDoc), + { + "$and": [ + {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]}, + {"x": {"$lt": 100}}, + ] + }, + ) - q1 = (Q(x__gt=0) | Q(x__exists=False)) - q2 = (Q(x__lt=100) | Q(y=True)) + q1 = Q(x__gt=0) | Q(x__exists=False) + q2 = Q(x__lt=100) | Q(y=True) query = (q1 & q2).to_query(TestDoc) TestDoc(x=101).save() TestDoc(x=10).save() TestDoc(y=True).save() - self.assertEqual(query, { - '$and': [ - {'$or': [{'x': {'$gt': 0}}, {'x': {'$exists': False}}]}, - {'$or': [{'x': {'$lt': 100}}, {'y': True}]} - ] - }) + self.assertEqual( + query, + { + "$and": [ + {"$or": [{"x": {"$gt": 0}}, {"x": {"$exists": False}}]}, + {"$or": [{"x": {"$lt": 100}}, {"y": True}]}, + ] + }, + ) self.assertEqual(2, TestDoc.objects(q1 & q2).count()) def test_or_and_or_combination(self): """Ensure that Q-objects handle ORing ANDed ORed components. :) """ + class TestDoc(Document): x = IntField() y = BooleanField() @@ -151,18 +158,29 @@ class QTest(unittest.TestCase): TestDoc(x=99, y=False).save() TestDoc(x=101, y=False).save() - q1 = (Q(x__gt=0) & (Q(y=True) | Q(y__exists=False))) - q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False))) + q1 = Q(x__gt=0) & (Q(y=True) | Q(y__exists=False)) + q2 = Q(x__lt=100) & (Q(y=False) | Q(y__exists=False)) query = (q1 | q2).to_query(TestDoc) - self.assertEqual(query, { - '$or': [ - {'$and': [{'x': {'$gt': 0}}, - {'$or': [{'y': True}, {'y': {'$exists': False}}]}]}, - {'$and': [{'x': {'$lt': 100}}, - {'$or': [{'y': False}, {'y': {'$exists': False}}]}]} - ] - }) + self.assertEqual( + query, + { + "$or": [ + { + "$and": [ + {"x": {"$gt": 0}}, + {"$or": [{"y": True}, {"y": {"$exists": False}}]}, + ] + }, + { + "$and": [ + {"x": {"$lt": 100}}, + {"$or": [{"y": False}, {"y": {"$exists": False}}]}, + ] + }, + ] + }, + ) self.assertEqual(2, TestDoc.objects(q1 | q2).count()) def test_multiple_occurence_in_field(self): @@ -170,8 +188,8 @@ class QTest(unittest.TestCase): name = StringField(max_length=40) title = StringField(max_length=40) - q1 = Q(name__contains='te') | Q(title__contains='te') - q2 = Q(name__contains='12') | Q(title__contains='12') + q1 = Q(name__contains="te") | Q(title__contains="te") + q2 = Q(name__contains="12") | Q(title__contains="12") q3 = q1 & q2 @@ -180,7 +198,6 @@ class QTest(unittest.TestCase): self.assertEqual(query["$and"][1], q2.to_query(Test)) def test_q_clone(self): - class TestDoc(Document): x = IntField() @@ -205,6 +222,7 @@ class QTest(unittest.TestCase): def test_q(self): """Ensure that Q objects may be used to query for documents. """ + class BlogPost(Document): title = StringField() publish_date = DateTimeField() @@ -212,22 +230,26 @@ class QTest(unittest.TestCase): BlogPost.drop_collection() - post1 = BlogPost(title='Test 1', publish_date=datetime.datetime(2010, 1, 8), published=False) + post1 = BlogPost( + title="Test 1", publish_date=datetime.datetime(2010, 1, 8), published=False + ) post1.save() - post2 = BlogPost(title='Test 2', publish_date=datetime.datetime(2010, 1, 15), published=True) + post2 = BlogPost( + title="Test 2", publish_date=datetime.datetime(2010, 1, 15), published=True + ) post2.save() - post3 = BlogPost(title='Test 3', published=True) + post3 = BlogPost(title="Test 3", published=True) post3.save() - post4 = BlogPost(title='Test 4', publish_date=datetime.datetime(2010, 1, 8)) + post4 = BlogPost(title="Test 4", publish_date=datetime.datetime(2010, 1, 8)) post4.save() - post5 = BlogPost(title='Test 1', publish_date=datetime.datetime(2010, 1, 15)) + post5 = BlogPost(title="Test 1", publish_date=datetime.datetime(2010, 1, 15)) post5.save() - post6 = BlogPost(title='Test 1', published=False) + post6 = BlogPost(title="Test 1", published=False) post6.save() # Check ObjectId lookup works @@ -235,13 +257,13 @@ class QTest(unittest.TestCase): self.assertEqual(obj, post1) # Check Q object combination with one does not exist - q = BlogPost.objects(Q(title='Test 5') | Q(published=True)) + q = BlogPost.objects(Q(title="Test 5") | Q(published=True)) posts = [post.id for post in q] published_posts = (post2, post3) self.assertTrue(all(obj.id in posts for obj in published_posts)) - q = BlogPost.objects(Q(title='Test 1') | Q(published=True)) + q = BlogPost.objects(Q(title="Test 1") | Q(published=True)) posts = [post.id for post in q] published_posts = (post1, post2, post3, post5, post6) self.assertTrue(all(obj.id in posts for obj in published_posts)) @@ -259,85 +281,91 @@ class QTest(unittest.TestCase): BlogPost.drop_collection() # Check the 'in' operator - self.Person(name='user1', age=20).save() - self.Person(name='user2', age=20).save() - self.Person(name='user3', age=30).save() - self.Person(name='user4', age=40).save() + self.Person(name="user1", age=20).save() + self.Person(name="user2", age=20).save() + self.Person(name="user3", age=30).save() + self.Person(name="user4", age=40).save() self.assertEqual(self.Person.objects(Q(age__in=[20])).count(), 2) self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) # Test invalid query objs with self.assertRaises(InvalidQueryError): - self.Person.objects('user1') + self.Person.objects("user1") # filter should fail, too with self.assertRaises(InvalidQueryError): - self.Person.objects.filter('user1') + self.Person.objects.filter("user1") def test_q_regex(self): """Ensure that Q objects can be queried using regexes. """ - person = self.Person(name='Guido van Rossum') + person = self.Person(name="Guido van Rossum") person.save() - obj = self.Person.objects(Q(name=re.compile('^Gui'))).first() + obj = self.Person.objects(Q(name=re.compile("^Gui"))).first() self.assertEqual(obj, person) - obj = self.Person.objects(Q(name=re.compile('^gui'))).first() + obj = self.Person.objects(Q(name=re.compile("^gui"))).first() self.assertEqual(obj, None) - obj = self.Person.objects(Q(name=re.compile('^gui', re.I))).first() + obj = self.Person.objects(Q(name=re.compile("^gui", re.I))).first() self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() + obj = self.Person.objects(Q(name__not=re.compile("^bob"))).first() self.assertEqual(obj, person) - obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() + obj = self.Person.objects(Q(name__not=re.compile("^Gui"))).first() self.assertEqual(obj, None) def test_q_repr(self): - self.assertEqual(repr(Q()), 'Q(**{})') - self.assertEqual(repr(Q(name='test')), "Q(**{'name': 'test'})") + self.assertEqual(repr(Q()), "Q(**{})") + self.assertEqual(repr(Q(name="test")), "Q(**{'name': 'test'})") self.assertEqual( - repr(Q(name='test') & Q(age__gte=18)), - "(Q(**{'name': 'test'}) & Q(**{'age__gte': 18}))") + repr(Q(name="test") & Q(age__gte=18)), + "(Q(**{'name': 'test'}) & Q(**{'age__gte': 18}))", + ) self.assertEqual( - repr(Q(name='test') | Q(age__gte=18)), - "(Q(**{'name': 'test'}) | Q(**{'age__gte': 18}))") + repr(Q(name="test") | Q(age__gte=18)), + "(Q(**{'name': 'test'}) | Q(**{'age__gte': 18}))", + ) def test_q_lists(self): """Ensure that Q objects query ListFields correctly. """ + class BlogPost(Document): tags = ListField(StringField()) BlogPost.drop_collection() - BlogPost(tags=['python', 'mongo']).save() - BlogPost(tags=['python']).save() + BlogPost(tags=["python", "mongo"]).save() + BlogPost(tags=["python"]).save() - self.assertEqual(BlogPost.objects(Q(tags='mongo')).count(), 1) - self.assertEqual(BlogPost.objects(Q(tags='python')).count(), 2) + self.assertEqual(BlogPost.objects(Q(tags="mongo")).count(), 1) + self.assertEqual(BlogPost.objects(Q(tags="python")).count(), 2) BlogPost.drop_collection() def test_q_merge_queries_edge_case(self): - class User(Document): email = EmailField(required=False) name = StringField() User.drop_collection() pk = ObjectId() - User(email='example@example.com', pk=pk).save() + User(email="example@example.com", pk=pk).save() - self.assertEqual(1, User.objects.filter(Q(email='example@example.com') | - Q(name='John Doe')).limit(2).filter(pk=pk).count()) + self.assertEqual( + 1, + User.objects.filter(Q(email="example@example.com") | Q(name="John Doe")) + .limit(2) + .filter(pk=pk) + .count(), + ) def test_chained_q_or_filtering(self): - class Post(EmbeddedDocument): name = StringField(required=True) @@ -350,9 +378,16 @@ class QTest(unittest.TestCase): Item(postables=[Post(name="a"), Post(name="c")]).save() Item(postables=[Post(name="a"), Post(name="b"), Post(name="c")]).save() - self.assertEqual(Item.objects(Q(postables__name="a") & Q(postables__name="b")).count(), 2) - self.assertEqual(Item.objects.filter(postables__name="a").filter(postables__name="b").count(), 2) + self.assertEqual( + Item.objects(Q(postables__name="a") & Q(postables__name="b")).count(), 2 + ) + self.assertEqual( + Item.objects.filter(postables__name="a") + .filter(postables__name="b") + .count(), + 2, + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_common.py b/tests/test_common.py index 04ad5b34..5d702668 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,7 +5,6 @@ from mongoengine import Document class TestCommon(unittest.TestCase): - def test__import_class(self): doc_cls = _import_class("Document") self.assertIs(doc_cls, Document) diff --git a/tests/test_connection.py b/tests/test_connection.py index d3fcc395..25007132 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -14,12 +14,21 @@ import pymongo from bson.tz_util import utc from mongoengine import ( - connect, register_connection, - Document, DateTimeField, - disconnect_all, StringField) + connect, + register_connection, + Document, + DateTimeField, + disconnect_all, + StringField, +) import mongoengine.connection -from mongoengine.connection import (MongoEngineConnectionError, get_db, - get_connection, disconnect, DEFAULT_DATABASE_NAME) +from mongoengine.connection import ( + MongoEngineConnectionError, + get_db, + get_connection, + disconnect, + DEFAULT_DATABASE_NAME, +) def get_tz_awareness(connection): @@ -27,7 +36,6 @@ def get_tz_awareness(connection): class ConnectionTest(unittest.TestCase): - @classmethod def setUpClass(cls): disconnect_all() @@ -43,44 +51,46 @@ class ConnectionTest(unittest.TestCase): def test_connect(self): """Ensure that the connect() method works properly.""" - connect('mongoenginetest') + connect("mongoenginetest") conn = get_connection() self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) db = get_db() self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'mongoenginetest') + self.assertEqual(db.name, "mongoenginetest") - connect('mongoenginetest2', alias='testdb') - conn = get_connection('testdb') + connect("mongoenginetest2", alias="testdb") + conn = get_connection("testdb") self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) def test_connect_disconnect_works_properly(self): class History1(Document): name = StringField() - meta = {'db_alias': 'db1'} + meta = {"db_alias": "db1"} class History2(Document): name = StringField() - meta = {'db_alias': 'db2'} + meta = {"db_alias": "db2"} - connect('db1', alias='db1') - connect('db2', alias='db2') + connect("db1", alias="db1") + connect("db2", alias="db2") History1.drop_collection() History2.drop_collection() - h = History1(name='default').save() - h1 = History2(name='db1').save() + h = History1(name="default").save() + h1 = History2(name="db1").save() - self.assertEqual(list(History1.objects().as_pymongo()), - [{'_id': h.id, 'name': 'default'}]) - self.assertEqual(list(History2.objects().as_pymongo()), - [{'_id': h1.id, 'name': 'db1'}]) + self.assertEqual( + list(History1.objects().as_pymongo()), [{"_id": h.id, "name": "default"}] + ) + self.assertEqual( + list(History2.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}] + ) - disconnect('db1') - disconnect('db2') + disconnect("db1") + disconnect("db2") with self.assertRaises(MongoEngineConnectionError): list(History1.objects().as_pymongo()) @@ -88,13 +98,15 @@ class ConnectionTest(unittest.TestCase): with self.assertRaises(MongoEngineConnectionError): list(History2.objects().as_pymongo()) - connect('db1', alias='db1') - connect('db2', alias='db2') + connect("db1", alias="db1") + connect("db2", alias="db2") - self.assertEqual(list(History1.objects().as_pymongo()), - [{'_id': h.id, 'name': 'default'}]) - self.assertEqual(list(History2.objects().as_pymongo()), - [{'_id': h1.id, 'name': 'db1'}]) + self.assertEqual( + list(History1.objects().as_pymongo()), [{"_id": h.id, "name": "default"}] + ) + self.assertEqual( + list(History2.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}] + ) def test_connect_different_documents_to_different_database(self): class History(Document): @@ -102,99 +114,110 @@ class ConnectionTest(unittest.TestCase): class History1(Document): name = StringField() - meta = {'db_alias': 'db1'} + meta = {"db_alias": "db1"} class History2(Document): name = StringField() - meta = {'db_alias': 'db2'} + meta = {"db_alias": "db2"} connect() - connect('db1', alias='db1') - connect('db2', alias='db2') + connect("db1", alias="db1") + connect("db2", alias="db2") History.drop_collection() History1.drop_collection() History2.drop_collection() - h = History(name='default').save() - h1 = History1(name='db1').save() - h2 = History2(name='db2').save() + h = History(name="default").save() + h1 = History1(name="db1").save() + h2 = History2(name="db2").save() self.assertEqual(History._collection.database.name, DEFAULT_DATABASE_NAME) - self.assertEqual(History1._collection.database.name, 'db1') - self.assertEqual(History2._collection.database.name, 'db2') + self.assertEqual(History1._collection.database.name, "db1") + self.assertEqual(History2._collection.database.name, "db2") - self.assertEqual(list(History.objects().as_pymongo()), - [{'_id': h.id, 'name': 'default'}]) - self.assertEqual(list(History1.objects().as_pymongo()), - [{'_id': h1.id, 'name': 'db1'}]) - self.assertEqual(list(History2.objects().as_pymongo()), - [{'_id': h2.id, 'name': 'db2'}]) + self.assertEqual( + list(History.objects().as_pymongo()), [{"_id": h.id, "name": "default"}] + ) + self.assertEqual( + list(History1.objects().as_pymongo()), [{"_id": h1.id, "name": "db1"}] + ) + self.assertEqual( + list(History2.objects().as_pymongo()), [{"_id": h2.id, "name": "db2"}] + ) def test_connect_fails_if_connect_2_times_with_default_alias(self): - connect('mongoenginetest') + connect("mongoenginetest") with self.assertRaises(MongoEngineConnectionError) as ctx_err: - connect('mongoenginetest2') - self.assertEqual("A different connection with alias `default` was already registered. Use disconnect() first", str(ctx_err.exception)) + connect("mongoenginetest2") + self.assertEqual( + "A different connection with alias `default` was already registered. Use disconnect() first", + str(ctx_err.exception), + ) def test_connect_fails_if_connect_2_times_with_custom_alias(self): - connect('mongoenginetest', alias='alias1') + connect("mongoenginetest", alias="alias1") with self.assertRaises(MongoEngineConnectionError) as ctx_err: - connect('mongoenginetest2', alias='alias1') + connect("mongoenginetest2", alias="alias1") - self.assertEqual("A different connection with alias `alias1` was already registered. Use disconnect() first", str(ctx_err.exception)) + self.assertEqual( + "A different connection with alias `alias1` was already registered. Use disconnect() first", + str(ctx_err.exception), + ) - def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way(self): + def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way( + self + ): """Intended to keep the detecton function simple but robust""" - db_name = 'mongoenginetest' - db_alias = 'alias1' - connect(db=db_name, alias=db_alias, host='localhost', port=27017) + db_name = "mongoenginetest" + db_alias = "alias1" + connect(db=db_name, alias=db_alias, host="localhost", port=27017) with self.assertRaises(MongoEngineConnectionError): - connect(host='mongodb://localhost:27017/%s' % db_name, alias=db_alias) + connect(host="mongodb://localhost:27017/%s" % db_name, alias=db_alias) def test_connect_passes_silently_connect_multiple_times_with_same_config(self): # test default connection to `test` connect() connect() self.assertEqual(len(mongoengine.connection._connections), 1) - connect('test01', alias='test01') - connect('test01', alias='test01') + connect("test01", alias="test01") + connect("test01", alias="test01") self.assertEqual(len(mongoengine.connection._connections), 2) - connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02') - connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02') + connect(host="mongodb://localhost:27017/mongoenginetest02", alias="test02") + connect(host="mongodb://localhost:27017/mongoenginetest02", alias="test02") self.assertEqual(len(mongoengine.connection._connections), 3) def test_connect_with_invalid_db_name(self): """Ensure that connect() method fails fast if db name is invalid """ with self.assertRaises(InvalidName): - connect('mongomock://localhost') + connect("mongomock://localhost") def test_connect_with_db_name_external(self): """Ensure that connect() works if db name is $external """ """Ensure that the connect() method works properly.""" - connect('$external') + connect("$external") conn = get_connection() self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) db = get_db() self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, '$external') + self.assertEqual(db.name, "$external") - connect('$external', alias='testdb') - conn = get_connection('testdb') + connect("$external", alias="testdb") + conn = get_connection("testdb") self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) def test_connect_with_invalid_db_name_type(self): """Ensure that connect() method fails fast if db name has invalid type """ with self.assertRaises(TypeError): - non_string_db_name = ['e. g. list instead of a string'] + non_string_db_name = ["e. g. list instead of a string"] connect(non_string_db_name) def test_connect_in_mocking(self): @@ -203,34 +226,47 @@ class ConnectionTest(unittest.TestCase): try: import mongomock except ImportError: - raise SkipTest('you need mongomock installed to run this testcase') + raise SkipTest("you need mongomock installed to run this testcase") - connect('mongoenginetest', host='mongomock://localhost') + connect("mongoenginetest", host="mongomock://localhost") conn = get_connection() self.assertIsInstance(conn, mongomock.MongoClient) - connect('mongoenginetest2', host='mongomock://localhost', alias='testdb2') - conn = get_connection('testdb2') + connect("mongoenginetest2", host="mongomock://localhost", alias="testdb2") + conn = get_connection("testdb2") self.assertIsInstance(conn, mongomock.MongoClient) - connect('mongoenginetest3', host='mongodb://localhost', is_mock=True, alias='testdb3') - conn = get_connection('testdb3') + connect( + "mongoenginetest3", + host="mongodb://localhost", + is_mock=True, + alias="testdb3", + ) + conn = get_connection("testdb3") self.assertIsInstance(conn, mongomock.MongoClient) - connect('mongoenginetest4', is_mock=True, alias='testdb4') - conn = get_connection('testdb4') + connect("mongoenginetest4", is_mock=True, alias="testdb4") + conn = get_connection("testdb4") self.assertIsInstance(conn, mongomock.MongoClient) - connect(host='mongodb://localhost:27017/mongoenginetest5', is_mock=True, alias='testdb5') - conn = get_connection('testdb5') + connect( + host="mongodb://localhost:27017/mongoenginetest5", + is_mock=True, + alias="testdb5", + ) + conn = get_connection("testdb5") self.assertIsInstance(conn, mongomock.MongoClient) - connect(host='mongomock://localhost:27017/mongoenginetest6', alias='testdb6') - conn = get_connection('testdb6') + connect(host="mongomock://localhost:27017/mongoenginetest6", alias="testdb6") + conn = get_connection("testdb6") self.assertIsInstance(conn, mongomock.MongoClient) - connect(host='mongomock://localhost:27017/mongoenginetest7', is_mock=True, alias='testdb7') - conn = get_connection('testdb7') + connect( + host="mongomock://localhost:27017/mongoenginetest7", + is_mock=True, + alias="testdb7", + ) + conn = get_connection("testdb7") self.assertIsInstance(conn, mongomock.MongoClient) def test_connect_with_host_list(self): @@ -241,30 +277,39 @@ class ConnectionTest(unittest.TestCase): try: import mongomock except ImportError: - raise SkipTest('you need mongomock installed to run this testcase') + raise SkipTest("you need mongomock installed to run this testcase") - connect(host=['mongomock://localhost']) + connect(host=["mongomock://localhost"]) conn = get_connection() self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['mongodb://localhost'], is_mock=True, alias='testdb2') - conn = get_connection('testdb2') + connect(host=["mongodb://localhost"], is_mock=True, alias="testdb2") + conn = get_connection("testdb2") self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['localhost'], is_mock=True, alias='testdb3') - conn = get_connection('testdb3') + connect(host=["localhost"], is_mock=True, alias="testdb3") + conn = get_connection("testdb3") self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['mongomock://localhost:27017', 'mongomock://localhost:27018'], alias='testdb4') - conn = get_connection('testdb4') + connect( + host=["mongomock://localhost:27017", "mongomock://localhost:27018"], + alias="testdb4", + ) + conn = get_connection("testdb4") self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['mongodb://localhost:27017', 'mongodb://localhost:27018'], is_mock=True, alias='testdb5') - conn = get_connection('testdb5') + connect( + host=["mongodb://localhost:27017", "mongodb://localhost:27018"], + is_mock=True, + alias="testdb5", + ) + conn = get_connection("testdb5") self.assertIsInstance(conn, mongomock.MongoClient) - connect(host=['localhost:27017', 'localhost:27018'], is_mock=True, alias='testdb6') - conn = get_connection('testdb6') + connect( + host=["localhost:27017", "localhost:27018"], is_mock=True, alias="testdb6" + ) + conn = get_connection("testdb6") self.assertIsInstance(conn, mongomock.MongoClient) def test_disconnect_cleans_globals(self): @@ -273,7 +318,7 @@ class ConnectionTest(unittest.TestCase): dbs = mongoengine.connection._dbs connection_settings = mongoengine.connection._connection_settings - connect('mongoenginetest') + connect("mongoenginetest") self.assertEqual(len(connections), 1) self.assertEqual(len(dbs), 0) @@ -292,7 +337,7 @@ class ConnectionTest(unittest.TestCase): def test_disconnect_cleans_cached_collection_attribute_in_document(self): """Ensure that the disconnect() method works properly""" - conn1 = connect('mongoenginetest') + conn1 = connect("mongoenginetest") class History(Document): pass @@ -301,7 +346,7 @@ class ConnectionTest(unittest.TestCase): History.drop_collection() - History.objects.first() # will trigger the caching of _collection attribute + History.objects.first() # will trigger the caching of _collection attribute self.assertIsNotNone(History._collection) disconnect() @@ -310,15 +355,17 @@ class ConnectionTest(unittest.TestCase): with self.assertRaises(MongoEngineConnectionError) as ctx_err: History.objects.first() - self.assertEqual("You have not defined a default connection", str(ctx_err.exception)) + self.assertEqual( + "You have not defined a default connection", str(ctx_err.exception) + ) def test_connect_disconnect_works_on_same_document(self): """Ensure that the connect/disconnect works properly with a single Document""" - db1 = 'db1' - db2 = 'db2' + db1 = "db1" + db2 = "db2" # Ensure freshness of the 2 databases through pymongo - client = MongoClient('localhost', 27017) + client = MongoClient("localhost", 27017) client.drop_database(db1) client.drop_database(db2) @@ -328,44 +375,44 @@ class ConnectionTest(unittest.TestCase): class User(Document): name = StringField(required=True) - user1 = User(name='John is in db1').save() + user1 = User(name="John is in db1").save() disconnect() # Make sure save doesnt work at this stage with self.assertRaises(MongoEngineConnectionError): - User(name='Wont work').save() + User(name="Wont work").save() # Save in db2 connect(db2) - user2 = User(name='Bob is in db2').save() + user2 = User(name="Bob is in db2").save() disconnect() db1_users = list(client[db1].user.find()) - self.assertEqual(db1_users, [{'_id': user1.id, 'name': 'John is in db1'}]) + self.assertEqual(db1_users, [{"_id": user1.id, "name": "John is in db1"}]) db2_users = list(client[db2].user.find()) - self.assertEqual(db2_users, [{'_id': user2.id, 'name': 'Bob is in db2'}]) + self.assertEqual(db2_users, [{"_id": user2.id, "name": "Bob is in db2"}]) def test_disconnect_silently_pass_if_alias_does_not_exist(self): connections = mongoengine.connection._connections self.assertEqual(len(connections), 0) - disconnect(alias='not_exist') + disconnect(alias="not_exist") def test_disconnect_all(self): connections = mongoengine.connection._connections dbs = mongoengine.connection._dbs connection_settings = mongoengine.connection._connection_settings - connect('mongoenginetest') - connect('mongoenginetest2', alias='db1') + connect("mongoenginetest") + connect("mongoenginetest2", alias="db1") class History(Document): pass class History1(Document): name = StringField() - meta = {'db_alias': 'db1'} + meta = {"db_alias": "db1"} - History.drop_collection() # will trigger the caching of _collection attribute + History.drop_collection() # will trigger the caching of _collection attribute History.objects.first() History1.drop_collection() History1.objects.first() @@ -398,11 +445,11 @@ class ConnectionTest(unittest.TestCase): def test_sharing_connections(self): """Ensure that connections are shared when the connection settings are exactly the same """ - connect('mongoenginetests', alias='testdb1') - expected_connection = get_connection('testdb1') + connect("mongoenginetests", alias="testdb1") + expected_connection = get_connection("testdb1") - connect('mongoenginetests', alias='testdb2') - actual_connection = get_connection('testdb2') + connect("mongoenginetests", alias="testdb2") + actual_connection = get_connection("testdb2") expected_connection.server_info() @@ -410,7 +457,7 @@ class ConnectionTest(unittest.TestCase): def test_connect_uri(self): """Ensure that the connect() method works properly with URIs.""" - c = connect(db='mongoenginetest', alias='admin') + c = connect(db="mongoenginetest", alias="admin") c.admin.system.users.delete_many({}) c.mongoenginetest.system.users.delete_many({}) @@ -418,14 +465,16 @@ class ConnectionTest(unittest.TestCase): c.admin.authenticate("admin", "password") c.admin.command("createUser", "username", pwd="password", roles=["dbOwner"]) - connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') + connect( + "testdb_uri", host="mongodb://username:password@localhost/mongoenginetest" + ) conn = get_connection() self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) db = get_db() self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'mongoenginetest') + self.assertEqual(db.name, "mongoenginetest") c.admin.system.users.delete_many({}) c.mongoenginetest.system.users.delete_many({}) @@ -434,35 +483,35 @@ class ConnectionTest(unittest.TestCase): """Ensure connect() method works properly if the URI doesn't include a database name. """ - connect("mongoenginetest", host='mongodb://localhost/') + connect("mongoenginetest", host="mongodb://localhost/") conn = get_connection() self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) db = get_db() self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'mongoenginetest') + self.assertEqual(db.name, "mongoenginetest") def test_connect_uri_default_db(self): """Ensure connect() defaults to the right database name if the URI and the database_name don't explicitly specify it. """ - connect(host='mongodb://localhost/') + connect(host="mongodb://localhost/") conn = get_connection() self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) db = get_db() self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'test') + self.assertEqual(db.name, "test") def test_uri_without_credentials_doesnt_override_conn_settings(self): """Ensure connect() uses the username & password params if the URI doesn't explicitly specify them. """ - c = connect(host='mongodb://localhost/mongoenginetest', - username='user', - password='pass') + c = connect( + host="mongodb://localhost/mongoenginetest", username="user", password="pass" + ) # OperationFailure means that mongoengine attempted authentication # w/ the provided username/password and failed - that's the desired @@ -474,27 +523,31 @@ class ConnectionTest(unittest.TestCase): option in the URI. """ # Create users - c = connect('mongoenginetest') + c = connect("mongoenginetest") c.admin.system.users.delete_many({}) c.admin.command("createUser", "username2", pwd="password", roles=["dbOwner"]) # Authentication fails without "authSource" test_conn = connect( - 'mongoenginetest', alias='test1', - host='mongodb://username2:password@localhost/mongoenginetest' + "mongoenginetest", + alias="test1", + host="mongodb://username2:password@localhost/mongoenginetest", ) self.assertRaises(OperationFailure, test_conn.server_info) # Authentication succeeds with "authSource" authd_conn = connect( - 'mongoenginetest', alias='test2', - host=('mongodb://username2:password@localhost/' - 'mongoenginetest?authSource=admin') + "mongoenginetest", + alias="test2", + host=( + "mongodb://username2:password@localhost/" + "mongoenginetest?authSource=admin" + ), ) - db = get_db('test2') + db = get_db("test2") self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'mongoenginetest') + self.assertEqual(db.name, "mongoenginetest") # Clear all users authd_conn.admin.system.users.delete_many({}) @@ -502,82 +555,86 @@ class ConnectionTest(unittest.TestCase): def test_register_connection(self): """Ensure that connections with different aliases may be registered. """ - register_connection('testdb', 'mongoenginetest2') + register_connection("testdb", "mongoenginetest2") self.assertRaises(MongoEngineConnectionError, get_connection) - conn = get_connection('testdb') + conn = get_connection("testdb") self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) - db = get_db('testdb') + db = get_db("testdb") self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'mongoenginetest2') + self.assertEqual(db.name, "mongoenginetest2") def test_register_connection_defaults(self): """Ensure that defaults are used when the host and port are None. """ - register_connection('testdb', 'mongoenginetest', host=None, port=None) + register_connection("testdb", "mongoenginetest", host=None, port=None) - conn = get_connection('testdb') + conn = get_connection("testdb") self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) def test_connection_kwargs(self): """Ensure that connection kwargs get passed to pymongo.""" - connect('mongoenginetest', alias='t1', tz_aware=True) - conn = get_connection('t1') + connect("mongoenginetest", alias="t1", tz_aware=True) + conn = get_connection("t1") self.assertTrue(get_tz_awareness(conn)) - connect('mongoenginetest2', alias='t2') - conn = get_connection('t2') + connect("mongoenginetest2", alias="t2") + conn = get_connection("t2") self.assertFalse(get_tz_awareness(conn)) def test_connection_pool_via_kwarg(self): """Ensure we can specify a max connection pool size using a connection kwarg. """ - pool_size_kwargs = {'maxpoolsize': 100} + pool_size_kwargs = {"maxpoolsize": 100} - conn = connect('mongoenginetest', alias='max_pool_size_via_kwarg', **pool_size_kwargs) + conn = connect( + "mongoenginetest", alias="max_pool_size_via_kwarg", **pool_size_kwargs + ) self.assertEqual(conn.max_pool_size, 100) def test_connection_pool_via_uri(self): """Ensure we can specify a max connection pool size using an option in a connection URI. """ - conn = connect(host='mongodb://localhost/test?maxpoolsize=100', alias='max_pool_size_via_uri') + conn = connect( + host="mongodb://localhost/test?maxpoolsize=100", + alias="max_pool_size_via_uri", + ) self.assertEqual(conn.max_pool_size, 100) def test_write_concern(self): """Ensure write concern can be specified in connect() via a kwarg or as part of the connection URI. """ - conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true') - conn2 = connect('testing', alias='conn2', w=1, j=True) - self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True}) - self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True}) + conn1 = connect(alias="conn1", host="mongodb://localhost/testing?w=1&j=true") + conn2 = connect("testing", alias="conn2", w=1, j=True) + self.assertEqual(conn1.write_concern.document, {"w": 1, "j": True}) + self.assertEqual(conn2.write_concern.document, {"w": 1, "j": True}) def test_connect_with_replicaset_via_uri(self): """Ensure connect() works when specifying a replicaSet via the MongoDB URI. """ - c = connect(host='mongodb://localhost/test?replicaSet=local-rs') + c = connect(host="mongodb://localhost/test?replicaSet=local-rs") db = get_db() self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'test') + self.assertEqual(db.name, "test") def test_connect_with_replicaset_via_kwargs(self): """Ensure connect() works when specifying a replicaSet via the connection kwargs """ - c = connect(replicaset='local-rs') - self.assertEqual(c._MongoClient__options.replica_set_name, - 'local-rs') + c = connect(replicaset="local-rs") + self.assertEqual(c._MongoClient__options.replica_set_name, "local-rs") db = get_db() self.assertIsInstance(db, pymongo.database.Database) - self.assertEqual(db.name, 'test') + self.assertEqual(db.name, "test") def test_connect_tz_aware(self): - connect('mongoenginetest', tz_aware=True) + connect("mongoenginetest", tz_aware=True) d = datetime.datetime(2010, 5, 5, tzinfo=utc) class DateDoc(Document): @@ -590,37 +647,39 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(d, date_doc.the_date) def test_read_preference_from_parse(self): - conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred") + conn = connect( + host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred" + ) self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED) def test_multiple_connection_settings(self): - connect('mongoenginetest', alias='t1', host="localhost") + connect("mongoenginetest", alias="t1", host="localhost") - connect('mongoenginetest2', alias='t2', host="127.0.0.1") + connect("mongoenginetest2", alias="t2", host="127.0.0.1") mongo_connections = mongoengine.connection._connections self.assertEqual(len(mongo_connections.items()), 2) - self.assertIn('t1', mongo_connections.keys()) - self.assertIn('t2', mongo_connections.keys()) + self.assertIn("t1", mongo_connections.keys()) + self.assertIn("t2", mongo_connections.keys()) # Handle PyMongo 3+ Async Connection # Ensure we are connected, throws ServerSelectionTimeoutError otherwise. # Purposely not catching exception to fail test if thrown. - mongo_connections['t1'].server_info() - mongo_connections['t2'].server_info() - self.assertEqual(mongo_connections['t1'].address[0], 'localhost') - self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') + mongo_connections["t1"].server_info() + mongo_connections["t2"].server_info() + self.assertEqual(mongo_connections["t1"].address[0], "localhost") + self.assertEqual(mongo_connections["t2"].address[0], "127.0.0.1") def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self): - c1 = connect(alias='testdb1', db='testdb1') - c2 = connect(alias='testdb2', db='testdb2') + c1 = connect(alias="testdb1", db="testdb1") + c2 = connect(alias="testdb2", db="testdb2") self.assertIs(c1, c2) def test_connect_2_databases_uses_different_client_if_different_parameters(self): - c1 = connect(alias='testdb1', db='testdb1', username='u1') - c2 = connect(alias='testdb2', db='testdb2', username='u2') + c1 = connect(alias="testdb1", db="testdb1", username="u1") + c2 = connect(alias="testdb2", db="testdb2", username="u2") self.assertIsNot(c1, c2) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 529032fe..dc9b9bf3 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -2,17 +2,20 @@ import unittest from mongoengine import * from mongoengine.connection import get_db -from mongoengine.context_managers import (switch_db, switch_collection, - no_sub_classes, no_dereference, - query_counter) +from mongoengine.context_managers import ( + switch_db, + switch_collection, + no_sub_classes, + no_dereference, + query_counter, +) from mongoengine.pymongo_support import count_documents class ContextManagersTest(unittest.TestCase): - def test_switch_db_context_manager(self): - connect('mongoenginetest') - register_connection('testdb-1', 'mongoenginetest2') + connect("mongoenginetest") + register_connection("testdb-1", "mongoenginetest2") class Group(Document): name = StringField() @@ -22,7 +25,7 @@ class ContextManagersTest(unittest.TestCase): Group(name="hello - default").save() self.assertEqual(1, Group.objects.count()) - with switch_db(Group, 'testdb-1') as Group: + with switch_db(Group, "testdb-1") as Group: self.assertEqual(0, Group.objects.count()) @@ -36,21 +39,21 @@ class ContextManagersTest(unittest.TestCase): self.assertEqual(1, Group.objects.count()) def test_switch_collection_context_manager(self): - connect('mongoenginetest') - register_connection(alias='testdb-1', db='mongoenginetest2') + connect("mongoenginetest") + register_connection(alias="testdb-1", db="mongoenginetest2") class Group(Document): name = StringField() - Group.drop_collection() # drops in default + Group.drop_collection() # drops in default - with switch_collection(Group, 'group1') as Group: - Group.drop_collection() # drops in group1 + with switch_collection(Group, "group1") as Group: + Group.drop_collection() # drops in group1 Group(name="hello - group").save() self.assertEqual(1, Group.objects.count()) - with switch_collection(Group, 'group1') as Group: + with switch_collection(Group, "group1") as Group: self.assertEqual(0, Group.objects.count()) @@ -66,7 +69,7 @@ class ContextManagersTest(unittest.TestCase): def test_no_dereference_context_manager_object_id(self): """Ensure that DBRef items in ListFields aren't dereferenced. """ - connect('mongoenginetest') + connect("mongoenginetest") class User(Document): name = StringField() @@ -80,14 +83,14 @@ class ContextManagersTest(unittest.TestCase): Group.drop_collection() for i in range(1, 51): - User(name='user %s' % i).save() + User(name="user %s" % i).save() user = User.objects.first() Group(ref=user, members=User.objects, generic=user).save() with no_dereference(Group) as NoDeRefGroup: - self.assertTrue(Group._fields['members']._auto_dereference) - self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) + self.assertTrue(Group._fields["members"]._auto_dereference) + self.assertFalse(NoDeRefGroup._fields["members"]._auto_dereference) with no_dereference(Group) as Group: group = Group.objects.first() @@ -104,7 +107,7 @@ class ContextManagersTest(unittest.TestCase): def test_no_dereference_context_manager_dbref(self): """Ensure that DBRef items in ListFields aren't dereferenced. """ - connect('mongoenginetest') + connect("mongoenginetest") class User(Document): name = StringField() @@ -118,31 +121,29 @@ class ContextManagersTest(unittest.TestCase): Group.drop_collection() for i in range(1, 51): - User(name='user %s' % i).save() + User(name="user %s" % i).save() user = User.objects.first() Group(ref=user, members=User.objects, generic=user).save() with no_dereference(Group) as NoDeRefGroup: - self.assertTrue(Group._fields['members']._auto_dereference) - self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) + self.assertTrue(Group._fields["members"]._auto_dereference) + self.assertFalse(NoDeRefGroup._fields["members"]._auto_dereference) with no_dereference(Group) as Group: group = Group.objects.first() - self.assertTrue(all([not isinstance(m, User) - for m in group.members])) + self.assertTrue(all([not isinstance(m, User) for m in group.members])) self.assertNotIsInstance(group.ref, User) self.assertNotIsInstance(group.generic, User) - self.assertTrue(all([isinstance(m, User) - for m in group.members])) + self.assertTrue(all([isinstance(m, User) for m in group.members])) self.assertIsInstance(group.ref, User) self.assertIsInstance(group.generic, User) def test_no_sub_classes(self): class A(Document): x = IntField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class B(A): z = IntField() @@ -188,20 +189,20 @@ class ContextManagersTest(unittest.TestCase): def test_no_sub_classes_modification_to_document_class_are_temporary(self): class A(Document): x = IntField() - meta = {'allow_inheritance': True} + meta = {"allow_inheritance": True} class B(A): z = IntField() - self.assertEqual(A._subclasses, ('A', 'A.B')) + self.assertEqual(A._subclasses, ("A", "A.B")) with no_sub_classes(A): - self.assertEqual(A._subclasses, ('A',)) - self.assertEqual(A._subclasses, ('A', 'A.B')) + self.assertEqual(A._subclasses, ("A",)) + self.assertEqual(A._subclasses, ("A", "A.B")) - self.assertEqual(B._subclasses, ('A.B',)) + self.assertEqual(B._subclasses, ("A.B",)) with no_sub_classes(B): - self.assertEqual(B._subclasses, ('A.B',)) - self.assertEqual(B._subclasses, ('A.B',)) + self.assertEqual(B._subclasses, ("A.B",)) + self.assertEqual(B._subclasses, ("A.B",)) def test_no_subclass_context_manager_does_not_swallow_exception(self): class User(Document): @@ -218,7 +219,7 @@ class ContextManagersTest(unittest.TestCase): raise TypeError() def test_query_counter_temporarily_modifies_profiling_level(self): - connect('mongoenginetest') + connect("mongoenginetest") db = get_db() initial_profiling_level = db.profiling_level() @@ -231,11 +232,13 @@ class ContextManagersTest(unittest.TestCase): self.assertEqual(db.profiling_level(), 2) self.assertEqual(db.profiling_level(), NEW_LEVEL) except Exception: - db.set_profiling_level(initial_profiling_level) # Ensures it gets reseted no matter the outcome of the test + db.set_profiling_level( + initial_profiling_level + ) # Ensures it gets reseted no matter the outcome of the test raise def test_query_counter(self): - connect('mongoenginetest') + connect("mongoenginetest") db = get_db() collection = db.query_counter @@ -245,7 +248,7 @@ class ContextManagersTest(unittest.TestCase): count_documents(collection, {}) def issue_1_insert_query(): - collection.insert_one({'test': 'garbage'}) + collection.insert_one({"test": "garbage"}) def issue_1_find_query(): collection.find_one() @@ -253,7 +256,9 @@ class ContextManagersTest(unittest.TestCase): counter = 0 with query_counter() as q: self.assertEqual(q, counter) - self.assertEqual(q, counter) # Ensures previous count query did not get counted + self.assertEqual( + q, counter + ) # Ensures previous count query did not get counted for _ in range(10): issue_1_insert_query() @@ -270,23 +275,25 @@ class ContextManagersTest(unittest.TestCase): counter += 1 self.assertEqual(q, counter) - self.assertEqual(int(q), counter) # test __int__ + self.assertEqual(int(q), counter) # test __int__ self.assertEqual(repr(q), str(int(q))) # test __repr__ - self.assertGreater(q, -1) # test __gt__ - self.assertGreaterEqual(q, int(q)) # test __gte__ + self.assertGreater(q, -1) # test __gt__ + self.assertGreaterEqual(q, int(q)) # test __gte__ self.assertNotEqual(q, -1) self.assertLess(q, 1000) self.assertLessEqual(q, int(q)) def test_query_counter_counts_getmore_queries(self): - connect('mongoenginetest') + connect("mongoenginetest") db = get_db() collection = db.query_counter collection.drop() - many_docs = [{'test': 'garbage %s' % i} for i in range(150)] - collection.insert_many(many_docs) # first batch of documents contains 101 documents + many_docs = [{"test": "garbage %s" % i} for i in range(150)] + collection.insert_many( + many_docs + ) # first batch of documents contains 101 documents with query_counter() as q: self.assertEqual(q, 0) @@ -294,24 +301,26 @@ class ContextManagersTest(unittest.TestCase): self.assertEqual(q, 2) # 1st select + 1 getmore def test_query_counter_ignores_particular_queries(self): - connect('mongoenginetest') + connect("mongoenginetest") db = get_db() collection = db.query_counter - collection.insert_many([{'test': 'garbage %s' % i} for i in range(10)]) + collection.insert_many([{"test": "garbage %s" % i} for i in range(10)]) with query_counter() as q: self.assertEqual(q, 0) cursor = collection.find() - self.assertEqual(q, 0) # cursor wasn't opened yet - _ = next(cursor) # opens the cursor and fires the find query + self.assertEqual(q, 0) # cursor wasn't opened yet + _ = next(cursor) # opens the cursor and fires the find query self.assertEqual(q, 1) - cursor.close() # issues a `killcursors` query that is ignored by the context + cursor.close() # issues a `killcursors` query that is ignored by the context self.assertEqual(q, 1) - _ = db.system.indexes.find_one() # queries on db.system.indexes are ignored as well + _ = ( + db.system.indexes.find_one() + ) # queries on db.system.indexes are ignored as well self.assertEqual(q, 1) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index a9ef98e7..7def2ac7 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -14,128 +14,129 @@ class DocumentStub(object): class TestBaseDict(unittest.TestCase): - @staticmethod def _get_basedict(dict_items): """Get a BaseList bound to a fake document instance""" fake_doc = DocumentStub() - base_list = BaseDict(dict_items, instance=None, name='my_name') - base_list._instance = fake_doc # hack to inject the mock, it does not work in the constructor + base_list = BaseDict(dict_items, instance=None, name="my_name") + base_list._instance = ( + fake_doc + ) # hack to inject the mock, it does not work in the constructor return base_list def test___init___(self): class MyDoc(Document): pass - dict_items = {'k': 'v'} + dict_items = {"k": "v"} doc = MyDoc() - base_dict = BaseDict(dict_items, instance=doc, name='my_name') + base_dict = BaseDict(dict_items, instance=doc, name="my_name") self.assertIsInstance(base_dict._instance, Document) - self.assertEqual(base_dict._name, 'my_name') + self.assertEqual(base_dict._name, "my_name") self.assertEqual(base_dict, dict_items) def test_setdefault_calls_mark_as_changed(self): base_dict = self._get_basedict({}) - base_dict.setdefault('k', 'v') + base_dict.setdefault("k", "v") self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) def test_popitems_calls_mark_as_changed(self): - base_dict = self._get_basedict({'k': 'v'}) - self.assertEqual(base_dict.popitem(), ('k', 'v')) + base_dict = self._get_basedict({"k": "v"}) + self.assertEqual(base_dict.popitem(), ("k", "v")) self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) self.assertFalse(base_dict) def test_pop_calls_mark_as_changed(self): - base_dict = self._get_basedict({'k': 'v'}) - self.assertEqual(base_dict.pop('k'), 'v') + base_dict = self._get_basedict({"k": "v"}) + self.assertEqual(base_dict.pop("k"), "v") self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) self.assertFalse(base_dict) def test_pop_calls_does_not_mark_as_changed_when_it_fails(self): - base_dict = self._get_basedict({'k': 'v'}) + base_dict = self._get_basedict({"k": "v"}) with self.assertRaises(KeyError): - base_dict.pop('X') + base_dict.pop("X") self.assertFalse(base_dict._instance._changed_fields) def test_clear_calls_mark_as_changed(self): - base_dict = self._get_basedict({'k': 'v'}) + base_dict = self._get_basedict({"k": "v"}) base_dict.clear() - self.assertEqual(base_dict._instance._changed_fields, ['my_name']) + self.assertEqual(base_dict._instance._changed_fields, ["my_name"]) self.assertEqual(base_dict, {}) def test___delitem___calls_mark_as_changed(self): - base_dict = self._get_basedict({'k': 'v'}) - del base_dict['k'] - self.assertEqual(base_dict._instance._changed_fields, ['my_name.k']) + base_dict = self._get_basedict({"k": "v"}) + del base_dict["k"] + self.assertEqual(base_dict._instance._changed_fields, ["my_name.k"]) self.assertEqual(base_dict, {}) def test___getitem____KeyError(self): base_dict = self._get_basedict({}) with self.assertRaises(KeyError): - base_dict['new'] + base_dict["new"] def test___getitem____simple_value(self): - base_dict = self._get_basedict({'k': 'v'}) - base_dict['k'] = 'v' + base_dict = self._get_basedict({"k": "v"}) + base_dict["k"] = "v" def test___getitem____sublist_gets_converted_to_BaseList(self): - base_dict = self._get_basedict({'k': [0, 1, 2]}) - sub_list = base_dict['k'] + base_dict = self._get_basedict({"k": [0, 1, 2]}) + sub_list = base_dict["k"] self.assertEqual(sub_list, [0, 1, 2]) self.assertIsInstance(sub_list, BaseList) self.assertIs(sub_list._instance, base_dict._instance) - self.assertEqual(sub_list._name, 'my_name.k') + self.assertEqual(sub_list._name, "my_name.k") self.assertEqual(base_dict._instance._changed_fields, []) # Challenge mark_as_changed from sublist sub_list[1] = None - self.assertEqual(base_dict._instance._changed_fields, ['my_name.k.1']) + self.assertEqual(base_dict._instance._changed_fields, ["my_name.k.1"]) def test___getitem____subdict_gets_converted_to_BaseDict(self): - base_dict = self._get_basedict({'k': {'subk': 'subv'}}) - sub_dict = base_dict['k'] - self.assertEqual(sub_dict, {'subk': 'subv'}) + base_dict = self._get_basedict({"k": {"subk": "subv"}}) + sub_dict = base_dict["k"] + self.assertEqual(sub_dict, {"subk": "subv"}) self.assertIsInstance(sub_dict, BaseDict) self.assertIs(sub_dict._instance, base_dict._instance) - self.assertEqual(sub_dict._name, 'my_name.k') + self.assertEqual(sub_dict._name, "my_name.k") self.assertEqual(base_dict._instance._changed_fields, []) # Challenge mark_as_changed from subdict - sub_dict['subk'] = None - self.assertEqual(base_dict._instance._changed_fields, ['my_name.k.subk']) + sub_dict["subk"] = None + self.assertEqual(base_dict._instance._changed_fields, ["my_name.k.subk"]) def test_get_sublist_gets_converted_to_BaseList_just_like__getitem__(self): - base_dict = self._get_basedict({'k': [0, 1, 2]}) - sub_list = base_dict.get('k') + base_dict = self._get_basedict({"k": [0, 1, 2]}) + sub_list = base_dict.get("k") self.assertEqual(sub_list, [0, 1, 2]) self.assertIsInstance(sub_list, BaseList) def test_get_returns_the_same_as___getitem__(self): - base_dict = self._get_basedict({'k': [0, 1, 2]}) - get_ = base_dict.get('k') - getitem_ = base_dict['k'] + base_dict = self._get_basedict({"k": [0, 1, 2]}) + get_ = base_dict.get("k") + getitem_ = base_dict["k"] self.assertEqual(get_, getitem_) def test_get_default(self): base_dict = self._get_basedict({}) sentinel = object() - self.assertEqual(base_dict.get('new'), None) - self.assertIs(base_dict.get('new', sentinel), sentinel) + self.assertEqual(base_dict.get("new"), None) + self.assertIs(base_dict.get("new", sentinel), sentinel) def test___setitem___calls_mark_as_changed(self): base_dict = self._get_basedict({}) - base_dict['k'] = 'v' - self.assertEqual(base_dict._instance._changed_fields, ['my_name.k']) - self.assertEqual(base_dict, {'k': 'v'}) + base_dict["k"] = "v" + self.assertEqual(base_dict._instance._changed_fields, ["my_name.k"]) + self.assertEqual(base_dict, {"k": "v"}) def test_update_calls_mark_as_changed(self): base_dict = self._get_basedict({}) - base_dict.update({'k': 'v'}) - self.assertEqual(base_dict._instance._changed_fields, ['my_name']) + base_dict.update({"k": "v"}) + self.assertEqual(base_dict._instance._changed_fields, ["my_name"]) def test___setattr____not_tracked_by_changes(self): base_dict = self._get_basedict({}) - base_dict.a_new_attr = 'test' + base_dict.a_new_attr = "test" self.assertEqual(base_dict._instance._changed_fields, []) def test___delattr____tracked_by_changes(self): @@ -143,19 +144,20 @@ class TestBaseDict(unittest.TestCase): # This is even bad because it could be that there is an attribute # with the same name as a key base_dict = self._get_basedict({}) - base_dict.a_new_attr = 'test' + base_dict.a_new_attr = "test" del base_dict.a_new_attr - self.assertEqual(base_dict._instance._changed_fields, ['my_name.a_new_attr']) + self.assertEqual(base_dict._instance._changed_fields, ["my_name.a_new_attr"]) class TestBaseList(unittest.TestCase): - @staticmethod def _get_baselist(list_items): """Get a BaseList bound to a fake document instance""" fake_doc = DocumentStub() - base_list = BaseList(list_items, instance=None, name='my_name') - base_list._instance = fake_doc # hack to inject the mock, it does not work in the constructor + base_list = BaseList(list_items, instance=None, name="my_name") + base_list._instance = ( + fake_doc + ) # hack to inject the mock, it does not work in the constructor return base_list def test___init___(self): @@ -164,19 +166,19 @@ class TestBaseList(unittest.TestCase): list_items = [True] doc = MyDoc() - base_list = BaseList(list_items, instance=doc, name='my_name') + base_list = BaseList(list_items, instance=doc, name="my_name") self.assertIsInstance(base_list._instance, Document) - self.assertEqual(base_list._name, 'my_name') + self.assertEqual(base_list._name, "my_name") self.assertEqual(base_list, list_items) def test___iter__(self): values = [True, False, True, False] - base_list = BaseList(values, instance=None, name='my_name') + base_list = BaseList(values, instance=None, name="my_name") self.assertEqual(values, list(base_list)) def test___iter___allow_modification_while_iterating_withou_error(self): # regular list allows for this, thus this subclass must comply to that - base_list = BaseList([True, False, True, False], instance=None, name='my_name') + base_list = BaseList([True, False, True, False], instance=None, name="my_name") for idx, val in enumerate(base_list): if val: base_list.pop(idx) @@ -185,7 +187,7 @@ class TestBaseList(unittest.TestCase): base_list = self._get_baselist([]) self.assertFalse(base_list._instance._changed_fields) base_list.append(True) - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) def test_subclass_append(self): # Due to the way mark_as_changed_wrapper is implemented @@ -193,7 +195,7 @@ class TestBaseList(unittest.TestCase): class SubBaseList(BaseList): pass - base_list = SubBaseList([], instance=None, name='my_name') + base_list = SubBaseList([], instance=None, name="my_name") base_list.append(True) def test___getitem__using_simple_index(self): @@ -217,54 +219,45 @@ class TestBaseList(unittest.TestCase): self.assertEqual(base_list._instance._changed_fields, []) def test___getitem__sublist_returns_BaseList_bound_to_instance(self): - base_list = self._get_baselist( - [ - [1, 2], - [3, 4] - ] - ) + base_list = self._get_baselist([[1, 2], [3, 4]]) sub_list = base_list[0] self.assertEqual(sub_list, [1, 2]) self.assertIsInstance(sub_list, BaseList) self.assertIs(sub_list._instance, base_list._instance) - self.assertEqual(sub_list._name, 'my_name.0') + self.assertEqual(sub_list._name, "my_name.0") self.assertEqual(base_list._instance._changed_fields, []) # Challenge mark_as_changed from sublist sub_list[1] = None - self.assertEqual(base_list._instance._changed_fields, ['my_name.0.1']) + self.assertEqual(base_list._instance._changed_fields, ["my_name.0.1"]) def test___getitem__subdict_returns_BaseList_bound_to_instance(self): - base_list = self._get_baselist( - [ - {'subk': 'subv'} - ] - ) + base_list = self._get_baselist([{"subk": "subv"}]) sub_dict = base_list[0] - self.assertEqual(sub_dict, {'subk': 'subv'}) + self.assertEqual(sub_dict, {"subk": "subv"}) self.assertIsInstance(sub_dict, BaseDict) self.assertIs(sub_dict._instance, base_list._instance) - self.assertEqual(sub_dict._name, 'my_name.0') + self.assertEqual(sub_dict._name, "my_name.0") self.assertEqual(base_list._instance._changed_fields, []) # Challenge mark_as_changed from subdict - sub_dict['subk'] = None - self.assertEqual(base_list._instance._changed_fields, ['my_name.0.subk']) + sub_dict["subk"] = None + self.assertEqual(base_list._instance._changed_fields, ["my_name.0.subk"]) def test_extend_calls_mark_as_changed(self): base_list = self._get_baselist([]) base_list.extend([True]) - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) def test_insert_calls_mark_as_changed(self): base_list = self._get_baselist([]) base_list.insert(0, True) - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) def test_remove_calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list.remove(True) - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) def test_remove_not_mark_as_changed_when_it_fails(self): base_list = self._get_baselist([True]) @@ -275,70 +268,76 @@ class TestBaseList(unittest.TestCase): def test_pop_calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list.pop() - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) def test_reverse_calls_mark_as_changed(self): base_list = self._get_baselist([True, False]) base_list.reverse() - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) def test___delitem___calls_mark_as_changed(self): base_list = self._get_baselist([True]) del base_list[0] - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) def test___setitem___calls_with_full_slice_mark_as_changed(self): base_list = self._get_baselist([]) - base_list[:] = [0, 1] # Will use __setslice__ under py2 and __setitem__ under py3 - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + base_list[:] = [ + 0, + 1, + ] # Will use __setslice__ under py2 and __setitem__ under py3 + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) self.assertEqual(base_list, [0, 1]) def test___setitem___calls_with_partial_slice_mark_as_changed(self): base_list = self._get_baselist([0, 1, 2]) - base_list[0:2] = [1, 0] # Will use __setslice__ under py2 and __setitem__ under py3 - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + base_list[0:2] = [ + 1, + 0, + ] # Will use __setslice__ under py2 and __setitem__ under py3 + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) self.assertEqual(base_list, [1, 0, 2]) def test___setitem___calls_with_step_slice_mark_as_changed(self): base_list = self._get_baselist([0, 1, 2]) - base_list[0:3:2] = [-1, -2] # uses __setitem__ in both py2 & 3 - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + base_list[0:3:2] = [-1, -2] # uses __setitem__ in both py2 & 3 + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) self.assertEqual(base_list, [-1, 1, -2]) def test___setitem___with_slice(self): base_list = self._get_baselist([0, 1, 2, 3, 4, 5]) base_list[0:6:2] = [None, None, None] - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) self.assertEqual(base_list, [None, 1, None, 3, None, 5]) def test___setitem___item_0_calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list[0] = False - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) self.assertEqual(base_list, [False]) def test___setitem___item_1_calls_mark_as_changed(self): base_list = self._get_baselist([True, True]) base_list[1] = False - self.assertEqual(base_list._instance._changed_fields, ['my_name.1']) + self.assertEqual(base_list._instance._changed_fields, ["my_name.1"]) self.assertEqual(base_list, [True, False]) def test___delslice___calls_mark_as_changed(self): base_list = self._get_baselist([0, 1]) del base_list[0:1] - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) self.assertEqual(base_list, [1]) def test___iadd___calls_mark_as_changed(self): base_list = self._get_baselist([True]) base_list += [False] - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) def test___imul___calls_mark_as_changed(self): base_list = self._get_baselist([True]) self.assertEqual(base_list._instance._changed_fields, []) base_list *= 2 - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) def test_sort_calls_not_marked_as_changed_when_it_fails(self): base_list = self._get_baselist([True]) @@ -350,7 +349,7 @@ class TestBaseList(unittest.TestCase): def test_sort_calls_mark_as_changed(self): base_list = self._get_baselist([True, False]) base_list.sort() - self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list._instance._changed_fields, ["my_name"]) def test_sort_calls_with_key(self): base_list = self._get_baselist([1, 2, 11]) @@ -371,7 +370,7 @@ class TestStrictDict(unittest.TestCase): def test_iterkeys(self): d = self.dtype(a=1) - self.assertEqual(list(iterkeys(d)), ['a']) + self.assertEqual(list(iterkeys(d)), ["a"]) def test_len(self): d = self.dtype(a=1) @@ -379,9 +378,9 @@ class TestStrictDict(unittest.TestCase): def test_pop(self): d = self.dtype(a=1) - self.assertIn('a', d) - d.pop('a') - self.assertNotIn('a', d) + self.assertIn("a", d) + d.pop("a") + self.assertNotIn("a", d) def test_repr(self): d = self.dtype(a=1, b=2, c=3) @@ -416,7 +415,7 @@ class TestStrictDict(unittest.TestCase): d = self.dtype() d.a = 1 self.assertEqual(d.a, 1) - self.assertRaises(AttributeError, getattr, d, 'b') + self.assertRaises(AttributeError, getattr, d, "b") def test_setattr_raises_on_nonexisting_attr(self): d = self.dtype() @@ -430,20 +429,20 @@ class TestStrictDict(unittest.TestCase): def test_get(self): d = self.dtype(a=1) - self.assertEqual(d.get('a'), 1) - self.assertEqual(d.get('b', 'bla'), 'bla') + self.assertEqual(d.get("a"), 1) + self.assertEqual(d.get("b", "bla"), "bla") def test_items(self): d = self.dtype(a=1) - self.assertEqual(d.items(), [('a', 1)]) + self.assertEqual(d.items(), [("a", 1)]) d = self.dtype(a=1, b=2) - self.assertEqual(d.items(), [('a', 1), ('b', 2)]) + self.assertEqual(d.items(), [("a", 1), ("b", 2)]) def test_mappings_protocol(self): d = self.dtype(a=1, b=2) - self.assertEqual(dict(d), {'a': 1, 'b': 2}) - self.assertEqual(dict(**d), {'a': 1, 'b': 2}) + self.assertEqual(dict(d), {"a": 1, "b": 2}) + self.assertEqual(dict(**d), {"a": 1, "b": 2}) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 9c565810..4730e2e3 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -10,18 +10,18 @@ from mongoengine.context_managers import query_counter class FieldTest(unittest.TestCase): - @classmethod def setUpClass(cls): - cls.db = connect(db='mongoenginetest') + cls.db = connect(db="mongoenginetest") @classmethod def tearDownClass(cls): - cls.db.drop_database('mongoenginetest') + cls.db.drop_database("mongoenginetest") def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ + class User(Document): name = StringField() @@ -32,7 +32,7 @@ class FieldTest(unittest.TestCase): Group.drop_collection() for i in range(1, 51): - user = User(name='user %s' % i) + user = User(name="user %s" % i) user.save() group = Group(members=User.objects) @@ -47,7 +47,7 @@ class FieldTest(unittest.TestCase): group_obj = Group.objects.first() self.assertEqual(q, 1) - len(group_obj._data['members']) + len(group_obj._data["members"]) self.assertEqual(q, 1) len(group_obj.members) @@ -80,6 +80,7 @@ class FieldTest(unittest.TestCase): def test_list_item_dereference_dref_false(self): """Ensure that DBRef items in ListFields are dereferenced. """ + class User(Document): name = StringField() @@ -90,7 +91,7 @@ class FieldTest(unittest.TestCase): Group.drop_collection() for i in range(1, 51): - user = User(name='user %s' % i) + user = User(name="user %s" % i) user.save() group = Group(members=User.objects) @@ -105,14 +106,14 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - self.assertTrue(group_obj._data['members']._dereferenced) + self.assertTrue(group_obj._data["members"]._dereferenced) # verifies that no additional queries gets executed # if we re-iterate over the ListField once it is # dereferenced [m for m in group_obj.members] self.assertEqual(q, 2) - self.assertTrue(group_obj._data['members']._dereferenced) + self.assertTrue(group_obj._data["members"]._dereferenced) # Document select_related with query_counter() as q: @@ -136,6 +137,7 @@ class FieldTest(unittest.TestCase): def test_list_item_dereference_orphan_dbref(self): """Ensure that orphan DBRef items in ListFields are dereferenced. """ + class User(Document): name = StringField() @@ -146,7 +148,7 @@ class FieldTest(unittest.TestCase): Group.drop_collection() for i in range(1, 51): - user = User(name='user %s' % i) + user = User(name="user %s" % i) user.save() group = Group(members=User.objects) @@ -164,14 +166,14 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) - self.assertTrue(group_obj._data['members']._dereferenced) + self.assertTrue(group_obj._data["members"]._dereferenced) # verifies that no additional queries gets executed # if we re-iterate over the ListField once it is # dereferenced [m for m in group_obj.members] self.assertEqual(q, 2) - self.assertTrue(group_obj._data['members']._dereferenced) + self.assertTrue(group_obj._data["members"]._dereferenced) User.drop_collection() Group.drop_collection() @@ -179,6 +181,7 @@ class FieldTest(unittest.TestCase): def test_list_item_dereference_dref_false_stores_as_type(self): """Ensure that DBRef items are stored as their type """ + class User(Document): my_id = IntField(primary_key=True) name = StringField() @@ -189,17 +192,18 @@ class FieldTest(unittest.TestCase): User.drop_collection() Group.drop_collection() - user = User(my_id=1, name='user 1').save() + user = User(my_id=1, name="user 1").save() Group(members=User.objects).save() group = Group.objects.first() - self.assertEqual(Group._get_collection().find_one()['members'], [1]) + self.assertEqual(Group._get_collection().find_one()["members"], [1]) self.assertEqual(group.members, [user]) def test_handle_old_style_references(self): """Ensure that DBRef items in ListFields are dereferenced. """ + class User(Document): name = StringField() @@ -210,7 +214,7 @@ class FieldTest(unittest.TestCase): Group.drop_collection() for i in range(1, 26): - user = User(name='user %s' % i) + user = User(name="user %s" % i) user.save() group = Group(members=User.objects) @@ -227,8 +231,8 @@ class FieldTest(unittest.TestCase): group.save() group = Group.objects.first() - self.assertEqual(group.members[0].name, 'user 1') - self.assertEqual(group.members[-1].name, 'String!') + self.assertEqual(group.members[0].name, "user 1") + self.assertEqual(group.members[-1].name, "String!") def test_migrate_references(self): """Example of migrating ReferenceField storage @@ -249,8 +253,8 @@ class FieldTest(unittest.TestCase): group = Group(author=user, members=[user]).save() raw_data = Group._get_collection().find_one() - self.assertIsInstance(raw_data['author'], DBRef) - self.assertIsInstance(raw_data['members'][0], DBRef) + self.assertIsInstance(raw_data["author"], DBRef) + self.assertIsInstance(raw_data["members"][0], DBRef) group = Group.objects.first() self.assertEqual(group.author, user) @@ -264,8 +268,8 @@ class FieldTest(unittest.TestCase): # Migrate the data for g in Group.objects(): # Explicitly mark as changed so resets - g._mark_as_changed('author') - g._mark_as_changed('members') + g._mark_as_changed("author") + g._mark_as_changed("members") g.save() group = Group.objects.first() @@ -273,35 +277,36 @@ class FieldTest(unittest.TestCase): self.assertEqual(group.members, [user]) raw_data = Group._get_collection().find_one() - self.assertIsInstance(raw_data['author'], ObjectId) - self.assertIsInstance(raw_data['members'][0], ObjectId) + self.assertIsInstance(raw_data["author"], ObjectId) + self.assertIsInstance(raw_data["members"][0], ObjectId) def test_recursive_reference(self): """Ensure that ReferenceFields can reference their own documents. """ + class Employee(Document): name = StringField() - boss = ReferenceField('self') - friends = ListField(ReferenceField('self')) + boss = ReferenceField("self") + friends = ListField(ReferenceField("self")) Employee.drop_collection() - bill = Employee(name='Bill Lumbergh') + bill = Employee(name="Bill Lumbergh") bill.save() - michael = Employee(name='Michael Bolton') + michael = Employee(name="Michael Bolton") michael.save() - samir = Employee(name='Samir Nagheenanajar') + samir = Employee(name="Samir Nagheenanajar") samir.save() friends = [michael, samir] - peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) + peter = Employee(name="Peter Gibbons", boss=bill, friends=friends) peter.save() - Employee(name='Funky Gibbon', boss=bill, friends=friends).save() - Employee(name='Funky Gibbon', boss=bill, friends=friends).save() - Employee(name='Funky Gibbon', boss=bill, friends=friends).save() + Employee(name="Funky Gibbon", boss=bill, friends=friends).save() + Employee(name="Funky Gibbon", boss=bill, friends=friends).save() + Employee(name="Funky Gibbon", boss=bill, friends=friends).save() with query_counter() as q: self.assertEqual(q, 0) @@ -343,7 +348,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) def test_list_of_lists_of_references(self): - class User(Document): name = StringField() @@ -357,9 +361,9 @@ class FieldTest(unittest.TestCase): Post.drop_collection() SimpleList.drop_collection() - u1 = User.objects.create(name='u1') - u2 = User.objects.create(name='u2') - u3 = User.objects.create(name='u3') + u1 = User.objects.create(name="u1") + u2 = User.objects.create(name="u2") + u3 = User.objects.create(name="u3") SimpleList.objects.create(users=[u1, u2, u3]) self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3]) @@ -370,13 +374,14 @@ class FieldTest(unittest.TestCase): def test_circular_reference(self): """Ensure you can handle circular references """ + class Relation(EmbeddedDocument): name = StringField() - person = ReferenceField('Person') + person = ReferenceField("Person") class Person(Document): name = StringField() - relations = ListField(EmbeddedDocumentField('Relation')) + relations = ListField(EmbeddedDocumentField("Relation")) def __repr__(self): return "" % self.name @@ -398,14 +403,17 @@ class FieldTest(unittest.TestCase): daughter.relations.append(self_rel) daughter.save() - self.assertEqual("[, ]", "%s" % Person.objects()) + self.assertEqual( + "[, ]", "%s" % Person.objects() + ) def test_circular_reference_on_self(self): """Ensure you can handle circular references """ + class Person(Document): name = StringField() - relations = ListField(ReferenceField('self')) + relations = ListField(ReferenceField("self")) def __repr__(self): return "" % self.name @@ -424,14 +432,17 @@ class FieldTest(unittest.TestCase): daughter.relations.append(daughter) daughter.save() - self.assertEqual("[, ]", "%s" % Person.objects()) + self.assertEqual( + "[, ]", "%s" % Person.objects() + ) def test_circular_tree_reference(self): """Ensure you can handle circular references with more than one level """ + class Other(EmbeddedDocument): name = StringField() - friends = ListField(ReferenceField('Person')) + friends = ListField(ReferenceField("Person")) class Person(Document): name = StringField() @@ -443,8 +454,8 @@ class FieldTest(unittest.TestCase): Person.drop_collection() paul = Person(name="Paul").save() maria = Person(name="Maria").save() - julia = Person(name='Julia').save() - anna = Person(name='Anna').save() + julia = Person(name="Julia").save() + anna = Person(name="Anna").save() paul.other.friends = [maria, julia, anna] paul.other.name = "Paul's friends" @@ -464,11 +475,10 @@ class FieldTest(unittest.TestCase): self.assertEqual( "[, , , ]", - "%s" % Person.objects() + "%s" % Person.objects(), ) def test_generic_reference(self): - class UserA(Document): name = StringField() @@ -488,13 +498,13 @@ class FieldTest(unittest.TestCase): members = [] for i in range(1, 51): - a = UserA(name='User A %s' % i) + a = UserA(name="User A %s" % i) a.save() - b = UserB(name='User B %s' % i) + b = UserB(name="User B %s" % i) b.save() - c = UserC(name='User C %s' % i) + c = UserC(name="User C %s" % i) c.save() members += [a, b, c] @@ -518,7 +528,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for m in group_obj.members: - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) # Document select_related with query_counter() as q: @@ -534,7 +544,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for m in group_obj.members: - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) # Queryset select_related with query_counter() as q: @@ -551,8 +561,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for m in group_obj.members: - self.assertIn('User', m.__class__.__name__) - + self.assertIn("User", m.__class__.__name__) def test_generic_reference_orphan_dbref(self): """Ensure that generic orphan DBRef items in ListFields are dereferenced. @@ -577,13 +586,13 @@ class FieldTest(unittest.TestCase): members = [] for i in range(1, 51): - a = UserA(name='User A %s' % i) + a = UserA(name="User A %s" % i) a.save() - b = UserB(name='User B %s' % i) + b = UserB(name="User B %s" % i) b.save() - c = UserC(name='User C %s' % i) + c = UserC(name="User C %s" % i) c.save() members += [a, b, c] @@ -602,11 +611,11 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) - self.assertTrue(group_obj._data['members']._dereferenced) + self.assertTrue(group_obj._data["members"]._dereferenced) [m for m in group_obj.members] self.assertEqual(q, 4) - self.assertTrue(group_obj._data['members']._dereferenced) + self.assertTrue(group_obj._data["members"]._dereferenced) UserA.drop_collection() UserB.drop_collection() @@ -614,7 +623,6 @@ class FieldTest(unittest.TestCase): Group.drop_collection() def test_list_field_complex(self): - class UserA(Document): name = StringField() @@ -634,13 +642,13 @@ class FieldTest(unittest.TestCase): members = [] for i in range(1, 51): - a = UserA(name='User A %s' % i) + a = UserA(name="User A %s" % i) a.save() - b = UserB(name='User B %s' % i) + b = UserB(name="User B %s" % i) b.save() - c = UserC(name='User C %s' % i) + c = UserC(name="User C %s" % i) c.save() members += [a, b, c] @@ -664,7 +672,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for m in group_obj.members: - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) # Document select_related with query_counter() as q: @@ -680,7 +688,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for m in group_obj.members: - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) # Queryset select_related with query_counter() as q: @@ -697,7 +705,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for m in group_obj.members: - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) UserA.drop_collection() UserB.drop_collection() @@ -705,7 +713,6 @@ class FieldTest(unittest.TestCase): Group.drop_collection() def test_map_field_reference(self): - class User(Document): name = StringField() @@ -717,7 +724,7 @@ class FieldTest(unittest.TestCase): members = [] for i in range(1, 51): - user = User(name='user %s' % i) + user = User(name="user %s" % i) user.save() members.append(user) @@ -752,7 +759,7 @@ class FieldTest(unittest.TestCase): for k, m in iteritems(group_obj.members): self.assertIsInstance(m, User) - # Queryset select_related + # Queryset select_related with query_counter() as q: self.assertEqual(q, 0) @@ -770,7 +777,6 @@ class FieldTest(unittest.TestCase): Group.drop_collection() def test_dict_field(self): - class UserA(Document): name = StringField() @@ -790,13 +796,13 @@ class FieldTest(unittest.TestCase): members = [] for i in range(1, 51): - a = UserA(name='User A %s' % i) + a = UserA(name="User A %s" % i) a.save() - b = UserB(name='User B %s' % i) + b = UserB(name="User B %s" % i) b.save() - c = UserC(name='User C %s' % i) + c = UserC(name="User C %s" % i) c.save() members += [a, b, c] @@ -819,7 +825,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for k, m in iteritems(group_obj.members): - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) # Document select_related with query_counter() as q: @@ -835,7 +841,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for k, m in iteritems(group_obj.members): - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) # Queryset select_related with query_counter() as q: @@ -852,7 +858,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for k, m in iteritems(group_obj.members): - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) Group.objects.delete() Group().save() @@ -873,10 +879,9 @@ class FieldTest(unittest.TestCase): Group.drop_collection() def test_dict_field_no_field_inheritance(self): - class UserA(Document): name = StringField() - meta = {'allow_inheritance': False} + meta = {"allow_inheritance": False} class Group(Document): members = DictField() @@ -886,7 +891,7 @@ class FieldTest(unittest.TestCase): members = [] for i in range(1, 51): - a = UserA(name='User A %s' % i) + a = UserA(name="User A %s" % i) a.save() members += [a] @@ -949,7 +954,6 @@ class FieldTest(unittest.TestCase): Group.drop_collection() def test_generic_reference_map_field(self): - class UserA(Document): name = StringField() @@ -969,13 +973,13 @@ class FieldTest(unittest.TestCase): members = [] for i in range(1, 51): - a = UserA(name='User A %s' % i) + a = UserA(name="User A %s" % i) a.save() - b = UserB(name='User B %s' % i) + b = UserB(name="User B %s" % i) b.save() - c = UserC(name='User C %s' % i) + c = UserC(name="User C %s" % i) c.save() members += [a, b, c] @@ -998,7 +1002,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for k, m in iteritems(group_obj.members): - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) # Document select_related with query_counter() as q: @@ -1014,7 +1018,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for k, m in iteritems(group_obj.members): - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) # Queryset select_related with query_counter() as q: @@ -1031,7 +1035,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 4) for k, m in iteritems(group_obj.members): - self.assertIn('User', m.__class__.__name__) + self.assertIn("User", m.__class__.__name__) Group.objects.delete() Group().save() @@ -1051,7 +1055,6 @@ class FieldTest(unittest.TestCase): Group.drop_collection() def test_multidirectional_lists(self): - class Asset(Document): name = StringField(max_length=250, required=True) path = StringField() @@ -1062,10 +1065,10 @@ class FieldTest(unittest.TestCase): Asset.drop_collection() - root = Asset(name='', path="/", title="Site Root") + root = Asset(name="", path="/", title="Site Root") root.save() - company = Asset(name='company', title='Company', parent=root, parents=[root]) + company = Asset(name="company", title="Company", parent=root, parents=[root]) company.save() root.children = [company] @@ -1076,7 +1079,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(company.parents, [root]) def test_dict_in_dbref_instance(self): - class Person(Document): name = StringField(max_length=250, required=True) @@ -1087,34 +1089,35 @@ class FieldTest(unittest.TestCase): Person.drop_collection() Room.drop_collection() - bob = Person.objects.create(name='Bob') + bob = Person.objects.create(name="Bob") bob.save() - sarah = Person.objects.create(name='Sarah') + sarah = Person.objects.create(name="Sarah") sarah.save() room_101 = Room.objects.create(number="101") room_101.staffs_with_position = [ - {'position_key': 'window', 'staff': sarah}, - {'position_key': 'door', 'staff': bob.to_dbref()}] + {"position_key": "window", "staff": sarah}, + {"position_key": "door", "staff": bob.to_dbref()}, + ] room_101.save() room = Room.objects.first().select_related() - self.assertEqual(room.staffs_with_position[0]['staff'], sarah) - self.assertEqual(room.staffs_with_position[1]['staff'], bob) + self.assertEqual(room.staffs_with_position[0]["staff"], sarah) + self.assertEqual(room.staffs_with_position[1]["staff"], bob) def test_document_reload_no_inheritance(self): class Foo(Document): - meta = {'allow_inheritance': False} - bar = ReferenceField('Bar') - baz = ReferenceField('Baz') + meta = {"allow_inheritance": False} + bar = ReferenceField("Bar") + baz = ReferenceField("Baz") class Bar(Document): - meta = {'allow_inheritance': False} - msg = StringField(required=True, default='Blammo!') + meta = {"allow_inheritance": False} + msg = StringField(required=True, default="Blammo!") class Baz(Document): - meta = {'allow_inheritance': False} - msg = StringField(required=True, default='Kaboom!') + meta = {"allow_inheritance": False} + msg = StringField(required=True, default="Kaboom!") Foo.drop_collection() Bar.drop_collection() @@ -1138,11 +1141,14 @@ class FieldTest(unittest.TestCase): Ensure reloading a document with multiple similar id in different collections doesn't mix them. """ + class Topic(Document): id = IntField(primary_key=True) + class User(Document): id = IntField(primary_key=True) name = StringField() + class Message(Document): id = IntField(primary_key=True) topic = ReferenceField(Topic) @@ -1154,23 +1160,24 @@ class FieldTest(unittest.TestCase): # All objects share the same id, but each in a different collection topic = Topic(id=1).save() - user = User(id=1, name='user-name').save() + user = User(id=1, name="user-name").save() Message(id=1, topic=topic, author=user).save() concurrent_change_user = User.objects.get(id=1) - concurrent_change_user.name = 'new-name' + concurrent_change_user.name = "new-name" concurrent_change_user.save() - self.assertNotEqual(user.name, 'new-name') + self.assertNotEqual(user.name, "new-name") msg = Message.objects.get(id=1) msg.reload() self.assertEqual(msg.topic, topic) self.assertEqual(msg.author, user) - self.assertEqual(msg.author.name, 'new-name') + self.assertEqual(msg.author.name, "new-name") def test_list_lookup_not_checked_in_map(self): """Ensure we dereference list data correctly """ + class Comment(Document): id = IntField(primary_key=True) text = StringField() @@ -1182,8 +1189,8 @@ class FieldTest(unittest.TestCase): Comment.drop_collection() Message.drop_collection() - c1 = Comment(id=0, text='zero').save() - c2 = Comment(id=1, text='one').save() + c1 = Comment(id=0, text="zero").save() + c2 = Comment(id=1, text="one").save() Message(id=1, comments=[c1, c2]).save() msg = Message.objects.get(id=1) @@ -1193,6 +1200,7 @@ class FieldTest(unittest.TestCase): def test_list_item_dereference_dref_false_save_doesnt_cause_extra_queries(self): """Ensure that DBRef items in ListFields are dereferenced. """ + class User(Document): name = StringField() @@ -1204,7 +1212,7 @@ class FieldTest(unittest.TestCase): Group.drop_collection() for i in range(1, 51): - User(name='user %s' % i).save() + User(name="user %s" % i).save() Group(name="Test", members=User.objects).save() @@ -1222,6 +1230,7 @@ class FieldTest(unittest.TestCase): def test_list_item_dereference_dref_true_save_doesnt_cause_extra_queries(self): """Ensure that DBRef items in ListFields are dereferenced. """ + class User(Document): name = StringField() @@ -1233,7 +1242,7 @@ class FieldTest(unittest.TestCase): Group.drop_collection() for i in range(1, 51): - User(name='user %s' % i).save() + User(name="user %s" % i).save() Group(name="Test", members=User.objects).save() @@ -1249,7 +1258,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) def test_generic_reference_save_doesnt_cause_extra_queries(self): - class UserA(Document): name = StringField() @@ -1270,9 +1278,9 @@ class FieldTest(unittest.TestCase): members = [] for i in range(1, 51): - a = UserA(name='User A %s' % i).save() - b = UserB(name='User B %s' % i).save() - c = UserC(name='User C %s' % i).save() + a = UserA(name="User A %s" % i).save() + b = UserB(name="User B %s" % i).save() + c = UserC(name="User C %s" % i).save() members += [a, b, c] @@ -1292,7 +1300,7 @@ class FieldTest(unittest.TestCase): def test_objectid_reference_across_databases(self): # mongoenginetest - Is default connection alias from setUp() # Register Aliases - register_connection('testdb-1', 'mongoenginetest2') + register_connection("testdb-1", "mongoenginetest2") class User(Document): name = StringField() @@ -1311,16 +1319,17 @@ class FieldTest(unittest.TestCase): # Can't use query_counter across databases - so test the _data object book = Book.objects.first() - self.assertNotIsInstance(book._data['author'], User) + self.assertNotIsInstance(book._data["author"], User) book.select_related() - self.assertIsInstance(book._data['author'], User) + self.assertIsInstance(book._data["author"], User) def test_non_ascii_pk(self): """ Ensure that dbref conversion to string does not fail when non-ascii characters are used in primary key """ + class Brand(Document): title = StringField(max_length=255, primary_key=True) @@ -1341,7 +1350,7 @@ class FieldTest(unittest.TestCase): def test_dereferencing_embedded_listfield_referencefield(self): class Tag(Document): - meta = {'collection': 'tags'} + meta = {"collection": "tags"} name = StringField() class Post(EmbeddedDocument): @@ -1349,22 +1358,21 @@ class FieldTest(unittest.TestCase): tags = ListField(ReferenceField("Tag", dbref=True)) class Page(Document): - meta = {'collection': 'pages'} + meta = {"collection": "pages"} tags = ListField(ReferenceField("Tag", dbref=True)) posts = ListField(EmbeddedDocumentField(Post)) Tag.drop_collection() Page.drop_collection() - tag = Tag(name='test').save() - post = Post(body='test body', tags=[tag]) + tag = Tag(name="test").save() + post = Post(body="test body", tags=[tag]) Page(tags=[tag], posts=[post]).save() page = Page.objects.first() self.assertEqual(page.tags[0], page.posts[0].tags[0]) def test_select_related_follows_embedded_referencefields(self): - class Song(Document): title = StringField() @@ -1390,5 +1398,5 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index cacdce8b..5e3aa493 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -12,7 +12,6 @@ READ_PREF = ReadPreference.SECONDARY class ConnectionTest(unittest.TestCase): - def setUp(self): mongoengine.connection._connection_settings = {} mongoengine.connection._connections = {} @@ -28,9 +27,11 @@ class ConnectionTest(unittest.TestCase): """ try: - conn = mongoengine.connect(db='mongoenginetest', - host="mongodb://localhost/mongoenginetest?replicaSet=rs", - read_preference=READ_PREF) + conn = mongoengine.connect( + db="mongoenginetest", + host="mongodb://localhost/mongoenginetest?replicaSet=rs", + read_preference=READ_PREF, + ) except MongoEngineConnectionError as e: return @@ -41,5 +42,5 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(conn.read_preference, READ_PREF) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_signals.py b/tests/test_signals.py index 34cb43c3..1d0607d7 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -20,7 +20,7 @@ class SignalTests(unittest.TestCase): return signal_output def setUp(self): - connect(db='mongoenginetest') + connect(db="mongoenginetest") class Author(Document): # Make the id deterministic for easier testing @@ -32,60 +32,63 @@ class SignalTests(unittest.TestCase): @classmethod def pre_init(cls, sender, document, *args, **kwargs): - signal_output.append('pre_init signal, %s' % cls.__name__) - signal_output.append(kwargs['values']) + signal_output.append("pre_init signal, %s" % cls.__name__) + signal_output.append(kwargs["values"]) @classmethod def post_init(cls, sender, document, **kwargs): - signal_output.append('post_init signal, %s, document._created = %s' % (document, document._created)) + signal_output.append( + "post_init signal, %s, document._created = %s" + % (document, document._created) + ) @classmethod def pre_save(cls, sender, document, **kwargs): - signal_output.append('pre_save signal, %s' % document) + signal_output.append("pre_save signal, %s" % document) signal_output.append(kwargs) @classmethod def pre_save_post_validation(cls, sender, document, **kwargs): - signal_output.append('pre_save_post_validation signal, %s' % document) - if kwargs.pop('created', False): - signal_output.append('Is created') + signal_output.append("pre_save_post_validation signal, %s" % document) + if kwargs.pop("created", False): + signal_output.append("Is created") else: - signal_output.append('Is updated') + signal_output.append("Is updated") signal_output.append(kwargs) @classmethod def post_save(cls, sender, document, **kwargs): dirty_keys = document._delta()[0].keys() + document._delta()[1].keys() - signal_output.append('post_save signal, %s' % document) - signal_output.append('post_save dirty keys, %s' % dirty_keys) - if kwargs.pop('created', False): - signal_output.append('Is created') + signal_output.append("post_save signal, %s" % document) + signal_output.append("post_save dirty keys, %s" % dirty_keys) + if kwargs.pop("created", False): + signal_output.append("Is created") else: - signal_output.append('Is updated') + signal_output.append("Is updated") signal_output.append(kwargs) @classmethod def pre_delete(cls, sender, document, **kwargs): - signal_output.append('pre_delete signal, %s' % document) + signal_output.append("pre_delete signal, %s" % document) signal_output.append(kwargs) @classmethod def post_delete(cls, sender, document, **kwargs): - signal_output.append('post_delete signal, %s' % document) + signal_output.append("post_delete signal, %s" % document) signal_output.append(kwargs) @classmethod def pre_bulk_insert(cls, sender, documents, **kwargs): - signal_output.append('pre_bulk_insert signal, %s' % documents) + signal_output.append("pre_bulk_insert signal, %s" % documents) signal_output.append(kwargs) @classmethod def post_bulk_insert(cls, sender, documents, **kwargs): - signal_output.append('post_bulk_insert signal, %s' % documents) - if kwargs.pop('loaded', False): - signal_output.append('Is loaded') + signal_output.append("post_bulk_insert signal, %s" % documents) + if kwargs.pop("loaded", False): + signal_output.append("Is loaded") else: - signal_output.append('Not loaded') + signal_output.append("Not loaded") signal_output.append(kwargs) self.Author = Author @@ -101,12 +104,12 @@ class SignalTests(unittest.TestCase): @classmethod def pre_delete(cls, sender, document, **kwargs): - signal_output.append('pre_delete signal, %s' % document) + signal_output.append("pre_delete signal, %s" % document) signal_output.append(kwargs) @classmethod def post_delete(cls, sender, document, **kwargs): - signal_output.append('post_delete signal, %s' % document) + signal_output.append("post_delete signal, %s" % document) signal_output.append(kwargs) self.Another = Another @@ -117,11 +120,11 @@ class SignalTests(unittest.TestCase): @classmethod def post_save(cls, sender, document, **kwargs): - if 'created' in kwargs: - if kwargs['created']: - signal_output.append('Is created') + if "created" in kwargs: + if kwargs["created"]: + signal_output.append("Is created") else: - signal_output.append('Is updated') + signal_output.append("Is updated") self.ExplicitId = ExplicitId ExplicitId.drop_collection() @@ -136,9 +139,13 @@ class SignalTests(unittest.TestCase): @classmethod def pre_bulk_insert(cls, sender, documents, **kwargs): - signal_output.append('pre_bulk_insert signal, %s' % - [(doc, {'active': documents[n].active}) - for n, doc in enumerate(documents)]) + signal_output.append( + "pre_bulk_insert signal, %s" + % [ + (doc, {"active": documents[n].active}) + for n, doc in enumerate(documents) + ] + ) # make changes here, this is just an example - # it could be anything that needs pre-validation or looks-ups before bulk bulk inserting @@ -149,13 +156,17 @@ class SignalTests(unittest.TestCase): @classmethod def post_bulk_insert(cls, sender, documents, **kwargs): - signal_output.append('post_bulk_insert signal, %s' % - [(doc, {'active': documents[n].active}) - for n, doc in enumerate(documents)]) - if kwargs.pop('loaded', False): - signal_output.append('Is loaded') + signal_output.append( + "post_bulk_insert signal, %s" + % [ + (doc, {"active": documents[n].active}) + for n, doc in enumerate(documents) + ] + ) + if kwargs.pop("loaded", False): + signal_output.append("Is loaded") else: - signal_output.append('Not loaded') + signal_output.append("Not loaded") signal_output.append(kwargs) self.Post = Post @@ -178,7 +189,9 @@ class SignalTests(unittest.TestCase): signals.pre_init.connect(Author.pre_init, sender=Author) signals.post_init.connect(Author.post_init, sender=Author) signals.pre_save.connect(Author.pre_save, sender=Author) - signals.pre_save_post_validation.connect(Author.pre_save_post_validation, sender=Author) + signals.pre_save_post_validation.connect( + Author.pre_save_post_validation, sender=Author + ) signals.post_save.connect(Author.post_save, sender=Author) signals.pre_delete.connect(Author.pre_delete, sender=Author) signals.post_delete.connect(Author.post_delete, sender=Author) @@ -199,7 +212,9 @@ class SignalTests(unittest.TestCase): signals.post_delete.disconnect(self.Author.post_delete) signals.pre_delete.disconnect(self.Author.pre_delete) signals.post_save.disconnect(self.Author.post_save) - signals.pre_save_post_validation.disconnect(self.Author.pre_save_post_validation) + signals.pre_save_post_validation.disconnect( + self.Author.pre_save_post_validation + ) signals.pre_save.disconnect(self.Author.pre_save) signals.pre_bulk_insert.disconnect(self.Author.pre_bulk_insert) signals.post_bulk_insert.disconnect(self.Author.post_bulk_insert) @@ -236,203 +251,236 @@ class SignalTests(unittest.TestCase): """ Model saves should throw some signals. """ def create_author(): - self.Author(name='Bill Shakespeare') + self.Author(name="Bill Shakespeare") def bulk_create_author_with_load(): - a1 = self.Author(name='Bill Shakespeare') + a1 = self.Author(name="Bill Shakespeare") self.Author.objects.insert([a1], load_bulk=True) def bulk_create_author_without_load(): - a1 = self.Author(name='Bill Shakespeare') + a1 = self.Author(name="Bill Shakespeare") self.Author.objects.insert([a1], load_bulk=False) def load_existing_author(): - a = self.Author(name='Bill Shakespeare') + a = self.Author(name="Bill Shakespeare") a.save() self.get_signal_output(lambda: None) # eliminate signal output - a1 = self.Author.objects(name='Bill Shakespeare')[0] + a1 = self.Author.objects(name="Bill Shakespeare")[0] - self.assertEqual(self.get_signal_output(create_author), [ - "pre_init signal, Author", - {'name': 'Bill Shakespeare'}, - "post_init signal, Bill Shakespeare, document._created = True", - ]) + self.assertEqual( + self.get_signal_output(create_author), + [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + ], + ) - a1 = self.Author(name='Bill Shakespeare') - self.assertEqual(self.get_signal_output(a1.save), [ - "pre_save signal, Bill Shakespeare", - {}, - "pre_save_post_validation signal, Bill Shakespeare", - "Is created", - {}, - "post_save signal, Bill Shakespeare", - "post_save dirty keys, ['name']", - "Is created", - {} - ]) + a1 = self.Author(name="Bill Shakespeare") + self.assertEqual( + self.get_signal_output(a1.save), + [ + "pre_save signal, Bill Shakespeare", + {}, + "pre_save_post_validation signal, Bill Shakespeare", + "Is created", + {}, + "post_save signal, Bill Shakespeare", + "post_save dirty keys, ['name']", + "Is created", + {}, + ], + ) a1.reload() - a1.name = 'William Shakespeare' - self.assertEqual(self.get_signal_output(a1.save), [ - "pre_save signal, William Shakespeare", - {}, - "pre_save_post_validation signal, William Shakespeare", - "Is updated", - {}, - "post_save signal, William Shakespeare", - "post_save dirty keys, ['name']", - "Is updated", - {} - ]) + a1.name = "William Shakespeare" + self.assertEqual( + self.get_signal_output(a1.save), + [ + "pre_save signal, William Shakespeare", + {}, + "pre_save_post_validation signal, William Shakespeare", + "Is updated", + {}, + "post_save signal, William Shakespeare", + "post_save dirty keys, ['name']", + "Is updated", + {}, + ], + ) - self.assertEqual(self.get_signal_output(a1.delete), [ - 'pre_delete signal, William Shakespeare', - {}, - 'post_delete signal, William Shakespeare', - {} - ]) + self.assertEqual( + self.get_signal_output(a1.delete), + [ + "pre_delete signal, William Shakespeare", + {}, + "post_delete signal, William Shakespeare", + {}, + ], + ) - self.assertEqual(self.get_signal_output(load_existing_author), [ - "pre_init signal, Author", - {'id': 2, 'name': 'Bill Shakespeare'}, - "post_init signal, Bill Shakespeare, document._created = False" - ]) + self.assertEqual( + self.get_signal_output(load_existing_author), + [ + "pre_init signal, Author", + {"id": 2, "name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = False", + ], + ) - self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [ - 'pre_init signal, Author', - {'name': 'Bill Shakespeare'}, - 'post_init signal, Bill Shakespeare, document._created = True', - 'pre_bulk_insert signal, []', - {}, - 'pre_init signal, Author', - {'id': 3, 'name': 'Bill Shakespeare'}, - 'post_init signal, Bill Shakespeare, document._created = False', - 'post_bulk_insert signal, []', - 'Is loaded', - {} - ]) + self.assertEqual( + self.get_signal_output(bulk_create_author_with_load), + [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_bulk_insert signal, []", + {}, + "pre_init signal, Author", + {"id": 3, "name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = False", + "post_bulk_insert signal, []", + "Is loaded", + {}, + ], + ) - self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [ - "pre_init signal, Author", - {'name': 'Bill Shakespeare'}, - "post_init signal, Bill Shakespeare, document._created = True", - "pre_bulk_insert signal, []", - {}, - "post_bulk_insert signal, []", - "Not loaded", - {} - ]) + self.assertEqual( + self.get_signal_output(bulk_create_author_without_load), + [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_bulk_insert signal, []", + {}, + "post_bulk_insert signal, []", + "Not loaded", + {}, + ], + ) def test_signal_kwargs(self): """ Make sure signal_kwargs is passed to signals calls. """ def live_and_let_die(): - a = self.Author(name='Bill Shakespeare') - a.save(signal_kwargs={'live': True, 'die': False}) - a.delete(signal_kwargs={'live': False, 'die': True}) + a = self.Author(name="Bill Shakespeare") + a.save(signal_kwargs={"live": True, "die": False}) + a.delete(signal_kwargs={"live": False, "die": True}) - self.assertEqual(self.get_signal_output(live_and_let_die), [ - "pre_init signal, Author", - {'name': 'Bill Shakespeare'}, - "post_init signal, Bill Shakespeare, document._created = True", - "pre_save signal, Bill Shakespeare", - {'die': False, 'live': True}, - "pre_save_post_validation signal, Bill Shakespeare", - "Is created", - {'die': False, 'live': True}, - "post_save signal, Bill Shakespeare", - "post_save dirty keys, ['name']", - "Is created", - {'die': False, 'live': True}, - 'pre_delete signal, Bill Shakespeare', - {'die': True, 'live': False}, - 'post_delete signal, Bill Shakespeare', - {'die': True, 'live': False} - ]) + self.assertEqual( + self.get_signal_output(live_and_let_die), + [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_save signal, Bill Shakespeare", + {"die": False, "live": True}, + "pre_save_post_validation signal, Bill Shakespeare", + "Is created", + {"die": False, "live": True}, + "post_save signal, Bill Shakespeare", + "post_save dirty keys, ['name']", + "Is created", + {"die": False, "live": True}, + "pre_delete signal, Bill Shakespeare", + {"die": True, "live": False}, + "post_delete signal, Bill Shakespeare", + {"die": True, "live": False}, + ], + ) def bulk_create_author(): - a1 = self.Author(name='Bill Shakespeare') - self.Author.objects.insert([a1], signal_kwargs={'key': True}) + a1 = self.Author(name="Bill Shakespeare") + self.Author.objects.insert([a1], signal_kwargs={"key": True}) - self.assertEqual(self.get_signal_output(bulk_create_author), [ - 'pre_init signal, Author', - {'name': 'Bill Shakespeare'}, - 'post_init signal, Bill Shakespeare, document._created = True', - 'pre_bulk_insert signal, []', - {'key': True}, - 'pre_init signal, Author', - {'id': 2, 'name': 'Bill Shakespeare'}, - 'post_init signal, Bill Shakespeare, document._created = False', - 'post_bulk_insert signal, []', - 'Is loaded', - {'key': True} - ]) + self.assertEqual( + self.get_signal_output(bulk_create_author), + [ + "pre_init signal, Author", + {"name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = True", + "pre_bulk_insert signal, []", + {"key": True}, + "pre_init signal, Author", + {"id": 2, "name": "Bill Shakespeare"}, + "post_init signal, Bill Shakespeare, document._created = False", + "post_bulk_insert signal, []", + "Is loaded", + {"key": True}, + ], + ) def test_queryset_delete_signals(self): """ Queryset delete should throw some signals. """ - self.Another(name='Bill Shakespeare').save() - self.assertEqual(self.get_signal_output(self.Another.objects.delete), [ - 'pre_delete signal, Bill Shakespeare', - {}, - 'post_delete signal, Bill Shakespeare', - {} - ]) + self.Another(name="Bill Shakespeare").save() + self.assertEqual( + self.get_signal_output(self.Another.objects.delete), + [ + "pre_delete signal, Bill Shakespeare", + {}, + "post_delete signal, Bill Shakespeare", + {}, + ], + ) def test_signals_with_explicit_doc_ids(self): """ Model saves must have a created flag the first time.""" ei = self.ExplicitId(id=123) # post save must received the created flag, even if there's already # an object id present - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) # second time, it must be an update - self.assertEqual(self.get_signal_output(ei.save), ['Is updated']) + self.assertEqual(self.get_signal_output(ei.save), ["Is updated"]) def test_signals_with_switch_collection(self): ei = self.ExplicitId(id=123) ei.switch_collection("explicit__1") - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) ei.switch_collection("explicit__1") - self.assertEqual(self.get_signal_output(ei.save), ['Is updated']) + self.assertEqual(self.get_signal_output(ei.save), ["Is updated"]) ei.switch_collection("explicit__1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) ei.switch_collection("explicit__1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) def test_signals_with_switch_db(self): - connect('mongoenginetest') - register_connection('testdb-1', 'mongoenginetest2') + connect("mongoenginetest") + register_connection("testdb-1", "mongoenginetest2") ei = self.ExplicitId(id=123) ei.switch_db("testdb-1") - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) ei.switch_db("testdb-1") - self.assertEqual(self.get_signal_output(ei.save), ['Is updated']) + self.assertEqual(self.get_signal_output(ei.save), ["Is updated"]) ei.switch_db("testdb-1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) ei.switch_db("testdb-1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + self.assertEqual(self.get_signal_output(ei.save), ["Is created"]) def test_signals_bulk_insert(self): def bulk_set_active_post(): posts = [ - self.Post(title='Post 1'), - self.Post(title='Post 2'), - self.Post(title='Post 3') + self.Post(title="Post 1"), + self.Post(title="Post 2"), + self.Post(title="Post 3"), ] self.Post.objects.insert(posts) results = self.get_signal_output(bulk_set_active_post) - self.assertEqual(results, [ - "pre_bulk_insert signal, [(, {'active': False}), (, {'active': False}), (, {'active': False})]", - {}, - "post_bulk_insert signal, [(, {'active': True}), (, {'active': True}), (, {'active': True})]", - 'Is loaded', - {} - ]) + self.assertEqual( + results, + [ + "pre_bulk_insert signal, [(, {'active': False}), (, {'active': False}), (, {'active': False})]", + {}, + "post_bulk_insert signal, [(, {'active': True}), (, {'active': True}), (, {'active': True})]", + "Is loaded", + {}, + ], + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index 562cc1ff..2d1e8b00 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,32 +7,33 @@ signal_output = [] class LazyRegexCompilerTest(unittest.TestCase): - def test_lazy_regex_compiler_verify_laziness_of_descriptor(self): class UserEmail(object): - EMAIL_REGEX = LazyRegexCompiler('@', flags=32) + EMAIL_REGEX = LazyRegexCompiler("@", flags=32) - descriptor = UserEmail.__dict__['EMAIL_REGEX'] + descriptor = UserEmail.__dict__["EMAIL_REGEX"] self.assertIsNone(descriptor._compiled_regex) regex = UserEmail.EMAIL_REGEX - self.assertEqual(regex, re.compile('@', flags=32)) - self.assertEqual(regex.search('user@domain.com').group(), '@') + self.assertEqual(regex, re.compile("@", flags=32)) + self.assertEqual(regex.search("user@domain.com").group(), "@") user_email = UserEmail() self.assertIs(user_email.EMAIL_REGEX, UserEmail.EMAIL_REGEX) def test_lazy_regex_compiler_verify_cannot_set_descriptor_on_instance(self): class UserEmail(object): - EMAIL_REGEX = LazyRegexCompiler('@') + EMAIL_REGEX = LazyRegexCompiler("@") user_email = UserEmail() with self.assertRaises(AttributeError): - user_email.EMAIL_REGEX = re.compile('@') + user_email.EMAIL_REGEX = re.compile("@") def test_lazy_regex_compiler_verify_can_override_class_attr(self): class UserEmail(object): - EMAIL_REGEX = LazyRegexCompiler('@') + EMAIL_REGEX = LazyRegexCompiler("@") - UserEmail.EMAIL_REGEX = re.compile('cookies') - self.assertEqual(UserEmail.EMAIL_REGEX.search('Cake & cookies').group(), 'cookies') + UserEmail.EMAIL_REGEX = re.compile("cookies") + self.assertEqual( + UserEmail.EMAIL_REGEX.search("Cake & cookies").group(), "cookies" + ) diff --git a/tests/utils.py b/tests/utils.py index 27d5ada7..eb3f016f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,7 +8,7 @@ from mongoengine.connection import get_db, disconnect_all from mongoengine.mongodb_support import get_mongodb_version -MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database +MONGO_TEST_DB = "mongoenginetest" # standard name for the test database class MongoDBTestCase(unittest.TestCase): @@ -53,12 +53,15 @@ def _decorated_with_ver_requirement(func, mongo_version_req, oper): :param mongo_version_req: The mongodb version requirement (tuple(int, int)) :param oper: The operator to apply (e.g: operator.ge) """ + def _inner(*args, **kwargs): mongodb_v = get_mongodb_version() if oper(mongodb_v, mongo_version_req): return func(*args, **kwargs) - raise SkipTest('Needs MongoDB v{}+'.format('.'.join(str(n) for n in mongo_version_req))) + raise SkipTest( + "Needs MongoDB v{}+".format(".".join(str(n) for n in mongo_version_req)) + ) _inner.__name__ = func.__name__ _inner.__doc__ = func.__doc__