Skip to content

Commit bb5bcdf

Browse files
authored
Add the database primary key mode config (#953)
* Add the database primary key mode config * Update auto to autoincrement
1 parent aad9afa commit bb5bcdf

File tree

9 files changed

+39
-27
lines changed

9 files changed

+39
-27
lines changed

backend/app/task/celery.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import celery_aio_pool
55

66
from backend.app.task.tasks.beat import LOCAL_BEAT_SCHEDULE
7+
from backend.common.enums import DataBaseType
78
from backend.core.conf import settings
89
from backend.core.path_conf import BASE_PATH
910

@@ -32,7 +33,7 @@ def init_celery() -> celery.Celery:
3233
broker_url = f'redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER_REDIS_DATABASE}'
3334

3435
result_backend = f'db+postgresql+psycopg://{settings.DATABASE_USER}:{settings.DATABASE_PASSWORD}@{settings.DATABASE_HOST}:{settings.DATABASE_PORT}/{settings.DATABASE_SCHEMA}'
35-
if settings.DATABASE_TYPE == 'mysql':
36+
if DataBaseType.mysql == settings.DATABASE_TYPE:
3637
result_backend = result_backend.replace('postgresql+psycopg', 'mysql+pymysql')
3738

3839
# https://docs.celeryq.dev/en/stable/userguide/configuration.html

backend/common/model.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sqlalchemy.ext.asyncio import AsyncAttrs
77
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
88

9+
from backend.common.enums import DataBaseType, PrimaryKeyType
910
from backend.core.conf import settings
1011
from backend.utils.snowflake import snowflake
1112
from backend.utils.timezone import timezone
@@ -23,15 +24,11 @@
2324
autoincrement=True,
2425
sort_order=-999,
2526
comment='主键 ID',
26-
),
27-
]
28-
29-
30-
# 雪花算法 Mapped 类型主键,使用方法与 id_key 相同
31-
# 详情:https://fastapi-practices.github.io/fastapi_best_architecture_docs/backend/reference/pk.html
32-
snowflake_id_key = Annotated[
33-
int,
34-
mapped_column(
27+
)
28+
if PrimaryKeyType.autoincrement == settings.DATABASE_PK_MODE
29+
# 雪花算法 Mapped 类型主键
30+
# 详情:https://fastapi-practices.github.io/fastapi_best_architecture_docs/backend/reference/pk.html
31+
else mapped_column(
3532
BigInteger,
3633
primary_key=True,
3734
unique=True,
@@ -46,7 +43,7 @@
4643
class UniversalText(TypeDecorator[str]):
4744
"""PostgreSQL、MySQL 兼容性(长)文本类型"""
4845

49-
impl = LONGTEXT if settings.DATABASE_TYPE == 'mysql' else Text
46+
impl = LONGTEXT if DataBaseType.mysql == settings.DATABASE_TYPE else Text
5047
cache_ok = True
5148

5249
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001

backend/common/schema.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from pydantic import BaseModel, ConfigDict, EmailStr, Field, validate_email
55

6+
from backend.core.conf import settings
67
from backend.utils.timezone import timezone
78

89
CustomPhoneNumber = Annotated[str, Field(pattern=r'^1[3-9]\d{9}$')]
@@ -28,6 +29,14 @@ class SchemaBase(BaseModel):
2829
},
2930
)
3031

32+
if settings.DATABASE_PK_MODE:
33+
from pydantic import field_serializer
34+
35+
# 详情:https://fastapi-practices.github.io/fastapi_best_architecture_docs/backend/reference/pk.html#%E6%B3%A8%E6%84%8F%E4%BA%8B%E9%A1%B9
36+
@field_serializer('id', check_fields=False)
37+
def serialize_id(self, value: int) -> str:
38+
return str(value)
39+
3140

3241
def ser_string(value: Any) -> str | None:
3342
if value:

backend/core/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class Settings(BaseSettings):
4242
DATABASE_POOL_ECHO: bool | Literal['debug'] = False
4343
DATABASE_SCHEMA: str = 'fba'
4444
DATABASE_CHARSET: str = 'utf8mb4'
45+
DATABASE_PK_MODE: Literal['autoincrement', 'snowflake'] = 'autoincrement'
4546

4647
# .env Redis
4748
REDIS_HOST: str

backend/database/db.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
create_async_engine,
1414
)
1515

16+
from backend.common.enums import DataBaseType
1617
from backend.common.log import log
1718
from backend.common.model import MappedBase
1819
from backend.core.conf import settings
@@ -26,14 +27,14 @@ def create_database_url(*, unittest: bool = False) -> URL:
2627
:return:
2728
"""
2829
url = URL.create(
29-
drivername='mysql+asyncmy' if settings.DATABASE_TYPE == 'mysql' else 'postgresql+asyncpg',
30+
drivername='mysql+asyncmy' if DataBaseType.mysql == settings.DATABASE_TYPE else 'postgresql+asyncpg',
3031
username=settings.DATABASE_USER,
3132
password=settings.DATABASE_PASSWORD,
3233
host=settings.DATABASE_HOST,
3334
port=settings.DATABASE_PORT,
3435
database=settings.DATABASE_SCHEMA if not unittest else f'{settings.DATABASE_SCHEMA}_test',
3536
)
36-
if settings.DATABASE_TYPE == 'mysql':
37+
if DataBaseType.mysql == settings.DATABASE_TYPE:
3738
url.update_query_dict({'charset': settings.DATABASE_CHARSET})
3839
return url
3940

backend/plugin/code_generator/crud/crud_code.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from sqlalchemy import Row, RowMapping, text
44
from sqlalchemy.ext.asyncio import AsyncSession
55

6+
from backend.common.enums import DataBaseType
67
from backend.core.conf import settings
78

89

@@ -18,7 +19,7 @@ async def get_all_tables(db: AsyncSession, table_schema: str) -> Sequence[RowMap
1819
:param table_schema: 数据库 schema 名称
1920
:return:
2021
"""
21-
if settings.DATABASE_TYPE == 'mysql':
22+
if DataBaseType.mysql == settings.DATABASE_TYPE:
2223
sql = """
2324
SELECT table_name AS table_name, table_comment AS table_comment
2425
FROM information_schema.tables
@@ -48,7 +49,7 @@ async def get_table(db: AsyncSession, table_name: str) -> Row[tuple]:
4849
:param table_name: 表名
4950
:return:
5051
"""
51-
if settings.DATABASE_TYPE == 'mysql':
52+
if DataBaseType.mysql == settings.DATABASE_TYPE:
5253
sql = """
5354
SELECT table_name AS table_name, table_comment AS table_comment
5455
FROM information_schema.tables
@@ -79,7 +80,7 @@ async def get_all_columns(db: AsyncSession, table_schema: str, table_name: str)
7980
:param table_name: 表名
8081
:return:
8182
"""
82-
if settings.DATABASE_TYPE == 'mysql':
83+
if DataBaseType.mysql == settings.DATABASE_TYPE:
8384
sql = """
8485
SELECT column_name AS column_name,
8586
CASE WHEN column_key = 'PRI' THEN 1 ELSE 0 END AS is_pk,

backend/plugin/code_generator/service/column_service.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from sqlalchemy.ext.asyncio import AsyncSession
44

5+
from backend.common.enums import DataBaseType
56
from backend.common.exception import errors
67
from backend.core.conf import settings
78
from backend.plugin.code_generator.crud.crud_column import gen_column_dao
@@ -32,7 +33,7 @@ async def get(*, db: AsyncSession, pk: int) -> GenColumn:
3233
@staticmethod
3334
async def get_types() -> list[str]:
3435
"""获取所有列类型"""
35-
if settings.DATABASE_TYPE == 'mysql':
36+
if DataBaseType.mysql == settings.DATABASE_TYPE:
3637
types = GenMySQLColumnType.get_member_keys()
3738
else:
3839
types = GenPostgreSQLColumnType.get_member_keys()

backend/plugin/code_generator/utils/type_conversion.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from functools import lru_cache
22

3+
from backend.common.enums import DataBaseType
34
from backend.core.conf import settings
45
from backend.plugin.code_generator.enums import GenMySQLColumnType, GenPostgreSQLColumnType
56

@@ -12,7 +13,7 @@ def sql_type_to_sqlalchemy(typing: str) -> str:
1213
:param typing: SQL 类型字符串
1314
:return:
1415
"""
15-
if settings.DATABASE_TYPE == 'mysql':
16+
if DataBaseType.mysql == settings.DATABASE_TYPE:
1617
if typing in GenMySQLColumnType.get_member_keys():
1718
return typing
1819
else:
@@ -30,7 +31,7 @@ def sql_type_to_pydantic(typing: str) -> str:
3031
:return:
3132
"""
3233
try:
33-
if settings.DATABASE_TYPE == 'mysql':
34+
if DataBaseType.mysql == settings.DATABASE_TYPE:
3435
return GenMySQLColumnType[typing].value
3536
if typing == 'CHARACTER VARYING': # postgresql 中 DDL VARCHAR 的别名
3637
return 'str'

backend/plugin/tools.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,16 @@ async def get_plugin_sql(plugin: str, db_type: DataBaseType, pk_type: PrimaryKey
7979
"""
8080
if db_type == DataBaseType.mysql:
8181
mysql_dir = PLUGIN_DIR / plugin / 'sql' / 'mysql'
82-
if pk_type == PrimaryKeyType.autoincrement:
83-
sql_file = mysql_dir / 'init.sql'
84-
else:
85-
sql_file = mysql_dir / 'init_snowflake.sql'
82+
sql_file = (
83+
mysql_dir / 'init.sql' if pk_type == PrimaryKeyType.autoincrement else mysql_dir / 'init_snowflake.sql'
84+
)
8685
else:
8786
postgresql_dir = PLUGIN_DIR / plugin / 'sql' / 'postgresql'
88-
if pk_type == PrimaryKeyType.autoincrement:
89-
sql_file = postgresql_dir / 'init.sql'
90-
else:
91-
sql_file = postgresql_dir / 'init_snowflake.sql'
87+
sql_file = (
88+
postgresql_dir / 'init.sql'
89+
if pk_type == PrimaryKeyType.autoincrement
90+
else postgresql_dir / 'init_snowflake.sql'
91+
)
9292

9393
path = anyio.Path(sql_file)
9494
if not await path.exists():

0 commit comments

Comments
 (0)