Commit e3510f4

Fix for PR546, adding float32 and float16 (#569)
Signed-off-by: Iryna Boiko <[email protected]>
1 parent 9abbdc0 commit e3510f4

File tree: 1 file changed (+6, -1 lines)

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 6 additions & 1 deletion
@@ -104,7 +104,12 @@
 _TYPE_CACHE: dict[str, dict[str, Any]] = {}
 
 hpu_buffer: list[list[torch.Tensor]] = []
-HPU_TORCH_DTYPE_TO_STR_DTYPE = {torch.bfloat16: "bfloat16", torch.float8_e4m3fn: "fp8_e4m3"}
+HPU_TORCH_DTYPE_TO_STR_DTYPE = {
+    torch.float32: "float32",
+    torch.bfloat16: "bfloat16",
+    torch.float16: "float16",
+    torch.float8_e4m3fn: "fp8_e4m3"
+}
 
 
 class BucketingFailedException(Exception):
