diff --git a/AUTHORS b/AUTHORS index f64093d8..42082278 100644 --- a/AUTHORS +++ b/AUTHORS @@ -224,3 +224,5 @@ that much better: * Matthieu Rigal (https://github.com/MRigal) * Charanpal Dhanjal (https://github.com/charanpald) * Emmanuel Leblond (https://github.com/touilleMan) + * Breeze.Kay (https://github.com/9nix00) + diff --git a/docs/changelog.rst b/docs/changelog.rst index b9ad5b0e..55ff7754 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in 0.9.X - DEV ====================== +- Improve Document._created status when switch collection and db #1020 - Queryset update doesn't go through field validation #453 - Added support for specifying authentication source as option `authSource` in URI. #967 - Fixed mark_as_changed to handle higher/lower level fields changed. #927 diff --git a/mongoengine/document.py b/mongoengine/document.py index 429f6065..0ceedfc1 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -491,7 +491,7 @@ class Document(BaseDocument): raise OperationError(message) signals.post_delete.send(self.__class__, document=self) - def switch_db(self, db_alias): + def switch_db(self, db_alias, keep_created=True): """ Temporarily switch the database for a document instance. @@ -503,6 +503,9 @@ class Document(BaseDocument): :param str db_alias: The database alias to use for saving the document + :param bool keep_created: keep self._created value after switching db, else is reset to True + + .. seealso:: Use :class:`~mongoengine.context_managers.switch_collection` if you need to read from another collection @@ -513,12 +516,12 @@ class Document(BaseDocument): self._get_collection = lambda: collection self._get_db = lambda: db self._collection = collection - self._created = True + self._created = True if not keep_created else self._created self.__objects = self._qs self.__objects._collection_obj = collection return self - def switch_collection(self, collection_name): + def switch_collection(self, collection_name, keep_created=True): """ Temporarily switch the collection for a document instance. @@ -531,6 +534,9 @@ class Document(BaseDocument): :param str collection_name: The database alias to use for saving the document + :param bool keep_created: keep self._created value after switching collection, else is reset to True + + .. seealso:: Use :class:`~mongoengine.context_managers.switch_db` if you need to read from another database @@ -539,7 +545,7 @@ class Document(BaseDocument): collection = cls._get_collection() self._get_collection = lambda: collection self._collection = collection - self._created = True + self._created = True if not keep_created else self._created self.__objects = self._qs self.__objects._collection_obj = collection return self diff --git a/tests/test_signals.py b/tests/test_signals.py index 6ab061d1..8672925c 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -279,5 +279,33 @@ class SignalTests(unittest.TestCase): # second time, it must be an update self.assertEqual(self.get_signal_output(ei.save), ['Is updated']) + def test_signals_with_switch_collection(self): + ei = self.ExplicitId(id=123) + ei.switch_collection("explicit__1") + self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + ei.switch_collection("explicit__1") + self.assertEqual(self.get_signal_output(ei.save), ['Is updated']) + + ei.switch_collection("explicit__1", keep_created=False) + self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + ei.switch_collection("explicit__1", keep_created=False) + self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + + def test_signals_with_switch_db(self): + connect('mongoenginetest') + register_connection('testdb-1', 'mongoenginetest2') + + ei = self.ExplicitId(id=123) + ei.switch_db("testdb-1") + self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + ei.switch_db("testdb-1") + self.assertEqual(self.get_signal_output(ei.save), ['Is updated']) + + ei.switch_db("testdb-1", keep_created=False) + self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + ei.switch_db("testdb-1", keep_created=False) + self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + + if __name__ == '__main__': unittest.main()