From 3b469a75ccd85a6ae543c74188f1395cd24adf38 Mon Sep 17 00:00:00 2001
From: Geoffrey Yu
Date: Sun, 5 May 2024 16:36:24 -0400
Subject: [PATCH] Fix type issues in RedshiftProvisioningScore

---
 .../scoring/performance/unified_redshift.py  | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/brad/planner/scoring/performance/unified_redshift.py b/src/brad/planner/scoring/performance/unified_redshift.py
index 4f51b85b..e509cc9d 100644
--- a/src/brad/planner/scoring/performance/unified_redshift.py
+++ b/src/brad/planner/scoring/performance/unified_redshift.py
@@ -53,10 +53,10 @@ def compute(
             ctx.metrics.redshift_cpu_list is not None
             and ctx.metrics.redshift_cpu_list.shape[0] > 0
         ):
-            avg_cpu = ctx.metrics.redshift_cpu_list.mean()
+            avg_cpu: float = ctx.metrics.redshift_cpu_list.mean().item()
         else:
             # This won't be used. This is actually max.
-            avg_cpu = ctx.metrics.redshift_cpu_avg
+            avg_cpu = float(ctx.metrics.redshift_cpu_avg)
 
         gamma_norm_factor = HotConfig.instance().get_value(
             "query_lat_p90", default=30.0
@@ -180,7 +180,7 @@ def predict_max_node_cpu_util(
 
         curr_cpu_util *= gamma
         curr_cpu_denorm = curr_cpu_util * redshift_num_cpus(curr_prov)
-        curr_max_cpu_denorm = curr_cpu_denorm.max()
+        curr_max_cpu_denorm = curr_cpu_denorm.max().item()
 
         (
             peak_load,
@@ -262,11 +262,11 @@ def compute_direct_cpu_denorm(
             per_query_cpu_denorm = np.clip(
                 query_run_times * alpha, a_min=0.0, a_max=load_max
             )
-            total_denorm = np.dot(per_query_cpu_denorm, arrival_weights)
-            max_query_cpu_denorm = per_query_cpu_denorm.max()
+            total_denorm = np.dot(per_query_cpu_denorm, arrival_weights).item()
+            max_query_cpu_denorm = (per_query_cpu_denorm * arrival_weights).max().item()
         else:
             # Edge case: Query with 0 arrival count (used as a constraint).
-            total_denorm = np.zeros_like(query_run_times)
+            total_denorm = 0.0
             max_query_cpu_denorm = 0.0
         if debug_dict is not None:
             debug_dict["redshift_total_cpu_denorm"] = total_denorm
@@ -294,7 +294,7 @@ def query_movement_factor(
         total_next_latency = np.dot(
             curr_query_run_times, workload.get_arrival_counts_batch(query_indices)
         )
-        return total_next_latency / norm_factor
+        return total_next_latency.item() / norm_factor
 
     @staticmethod
     def predict_query_latency_load_resources(
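
Reviewer note (not part of the patch): a minimal standalone sketch of the NumPy
typing behavior these changes address, assuming NumPy 1.x reduction semantics.
The array names below are hypothetical stand-ins for the planner's inputs, not
the actual values used in unified_redshift.py.

    import numpy as np

    # Hypothetical stand-ins for the planner's per-query inputs.
    query_run_times = np.array([0.5, 1.2, 3.0])
    arrival_weights = np.array([10.0, 2.0, 1.0])

    # Reductions such as np.dot(), .mean(), and .max() return NumPy scalar
    # types (np.float64), not built-in floats; .item() converts them.
    total = np.dot(query_run_times, arrival_weights)
    print(type(total))         # <class 'numpy.float64'>
    print(type(total.item()))  # <class 'float'> -- what the patch switches to

    # The zero-arrival edge case: np.zeros_like() yields an array, not a
    # scalar, so code expecting a float total would receive a vector.
    wrong = np.zeros_like(query_run_times)  # array([0., 0., 0.])
    fixed = 0.0                             # plain scalar, as in the patch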