From 8aa2667ae61ffecbee883009d5fa7f18bef595db Mon Sep 17 00:00:00 2001 From: Frank Dong <123416088+frank-dong-ms@users.noreply.github.com> Date: Mon, 17 Jun 2024 05:52:13 -0700 Subject: [PATCH] add bf16 for Tile CUDA executor (#20854) ### Description add bf16 for Tile CUDA executor ### Motivation and Context required change to support phimm model for ORT training --- docs/OperatorKernels.md | 2 +- onnxruntime/core/providers/cuda/tensor/tile.cc | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 67bfe48327e14..5f19c16cba616 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -827,7 +827,7 @@ Do not modify directly.* |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |ThresholdedRelu|*in* X:**T**
*out* Y:**T**|10+|**T** = tensor(double), tensor(float), tensor(float16)| |||1+|**T** = tensor(double), tensor(float), tensor(float16)| -|Tile|*in* input:**T**
*in* repeats:**T1**
*out* output:**T**

or

*in* input:**T**
*in* tiles:**T**
*in* axis:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| +|Tile|*in* input:**T**
*in* repeats:**T1**
*out* output:**T**

or

*in* input:**T**
*in* tiles:**T**
*in* axis:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(int64)| |TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| |||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc index e671b2cdc0277..01522c71dc51b 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile.cc +++ b/onnxruntime/core/providers/cuda/tensor/tile.cc @@ -36,7 +36,8 @@ ONNX_OPERATOR_KERNEL_EX( DataTypeImpl::GetTensorType<double>(), DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>(), - DataTypeImpl::GetTensorType<MLFloat16>()}) + DataTypeImpl::GetTensorType<MLFloat16>(), + DataTypeImpl::GetTensorType<BFloat16>()}) .TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()), Tile);