|
| 1 | +import contextlib |
| 2 | +import functools |
| 3 | +import operator |
| 4 | +from datetime import datetime |
| 5 | +from pathlib import ( |
| 6 | + Path, |
| 7 | +) |
| 8 | + |
| 9 | +import dask |
| 10 | +import feast |
| 11 | +import feast.repo_operations |
| 12 | +import feast.utils as utils |
| 13 | +import toolz |
| 14 | +from attr import ( |
| 15 | + field, |
| 16 | + frozen, |
| 17 | +) |
| 18 | +from attr.validators import ( |
| 19 | + instance_of, |
| 20 | +) |
| 21 | +from feast.infra.offline_stores.offline_utils import ( |
| 22 | + DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, |
| 23 | +) |
| 24 | + |
| 25 | +import xorq.api as xo |
| 26 | + |
| 27 | + |
@dask.base.normalize_token.register(dask.utils.methodcaller)
def normalize_methodcaller(mc):
    """Make ``dask.utils.methodcaller`` instances dask-tokenizable.

    Two methodcallers for the same method name normalize to the same token,
    so dask's hashing/caching machinery treats them as interchangeable.
    """
    token = (dask.utils.methodcaller, mc.method)
    return dask.base.normalize_token(token)
| 36 | + |
| 37 | + |
@frozen
class Store:
    """Frozen value-object wrapper around a feast feature-store repository.

    Bundles the ``feast.FeatureStore``, its registry, and the repo-operations
    helpers behind a single object keyed only by the repository path, and adds
    an xorq-expression based implementation of historical feature retrieval
    (``get_historical_features``) alongside feast's native one.
    """

    # Feast repo root; converted to Path on construction and required to exist.
    path = field(validator=instance_of(Path), converter=Path)

    def __attrs_post_init__(self):
        # Fail fast with a real exception: the previous ``assert`` would be
        # stripped under ``python -O`` and raised a bare AssertionError.
        if not self.path.exists():
            raise FileNotFoundError(f"feast repo path does not exist: {self.path}")

    @property
    @functools.cache
    def store(self):
        """Lazily built ``feast.FeatureStore`` for :attr:`path` (built once).

        ``property`` + ``functools.cache`` stands in for ``cached_property``,
        which cannot write to a frozen attrs instance.  NOTE: the cache keys
        on ``self`` and keeps each Store alive for the process lifetime.
        """
        return feast.FeatureStore(self.path)

    @property
    def config(self):
        """The repo's ``RepoConfig``."""
        return self.store.config

    @property
    def provider(self):
        """The configured feast provider (uses a private feast accessor)."""
        return self.store._get_provider()

    @property
    @functools.cache
    def repo_contents(self):
        """Parsed repo contents (feature views, ODFVs, ...), computed once.

        chdir first — presumably feast resolves repo-relative paths against
        the cwd (TODO confirm).
        """
        with contextlib.chdir(self.path):
            return feast.repo_operations._get_repo_contents(
                self.path, self.project_name
            )

    @property
    def registry(self):
        """The store's registry (private feast attribute)."""
        return self.store._registry

    @property
    def entities(self):
        """All entities registered in the project."""
        return self.store.list_entities()

    @property
    def project_name(self):
        """The feast project name from the repo config."""
        return self.config.project

    def apply(self, skip_source_validation=False):
        """Apply the full repo (equivalent of ``feast apply``)."""
        with contextlib.chdir(self.path):
            return feast.repo_operations.apply_total(
                self.config, self.path, skip_source_validation=skip_source_validation
            )

    def teardown(self):
        """Tear down all infrastructure created by :meth:`apply`."""
        return self.store.teardown()

    def list_on_demand_feature_view_names(self):
        """Names of every on-demand feature view in the repo."""
        return tuple(el.name for el in self.repo_contents.on_demand_feature_views)

    def get_on_demand_feature_view(self, on_demand_feature_view_name):
        """Fetch one on-demand feature view from the registry by name."""
        return self.registry.get_on_demand_feature_view(
            on_demand_feature_view_name, self.store.project
        )

    def list_feature_view_names(self):
        """Names of every regular feature view in the repo."""
        return tuple(el.name for el in self.repo_contents.feature_views)

    def get_feature_view(self, feature_view_name):
        """Fetch one regular feature view from the registry by name."""
        return self.registry.get_feature_view(feature_view_name, self.store.project)

    def get_feature_refs(self, features):
        """Resolve feature refs ("view:feature" strings / services) via feast."""
        return utils._get_features(self.registry, self.store.project, list(features))

    def get_feature_views_to_use(self, features):
        """Return (feature_views, on_demand_feature_views) needed by ``features``."""
        (all_feature_views, all_on_demand_feature_views) = (
            utils._get_feature_views_to_use(
                self.registry,
                self.store.project,
                list(features),
            )
        )
        return (all_feature_views, all_on_demand_feature_views)

    def get_grouped_feature_views(self, features):
        """Group requested features into (feature_views, on_demand_feature_views)."""
        feature_refs = self.get_feature_refs(features)
        (all_feature_views, all_on_demand_feature_views) = (
            self.get_feature_views_to_use(features)
        )
        fvs, odfvs = utils._group_feature_refs(
            feature_refs,
            all_feature_views,
            all_on_demand_feature_views,
        )
        # _group_feature_refs yields (view, feature_names) pairs; keep only
        # the views here.
        (feature_views, on_demand_feature_views) = (
            tuple(view for view, _ in gen) for gen in (fvs, odfvs)
        )
        return feature_views, on_demand_feature_views

    def validate_entity_expr(self, entity_expr, features, full_feature_names=False):
        """Check ``entity_expr`` supplies every request column the ODFVs need.

        Raises feast's ``RequestDataNotFoundInEntityDfException`` when an ODFV
        request column is missing, then validates the feature refs.
        """
        (_, on_demand_feature_views) = self.get_grouped_feature_views(features)
        if self.store.config.coerce_tz_aware:
            # FIXME: pass entity_expr back out
            # entity_df = utils.make_df_tzaware(typing.cast(pd.DataFrame, entity_df))
            pass
        bad_pairs = (
            (feature_name, odfv.name)
            for odfv in on_demand_feature_views
            for feature_name in odfv.get_request_data_schema().keys()
            if feature_name not in entity_expr.columns
        )
        if pair := next(bad_pairs, None):
            # Fixed spelling: feast exports RequestDataNotFoundInEntityDfException
            # (the previous lowercase-"not" name raised AttributeError here).
            from feast.feature_store import RequestDataNotFoundInEntityDfException

            (feature_name, feature_view_name) = pair
            raise RequestDataNotFoundInEntityDfException(
                feature_name=feature_name,
                feature_view_name=feature_view_name,
            )
        utils._validate_feature_refs(
            self.get_feature_refs(features),
            full_feature_names,
        )

    def get_historical_features(self, entity_expr, features, full_feature_names=False):
        """xorq-native historical retrieval: join feature views, then ODFVs."""
        self.validate_entity_expr(
            entity_expr, features, full_feature_names=full_feature_names
        )
        (odfv_dct, fv_dct) = group_features(self, features)
        # The accumulated join keys are a by-product of the fold and are not
        # needed past this point.
        entity_expr, _all_join_keys = process_all_feature_views(
            self, entity_expr, fv_dct
        )
        return process_odfvs(entity_expr, odfv_dct)

    def get_historical_features_feast(
        self, entity_df, features, full_feature_names=False
    ):
        """Feast's native historical retrieval, for comparison/fallback."""
        return self.store.get_historical_features(
            entity_df=entity_df,
            features=features,
            full_feature_names=full_feature_names,
        )

    def get_online_features(self, features, entity_rows):
        """Online feature lookup, returned as a plain dict."""
        return self.store.get_online_features(
            features=features,
            entity_rows=entity_rows,
        ).to_dict()

    def list_feature_service_names(self):
        """Names of every feature service in the project."""
        return tuple(el.name for el in self.store.list_feature_services())

    def get_feature_service(self, feature_service_name):
        """Fetch one feature service by name."""
        return self.store.get_feature_service(feature_service_name)

    def list_data_source_names(self):
        """Names of every registered data source."""
        return tuple(
            el.name for el in self.registry.list_data_sources(self.project_name)
        )

    def get_data_source(self, data_source_name):
        """Fetch one data source by name."""
        return self.registry.get_data_source(data_source_name, self.project_name)

    @classmethod
    def make_applied_materialized(cls, path, end_date=None):
        """Build a Store, apply the repo, and materialize up to ``end_date``.

        ``end_date`` defaults to naive ``datetime.now()``; feast presumably
        coerces it tz-aware internally — TODO confirm per provider.
        """
        end_date = end_date or datetime.now()
        store = cls(path)
        store.apply()
        store.store.materialize_incremental(end_date=end_date)
        return store
| 201 | + |
| 202 | + |
def process_one_feature_view(
    entity_expr, store, feature_view, feature_names, all_join_keys
):
    """Point-in-time join one feature view's features onto ``entity_expr``.

    Mirrors feast's offline-store historical retrieval for a single view:
    read the batch source, left-join on the view's join keys, drop rows
    outside the TTL window, and keep a single row per (join keys, entity
    event timestamp) group.  Returns ``(expr, all_join_keys)`` where
    ``all_join_keys`` has this view's join keys appended.
    """

    def _read_mapped(
        con,
        store,
        feature_view,
        feature_names,
        right_entity_key_columns,
        ets,
        ts,
        full_feature_names=False,
    ):
        # Read the view's parquet batch source, apply field-mapping and
        # join-key renames, and select only the key + feature columns.
        # Returns the expression plus the (possibly renamed) timestamp name.
        def maybe_rename(expr, dct):
            # dct maps from->to; the comprehension flips it to the
            # {new: old} shape that .rename expects, skipping absent columns.
            return (
                expr.rename({to_: from_ for from_, to_ in dct.items() if from_ in expr})
                if dct
                else expr
            )

        if full_feature_names:
            # "view__feature" output naming is not implemented here.
            raise ValueError
        expr = (
            xo.deferred_read_parquet(
                store.config.repo_path.joinpath(feature_view.batch_source.path), con=con
            )
            .pipe(maybe_rename, feature_view.batch_source.field_mapping)
            .pipe(maybe_rename, feature_view.projection.join_key_map)
            .select(list(right_entity_key_columns) + list(feature_names))
        )
        if ts == ets:
            # Avoid a name collision with the entity event-timestamp column.
            new_ts = f"__{ts}"
            expr, ts = expr.pipe(maybe_rename, {ts: new_ts}), new_ts
        return expr, ts

    def _merge(entity_expr, feature_expr, join_keys):
        # Left join keeps entities with no matching feature rows; rname
        # disambiguates overlapping right-hand column names.
        return entity_expr.join(
            feature_expr, predicates=join_keys, how="left", rname="{name}__"
        )

    def _normalize_timestamp(expr, *tss):
        # Cast any naive timestamp columns to tz-aware UTC so comparisons
        # between entity and feature timestamps are well-defined.
        casts = {
            ts: xo.expr.datatypes.Timestamp(timezone="UTC")
            for ts in tss
            if ts in expr and expr[ts].type().timezone is None
        }
        return expr.cast(casts) if casts else expr

    def _filter_ttl(expr, ttl, ets, ts):
        # Keep unmatched rows (null ts) plus rows whose feature timestamp
        # falls in [ets - ttl, ets]; with no/zero ttl only require ts <= ets.
        isna_condition = expr[ts].isnull()
        le_condition = expr[ts] <= expr[ets]
        if ttl and ttl.total_seconds() != 0:
            ge_condition = (
                expr[ets] - xo.interval(seconds=ttl.total_seconds())
            ) <= expr[ts]
            time_condition = ge_condition & le_condition
        else:
            time_condition = le_condition
        condition = isna_condition | time_condition
        return expr[condition]

    def _drop_duplicates(expr, join_keys, ets, ts, cts):
        # Keep one row per (join keys, entity event timestamp) group.
        order_by = tuple(
            expr[ts].desc(nulls_first=False)
            # cts desc first: most recent update
            # ts desc: closest to the event ts
            for ts in (cts, ts)
            if ts in expr
        )
        ROW_NUM = "row_num"
        expr = (
            expr.mutate(
                **{
                    ROW_NUM: (
                        xo.row_number().over(
                            group_by=list(join_keys) + [ets],
                            order_by=order_by,
                        )
                    ),
                }
            )
            # keep row 0 (xorq/ibis row_number starts at 0 — first per group)
            .filter(xo._[ROW_NUM] == 0)
            .drop(ROW_NUM)
        )
        return expr

    # The entity expression must already carry feast's default event-ts column.
    ets = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
    assert ets in entity_expr
    con = entity_expr._find_backend()

    ts, cts = (
        feature_view.batch_source.timestamp_field,
        feature_view.batch_source.created_timestamp_column,
    )
    # Map entity column names through the projection's join-key overrides.
    join_keys = tuple(
        feature_view.projection.join_key_map.get(entity_column.name, entity_column.name)
        for entity_column in feature_view.entity_columns
    )
    all_join_keys = all_join_keys + [
        join_key for join_key in join_keys if join_key not in all_join_keys
    ]
    # cts may be empty/None; filter(None, ...) drops such entries.
    right_entity_key_columns = list(filter(None, [ts, cts] + list(join_keys)))

    entity_expr = _normalize_timestamp(entity_expr, ets)

    feature_expr, ts = _read_mapped(
        con, store, feature_view, feature_names, right_entity_key_columns, ets, ts
    )
    expr = _merge(entity_expr, feature_expr, join_keys)
    expr = _normalize_timestamp(expr, ts, cts)
    expr = _filter_ttl(expr, feature_view.ttl, ets, ts)
    expr = _drop_duplicates(expr, all_join_keys, ets, ts, cts)
    return expr, all_join_keys
| 316 | + |
| 317 | + |
def process_all_feature_views(store, entity_expr, fv_dct):
    """Left-join every regular feature view onto ``entity_expr`` in turn.

    ``fv_dct`` maps feature view -> tuple of requested feature names.  Join
    keys accumulate across iterations so each later join can deduplicate on
    the full key set.  Returns the joined expression and the accumulated keys.
    """
    accumulated_join_keys = []
    for view, names in fv_dct.items():
        entity_expr, accumulated_join_keys = process_one_feature_view(
            entity_expr, store, view, names, accumulated_join_keys
        )
    return entity_expr, accumulated_join_keys
| 325 | + |
| 326 | + |
@toolz.curry
def apply_odfv_dct(df, odfv_udfs):
    """Run each on-demand-feature-view UDF and join its output onto ``df``.

    Each successive UDF receives the progressively joined frame — identical
    to the original generator form, whose ``udf(df)`` read the rebound ``df``
    lazily on every iteration.
    """
    for udf in odfv_udfs:
        df = df.join(udf(df))
    return df
| 332 | + |
| 333 | + |
def make_uniform_timestamps(expr, timezone="UTC", scale=6):
    """Cast every Timestamp column of ``expr`` to one timezone and scale.

    Returns ``expr`` unchanged when no Timestamp columns are present.
    """
    import xorq.vendor.ibis.expr.datatypes as dt

    target = dt.Timestamp(timezone=timezone, scale=scale)
    casts = {
        name: target
        for name, typ in expr.schema().items()
        if isinstance(typ, dt.Timestamp)
    }
    if not casts:
        return expr
    return expr.cast(casts)
| 343 | + |
| 344 | + |
def calc_odfv_schema_append(odfv_dct):
    """Build {feature name: dtype name} for every feature of every ODFV.

    Iterating ``odfv_dct`` visits its keys (the on-demand feature views);
    on duplicate feature names the later view wins, as before.  Also avoids
    shadowing the module-level attr ``field`` import.
    """
    return {
        feature.name: feature.dtype.name
        for odfv in odfv_dct
        for feature in odfv.features
    }
| 349 | + |
| 350 | + |
def process_odfvs(entity_expr, odfv_dct, full_feature_names=False):
    """Append on-demand feature view outputs to ``entity_expr`` via a UDXF.

    Normalizes timestamp columns, then wraps the ODFV pandas UDFs in a single
    xorq flight UDXF whose output schema is the input schema plus the ODFV
    feature columns.

    Raises
    ------
    ValueError
        If ``full_feature_names`` is requested ("view__feature" output naming
        is not implemented here).
    """
    if full_feature_names:
        # Same guard as before, but with a message instead of a bare ValueError.
        raise ValueError("full_feature_names is not supported by process_odfvs")
    entity_expr = make_uniform_timestamps(entity_expr)
    odfv_udfs = tuple(odfv.feature_transformation.udf for odfv in odfv_dct.keys())
    schema_in = entity_expr.schema()
    schema_append = calc_odfv_schema_append(odfv_dct)
    udxf = xo.expr.relations.flight_udxf(
        process_df=apply_odfv_dct(odfv_udfs=odfv_udfs),
        maybe_schema_in=schema_in,
        maybe_schema_out=schema_in | schema_append,
        name="process_odfvs",
    )
    return udxf(entity_expr)
| 365 | + |
| 366 | + |
def group_features(store, feature_names):
    """Split "view:feature" references into per-view groups of feature names.

    Returns ``(odfv_dct, fv_dct)``: each maps a view object to the tuple of
    requested feature names, with on-demand feature views in the first dict
    and regular feature views in the second.

    Raises
    ------
    ValueError
        If ``feature_names`` is empty or any entry is not "view:feature".
    """
    splat = tuple(feature_name.split(":") for feature_name in feature_names)
    # Explicit validation: the previous ``assert`` vanished under ``python -O``
    # and produced an opaque AssertionError for empty input.
    if not splat or any(len(parts) != 2 for parts in splat):
        raise ValueError(
            "every feature must be referenced as 'feature_view:feature_name'"
        )
    name_to_use_to_view = {
        view.projection.name_to_use(): view
        for view in store.store.list_all_feature_views()
    }
    dct = toolz.groupby(
        operator.itemgetter(0),
        splat,
    )
    view_to_feature_names = {
        name_to_use_to_view[feature_view_name]: tuple(
            feature_name for _, feature_name in pairs
        )
        for feature_view_name, pairs in dct.items()
    }
    # Partition by whether the view is an OnDemandFeatureView.
    is_odfv = toolz.flip(isinstance)(feast.OnDemandFeatureView)
    (odfv_dct, fv_dct) = (
        toolz.keyfilter(f, view_to_feature_names)
        for f in (is_odfv, toolz.complement(is_odfv))
    )
    return odfv_dct, fv_dct
0 commit comments