fix: modify multiple times. (#279)

This commit is contained in:
long2ice 2023-01-27 13:49:07 +08:00
parent 3fbf9febfb
commit 1c9b65cc37
3 changed files with 38 additions and 23 deletions

View File

@ -5,6 +5,7 @@
### 0.7.2 ### 0.7.2
- Support virtual fields. - Support virtual fields.
- Fix modify multiple times. (#279)
### 0.7.1 ### 0.7.1

View File

@ -40,7 +40,9 @@ class BaseDDL:
self.schema_generator = self.schema_generator_cls(client) self.schema_generator = self.schema_generator_cls(client)
def create_table(self, model: "Type[Model]"): def create_table(self, model: "Type[Model]"):
return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip(";") return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip(
";"
)
def drop_table(self, table_name: str): def drop_table(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)

View File

@ -281,17 +281,25 @@ class Migrate:
# remove indexes # remove indexes
for index in old_indexes.difference(new_indexes): for index in old_indexes.difference(new_indexes):
cls._add_operator(cls._drop_index(model, index, False), upgrade, True) cls._add_operator(cls._drop_index(model, index, False), upgrade, True)
old_data_fields = list(filter(lambda x: x.get('db_field_types') is not None, old_data_fields = list(
old_model_describe.get("data_fields"))) filter(
new_data_fields = list(filter(lambda x: x.get('db_field_types') is not None, lambda x: x.get("db_field_types") is not None,
new_model_describe.get("data_fields"))) old_model_describe.get("data_fields"),
)
)
new_data_fields = list(
filter(
lambda x: x.get("db_field_types") is not None,
new_model_describe.get("data_fields"),
)
)
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields)) old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields)) new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
# add fields or rename fields # add fields or rename fields
for new_data_field_name in set(new_data_fields_name).difference( for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name) set(old_data_fields_name)
): ):
new_data_field = next( new_data_field = next(
filter(lambda x: x.get("name") == new_data_field_name, new_data_fields) filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
@ -303,22 +311,22 @@ class Migrate:
if len(changes) == 2: if len(changes) == 2:
# rename field # rename field
if ( if (
changes[0] changes[0]
== ( == (
"change", "change",
"name", "name",
(old_data_field_name, new_data_field_name), (old_data_field_name, new_data_field_name),
) )
and changes[1] and changes[1]
== ( == (
"change", "change",
"db_column", "db_column",
( (
old_data_field.get("db_column"), old_data_field.get("db_column"),
new_data_field.get("db_column"), new_data_field.get("db_column"),
), ),
) )
and old_data_field_name not in new_data_fields_name and old_data_field_name not in new_data_fields_name
): ):
if upgrade: if upgrade:
is_rename = click.prompt( is_rename = click.prompt(
@ -334,9 +342,9 @@ class Migrate:
cls._rename_old.append(old_data_field_name) cls._rename_old.append(old_data_field_name)
# only MySQL8+ has rename syntax # only MySQL8+ has rename syntax
if ( if (
cls.dialect == "mysql" cls.dialect == "mysql"
and cls._db_version and cls._db_version
and cls._db_version.startswith("5.") and cls._db_version.startswith("5.")
): ):
cls._add_operator( cls._add_operator(
cls._change_field( cls._change_field(
@ -367,11 +375,11 @@ class Migrate:
) )
# remove fields # remove fields
for old_data_field_name in set(old_data_fields_name).difference( for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name) set(new_data_fields_name)
): ):
# don't remove field if is renamed # don't remove field if is renamed
if (upgrade and old_data_field_name in cls._rename_old) or ( if (upgrade and old_data_field_name in cls._rename_old) or (
not upgrade and old_data_field_name in cls._rename_new not upgrade and old_data_field_name in cls._rename_new
): ):
continue continue
old_data_field = next( old_data_field = next(
@ -403,7 +411,7 @@ class Migrate:
# add fk # add fk
for new_fk_field_name in set(new_fk_fields_name).difference( for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name) set(old_fk_fields_name)
): ):
fk_field = next( fk_field = next(
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields) filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
@ -418,7 +426,7 @@ class Migrate:
) )
# drop fk # drop fk
for old_fk_field_name in set(old_fk_fields_name).difference( for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name) set(new_fk_fields_name)
): ):
old_fk_field = next( old_fk_field = next(
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields) filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
@ -440,6 +448,7 @@ class Migrate:
filter(lambda x: x.get("name") == field_name, new_data_fields) filter(lambda x: x.get("name") == field_name, new_data_fields)
) )
changes = diff(old_data_field, new_data_field) changes = diff(old_data_field, new_data_field)
modified = False
for change in changes: for change in changes:
_, option, old_new = change _, option, old_new = change
if option == "indexed": if option == "indexed":
@ -464,7 +473,7 @@ class Migrate:
continue continue
elif option == "default": elif option == "default":
if not ( if not (
is_default_function(old_new[0]) or is_default_function(old_new[1]) is_default_function(old_new[0]) or is_default_function(old_new[1])
): ):
# change column default # change column default
cls._add_operator( cls._add_operator(
@ -477,11 +486,14 @@ class Migrate:
# change nullable # change nullable
cls._add_operator(cls._alter_null(model, new_data_field), upgrade) cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
else: else:
if modified:
continue
# modify column # modify column
cls._add_operator( cls._add_operator(
cls._modify_field(model, new_data_field), cls._modify_field(model, new_data_field),
upgrade, upgrade,
) )
modified = True
for old_model in old_models: for old_model in old_models:
if old_model not in new_models.keys(): if old_model not in new_models.keys():