Compare commits

..

1 Commits

Author SHA1 Message Date
Stefan Wojcik
df12211c25 dont let the MongoDB URI override connection settings it doesnt explicitly specify 2016-12-03 21:08:18 -05:00
10 changed files with 96 additions and 122 deletions

View File

@@ -438,7 +438,7 @@ class StrictDict(object):
__slots__ = allowed_keys_tuple __slots__ = allowed_keys_tuple
def __repr__(self): def __repr__(self):
return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys())
cls._classes[allowed_keys] = SpecificStrictDict cls._classes[allowed_keys] = SpecificStrictDict
return cls._classes[allowed_keys] return cls._classes[allowed_keys]

View File

@@ -121,7 +121,7 @@ class BaseDocument(object):
else: else:
self._data[key] = value self._data[key] = value
# Set any get_<field>_display methods # Set any get_fieldname_display methods
self.__set_field_display() self.__set_field_display()
if self._dynamic: if self._dynamic:
@@ -1005,18 +1005,19 @@ class BaseDocument(object):
return '.'.join(parts) return '.'.join(parts)
def __set_field_display(self): def __set_field_display(self):
"""For each field that specifies choices, create a """Dynamically set the display value for a field with choices"""
get_<field>_display method. for attr_name, field in self._fields.items():
""" if field.choices:
fields_with_choices = [(n, f) for n, f in self._fields.items() if self._dynamic:
if f.choices] obj = self
for attr_name, field in fields_with_choices: else:
setattr(self, obj = type(self)
setattr(obj,
'get_%s_display' % attr_name, 'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field)) partial(self.__get_field_display, field=field))
def __get_field_display(self, field): def __get_field_display(self, field):
"""Return the display value for a choice field""" """Returns the display value for a choice field"""
value = getattr(self, field.name) value = getattr(self, field.name)
if field.choices and isinstance(field.choices[0], (list, tuple)): if field.choices and isinstance(field.choices[0], (list, tuple)):
return dict(field.choices).get(value, value) return dict(field.choices).get(value, value)

View File

@@ -25,7 +25,8 @@ _dbs = {}
def register_connection(alias, name=None, host=None, port=None, def register_connection(alias, name=None, host=None, port=None,
read_preference=READ_PREFERENCE, read_preference=READ_PREFERENCE,
username=None, password=None, authentication_source=None, username=None, password=None,
authentication_source=None,
authentication_mechanism=None, authentication_mechanism=None,
**kwargs): **kwargs):
"""Add a connection. """Add a connection.
@@ -70,20 +71,26 @@ def register_connection(alias, name=None, host=None, port=None,
resolved_hosts = [] resolved_hosts = []
for entity in conn_host: for entity in conn_host:
# Handle uri style connections
# Handle Mongomock
if entity.startswith('mongomock://'): if entity.startswith('mongomock://'):
conn_settings['is_mock'] = True conn_settings['is_mock'] = True
# `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://`
resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1)) resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1))
# Handle URI style connections, only updating connection params which
# were explicitly specified in the URI.
elif '://' in entity: elif '://' in entity:
uri_dict = uri_parser.parse_uri(entity) uri_dict = uri_parser.parse_uri(entity)
resolved_hosts.append(entity) resolved_hosts.append(entity)
conn_settings.update({
'name': uri_dict.get('database') or name, if uri_dict.get('database'):
'username': uri_dict.get('username'), conn_settings['name'] = uri_dict.get('database')
'password': uri_dict.get('password'),
'read_preference': read_preference, for param in ('read_preference', 'username', 'password'):
}) if uri_dict.get(param):
conn_settings[param] = uri_dict[param]
uri_options = uri_dict['options'] uri_options = uri_dict['options']
if 'replicaset' in uri_options: if 'replicaset' in uri_options:
conn_settings['replicaSet'] = True conn_settings['replicaSet'] = True

View File

@@ -577,7 +577,7 @@ class EmbeddedDocumentField(BaseField):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if value is not None and not isinstance(value, self.document_type): if not isinstance(value, self.document_type):
value = self.document_type._from_son(value) value = self.document_type._from_son(value)
super(EmbeddedDocumentField, self).prepare_query_value(op, value) super(EmbeddedDocumentField, self).prepare_query_value(op, value)
return self.to_mongo(value) return self.to_mongo(value)

View File

@@ -933,14 +933,6 @@ class BaseQuerySet(object):
queryset._ordering = queryset._get_order_by(keys) queryset._ordering = queryset._get_order_by(keys)
return queryset return queryset
def comment(self, text):
"""Add a comment to the query.
See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment
for details.
"""
return self._chainable_method("comment", text)
def explain(self, format=False): def explain(self, format=False):
"""Return an explain plan record for the """Return an explain plan record for the
:class:`~mongoengine.queryset.QuerySet`\ 's cursor. :class:`~mongoengine.queryset.QuerySet`\ 's cursor.

View File

@@ -2,8 +2,10 @@
import unittest import unittest
import sys import sys
sys.path[0:0] = [""]
import pymongo import pymongo
from random import randint
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from datetime import datetime from datetime import datetime
@@ -15,9 +17,11 @@ __all__ = ("IndexesTest", )
class IndexesTest(unittest.TestCase): class IndexesTest(unittest.TestCase):
_MAX_RAND = 10 ** 10
def setUp(self): def setUp(self):
self.connection = connect(db='mongoenginetest') self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND))
self.connection = connect(db=self.db_name)
self.db = get_db() self.db = get_db()
class Person(Document): class Person(Document):

View File

@@ -3001,32 +3001,28 @@ class FieldTest(unittest.TestCase):
('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
style = StringField(max_length=3, choices=( style = StringField(max_length=3, choices=(
('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W') ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S')
Shirt.drop_collection() Shirt.drop_collection()
shirt1 = Shirt() shirt = Shirt()
shirt2 = Shirt()
# Make sure get_<field>_display returns the default value (or None) self.assertEqual(shirt.get_size_display(), None)
self.assertEqual(shirt1.get_size_display(), None) self.assertEqual(shirt.get_style_display(), 'Small')
self.assertEqual(shirt1.get_style_display(), 'Wide')
shirt1.size = 'XXL' shirt.size = "XXL"
shirt1.style = 'B' shirt.style = "B"
shirt2.size = 'M' self.assertEqual(shirt.get_size_display(), 'Extra Extra Large')
shirt2.style = 'S' self.assertEqual(shirt.get_style_display(), 'Baggy')
self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt1.get_style_display(), 'Baggy')
self.assertEqual(shirt2.get_size_display(), 'Medium')
self.assertEqual(shirt2.get_style_display(), 'Small')
# Set as Z - an invalid choice # Set as Z - an invalid choice
shirt1.size = 'Z' shirt.size = "Z"
shirt1.style = 'Z' shirt.style = "Z"
self.assertEqual(shirt1.get_size_display(), 'Z') self.assertEqual(shirt.get_size_display(), 'Z')
self.assertEqual(shirt1.get_style_display(), 'Z') self.assertEqual(shirt.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt1.validate) self.assertRaises(ValidationError, shirt.validate)
Shirt.drop_collection()
def test_simple_choices_validation(self): def test_simple_choices_validation(self):
"""Ensure that value is in a container of allowed values. """Ensure that value is in a container of allowed values.

View File

@@ -339,6 +339,7 @@ class QuerySetTest(unittest.TestCase):
def test_update_write_concern(self): def test_update_write_concern(self):
"""Test that passing write_concern works""" """Test that passing write_concern works"""
self.Person.drop_collection() self.Person.drop_collection()
write_concern = {"fsync": True} write_concern = {"fsync": True}
@@ -1238,8 +1239,7 @@ class QuerySetTest(unittest.TestCase):
self.assertFalse('$orderby' in q.get_ops()[0]['query']) self.assertFalse('$orderby' in q.get_ops()[0]['query'])
def test_find_embedded(self): def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from """Ensure that an embedded document is properly returned from a query.
a query.
""" """
class User(EmbeddedDocument): class User(EmbeddedDocument):
name = StringField() name = StringField()
@@ -1250,31 +1250,16 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
BlogPost.objects.create( post = BlogPost(content='Had a good coffee today...')
author=User(name='Test User'), post.author = User(name='Test User')
content='Had a good coffee today...' post.save()
)
result = BlogPost.objects.first() result = BlogPost.objects.first()
self.assertTrue(isinstance(result.author, User)) self.assertTrue(isinstance(result.author, User))
self.assertEqual(result.author.name, 'Test User') self.assertEqual(result.author.name, 'Test User')
def test_find_empty_embedded(self):
"""Ensure that you can save and find an empty embedded document."""
class User(EmbeddedDocument):
name = StringField()
class BlogPost(Document):
content = StringField()
author = EmbeddedDocumentField(User)
BlogPost.drop_collection() BlogPost.drop_collection()
BlogPost.objects.create(content='Anonymous post...')
result = BlogPost.objects.get(author=None)
self.assertEqual(result.author, None)
def test_find_dict_item(self): def test_find_dict_item(self):
"""Ensure that DictField items may be found. """Ensure that DictField items may be found.
""" """
@@ -2214,21 +2199,6 @@ class QuerySetTest(unittest.TestCase):
a.author.name for a in Author.objects.order_by('-author__age')] a.author.name for a in Author.objects.order_by('-author__age')]
self.assertEqual(names, ['User A', 'User B', 'User C']) self.assertEqual(names, ['User A', 'User B', 'User C'])
def test_comment(self):
"""Make sure adding a comment to the query works."""
class User(Document):
age = IntField()
with db_ops_tracker() as q:
adult = (User.objects.filter(age__gte=18)
.comment('looking for an adult')
.first())
ops = q.get_ops()
self.assertEqual(len(ops), 1)
op = ops[0]
self.assertEqual(op['query']['$query'], {'age': {'$gte': 18}})
self.assertEqual(op['query']['$comment'], 'looking for an adult')
def test_map_reduce(self): def test_map_reduce(self):
"""Ensure map/reduce is both mapping and reducing. """Ensure map/reduce is both mapping and reducing.
""" """

View File

@@ -174,19 +174,9 @@ class ConnectionTest(unittest.TestCase):
c.mongoenginetest.system.users.remove({}) c.mongoenginetest.system.users.remove({})
def test_connect_uri_without_db(self): def test_connect_uri_without_db(self):
"""Ensure connect() method works properly with uri's without database_name """Ensure connect() method works properly if the URI doesn't
include a database name.
""" """
c = connect(db='mongoenginetest', alias='admin')
c.admin.system.users.remove({})
c.mongoenginetest.system.users.remove({})
c.admin.add_user("admin", "password")
c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password")
if not IS_PYMONGO_3:
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
connect("mongoenginetest", host='mongodb://localhost/') connect("mongoenginetest", host='mongodb://localhost/')
conn = get_connection() conn = get_connection()
@@ -196,8 +186,31 @@ class ConnectionTest(unittest.TestCase):
self.assertTrue(isinstance(db, pymongo.database.Database)) self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'mongoenginetest') self.assertEqual(db.name, 'mongoenginetest')
c.admin.system.users.remove({}) def test_connect_uri_default_db(self):
c.mongoenginetest.system.users.remove({}) """Ensure connect() defaults to the right database name if
the URI and the database_name don't explicitly specify it.
"""
connect(host='mongodb://localhost/')
conn = get_connection()
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
db = get_db()
self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'test')
def test_uri_without_credentials_doesnt_override_conn_settings(self):
"""Ensure connect() uses the username & password params if the URI
doesn't explicitly specify them.
"""
c = connect(host='mongodb://localhost/mongoenginetest',
username='user',
password='pass')
# OperationFailure means that mongoengine attempted authentication
# w/ the provided username/password and failed - that's the desired
# behavior. If the MongoDB URI would override the credentials
self.assertRaises(OperationFailure, get_db)
def test_connect_uri_with_authsource(self): def test_connect_uri_with_authsource(self):
"""Ensure that the connect() method works well with """Ensure that the connect() method works well with

View File

@@ -1,5 +1,4 @@
import unittest import unittest
from mongoengine.base.datastructures import StrictDict, SemiStrictDict from mongoengine.base.datastructures import StrictDict, SemiStrictDict
@@ -14,14 +13,6 @@ class TestStrictDict(unittest.TestCase):
d = self.dtype(a=1, b=1, c=1) d = self.dtype(a=1, b=1, c=1)
self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
def test_repr(self):
d = self.dtype(a=1, b=2, c=3)
self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}')
# make sure quotes are escaped properly
d = self.dtype(a='"', b="'", c="")
self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}')
def test_init_fails_on_nonexisting_attrs(self): def test_init_fails_on_nonexisting_attrs(self):
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))