diff --git a/tests/document.py b/tests/document.py index 45f1c3c7..6f9d9ecb 100644 --- a/tests/document.py +++ b/tests/document.py @@ -588,6 +588,34 @@ class DocumentTest(unittest.TestCase): # Ensure that the 'details' embedded object saved correctly self.assertEqual(employee_obj['details']['position'], 'Developer') + def test_updating_an_embedded_document(self): + """Ensure that a document with an embedded document field may be + saved in the database. + """ + class EmployeeDetails(EmbeddedDocument): + position = StringField() + + class Employee(self.Person): + salary = IntField() + details = EmbeddedDocumentField(EmployeeDetails) + + # Create employee object and save it to the database + employee = Employee(name='Test Employee', age=50, salary=20000) + employee.details = EmployeeDetails(position='Developer') + employee.save() + + # Test updating an embedded document + promoted_employee = Employee.objects.get(name='Test Employee') + promoted_employee.details.position = 'Senior Developer' + promoted_employee.save() + + collection = self.db[self.Person._meta['collection']] + employee_obj = collection.find_one({'name': 'Test Employee'}) + self.assertEqual(employee_obj['name'], 'Test Employee') + self.assertEqual(employee_obj['age'], 50) + # Ensure that the 'details' embedded object saved correctly + self.assertEqual(employee_obj['details']['position'], 'Senior Developer') + def test_save_reference(self): """Ensure that a document reference field may be saved in the database. """