Merge branch 'dev' into feature/where

This commit is contained in:
Ross Lawley 2011-08-16 09:54:56 +01:00
commit 89ad7ef1ab
8 changed files with 163 additions and 30 deletions

View File

@ -53,7 +53,7 @@ Changes in dev
- Added reverse delete rules - Added reverse delete rules
- Fixed issue with unset operation - Fixed issue with unset operation
- Fixed Q-object bug - Fixed Q-object bug
- Added ``QuerySet.all_fields`` resets previous .only() and .exlude() - Added ``QuerySet.all_fields`` resets previous .only() and .exclude()
- Added ``QuerySet.exclude`` - Added ``QuerySet.exclude``
- Added django style choices - Added django style choices
- Fixed order and filter issue - Fixed order and filter issue

View File

@ -381,8 +381,8 @@ class DocumentMetaclass(type):
attr_value.db_field = attr_name attr_value.db_field = attr_name
doc_fields[attr_name] = attr_value doc_fields[attr_name] = attr_value
attrs['_fields'] = doc_fields attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, v.db_field) for k, v in doc_fields.items()]) attrs['_db_field_map'] = dict([(k, v.db_field) for k, v in doc_fields.items() if k!=v.db_field])
attrs['_reverse_db_field_map'] = dict([(v.db_field, k) for k, v in doc_fields.items()]) attrs['_reverse_db_field_map'] = dict([(v, k) for k, v in attrs['_db_field_map'].items()])
new_class = super_new(cls, name, bases, attrs) new_class = super_new(cls, name, bases, attrs)
for field in new_class._fields.values(): for field in new_class._fields.values():
@ -577,6 +577,7 @@ class BaseDocument(object):
signals.pre_init.send(self.__class__, document=self, values=values) signals.pre_init.send(self.__class__, document=self, values=values)
self._data = {} self._data = {}
self._initialised = False
# Assign default values to instance # Assign default values to instance
for attr_name, field in self._fields.items(): for attr_name, field in self._fields.items():
value = getattr(self, attr_name, None) value = getattr(self, attr_name, None)
@ -720,12 +721,18 @@ class BaseDocument(object):
field = getattr(self, field_name, None) field = getattr(self, field_name, None)
if isinstance(field, EmbeddedDocument) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed if isinstance(field, EmbeddedDocument) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k] _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k]
elif isinstance(field, (list, tuple)) and db_field_name not in _changed_fields: # Loop list fields as they contain documents elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents
for index, value in enumerate(field): # Determine the iterator to use
if not hasattr(field, 'items'):
iterator = enumerate(field)
else:
iterator = field.iteritems()
for index, value in iterator:
if not hasattr(value, '_get_changed_fields'): if not hasattr(value, '_get_changed_fields'):
continue continue
list_key = "%s%s." % (key, index) list_key = "%s%s." % (key, index)
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k] _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k]
return _changed_fields return _changed_fields
def _delta(self): def _delta(self):
@ -735,7 +742,6 @@ class BaseDocument(object):
# Handles cases where not loaded from_son but has _id # Handles cases where not loaded from_son but has _id
doc = self.to_mongo() doc = self.to_mongo()
set_fields = self._get_changed_fields() set_fields = self._get_changed_fields()
set_data = {} set_data = {}
unset_data = {} unset_data = {}
if hasattr(self, '_changed_fields'): if hasattr(self, '_changed_fields'):
@ -762,7 +768,7 @@ class BaseDocument(object):
if value: if value:
continue continue
# If we've set a value that ain't the default value unset it. # If we've set a value that ain't the default value dont unset it.
default = None default = None
if path in self._fields: if path in self._fields:
@ -774,7 +780,7 @@ class BaseDocument(object):
for p in parts: for p in parts:
if p.isdigit(): if p.isdigit():
d = d[int(p)] d = d[int(p)]
elif hasattr(d, '__getattribute__'): elif hasattr(d, '__getattribute__') and not isinstance(d, dict):
real_path = d._reverse_db_field_map.get(p, p) real_path = d._reverse_db_field_map.get(p, p)
d = getattr(d, real_path) d = getattr(d, real_path)
else: else:

View File

@ -3,6 +3,7 @@ import operator
import pymongo import pymongo
from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass
from fields import ReferenceField
from connection import _get_db from connection import _get_db
from queryset import QuerySet from queryset import QuerySet
from document import Document from document import Document
@ -32,8 +33,16 @@ class DeReference(object):
items = [i for i in items] items = [i for i in items]
self.max_depth = max_depth self.max_depth = max_depth
doc_type = None
if instance and instance._fields:
doc_type = instance._fields[name].field
if isinstance(doc_type, ReferenceField):
doc_type = doc_type.document_type
self.reference_map = self._find_references(items) self.reference_map = self._find_references(items)
self.object_map = self._fetch_objects() self.object_map = self._fetch_objects(doc_type=doc_type)
return self._attach_objects(items, 0, instance, name, get) return self._attach_objects(items, 0, instance, name, get)
def _find_references(self, items, depth=0): def _find_references(self, items, depth=0):
@ -80,7 +89,7 @@ class DeReference(object):
depth += 1 depth += 1
return reference_map return reference_map
def _fetch_objects(self): def _fetch_objects(self, doc_type=None):
"""Fetch all references and convert to their document objects """Fetch all references and convert to their document objects
""" """
object_map = {} object_map = {}
@ -94,7 +103,10 @@ class DeReference(object):
else: # Generic reference: use the refs data to convert to document else: # Generic reference: use the refs data to convert to document
references = _get_db()[col].find({'_id': {'$in': refs}}) references = _get_db()[col].find({'_id': {'$in': refs}})
for ref in references: for ref in references:
if '_cls' in ref:
doc = get_document(ref['_cls'])._from_son(ref) doc = get_document(ref['_cls'])._from_son(ref)
else:
doc = doc_type._from_son(ref)
object_map[doc.id] = doc object_map[doc.id] = doc
return object_map return object_map

View File

@ -3,6 +3,7 @@ from mongoengine import *
from django.utils.hashcompat import md5_constructor, sha_constructor from django.utils.hashcompat import md5_constructor, sha_constructor
from django.utils.encoding import smart_str from django.utils.encoding import smart_str
from django.contrib.auth.models import AnonymousUser from django.contrib.auth.models import AnonymousUser
from django.utils.translation import ugettext_lazy as _
import datetime import datetime
@ -21,16 +22,32 @@ class User(Document):
"""A User document that aims to mirror most of the API specified by Django """A User document that aims to mirror most of the API specified by Django
at http://docs.djangoproject.com/en/dev/topics/auth/#users at http://docs.djangoproject.com/en/dev/topics/auth/#users
""" """
username = StringField(max_length=30, required=True) username = StringField(max_length=30, required=True,
first_name = StringField(max_length=30) verbose_name=_('username'),
last_name = StringField(max_length=30) help_text=_("Required. 30 characters or fewer. Letters, numbers and @/./+/-/_ characters"))
email = StringField()
password = StringField(max_length=128) first_name = StringField(max_length=30,
is_staff = BooleanField(default=False) verbose_name=_('first name'))
is_active = BooleanField(default=True)
is_superuser = BooleanField(default=False) last_name = StringField(max_length=30,
last_login = DateTimeField(default=datetime.datetime.now) verbose_name=_('last name'))
date_joined = DateTimeField(default=datetime.datetime.now) email = EmailField(verbose_name=_('e-mail address'))
password = StringField(max_length=128,
verbose_name=_('password'),
help_text=_("Use '[algo]$[salt]$[hexdigest]' or use the <a href=\"password/\">change password form</a>."))
is_staff = BooleanField(default=False,
verbose_name=_('staff status'),
help_text=_("Designates whether the user can log into this admin site."))
is_active = BooleanField(default=True,
verbose_name=_('active'),
help_text=_("Designates whether this user should be treated as active. Unselect this instead of deleting accounts."))
is_superuser = BooleanField(default=False,
verbose_name=_('superuser status'),
help_text=_("Designates that this user has all permissions without explicitly assigning them."))
last_login = DateTimeField(default=datetime.datetime.now,
verbose_name=_('last login'))
date_joined = DateTimeField(default=datetime.datetime.now,
verbose_name=_('date joined'))
meta = { meta = {
'indexes': [ 'indexes': [

View File

@ -143,12 +143,14 @@ class Document(BaseDocument):
doc = self.to_mongo() doc = self.to_mongo()
created = '_id' not in doc created = '_id' in doc
creation_mode = force_insert or not created
try: try:
collection = self.__class__.objects._collection collection = self.__class__.objects._collection
if creation_mode:
if force_insert: if force_insert:
object_id = collection.insert(doc, safe=safe, **write_options) object_id = collection.insert(doc, safe=safe, **write_options)
if created: else:
object_id = collection.save(doc, safe=safe, **write_options) object_id = collection.save(doc, safe=safe, **write_options)
else: else:
object_id = doc['_id'] object_id = doc['_id']
@ -191,7 +193,7 @@ class Document(BaseDocument):
reset_changed_fields(field, inspected_docs) reset_changed_fields(field, inspected_docs)
reset_changed_fields(self) reset_changed_fields(self)
signals.post_save.send(self.__class__, document=self, created=created) signals.post_save.send(self.__class__, document=self, created=creation_mode)
def update(self, **kwargs): def update(self, **kwargs):
"""Performs an update on the :class:`~mongoengine.Document` """Performs an update on the :class:`~mongoengine.Document`

View File

@ -911,14 +911,26 @@ class SequenceField(IntField):
if instance is None: if instance is None:
return self return self
if not instance._data: if not instance._data:
return return
value = instance._data.get(self.name) value = instance._data.get(self.name)
if not value and instance._initialised: if not value and instance._initialised:
value = self.generate_new_value() value = self.generate_new_value()
instance._data[self.name] = value instance._data[self.name] = value
instance._mark_as_changed(self.name)
return value return value
def __set__(self, instance, value):
if value is None and instance._initialised:
value = self.generate_new_value()
return super(SequenceField, self).__set__(instance, value)
def to_python(self, value): def to_python(self, value):
if value is None: if value is None:
value = self.generate_new_value() value = self.generate_new_value()

View File

@ -289,6 +289,31 @@ class DocumentTest(unittest.TestCase):
Zoo.drop_collection() Zoo.drop_collection()
Animal.drop_collection() Animal.drop_collection()
def test_reference_inheritance(self):
class Stats(Document):
created = DateTimeField(default=datetime.now)
meta = {'allow_inheritance': False}
class CompareStats(Document):
generated = DateTimeField(default=datetime.now)
stats = ListField(ReferenceField(Stats))
Stats.drop_collection()
CompareStats.drop_collection()
list_stats = []
for i in xrange(10):
s = Stats()
s.save()
list_stats.append(s)
cmp_stats = CompareStats(stats=list_stats)
cmp_stats.save()
self.assertEqual(list_stats, CompareStats.objects.first().stats)
def test_inheritance(self): def test_inheritance(self):
"""Ensure that document may inherit fields from a superclass document. """Ensure that document may inherit fields from a superclass document.
""" """
@ -1048,6 +1073,26 @@ class DocumentTest(unittest.TestCase):
except ValidationError: except ValidationError:
self.fail() self.fail()
def test_save_to_a_value_that_equates_to_false(self):
class Thing(EmbeddedDocument):
count = IntField()
class User(Document):
thing = EmbeddedDocumentField(Thing)
User.drop_collection()
user = User(thing=Thing(count=1))
user.save()
user.reload()
user.thing.count = 0
user.save()
user.reload()
self.assertEquals(user.thing.count, 0)
def test_save_max_recursion_not_hit(self): def test_save_max_recursion_not_hit(self):
class Person(Document): class Person(Document):
@ -1484,6 +1529,18 @@ class DocumentTest(unittest.TestCase):
del(doc.embedded_field.list_field[2].list_field) del(doc.embedded_field.list_field[2].list_field)
self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1})) self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1}))
doc.save()
doc.reload()
doc.dict_field['Embedded'] = embedded_1
doc.save()
doc.reload()
doc.dict_field['Embedded'].string_field = 'Hello World'
self.assertEquals(doc._get_changed_fields(), ['dict_field.Embedded.string_field'])
self.assertEquals(doc._delta(), ({'dict_field.Embedded.string_field': 'Hello World'}, {}))
def test_delta_db_field(self): def test_delta_db_field(self):
class Doc(Document): class Doc(Document):
@ -1775,7 +1832,8 @@ class DocumentTest(unittest.TestCase):
person.save() person.save()
person = self.Person.objects.get() person = self.Person.objects.get()
self.assertTrue(person.comments_dict['first_post'].published) self.assertFalse(person.comments_dict['first_post'].published)
def test_delete(self): def test_delete(self):
"""Ensure that document may be deleted using the delete method. """Ensure that document may be deleted using the delete method.
""" """

View File

@ -1425,6 +1425,32 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
self.assertEqual(c['next'], 10) self.assertEqual(c['next'], 10)
def test_sequence_fields_reload(self):
class Animal(Document):
counter = SequenceField()
type = StringField()
self.db['mongoengine.counters'].drop()
Animal.drop_collection()
a = Animal(type="Boi")
a.save()
self.assertEqual(a.counter, 1)
a.reload()
self.assertEqual(a.counter, 1)
a.counter = None
self.assertEqual(a.counter, 2)
a.save()
self.assertEqual(a.counter, 2)
a = Animal.objects.first()
self.assertEqual(a.counter, 2)
a.reload()
self.assertEqual(a.counter, 2)
def test_multiple_sequence_fields_on_docs(self): def test_multiple_sequence_fields_on_docs(self):
class Animal(Document): class Animal(Document):