Merge remote-tracking branch 'origin/pr/268'

This commit is contained in:
Ross Lawley 2013-04-16 20:49:33 +00:00
commit 4401a309ee
2 changed files with 42 additions and 4 deletions

View File

@ -558,8 +558,11 @@ class DocumentMetaclass(type):
# Set _fields and db_field maps # Set _fields and db_field maps
attrs['_fields'] = doc_fields attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) attrs['_fields_ordered'] = tuple(i[1]
for k, v in doc_fields.iteritems()]) for i in sorted((v.creation_counter, v.name)
for v in doc_fields.itervalues()))
attrs['_db_field_map'] = dict((k, getattr(v, 'db_field', k))
for k, v in doc_fields.iteritems())
attrs['_reverse_db_field_map'] = dict( attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems()) (v, k) for k, v in attrs['_db_field_map'].iteritems())
@ -902,7 +905,17 @@ class BaseDocument(object):
_dynamic_lock = True _dynamic_lock = True
_initialised = False _initialised = False
def __init__(self, **values): def __init__(self, *args, **values):
if args:
# Combine positional arguments with named arguments.
# We only want named arguments.
field = iter(self._fields_ordered)
for value in args:
name = next(field)
if name in values:
raise TypeError("Multiple values for keyword argument '" + name + "'")
values[name] = value
signals.pre_init.send(self.__class__, document=self, values=values) signals.pre_init.send(self.__class__, document=self, values=values)
self._data = {} self._data = {}
@ -1316,7 +1329,10 @@ class BaseDocument(object):
return value return value
def __iter__(self): def __iter__(self):
return iter(self._fields) if 'id' in self._fields and 'id' not in self._fields_ordered:
return iter(('id', ) + self._fields_ordered)
return iter(self._fields_ordered)
def __getitem__(self, name): def __getitem__(self, name):
"""Dictionary-style field access, return a field's value if present. """Dictionary-style field access, return a field's value if present.

View File

@ -1386,6 +1386,28 @@ class DocumentTest(unittest.TestCase):
person = self.Person(name="Test User", age=30) person = self.Person(name="Test User", age=30)
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 30) self.assertEqual(person.age, 30)
def test_positional_creation(self):
"""Ensure that document may be created using positional arguments.
"""
person = self.Person("Test User", 42)
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42)
def test_mixed_creation(self):
"""Ensure that document may be created using mixed arguments.
"""
person = self.Person("Test User", age=42)
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42)
def test_bad_mixed_creation(self):
"""Ensure that document gives correct error when duplicating arguments
"""
def construct_bad_instance():
return self.Person("Test User", 42, name="Bad User")
self.assertRaises(TypeError, construct_bad_instance)
def test_to_dbref(self): def test_to_dbref(self):
"""Ensure that you can get a dbref of a document""" """Ensure that you can get a dbref of a document"""