diff --git a/mongoengine/base.py b/mongoengine/base.py index 8c038c8c..b4920eec 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -15,13 +15,14 @@ class BaseField(object): _index_with_types = True def __init__(self, name=None, required=False, default=None, unique=False, - unique_with=None, primary_key=False): + unique_with=None, primary_key=False, modified=False): self.name = name if not primary_key else '_id' self.required = required or primary_key self.default = default self.unique = bool(unique or unique_with) self.unique_with = unique_with self.primary_key = primary_key + self.modified = modified def __get__(self, instance, owner): """Descriptor for retrieving a value from a field in a document. Do @@ -44,6 +45,7 @@ class BaseField(object): """Descriptor for assigning a value to a field in a document. """ instance._data[self.name] = value + self.modified = True def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. @@ -252,8 +254,11 @@ class BaseDocument(object): def __init__(self, **values): self._data = {} + + modified = 'id' in values.keys() # Assign initial values to instance for attr_name, attr_value in self._fields.items(): + attr_value.modified = modified if attr_name in values: setattr(self, attr_name, values.pop(attr_name)) else: @@ -381,9 +386,9 @@ class BaseDocument(object): # that has been queried to return this SON return None cls = subclasses[class_name] - + for field_name, field in cls._fields.items(): if field.name in data: data[field_name] = field.to_python(data[field.name]) - - return cls(**data) + + return cls(**data) \ No newline at end of file diff --git a/mongoengine/document.py b/mongoengine/document.py index 62f9ecce..eec31e09 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,8 +1,9 @@ from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, ValidationError) -from queryset import OperationError +from queryset import OperationError, QuerySet from connection import _get_db + import pymongo @@ -75,12 +76,30 @@ class Document(BaseDocument): if force_insert: object_id = collection.insert(doc, safe=safe) else: - object_id = collection.save(doc, safe=safe) + if getattr(self, 'id', None) == None: + # new document + object_id = collection.save(doc, safe=safe) + else: + # update document + modified_fields = map(lambda obj: obj[0], filter(lambda obj: obj[1].modified, self._fields.items())) + modified_doc = dict(filter(lambda k: k[0] in modified_fields, doc.items())) + try: + id_field = self._meta['id_field'] + idObj = self._fields[id_field].to_mongo(self['id']) + collection.update({'_id': idObj}, {'$set': modified_doc}, safe=safe) + except pymongo.errors.OperationFailure, err: + if str(err) == 'multi not coded yet': + raise OperationError('update() method requires MongoDB 1.1.3+') + raise OperationError('Update failed (%s)' % str(err)) + object_id = self['id'] + + for field in self._fields.values(): field.modified = False except pymongo.errors.OperationFailure, err: message = 'Could not save document (%s)' if 'duplicate key' in str(err): message = 'Tried to save duplicate unique keys (%s)' raise OperationError(message % str(err)) + id_field = self._meta['id_field'] self[id_field] = self._fields[id_field].to_python(object_id) @@ -106,6 +125,7 @@ class Document(BaseDocument): obj = self.__class__.objects(**{id_field: self[id_field]}).first() for field in self._fields: setattr(self, field, obj[field]) + obj.modified = False @classmethod def drop_collection(cls): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index beb8ae00..90a3b3d2 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -71,6 +71,7 @@ class FloatField(BaseField): return float(value) def validate(self, value): + if isinstance(value, int): value = float(value) assert isinstance(value, float) if self.min_value is not None and value < self.min_value: diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 69bec002..e71f0598 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -316,15 +316,17 @@ class QuerySet(object): elif cnt == 1: return dataset.first() else: - raise MultipleObjectsReturned(u'%d items returned, instead of 1' % cnt) + raise MultipleObjectsReturned(u'%d items returned, expected exactly one' % cnt) def get(self, **kwargs): + """Retreive exactly one document. Raise DoesNotExist if it's not found. + """ dataset = self.filter(**kwargs) cnt = dataset.count() if cnt == 1: return dataset.first() elif cnt > 1: - raise MultipleObjectsReturned(u'%d items returned, instead of 1' % cnt) + raise MultipleObjectsReturned(u'%d items returned, expected exactly one' % cnt) else: raise DoesNotExist('Document not found')