From b21bedfc0503f980ed11278991e0fd43d42cd286 Mon Sep 17 00:00:00 2001
From: Frank Dong <123416088+frank-dong-ms@users.noreply.github.com>
Date: Mon, 17 Jun 2024 05:52:13 -0700
Subject: [PATCH] add bf16 for Tile CUDA executor (#20854)
### Description
Add bfloat16 (bf16) support to the Tile kernel in the CUDA execution provider.
### Motivation and Context
Required to support the phimm model in ORT training.
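
For illustration only, and not the ORT CUDA kernel: Tile copies elements without interpreting their values, so a reference version can treat bf16 data as opaque 16-bit patterns. The sketch below assumes a row-major 2-D tensor. The names `FloatToBf16Bits` and `Tile2D` are hypothetical helpers introduced for this example and do not exist in ONNX Runtime.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: bfloat16 bit pattern of a float, obtained by
// truncation (keeps the upper 16 bits of the IEEE-754 binary32 encoding).
inline std::uint16_t FloatToBf16Bits(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<std::uint16_t>(bits >> 16);
}

// Reference Tile for a row-major 2-D tensor. Elements of type T are only
// copied, never interpreted, so the element type does not affect the logic.
template <typename T>
std::vector<T> Tile2D(const std::vector<T>& input,
                      std::size_t rows, std::size_t cols,
                      std::size_t repeat_rows, std::size_t repeat_cols) {
  const std::size_t out_rows = rows * repeat_rows;
  const std::size_t out_cols = cols * repeat_cols;
  std::vector<T> output(out_rows * out_cols);
  for (std::size_t r = 0; r < out_rows; ++r) {
    for (std::size_t c = 0; c < out_cols; ++c) {
      // Wrap the output coordinates back into the input tensor.
      output[r * out_cols + c] = input[(r % rows) * cols + (c % cols)];
    }
  }
  return output;
}
```

Calling `Tile2D` on a 1x2 tensor of bf16 bit patterns with repeats (2, 2) would give a 2x4 tensor in which the two input values alternate along each row.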
---
docs/OperatorKernels.md | 2 +-
onnxruntime/core/providers/cuda/tensor/tile.cc | 3 ++-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 67bfe48327e14..5f19c16cba616 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -827,7 +827,7 @@ Do not modify directly.*
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|ThresholdedRelu|*in* X:**T**<br> *out* Y:**T**|10+|**T** = tensor(double), tensor(float), tensor(float16)|
|||1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
+|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc
index e671b2cdc0277..01522c71dc51b 100644
--- a/onnxruntime/core/providers/cuda/tensor/tile.cc
+++ b/onnxruntime/core/providers/cuda/tensor/tile.cc
@@ -36,7 +36,8 @@ ONNX_OPERATOR_KERNEL_EX(
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
- DataTypeImpl::GetTensorType<MLFloat16>()})
+ DataTypeImpl::GetTensorType<MLFloat16>(),
+ DataTypeImpl::GetTensorType<BFloat16>()})
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Tile);