diff --git a/mostlyai/engine/_common.py b/mostlyai/engine/_common.py index 32a627dc..b6434aaa 100644 --- a/mostlyai/engine/_common.py +++ b/mostlyai/engine/_common.py @@ -511,7 +511,7 @@ def skip_if_error_wrapper(*args, **kwargs) -> Any: def encode_slen_sidx_sdec(vals: pd.Series, max_seq_len: int, prefix: str = "") -> pd.DataFrame: - assert vals.dtype == int or vals.dtype == "int64[pyarrow]" + assert is_integer_dtype(vals) if max_seq_len < SLEN_SIDX_DIGIT_ENCODING_THRESHOLD or prefix == SDEC_SUB_COLUMN_PREFIX: # encode slen and sidx as numeric_discrete df = pd.DataFrame({f"{prefix}cat": vals})