Test checking can save if not included

ref: MongoEngine/mongoengine#70
This commit is contained in:
Ross Lawley 2012-08-24 11:45:04 +01:00
parent 4c8296acc6
commit 87792e1921
2 changed files with 48 additions and 0 deletions

View File

@ -393,6 +393,8 @@ class ComplexDateTimeField(StringField):
data = super(ComplexDateTimeField, self).__get__(instance, owner) data = super(ComplexDateTimeField, self).__get__(instance, owner)
if data == None: if data == None:
return datetime.datetime.now() return datetime.datetime.now()
if isinstance(data, datetime.datetime):
return data
return self._convert_from_string(data) return self._convert_from_string(data)
def __set__(self, instance, value): def __set__(self, instance, value):

View File

@ -1640,6 +1640,52 @@ class DocumentTest(unittest.TestCase):
self.assertEqual(person.name, None) self.assertEqual(person.name, None)
self.assertEqual(person.age, None) self.assertEqual(person.age, None)
def test_can_save_if_not_included(self):
class EmbeddedDoc(EmbeddedDocument):
pass
class Simple(Document):
pass
class Doc(Document):
string_field = StringField(default='1')
int_field = IntField(default=1)
float_field = FloatField(default=1.1)
boolean_field = BooleanField(default=True)
datetime_field = DateTimeField(default=datetime.datetime.now)
embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, default=lambda: EmbeddedDoc())
list_field = ListField(default=lambda: [1, 2, 3])
dict_field = DictField(default=lambda: {"hello": "world"})
objectid_field = ObjectIdField(default=ObjectId)
reference_field = ReferenceField(Simple, default=lambda: Simple().save())
map_field = MapField(IntField(), default=lambda: {"simple": 1})
decimal_field = DecimalField(default=1.0)
complex_datetime_field = ComplexDateTimeField(default=datetime.datetime.now)
url_field = URLField(default="http://mongoengine.org")
dynamic_field = DynamicField(default=1)
generic_reference_field = GenericReferenceField(default=lambda: Simple().save())
sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3])
email_field = EmailField(default="ross@example.com")
geo_point_field = GeoPointField(default=lambda: [1, 2])
sequence_field = SequenceField()
uuid_field = UUIDField(default=uuid.uuid4)
generic_embedded_document_field = GenericEmbeddedDocumentField(default=lambda: EmbeddedDoc())
Simple.drop_collection()
Doc.drop_collection()
Doc().save()
my_doc = Doc.objects.only("string_field").first()
my_doc.string_field = "string"
my_doc.save()
my_doc = Doc.objects.get(string_field="string")
self.assertEqual(my_doc.string_field, "string")
self.assertEqual(my_doc.int_field, 1)
def test_document_update(self): def test_document_update(self):
def update_not_saved_raises(): def update_not_saved_raises():