Skip to content

Commit c32f8ab

Browse files
fcollmanCopilot
andauthored
Fix precomputed queries (#208)
* fix return types to work for nglui * fixing NAs in segment properties * Update materializationengine/blueprints/client/query.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update materializationengine/blueprints/client/query.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent bb7ec28 commit c32f8ab

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

materializationengine/blueprints/client/api2.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,9 @@ def fix_dataframe_types(df):
266266
Returns:
267267
pd.DataFrame: dataframe with fixed types."""
268268
for colname in df.columns:
269+
if df[colname].isnull().all():
270+
df.drop(columns=[colname], inplace=True)
271+
continue
269272
if pd.api.types.is_float_dtype(df[colname]):
270273
df[colname]=df[colname].astype(np.float32)
271274
if pd.api.types.is_integer_dtype(df[colname]):
@@ -1151,29 +1154,21 @@ def process_fields(df, fields, column_names, tags, bool_tags, numerical):
11511154
or field_name == "target_id"
11521155
):
11531156
continue
1154-
1157+
if col not in df.columns:
1158+
continue
11551159
if isinstance(field, mm_fields.String):
1156-
if df[col].isnull().all():
1157-
continue
11581160
# check that this column is not all nulls
11591161
tags.append(col)
11601162
elif isinstance(field, mm_fields.Boolean):
1161-
if df[col].isnull().all():
1162-
continue
11631163
df[col] = df[col].astype(bool)
11641164
bool_tags.append(col)
11651165
elif isinstance(field, PostGISField):
1166-
# if all the values are NaNs skip this column
1167-
if df[col + "_x"].isnull().all():
1168-
continue
11691166
numerical.append(col + "_x")
11701167
numerical.append(col + "_y")
11711168
numerical.append(col + "_z")
11721169
elif isinstance(field, mm_fields.Number):
1173-
if df[col].isnull().all():
1174-
continue
11751170
numerical.append(col)
1176-
1171+
return df
11771172

11781173
def process_view_columns(df, model, column_names, tags, bool_tags, numerical):
11791174
for table_column_name, table_column in model.columns.items():
@@ -1246,10 +1241,10 @@ def preprocess_dataframe(df, table_name, aligned_volume_name, column_names):
12461241
numerical = []
12471242
bool_tags = []
12481243

1249-
process_fields(df, fields, column_names[table_name], tags, bool_tags, numerical)
1244+
df = process_fields(df, fields, column_names[table_name], tags, bool_tags, numerical)
12501245

12511246
if table_metadata["reference_table"]:
1252-
process_fields(
1247+
df = process_fields(
12531248
df,
12541249
ref_fields,
12551250
column_names[ref_table],

materializationengine/blueprints/client/query.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,17 @@
1616
from sqlalchemy.orm.util import AliasedClass
1717
from sqlalchemy.sql.schema import Table
1818
from sqlalchemy.sql.selectable import Alias
19-
from sqlalchemy.sql.sqltypes import Boolean, DateTime, Integer
19+
from sqlalchemy.sql.sqltypes import Boolean, DateTime, Integer, BigInteger, Float, String
2020

2121
DEFAULT_SUFFIX_LIST = ["x", "y", "z", "xx", "yy", "zz", "xxx", "yyy", "zzz"]
2222

23+
dtype_map = {
24+
Boolean: pd.BooleanDtype(),
25+
DateTime: np.dtype('datetime64[ns]'),
26+
Integer: pd.Int32Dtype(),
27+
BigInteger: pd.Int64Dtype(),
28+
Float: pd.Float32Dtype(),
29+
}
2330

2431
def concatenate_position_columns(df):
2532
grps = itertools.groupby(df.columns, key=lambda x: x[:-2])
@@ -83,7 +90,6 @@ def fix_columns_with_query(
8390
if n_tables == 1:
8491
schema_model = query.column_descriptions[0]["type"]
8592
for colname in df.columns:
86-
8793
if n_tables == 1:
8894
coltype = type(getattr(schema_model, colname).type)
8995
else:
@@ -112,7 +118,6 @@ def fix_columns_with_query(
112118
)
113119
elif isinstance(df[colname].loc[0], Decimal) and fix_decimal is True:
114120
df[colname] = _fix_decimal_column(df[colname])
115-
116121
return df
117122

118123

@@ -329,10 +334,16 @@ def _execute_query(
329334
else:
330335
if direct_sql_pandas:
331336
statement = str(query.statement.compile(engine, compile_kwargs={"literal_binds": True}))
337+
dtypes = {}
338+
for k in query.statement.columns.keys():
339+
coltype = query.statement.columns[k].type
340+
if type(coltype) in dtype_map:
341+
dtypes[k] = dtype_map[type(coltype)]
332342
df = pd.read_sql(statement,
333343
session.connection().connection,
334344
coerce_float=True,
335345
index_col=index_col,
346+
dtype=dtypes,
336347
dtype_backend='numpy_nullable')
337348
else:
338349
df = read_sql_tmpfile(

0 commit comments

Comments
 (0)