diff --git a/application_sdk/activities/query_extraction/sql.py b/application_sdk/activities/query_extraction/sql.py
index f712e3439..a2e5eaddf 100644
--- a/application_sdk/activities/query_extraction/sql.py
+++ b/application_sdk/activities/query_extraction/sql.py
@@ -202,6 +202,7 @@ async def fetch_queries(
         try:
             state = await self._get_state(workflow_args)
+
             sql_input = SQLQueryInput(
                 engine=state.sql_client.engine,
                 query=self.get_formatted_query(self.fetch_queries_sql, workflow_args),
@@ -209,11 +210,13 @@ async def fetch_queries(
             )
             sql_input = await sql_input.get_dataframe()
+            sql_input.columns = [str(c).upper() for c in sql_input.columns]
+
             raw_output = ParquetOutput(
                 output_path=workflow_args["output_path"],
                 output_suffix="raw/query",
                 chunk_size=workflow_args["miner_args"].get("chunk_size", 100000),
-                start_marker=workflow_args["start_marker"],
+                start_marker=str(workflow_args["start_marker"]),
                 end_marker=workflow_args["end_marker"],
             )
             await raw_output.write_dataframe(sql_input)
@@ -528,9 +531,107 @@ async def get_query_batches(
                 store_name=UPSTREAM_OBJECT_STORE_NAME,
             )
 
+        # Persist the full marker list in StateStore to avoid oversized activity results
         try:
-            await self.write_marker(parallel_markers, workflow_args)
+            from application_sdk.services.statestore import StateStore, StateType
+
+            workflow_id: str = workflow_args.get("workflow_id") or get_workflow_id()
+            await StateStore.save_state(
+                key="query_batches",
+                value=parallel_markers,
+                id=workflow_id,
+                type=StateType.WORKFLOWS,
+            )
+            logger.info(
+                f"Saved {len(parallel_markers)} query batches to StateStore for {workflow_id}"
+            )
         except Exception as e:
-            logger.warning(f"Failed to write marker file: {e}")
+            logger.error(f"Failed to save query batches in StateStore: {e}")
+            # Re-raise to ensure the workflow can retry per standards
+            raise
 
-        return parallel_markers
+        # Return a small handle to keep activity result size minimal
+        return [{"state_key": "query_batches", "count": len(parallel_markers)}]
+
+    @activity.defn
+    @auto_heartbeater
+    async def load_query_batches(
+        self, workflow_args: Dict[str, Any]
+    ) -> List[Dict[str, Any]]:
+        """Load previously saved query batches from StateStore.
+
+        Args:
+            workflow_args (Dict[str, Any]): Workflow arguments containing workflow_id
+
+        Returns:
+            List[Dict[str, Any]]: The list of parallelized query batch descriptors
+
+        Raises:
+            Exception: If retrieval from StateStore fails
+        """
+        try:
+            from application_sdk.services.statestore import StateStore, StateType
+
+            workflow_id: str = workflow_args.get("workflow_id") or get_workflow_id()
+            state = await StateStore.get_state(workflow_id, StateType.WORKFLOWS)
+            batches: List[Dict[str, Any]] = state.get("query_batches", [])
+            logger.info(
+                f"Loaded {len(batches)} query batches from StateStore for {workflow_id}"
+            )
+            return batches
+        except Exception as e:
+            logger.error(
+                f"Failed to load query batches from StateStore: {e}",
+                exc_info=True,
+            )
+            raise
+
+    @activity.defn
+    @auto_heartbeater
+    async def fetch_single_batch(
+        self, workflow_args: Dict[str, Any], batch_index: int
+    ) -> Dict[str, Any]:
+        """Fetch a single batch by index from StateStore.
+
+        Args:
+            workflow_args (Dict[str, Any]): Workflow arguments containing workflow_id
+            batch_index (int): Index of the batch to fetch
+
+        Returns:
+            Dict[str, Any]: The single batch data
+
+        Raises:
+            Exception: If batch retrieval fails
+        """
+        try:
+            from application_sdk.services.statestore import StateStore, StateType
+
+            workflow_id: str = workflow_args.get("workflow_id") or get_workflow_id()
+            state = await StateStore.get_state(workflow_id, StateType.WORKFLOWS)
+            batches: List[Dict[str, Any]] = state.get("query_batches", [])
+
+            if batch_index >= len(batches):
+                raise IndexError(
+                    f"Batch index {batch_index} out of range for {len(batches)} batches"
+                )
+
+            batch = batches[batch_index]
+            logger.info(f"Fetched batch {batch_index + 1}/{len(batches)}")
+            return batch
+
+        except Exception as e:
+            logger.error(f"Failed to fetch batch {batch_index}: {e}")
+            raise
+
+    @activity.defn
+    @auto_heartbeater
+    async def write_final_marker(self, workflow_args: Dict[str, Any]) -> None:
+        """Write final marker after all fetches complete.
+
+        Loads batches from StateStore and writes the last end marker as the marker file.
+        """
+        try:
+            batches = await self.load_query_batches(workflow_args)
+            await self.write_marker(batches, workflow_args)
+        except Exception as e:
+            logger.warning(f"Failed to write final marker file: {e}")
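Note (not part of the patch): the activity changes above stop returning the full marker list from get_query_batches; the list is persisted and only a small handle comes back, and fetch_single_batch later re-reads one descriptor by index. A minimal sketch of that round trip, using an in-memory dict as a stand-in for the SDK's StateStore. The handle shape and the {"sql", "start", "end"} batch shape come from this diff; everything else here is illustrative only.

import asyncio
from typing import Any, Dict, List

_FAKE_STATE: Dict[str, Dict[str, Any]] = {}  # in-memory stand-in, not the SDK's StateStore


async def save_query_batches(workflow_id: str, batches: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Persist the full batch list out-of-band and return only a small handle,
    # keeping the activity result well under payload size limits.
    _FAKE_STATE.setdefault(workflow_id, {})["query_batches"] = batches
    return [{"state_key": "query_batches", "count": len(batches)}]


async def fetch_single_batch(workflow_id: str, batch_index: int) -> Dict[str, Any]:
    # Re-read the persisted list and return exactly one descriptor by index,
    # mirroring the bounds check in the new fetch_single_batch activity.
    batches = _FAKE_STATE.get(workflow_id, {}).get("query_batches", [])
    if batch_index >= len(batches):
        raise IndexError(f"Batch index {batch_index} out of range for {len(batches)} batches")
    return batches[batch_index]


async def _demo() -> None:
    markers = [{"sql": "SELECT 1", "start": "1700000000", "end": "1700000600"}]
    handle = await save_query_batches("wf-123", markers)
    assert handle[0]["count"] == 1
    print(await fetch_single_batch("wf-123", 0))


if __name__ == "__main__":
    asyncio.run(_demo())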
diff --git a/application_sdk/outputs/__init__.py b/application_sdk/outputs/__init__.py
index 302828fda..5b84ef040 100644
--- a/application_sdk/outputs/__init__.py
+++ b/application_sdk/outputs/__init__.py
@@ -96,7 +96,6 @@ def path_gen(
         chunk_count: Optional[int] = None,
         chunk_part: int = 0,
         start_marker: Optional[str] = None,
-        end_marker: Optional[str] = None,
     ) -> str:
         """Generate a file path for a chunk.
@@ -109,9 +108,12 @@
         Returns:
             str: Generated file path for the chunk.
         """
-        # For Query Extraction - use start and end markers without chunk count
-        if start_marker and end_marker:
-            return f"{start_marker}_{end_marker}{self._EXTENSION}"
+        # For Query Extraction - use start marker
+        if start_marker:
+            if chunk_count is None:
+                return f"atlan_raw_mined_{str(start_marker)}_{str(chunk_part)}{self._EXTENSION}"
+            else:
+                return f"atlan_raw_mined_{str(start_marker)}_{str(chunk_count)}_{str(chunk_part)}{self._EXTENSION}"
 
         # For regular chunking - include chunk count
         if chunk_count is None:
@@ -213,7 +215,7 @@ async def write_dataframe(self, dataframe: "pd.DataFrame"):
                 self.current_buffer_size_bytes + chunk_size_bytes
                 > self.max_file_size_bytes
             ):
-                output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part)}"
+                output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part, self.start_marker)}"
                 if os.path.exists(output_file_name):
                     await self._upload_file(output_file_name)
                     self.chunk_part += 1
@@ -227,7 +229,7 @@ async def write_dataframe(self, dataframe: "pd.DataFrame"):
 
         if self.current_buffer_size_bytes > 0:
             # Finally upload the final file to the object store
-            output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part)}"
+            output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part, self.start_marker)}"
             if os.path.exists(output_file_name):
                 await self._upload_file(output_file_name)
                 self.chunk_part += 1
@@ -361,9 +363,7 @@ async def _flush_buffer(self, chunk: "pd.DataFrame", chunk_part: int):
         try:
             if not is_empty_dataframe(chunk):
                 self.total_record_count += len(chunk)
-                output_file_name = (
-                    f"{self.output_path}/{self.path_gen(self.chunk_count, chunk_part)}"
-                )
+                output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, chunk_part, self.start_marker)}"
                 await self.write_chunk(chunk, output_file_name)
 
                 self.current_buffer_size = 0
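Note (not part of the patch): path_gen above now keys query-extraction chunk files by start_marker plus chunk part (and optional chunk count) instead of a start/end marker pair. A standalone sketch of just that naming rule, assuming a .parquet extension in place of self._EXTENSION; the non-miner branch is a placeholder, since its exact format is outside this diff.

from typing import Optional

EXTENSION = ".parquet"  # assumed stand-in for self._EXTENSION on the output class


def chunk_file_name(
    chunk_count: Optional[int] = None,
    chunk_part: int = 0,
    start_marker: Optional[str] = None,
) -> str:
    # Miner outputs: names are keyed by the start marker plus chunk indices.
    if start_marker:
        if chunk_count is None:
            return f"atlan_raw_mined_{start_marker}_{chunk_part}{EXTENSION}"
        return f"atlan_raw_mined_{start_marker}_{chunk_count}_{chunk_part}{EXTENSION}"
    # Non-miner branch: placeholder only; the real format is not shown in this diff.
    if chunk_count is None:
        return f"chunk_{chunk_part}{EXTENSION}"
    return f"chunk_{chunk_count}_{chunk_part}{EXTENSION}"


print(chunk_file_name(chunk_part=0, start_marker="1700000000"))
# atlan_raw_mined_1700000000_0.parquet
print(chunk_file_name(chunk_count=5, chunk_part=2, start_marker="1700000000"))
# atlan_raw_mined_1700000000_5_2.parquet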
""" - # For Query Extraction - use start and end markers without chunk count - if start_marker and end_marker: - return f"{start_marker}_{end_marker}{self._EXTENSION}" + # For Query Extraction - use start marker + if start_marker: + if chunk_count is None: + return f"atlan_raw_mined_{str(start_marker)}_{str(chunk_part)}{self._EXTENSION}" + else: + return f"atlan_raw_mined_{str(start_marker)}_{str(chunk_count)}_{str(chunk_part)}{self._EXTENSION}" # For regular chunking - include chunk count if chunk_count is None: @@ -213,7 +215,7 @@ async def write_dataframe(self, dataframe: "pd.DataFrame"): self.current_buffer_size_bytes + chunk_size_bytes > self.max_file_size_bytes ): - output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part)}" + output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part, self.start_marker)}" if os.path.exists(output_file_name): await self._upload_file(output_file_name) self.chunk_part += 1 @@ -227,7 +229,7 @@ async def write_dataframe(self, dataframe: "pd.DataFrame"): if self.current_buffer_size_bytes > 0: # Finally upload the final file to the object store - output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part)}" + output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part, self.start_marker)}" if os.path.exists(output_file_name): await self._upload_file(output_file_name) self.chunk_part += 1 @@ -361,9 +363,7 @@ async def _flush_buffer(self, chunk: "pd.DataFrame", chunk_part: int): try: if not is_empty_dataframe(chunk): self.total_record_count += len(chunk) - output_file_name = ( - f"{self.output_path}/{self.path_gen(self.chunk_count, chunk_part)}" - ) + output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, chunk_part, self.start_marker)}" await self.write_chunk(chunk, output_file_name) self.current_buffer_size = 0 diff --git a/application_sdk/workflows/query_extraction/sql.py b/application_sdk/workflows/query_extraction/sql.py index cdd65e631..bd5d1987a 100644 --- a/application_sdk/workflows/query_extraction/sql.py +++ b/application_sdk/workflows/query_extraction/sql.py @@ -62,7 +62,10 @@ def get_activities( """ return [ activities.get_query_batches, + activities.load_query_batches, + activities.fetch_single_batch, activities.fetch_queries, + activities.write_final_marker, activities.preflight_check, activities.get_workflow_args, ] @@ -97,7 +100,8 @@ async def run(self, workflow_config: Dict[str, Any]): backoff_coefficient=2, ) - results: List[Dict[str, Any]] = await workflow.execute_activity_method( + # Generate and persist batch markers (activity returns only a small handle) + batch_handles: List[Dict[str, Any]] = await workflow.execute_activity_method( self.activities_cls.get_query_batches, workflow_args, retry_policy=retry_policy, @@ -105,14 +109,27 @@ async def run(self, workflow_config: Dict[str, Any]): heartbeat_timeout=self.default_heartbeat_timeout, ) + batch_count = batch_handles[0]["count"] if batch_handles else 0 + logger.info(f"Processing {batch_count} query batches") + miner_activities: List[Coroutine[Any, Any, None]] = [] - # Extract Queries - for result in results: + # Fetch and process each batch individually to avoid size limits + for batch_index in range(batch_count): + # Fetch the specific batch + batch_data = await workflow.execute_activity( + self.activities_cls.fetch_single_batch, + args=[workflow_args, batch_index], + retry_policy=retry_policy, + 
start_to_close_timeout=self.default_start_to_close_timeout, + heartbeat_timeout=self.default_heartbeat_timeout, + ) + + # Create activity args for this specific batch activity_args = workflow_args.copy() - activity_args["sql_query"] = result["sql"] - activity_args["start_marker"] = result["start"] - activity_args["end_marker"] = result["end"] + activity_args["sql_query"] = batch_data["sql"] + activity_args["start_marker"] = batch_data["start"] + activity_args["end_marker"] = batch_data["end"] miner_activities.append( workflow.execute_activity( @@ -126,4 +143,13 @@ async def run(self, workflow_config: Dict[str, Any]): await asyncio.gather(*miner_activities) + # Write marker only after all fetches complete + await workflow.execute_activity_method( + self.activities_cls.write_final_marker, + workflow_args, + retry_policy=retry_policy, + start_to_close_timeout=self.default_start_to_close_timeout, + heartbeat_timeout=self.default_heartbeat_timeout, + ) + logger.info(f"Miner workflow completed for {workflow_id}")
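Note (not part of the patch): taken together, the workflow now sequences get_query_batches (handle only), a per-index fetch_single_batch, a fan-out of fetch_queries, and only then write_final_marker. A plain-asyncio sketch of that ordering with stand-in coroutines; it deliberately omits Temporal's execute_activity, retry policies, and heartbeats shown in the diff above.

import asyncio
from typing import Any, Dict, List


async def get_query_batches(args: Dict[str, Any]) -> List[Dict[str, Any]]:
    # Stand-in: the real activity persists the batches and returns this handle.
    return [{"state_key": "query_batches", "count": 2}]


async def fetch_single_batch(args: Dict[str, Any], i: int) -> Dict[str, Any]:
    # Stand-in: the real activity reads batch i back from the state store.
    return {"sql": f"SELECT {i}", "start": f"s{i}", "end": f"e{i}"}


async def fetch_queries(args: Dict[str, Any]) -> None:
    print("mining", args["start_marker"], "->", args["end_marker"])


async def write_final_marker(args: Dict[str, Any]) -> None:
    print("marker written only after every batch succeeded")


async def run(workflow_args: Dict[str, Any]) -> None:
    handles = await get_query_batches(workflow_args)
    batch_count = handles[0]["count"] if handles else 0

    miners = []
    for i in range(batch_count):
        batch = await fetch_single_batch(workflow_args, i)
        miners.append(fetch_queries({**workflow_args,
                                     "sql_query": batch["sql"],
                                     "start_marker": batch["start"],
                                     "end_marker": batch["end"]}))

    await asyncio.gather(*miners)            # every batch must finish first
    await write_final_marker(workflow_args)  # the marker advances only afterwards


asyncio.run(run({"output_path": "/tmp/raw"}))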