diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py
index 5f889ff00c..0f16342762 100644
--- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py
+++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py
@@ -23,6 +23,7 @@
 )
 from src.pipelines.retrieval.sql_functions import SqlFunction
 from src.utils import trace_cost
+from src.web.v1.services import Configuration
 from src.web.v1.services.ask import AskHistory
 
 logger = logging.getLogger("wren-ai-service")
@@ -76,6 +77,7 @@
 
 ### QUESTION ###
 User's Follow-up Question: {{ query }}
+Current Time: {{ current_time }}
 
 ### REASONING PLAN ###
 {{ sql_generation_reasoning }}
@@ -97,6 +99,7 @@ def prompt(
     has_metric: bool = False,
     has_json_field: bool = False,
     sql_functions: list[SqlFunction] | None = None,
+    configuration: Configuration = Configuration(),
 ) -> dict:
     _prompt = prompt_builder.run(
         query=query,
@@ -112,6 +115,7 @@ def prompt(
         json_field_instructions=(json_field_instructions if has_json_field else ""),
         sql_samples=sql_samples,
         sql_functions=sql_functions,
+        current_time=configuration.show_current_time(),
     )
 
     return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}
diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py
index 4d6cd313cd..e48ac45c7e 100644
--- a/wren-ai-service/src/pipelines/generation/intent_classification.py
+++ b/wren-ai-service/src/pipelines/generation/intent_classification.py
@@ -156,6 +156,7 @@
 
 User's current question: {{query}}
 Output Language: {{ language }}
+Current Time: {{ current_time }}
 
 Let's think step by step
 """
@@ -286,6 +287,7 @@ def prompt(
             instructions=instructions,
         ),
         docs=wren_ai_docs,
+        current_time=configuration.show_current_time(),
     )
 
     return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}
diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py
index a1c4c2852b..997a94ca4b 100644
--- a/wren-ai-service/src/pipelines/generation/sql_correction.py
+++ b/wren-ai-service/src/pipelines/generation/sql_correction.py
@@ -19,6 +19,7 @@
     construct_instructions,
 )
 from src.utils import trace_cost
+from src.web.v1.services import Configuration
 
 
 logger = logging.getLogger("wren-ai-service")
@@ -62,6 +63,7 @@
 ### QUESTION ###
 SQL: {{ invalid_generation_result.sql }}
 Error Message: {{ invalid_generation_result.error }}
+Current Time: {{ current_time }}
 
 Let's think step by step.
 """
@@ -73,11 +75,13 @@ def prompt(
     documents: List[Document],
     invalid_generation_result: Dict,
     prompt_builder: PromptBuilder,
+    configuration: Configuration = Configuration(),
     instructions: list[dict] | None = None,
 ) -> dict:
     _prompt = prompt_builder.run(
         documents=documents,
         invalid_generation_result=invalid_generation_result,
+        current_time=configuration.show_current_time(),
         instructions=construct_instructions(
             instructions=instructions,
         ),
diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py
index 27d3bc5eab..07e1848b6c 100644
--- a/wren-ai-service/src/pipelines/generation/sql_generation.py
+++ b/wren-ai-service/src/pipelines/generation/sql_generation.py
@@ -22,6 +22,7 @@
 )
 from src.pipelines.retrieval.sql_functions import SqlFunction
 from src.utils import trace_cost
+from src.web.v1.services import Configuration
 
 
 logger = logging.getLogger("wren-ai-service")
@@ -70,6 +71,7 @@
 
 ### QUESTION ###
 User's Question: {{ query }}
+Current Time: {{ current_time }}
 {% if sql_generation_reasoning %}
 ### REASONING PLAN ###
 {{ sql_generation_reasoning }}
@@ -93,6 +95,7 @@ def prompt(
     has_metric: bool = False,
     has_json_field: bool = False,
     sql_functions: list[SqlFunction] | None = None,
+    configuration: Configuration = Configuration(),
 ) -> dict:
     _prompt = prompt_builder.run(
         query=query,
@@ -108,6 +111,7 @@ def prompt(
         json_field_instructions=(json_field_instructions if has_json_field else ""),
         sql_samples=sql_samples,
         sql_functions=sql_functions,
+        current_time=configuration.show_current_time(),
     )
 
     return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}
diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py
index 0dbf2fd808..6ffdb90ee5 100644
--- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py
+++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py
@@ -22,7 +22,7 @@
 )
 from src.pipelines.retrieval.sql_functions import SqlFunction
 from src.utils import trace_cost
-
+from src.web.v1.services import Configuration
 
 logger = logging.getLogger("wren-ai-service")
 
@@ -88,6 +88,7 @@
 ### QUESTION ###
 SQL generation reasoning: {{ sql_generation_reasoning }}
 Original SQL query: {{ sql }}
+Current Time: {{ current_time }}
 
 Let's think step by step.
 """
@@ -106,6 +107,7 @@ def prompt(
     has_metric: bool = False,
     has_json_field: bool = False,
     sql_functions: list[SqlFunction] | None = None,
+    configuration: Configuration = Configuration(),
 ) -> dict:
     _prompt = prompt_builder.run(
         sql=sql,
@@ -121,6 +123,7 @@ def prompt(
         json_field_instructions=(json_field_instructions if has_json_field else ""),
         sql_samples=sql_samples,
         sql_functions=sql_functions,
+        current_time=configuration.show_current_time(),
     )
 
     return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}