From df7c818e060f1d85a60f700ef77fd4b3ef48d1f4 Mon Sep 17 00:00:00 2001
From: Utkarsh Kunwar <14164924+UtkarshKunwar@users.noreply.github.com>
Date: Wed, 12 Feb 2025 21:58:40 +0530
Subject: [PATCH] feat(fx-importer): support for importing fp8 model parameters

Essentially follows the same approach already used for importing
`bfloat16` parameters.
---
 python/torch_mlir/extras/fx_importer.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 8840055744e7..119e20ded9a6 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -176,7 +176,6 @@
 
 TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = {
     torch.float16: lambda: F16Type.get(),
-    torch.bfloat16: lambda: BF16Type.get(),
     torch.float32: lambda: F32Type.get(),
     torch.float64: lambda: F64Type.get(),
     torch.uint8: lambda: IntegerType.get_unsigned(8),
@@ -191,16 +190,20 @@
     torch.complex64: lambda: ComplexType.get(F32Type.get()),
     torch.complex128: lambda: ComplexType.get(F64Type.get()),
 }
-# Type entries added only in torch with higher version
+# Type entries that exist only in newer torch versions. bfloat16 predates them,
+# but handling it here keeps the numpy and ml_dtypes mapping logic in one place.
 OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE = {
+    "bfloat16": lambda: BF16Type.get(),
     "float8_e5m2": lambda: Float8E5M2Type.get(),
     "float8_e4m3fn": lambda: Float8E4M3FNType.get(),
     "float8_e5m2fnuz": lambda: Float8E5M2FNUZType.get(),
     "float8_e4m3fnuz": lambda: Float8E4M3FNUZType.get(),
 }
+OPTIONAL_TORCH_DTYPES: List[torch.dtype] = []
 for dtype_str, mlir_type in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE.items():
     if hasattr(torch, dtype_str):
         TORCH_DTYPE_TO_MLIR_TYPE[getattr(torch, dtype_str)] = mlir_type
+        OPTIONAL_TORCH_DTYPES.append(getattr(torch, dtype_str))
 
 TORCH_DTYPE_TO_NPY_TYPE = {
     # torch.qint8: None, # no equivalent np datatype
@@ -218,8 +221,15 @@
     torch.complex64: np.complex64,
     torch.complex128: np.complex128,
 }
+
 if ml_dtypes is not None:
-    TORCH_DTYPE_TO_NPY_TYPE[torch.bfloat16] = ml_dtypes.bfloat16
+    # These dtypes exist only in newer torch versions. ml_dtypes follows the
+    # same naming scheme, but check both packages to be safe.
+    for dtype_str in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE.keys():
+        if hasattr(torch, dtype_str) and hasattr(ml_dtypes, dtype_str):
+            TORCH_DTYPE_TO_NPY_TYPE[getattr(torch, dtype_str)] = getattr(
+                ml_dtypes, dtype_str
+            )
 
 TORCH_DTYPE_TO_INT = {
     torch.uint8: 0,
@@ -2070,10 +2080,10 @@ def _make_vtensor_literal_op(
 ) -> Operation:
     mapping = py_attr_tracker.track(tensor)
     if mapping.is_empty:
-        # check support for bfloat16
+        # Check support for bfloat16 and the other optional dtypes.
         assert not (
-            tensor.dtype == torch.bfloat16 and ml_dtypes is None
-        ), f"torch.bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+            tensor.dtype in OPTIONAL_TORCH_DTYPES and ml_dtypes is None
+        ), f"{tensor.dtype} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
         # Resolve the attribute.
         npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype)
         assert (
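
Reviewer note: below is a minimal sketch of what this patch enables, not part
of the change itself. It assumes a torch build that ships the float8 dtypes,
the `ml_dtypes` package installed, and the `torch_mlir.fx.export_and_import`
entry point; the module and parameter names are illustrative.

```python
# Sketch: import a module whose parameter is stored in float8_e4m3fn.
import torch
from torch_mlir import fx


class Fp8ParamModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # fp8 tensors cannot carry gradients, so freeze the parameter.
        self.w = torch.nn.Parameter(
            torch.randn(4, 4).to(torch.float8_e4m3fn), requires_grad=False
        )

    def forward(self, x):
        # Most arithmetic is undefined for fp8; upcast before use.
        return x + self.w.to(torch.float32)


# With this patch, the fp8 parameter maps to the matching ml_dtypes numpy
# type and is emitted as a dense literal by _make_vtensor_literal_op;
# previously the dtype lookup had no entry for it.
print(fx.export_and_import(Fp8ParamModule(), torch.randn(4, 4)))
```

Since every entry in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE is guarded by hasattr
checks on both torch and ml_dtypes, older torch builds that lack the fp8
dtypes continue to import unaffected.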