Skip to content

Commit f906454

Browse files
Fix database initialization
1 parent de2c04f commit f906454

4 files changed

Lines changed: 43 additions & 21 deletions

File tree

compose.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ services:
1010
condition: service_healthy
1111
api:
1212
condition: service_healthy
13+
grafana-probe:
14+
condition: service_healthy
1315
healthcheck:
1416
test: ["CMD", "wget", "-qO-", "http://localhost/"]
1517
interval: 30s

scripts/create_db.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,43 @@ def create_database(settings: Settings) -> None:
4141
sql.Identifier(settings.database_db)
4242
)
4343
)
44-
cur.execute(sql.SQL("CREATE EXTENSION IF NOT EXISTS vector"))
45-
cur.execute(sql.SQL("CREATE EXTENSION IF NOT EXISTS pg_textsearch"))
46-
cur.execute(sql.SQL("CREATE EXTENSION IF NOT EXISTS zhparser"))
47-
cur.execute(
48-
sql.SQL("DROP TEXT SEARCH CONFIGURATION IF EXISTS chinese CASCADE")
49-
)
50-
cur.execute(
51-
sql.SQL("CREATE TEXT SEARCH CONFIGURATION chinese (parser = zhparser)")
52-
)
53-
cur.execute(
54-
sql.SQL(
55-
"ALTER TEXT SEARCH CONFIGURATION chinese "
56-
"ADD MAPPING FOR n,v,a,i,e,l WITH simple"
57-
)
58-
)
5944
logger.info("Created database: %s", settings.database_db)
6045
else:
6146
logger.info("Database already exists: %s", settings.database_db)
6247

6348

49+
def enable_extensions(settings: Settings) -> None:
50+
"""Enable the pgvector extension in the database.
51+
52+
Args:
53+
settings: Application settings containing database connection details.
54+
"""
55+
with (
56+
psycopg.connect(
57+
dbname=settings.database_db,
58+
user=settings.database_user,
59+
password=settings.database_password.get_secret_value(),
60+
host=settings.database_host,
61+
port=settings.database_port,
62+
autocommit=True,
63+
) as conn,
64+
conn.cursor() as cur,
65+
):
66+
cur.execute(sql.SQL("CREATE EXTENSION IF NOT EXISTS vector"))
67+
cur.execute(sql.SQL("CREATE EXTENSION IF NOT EXISTS pg_textsearch"))
68+
cur.execute(sql.SQL("CREATE EXTENSION IF NOT EXISTS zhparser"))
69+
cur.execute(sql.SQL("DROP TEXT SEARCH CONFIGURATION IF EXISTS chinese CASCADE"))
70+
cur.execute(
71+
sql.SQL("CREATE TEXT SEARCH CONFIGURATION chinese (parser = zhparser)")
72+
)
73+
cur.execute(
74+
sql.SQL(
75+
"ALTER TEXT SEARCH CONFIGURATION chinese "
76+
"ADD MAPPING FOR n,v,a,i,e,l WITH simple"
77+
)
78+
)
79+
80+
6481
def main() -> None:
6582
"""Create the database.
6683
@@ -70,6 +87,7 @@ def main() -> None:
7087
settings: Settings = get_settings()
7188
logging.basicConfig(level=settings.log_level)
7289
create_database(settings)
90+
enable_extensions(settings)
7391

7492

7593
if __name__ == "__main__":

scripts/import_csv.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
logger: logging.Logger = logging.getLogger(__name__)
2929

3030

31-
def read_csv(csv_path: str | os.PathLike) -> pd.DataFrame:
31+
async def read_csv(csv_path: str | os.PathLike) -> pd.DataFrame:
3232
"""Read CSV file into DataFrame.
3333
3434
Args:
@@ -42,11 +42,13 @@ def read_csv(csv_path: str | os.PathLike) -> pd.DataFrame:
4242
ValueError: If CSV format is invalid or headers don't match expected format.
4343
"""
4444
csv_file = Path(csv_path)
45-
if not csv_file.exists():
45+
if not await csv_file.exists():
4646
msg: str = f"CSV file not found: {csv_path}"
4747
raise FileNotFoundError(msg)
4848

49-
df: pd.DataFrame = pd.read_csv(csv_file, encoding="utf-8", dtype=str)
49+
df: pd.DataFrame = await asyncio.to_thread(
50+
pd.read_csv, csv_file, encoding="utf-8", dtype=str
51+
)
5052

5153
expected_headers: list[str] = [
5254
"上市公司代码",
@@ -93,11 +95,11 @@ async def import_stocks_from_csv(db: AsyncSession, csv_path: str | os.PathLike)
9395
ValueError: If CSV format is invalid or headers don't match expected format.
9496
"""
9597
csv_file = Path(csv_path)
96-
if not csv_file.exists():
98+
if not await csv_file.exists():
9799
msg: str = f"CSV file not found: {csv_path}"
98100
raise FileNotFoundError(msg)
99101

100-
df: pd.DataFrame = await asyncio.to_thread(read_csv, csv_file)
102+
df: pd.DataFrame = await read_csv(csv_file)
101103
valid_records: list[dict[str, str]] = []
102104
for r in df.to_dict(orient="records"):
103105
try:

src/stock_analysis/agent/ingest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ async def _process_pdf(
589589
finally:
590590
if tmp_pdf:
591591
tmp_pdf_path: Path = Path(tmp_pdf)
592-
if tmp_pdf_path.exists():
592+
if await tmp_pdf_path.exists():
593593
await tmp_pdf_path.unlink()
594594

595595
async def ingest(self, object_key: str) -> None:

0 commit comments

Comments
 (0)