Merge branch 'dev' into feature/where

This commit is contained in:
Ross Lawley 2011-08-16 09:54:56 +01:00
commit 89ad7ef1ab
8 changed files with 163 additions and 30 deletions

View File

@ -53,7 +53,7 @@ Changes in dev
- Added reverse delete rules
- Fixed issue with unset operation
- Fixed Q-object bug
- Added ``QuerySet.all_fields`` resets previous .only() and .exlude()
- Added ``QuerySet.all_fields`` resets previous .only() and .exclude()
- Added ``QuerySet.exclude``
- Added django style choices
- Fixed order and filter issue

View File

@ -381,8 +381,8 @@ class DocumentMetaclass(type):
attr_value.db_field = attr_name
doc_fields[attr_name] = attr_value
attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, v.db_field) for k, v in doc_fields.items()])
attrs['_reverse_db_field_map'] = dict([(v.db_field, k) for k, v in doc_fields.items()])
attrs['_db_field_map'] = dict([(k, v.db_field) for k, v in doc_fields.items() if k!=v.db_field])
attrs['_reverse_db_field_map'] = dict([(v, k) for k, v in attrs['_db_field_map'].items()])
new_class = super_new(cls, name, bases, attrs)
for field in new_class._fields.values():
@ -577,6 +577,7 @@ class BaseDocument(object):
signals.pre_init.send(self.__class__, document=self, values=values)
self._data = {}
self._initialised = False
# Assign default values to instance
for attr_name, field in self._fields.items():
value = getattr(self, attr_name, None)
@ -720,12 +721,18 @@ class BaseDocument(object):
field = getattr(self, field_name, None)
if isinstance(field, EmbeddedDocument) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k]
elif isinstance(field, (list, tuple)) and db_field_name not in _changed_fields: # Loop list fields as they contain documents
for index, value in enumerate(field):
elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(field, 'items'):
iterator = enumerate(field)
else:
iterator = field.iteritems()
for index, value in iterator:
if not hasattr(value, '_get_changed_fields'):
continue
list_key = "%s%s." % (key, index)
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k]
return _changed_fields
def _delta(self):
@ -735,7 +742,6 @@ class BaseDocument(object):
# Handles cases where not loaded from_son but has _id
doc = self.to_mongo()
set_fields = self._get_changed_fields()
set_data = {}
unset_data = {}
if hasattr(self, '_changed_fields'):
@ -762,7 +768,7 @@ class BaseDocument(object):
if value:
continue
# If we've set a value that ain't the default value unset it.
# If we've set a value that ain't the default value dont unset it.
default = None
if path in self._fields:
@ -774,7 +780,7 @@ class BaseDocument(object):
for p in parts:
if p.isdigit():
d = d[int(p)]
elif hasattr(d, '__getattribute__'):
elif hasattr(d, '__getattribute__') and not isinstance(d, dict):
real_path = d._reverse_db_field_map.get(p, p)
d = getattr(d, real_path)
else:
@ -789,8 +795,8 @@ class BaseDocument(object):
if default is not None:
if callable(default):
default = default()
if default != value:
continue
if default != value:
continue
del(set_data[path])
unset_data[path] = 1

View File

@ -3,6 +3,7 @@ import operator
import pymongo
from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass
from fields import ReferenceField
from connection import _get_db
from queryset import QuerySet
from document import Document
@ -32,8 +33,16 @@ class DeReference(object):
items = [i for i in items]
self.max_depth = max_depth
doc_type = None
if instance and instance._fields:
doc_type = instance._fields[name].field
if isinstance(doc_type, ReferenceField):
doc_type = doc_type.document_type
self.reference_map = self._find_references(items)
self.object_map = self._fetch_objects()
self.object_map = self._fetch_objects(doc_type=doc_type)
return self._attach_objects(items, 0, instance, name, get)
def _find_references(self, items, depth=0):
@ -80,7 +89,7 @@ class DeReference(object):
depth += 1
return reference_map
def _fetch_objects(self):
def _fetch_objects(self, doc_type=None):
"""Fetch all references and convert to their document objects
"""
object_map = {}
@ -94,7 +103,10 @@ class DeReference(object):
else: # Generic reference: use the refs data to convert to document
references = _get_db()[col].find({'_id': {'$in': refs}})
for ref in references:
doc = get_document(ref['_cls'])._from_son(ref)
if '_cls' in ref:
doc = get_document(ref['_cls'])._from_son(ref)
else:
doc = doc_type._from_son(ref)
object_map[doc.id] = doc
return object_map

View File

@ -3,6 +3,7 @@ from mongoengine import *
from django.utils.hashcompat import md5_constructor, sha_constructor
from django.utils.encoding import smart_str
from django.contrib.auth.models import AnonymousUser
from django.utils.translation import ugettext_lazy as _
import datetime
@ -21,16 +22,32 @@ class User(Document):
"""A User document that aims to mirror most of the API specified by Django
at http://docs.djangoproject.com/en/dev/topics/auth/#users
"""
username = StringField(max_length=30, required=True)
first_name = StringField(max_length=30)
last_name = StringField(max_length=30)
email = StringField()
password = StringField(max_length=128)
is_staff = BooleanField(default=False)
is_active = BooleanField(default=True)
is_superuser = BooleanField(default=False)
last_login = DateTimeField(default=datetime.datetime.now)
date_joined = DateTimeField(default=datetime.datetime.now)
username = StringField(max_length=30, required=True,
verbose_name=_('username'),
help_text=_("Required. 30 characters or fewer. Letters, numbers and @/./+/-/_ characters"))
first_name = StringField(max_length=30,
verbose_name=_('first name'))
last_name = StringField(max_length=30,
verbose_name=_('last name'))
email = EmailField(verbose_name=_('e-mail address'))
password = StringField(max_length=128,
verbose_name=_('password'),
help_text=_("Use '[algo]$[salt]$[hexdigest]' or use the <a href=\"password/\">change password form</a>."))
is_staff = BooleanField(default=False,
verbose_name=_('staff status'),
help_text=_("Designates whether the user can log into this admin site."))
is_active = BooleanField(default=True,
verbose_name=_('active'),
help_text=_("Designates whether this user should be treated as active. Unselect this instead of deleting accounts."))
is_superuser = BooleanField(default=False,
verbose_name=_('superuser status'),
help_text=_("Designates that this user has all permissions without explicitly assigning them."))
last_login = DateTimeField(default=datetime.datetime.now,
verbose_name=_('last login'))
date_joined = DateTimeField(default=datetime.datetime.now,
verbose_name=_('date joined'))
meta = {
'indexes': [

View File

@ -143,13 +143,15 @@ class Document(BaseDocument):
doc = self.to_mongo()
created = '_id' not in doc
created = '_id' in doc
creation_mode = force_insert or not created
try:
collection = self.__class__.objects._collection
if force_insert:
object_id = collection.insert(doc, safe=safe, **write_options)
if created:
object_id = collection.save(doc, safe=safe, **write_options)
if creation_mode:
if force_insert:
object_id = collection.insert(doc, safe=safe, **write_options)
else:
object_id = collection.save(doc, safe=safe, **write_options)
else:
object_id = doc['_id']
updates, removals = self._delta()
@ -191,7 +193,7 @@ class Document(BaseDocument):
reset_changed_fields(field, inspected_docs)
reset_changed_fields(self)
signals.post_save.send(self.__class__, document=self, created=created)
signals.post_save.send(self.__class__, document=self, created=creation_mode)
def update(self, **kwargs):
"""Performs an update on the :class:`~mongoengine.Document`

View File

@ -911,14 +911,26 @@ class SequenceField(IntField):
if instance is None:
return self
if not instance._data:
return
value = instance._data.get(self.name)
if not value and instance._initialised:
value = self.generate_new_value()
instance._data[self.name] = value
instance._mark_as_changed(self.name)
return value
def __set__(self, instance, value):
if value is None and instance._initialised:
value = self.generate_new_value()
return super(SequenceField, self).__set__(instance, value)
def to_python(self, value):
if value is None:
value = self.generate_new_value()

View File

@ -289,6 +289,31 @@ class DocumentTest(unittest.TestCase):
Zoo.drop_collection()
Animal.drop_collection()
def test_reference_inheritance(self):
class Stats(Document):
created = DateTimeField(default=datetime.now)
meta = {'allow_inheritance': False}
class CompareStats(Document):
generated = DateTimeField(default=datetime.now)
stats = ListField(ReferenceField(Stats))
Stats.drop_collection()
CompareStats.drop_collection()
list_stats = []
for i in xrange(10):
s = Stats()
s.save()
list_stats.append(s)
cmp_stats = CompareStats(stats=list_stats)
cmp_stats.save()
self.assertEqual(list_stats, CompareStats.objects.first().stats)
def test_inheritance(self):
"""Ensure that document may inherit fields from a superclass document.
"""
@ -1048,6 +1073,26 @@ class DocumentTest(unittest.TestCase):
except ValidationError:
self.fail()
def test_save_to_a_value_that_equates_to_false(self):
class Thing(EmbeddedDocument):
count = IntField()
class User(Document):
thing = EmbeddedDocumentField(Thing)
User.drop_collection()
user = User(thing=Thing(count=1))
user.save()
user.reload()
user.thing.count = 0
user.save()
user.reload()
self.assertEquals(user.thing.count, 0)
def test_save_max_recursion_not_hit(self):
class Person(Document):
@ -1484,6 +1529,18 @@ class DocumentTest(unittest.TestCase):
del(doc.embedded_field.list_field[2].list_field)
self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1}))
doc.save()
doc.reload()
doc.dict_field['Embedded'] = embedded_1
doc.save()
doc.reload()
doc.dict_field['Embedded'].string_field = 'Hello World'
self.assertEquals(doc._get_changed_fields(), ['dict_field.Embedded.string_field'])
self.assertEquals(doc._delta(), ({'dict_field.Embedded.string_field': 'Hello World'}, {}))
def test_delta_db_field(self):
class Doc(Document):
@ -1775,7 +1832,8 @@ class DocumentTest(unittest.TestCase):
person.save()
person = self.Person.objects.get()
self.assertTrue(person.comments_dict['first_post'].published)
self.assertFalse(person.comments_dict['first_post'].published)
def test_delete(self):
"""Ensure that document may be deleted using the delete method.
"""

View File

@ -1425,6 +1425,32 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
self.assertEqual(c['next'], 10)
def test_sequence_fields_reload(self):
class Animal(Document):
counter = SequenceField()
type = StringField()
self.db['mongoengine.counters'].drop()
Animal.drop_collection()
a = Animal(type="Boi")
a.save()
self.assertEqual(a.counter, 1)
a.reload()
self.assertEqual(a.counter, 1)
a.counter = None
self.assertEqual(a.counter, 2)
a.save()
self.assertEqual(a.counter, 2)
a = Animal.objects.first()
self.assertEqual(a.counter, 2)
a.reload()
self.assertEqual(a.counter, 2)
def test_multiple_sequence_fields_on_docs(self):
class Animal(Document):