@@ -51,6 +51,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
5151 return torch_upstream::ScalarType::Long;
5252 if (type.isSignedInteger (32 ))
5353 return torch_upstream::ScalarType::Int;
54+ if (type.isSignedInteger (16 ))
55+ return torch_upstream::ScalarType::Short;
5456 if (type.isSignlessInteger (1 ))
5557 return torch_upstream::ScalarType::Bool;
5658 if (type.isBF16 ())
@@ -95,6 +97,8 @@ Torch::getTypeForScalarType(MLIRContext *context,
9597 return IntegerType::get (context, 64 , mlir::IntegerType::Signed);
9698 case torch_upstream::ScalarType::Int:
9799 return IntegerType::get (context, 32 , mlir::IntegerType::Signed);
100+ case torch_upstream::ScalarType::Short:
101+ return IntegerType::get (context, 16 , mlir::IntegerType::Signed);
98102 case torch_upstream::ScalarType::Bool:
99103 return IntegerType::get (context, 1 );
100104 case torch_upstream::ScalarType::BFloat16:
@@ -213,8 +217,8 @@ Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
213217 Location loc, float value,
214218 Type dtype) {
215219 // Creating constants satisfying backend contract.
216- if (dtype.isInteger (64 ) || dtype.isInteger (32 ) || dtype.isInteger (8 ) ||
217- dtype.isInteger (1 ))
220+ if (dtype.isInteger (64 ) || dtype.isInteger (32 ) || dtype.isInteger (16 ) ||
221+ dtype.isInteger (8 ) || dtype. isInteger ( 1 ))
218222 return rewriter.create <ConstantIntOp>(
219223 loc, rewriter.getI64IntegerAttr ((int64_t )value));
220224 if (dtype.isF64 () || dtype.isF32 () || dtype.isF16 () || dtype.isBF16 ())
0 commit comments