Compare commits

..

1 Commits

Author SHA1 Message Date
Stefan Wojcik
df12211c25 dont let the MongoDB URI override connection settings it doesnt explicitly specify 2016-12-03 21:08:18 -05:00
4 changed files with 77 additions and 60 deletions

View File

@@ -121,7 +121,7 @@ class BaseDocument(object):
else:
self._data[key] = value
# Set any get_<field>_display methods
# Set any get_fieldname_display methods
self.__set_field_display()
if self._dynamic:
@@ -1005,18 +1005,19 @@ class BaseDocument(object):
return '.'.join(parts)
def __set_field_display(self):
"""For each field that specifies choices, create a
get_<field>_display method.
"""
fields_with_choices = [(n, f) for n, f in self._fields.items()
if f.choices]
for attr_name, field in fields_with_choices:
setattr(self,
'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field))
"""Dynamically set the display value for a field with choices"""
for attr_name, field in self._fields.items():
if field.choices:
if self._dynamic:
obj = self
else:
obj = type(self)
setattr(obj,
'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field))
def __get_field_display(self, field):
"""Return the display value for a choice field"""
"""Returns the display value for a choice field"""
value = getattr(self, field.name)
if field.choices and isinstance(field.choices[0], (list, tuple)):
return dict(field.choices).get(value, value)

View File

@@ -25,7 +25,8 @@ _dbs = {}
def register_connection(alias, name=None, host=None, port=None,
read_preference=READ_PREFERENCE,
username=None, password=None, authentication_source=None,
username=None, password=None,
authentication_source=None,
authentication_mechanism=None,
**kwargs):
"""Add a connection.
@@ -70,20 +71,26 @@ def register_connection(alias, name=None, host=None, port=None,
resolved_hosts = []
for entity in conn_host:
# Handle uri style connections
# Handle Mongomock
if entity.startswith('mongomock://'):
conn_settings['is_mock'] = True
# `mongomock://` is not a valid url prefix and must be replaced by `mongodb://`
resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1))
# Handle URI style connections, only updating connection params which
# were explicitly specified in the URI.
elif '://' in entity:
uri_dict = uri_parser.parse_uri(entity)
resolved_hosts.append(entity)
conn_settings.update({
'name': uri_dict.get('database') or name,
'username': uri_dict.get('username'),
'password': uri_dict.get('password'),
'read_preference': read_preference,
})
if uri_dict.get('database'):
conn_settings['name'] = uri_dict.get('database')
for param in ('read_preference', 'username', 'password'):
if uri_dict.get(param):
conn_settings[param] = uri_dict[param]
uri_options = uri_dict['options']
if 'replicaset' in uri_options:
conn_settings['replicaSet'] = True

View File

@@ -1047,7 +1047,7 @@ class FieldTest(unittest.TestCase):
BlogPost.drop_collection()
def test_list_assignment(self):
"""Ensure that list field element assignment and slicing work
"""Ensure that list field element assignment and slicing work
"""
class BlogPost(Document):
info = ListField()
@@ -1057,12 +1057,12 @@ class FieldTest(unittest.TestCase):
post = BlogPost()
post.info = ['e1', 'e2', 3, '4', 5]
post.save()
post.info[0] = 1
post.save()
post.reload()
self.assertEqual(post.info[0], 1)
post.info[1:3] = ['n2', 'n3']
post.save()
post.reload()
@@ -1209,7 +1209,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(simple.widgets, [4])
def test_list_field_with_negative_indices(self):
class Simple(Document):
widgets = ListField()
@@ -1823,7 +1823,7 @@ class FieldTest(unittest.TestCase):
'parent': "50a234ea469ac1eda42d347d"})
mongoed = p1.to_mongo()
self.assertTrue(isinstance(mongoed['parent'], ObjectId))
def test_cached_reference_field_get_and_save(self):
"""
Tests #1047: CachedReferenceField creates DBRefs on to_python, but can't save them on to_mongo
@@ -1835,11 +1835,11 @@ class FieldTest(unittest.TestCase):
class Ocorrence(Document):
person = StringField()
animal = CachedReferenceField(Animal)
Animal.drop_collection()
Ocorrence.drop_collection()
Ocorrence(person="testte",
Ocorrence(person="testte",
animal=Animal(name="Leopard", tag="heavy").save()).save()
p = Ocorrence.objects.get()
p.person = 'new_testte'
@@ -3001,32 +3001,28 @@ class FieldTest(unittest.TestCase):
('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
style = StringField(max_length=3, choices=(
('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W')
('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S')
Shirt.drop_collection()
shirt1 = Shirt()
shirt2 = Shirt()
shirt = Shirt()
# Make sure get_<field>_display returns the default value (or None)
self.assertEqual(shirt1.get_size_display(), None)
self.assertEqual(shirt1.get_style_display(), 'Wide')
self.assertEqual(shirt.get_size_display(), None)
self.assertEqual(shirt.get_style_display(), 'Small')
shirt1.size = 'XXL'
shirt1.style = 'B'
shirt2.size = 'M'
shirt2.style = 'S'
self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt1.get_style_display(), 'Baggy')
self.assertEqual(shirt2.get_size_display(), 'Medium')
self.assertEqual(shirt2.get_style_display(), 'Small')
shirt.size = "XXL"
shirt.style = "B"
self.assertEqual(shirt.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt.get_style_display(), 'Baggy')
# Set as Z - an invalid choice
shirt1.size = 'Z'
shirt1.style = 'Z'
self.assertEqual(shirt1.get_size_display(), 'Z')
self.assertEqual(shirt1.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt1.validate)
shirt.size = "Z"
shirt.style = "Z"
self.assertEqual(shirt.get_size_display(), 'Z')
self.assertEqual(shirt.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt.validate)
Shirt.drop_collection()
def test_simple_choices_validation(self):
"""Ensure that value is in a container of allowed values.

View File

@@ -174,19 +174,9 @@ class ConnectionTest(unittest.TestCase):
c.mongoenginetest.system.users.remove({})
def test_connect_uri_without_db(self):
"""Ensure connect() method works properly with uri's without database_name
"""Ensure connect() method works properly if the URI doesn't
include a database name.
"""
c = connect(db='mongoenginetest', alias='admin')
c.admin.system.users.remove({})
c.mongoenginetest.system.users.remove({})
c.admin.add_user("admin", "password")
c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password")
if not IS_PYMONGO_3:
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
connect("mongoenginetest", host='mongodb://localhost/')
conn = get_connection()
@@ -196,8 +186,31 @@ class ConnectionTest(unittest.TestCase):
self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'mongoenginetest')
c.admin.system.users.remove({})
c.mongoenginetest.system.users.remove({})
def test_connect_uri_default_db(self):
"""Ensure connect() defaults to the right database name if
the URI and the database_name don't explicitly specify it.
"""
connect(host='mongodb://localhost/')
conn = get_connection()
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
db = get_db()
self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'test')
def test_uri_without_credentials_doesnt_override_conn_settings(self):
"""Ensure connect() uses the username & password params if the URI
doesn't explicitly specify them.
"""
c = connect(host='mongodb://localhost/mongoenginetest',
username='user',
password='pass')
# OperationFailure means that mongoengine attempted authentication
# w/ the provided username/password and failed - that's the desired
# behavior. If the MongoDB URI would override the credentials
self.assertRaises(OperationFailure, get_db)
def test_connect_uri_with_authsource(self):
"""Ensure that the connect() method works well with