Revert "Add torch._scaled_mm for CPU (pytorch#139975)"
This reverts commit 22fae4c.

Reverted pytorch#139975 on behalf of https://github.com/huydhn due to third time is the charm ([comment](pytorch#139975 (comment)))
pytorchmergebot committed Feb 18, 2025
1 parent 59a0813 commit 49e8f9c
Showing 12 changed files with 586 additions and 915 deletions.
83 changes: 0 additions & 83 deletions aten/src/ATen/native/Blas.cpp
@@ -7,11 +7,6 @@
#include <ATen/Config.h>

#include <ATen/native/mkldnn/Matmul.h>
#include <ATen/native/mkldnn/Linear.h>
#include <ATen/native/Resize.h>
#if !defined(__s390x__) && !defined(__powerpc__)
#include <cpuinfo.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/CPUFunctions.h>
@@ -29,9 +24,6 @@
#include <ATen/ops/mv_native.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/matmul.h>
#endif

namespace at::meta {
@@ -230,79 +222,4 @@ Tensor vdot(const Tensor &self, const Tensor &other){

}

static Tensor&
_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
" but got ", bias->numel());

// Check types
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());

auto mat1_c = mat1.contiguous();
auto mat2_c = mat2.contiguous();
IntArrayRef mat1_sizes = mat1_c.sizes();
IntArrayRef mat2_sizes = mat2_c.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});

float input_scale = scale_a.item<float>();
float weight_scale = scale_b.item<float>();
auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale);
auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale);
auto out_tmp = at::matmul(fp32_mat1, fp32_mat2);
if (bias) {
out_tmp.add_(bias.value());
}
out_tmp = out_tmp.to(out.scalar_type());
out.copy_(out_tmp);
return out;
}

Tensor&
_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
#if AT_MKLDNN_ENABLED() && (IDEEP_VERSION_MAJOR >= 3 && IDEEP_VERSION_MINOR >= 5)
if (at::globalContext().userEnabledMkldnn() && cpuinfo_has_x86_amx_int8()) {
return mkldnn_scaled_mm(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
} else
#endif
{
return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}
}

Tensor
_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}

} // namespace at::native
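For reference, the emulated CPU path removed above boils down to: dequantize each fp8 operand with its per-tensor scale, run the matmul in fp32, add the optional bias, and cast to the requested output dtype. A minimal standalone sketch of that math, using only the ATen calls visible in the hunk (the helper name is hypothetical, not part of the codebase):

```cpp
#include <ATen/ATen.h>
#include <optional>

// Hypothetical reference helper mirroring the removed _scaled_mm_out_cpu_emulated:
// out = ((mat1 * scale_a) @ (mat2 * scale_b)) + bias, computed in fp32, then cast.
at::Tensor scaled_mm_reference(const at::Tensor& mat1, const at::Tensor& mat2,
                               const at::Tensor& scale_a, const at::Tensor& scale_b,
                               const std::optional<at::Tensor>& bias,
                               at::ScalarType out_dtype) {
  // Per-tensor scales: a single float per operand.
  auto fp32_mat1 = at::mul(mat1.to(at::kFloat), scale_a.item<float>());
  auto fp32_mat2 = at::mul(mat2.contiguous().to(at::kFloat), scale_b.item<float>());
  auto out = at::matmul(fp32_mat1, fp32_mat2);
  if (bias) {
    out.add_(bias.value());
  }
  return out.to(out_dtype);
}
```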
127 changes: 1 addition & 126 deletions aten/src/ATen/native/mkldnn/Linear.cpp
@@ -4,7 +4,6 @@
#include <ATen/core/Tensor.h>
#include <torch/library.h>
#include <ATen/native/mkldnn/Linear.h>
#include <ATen/native/Resize.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -47,20 +46,9 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
}

Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support");
}

} // namespace at::native


#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
@@ -459,119 +447,6 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
TORCH_FN(mkldnn_linear_pointwise_binary));
}

Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
" but got ", bias->numel());

// Check types
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
// TODO: This check of mat1 and mat2 must have the same data type will be removed after oneDNN v3.6.
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "Expected mat1 and mat2 must have the same data type");

// Validation checks have passed lets resize the output to actual size
auto mat1_c = mat1.contiguous();
auto mat2_c = mat2.contiguous();
IntArrayRef mat1_sizes = mat1_c.sizes();
IntArrayRef mat2_sizes = mat2_c.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});

float input_scale = scale_a.item<float>();
float weight_scale = scale_b.item<float>();
auto src = at::native::itensor_view_from_dense(mat1_c);
auto weight_t = at::native::itensor_view_from_dense(mat2_c);
bool with_bias = bias.has_value();
int64_t K = mat1_sizes[1], M = mat1_sizes[0],
N = mat2_sizes[1];

std::vector<int64_t> src_dims = {M, K};
std::vector<int64_t> weight_dims = {K, N};
std::vector<int64_t> dst_dims = {M, N};

ideep::tensor dst = at::native::itensor_view_from_dense(out);
auto src_desc = ideep::tensor::desc(
src_dims,
get_mkldnn_dtype(mat1.scalar_type()),
ideep::format_tag::any);
auto weights_desc = ideep::tensor::desc(
weight_dims,
get_mkldnn_dtype(mat2.scalar_type()),
ideep::format_tag::any);
auto dst_desc = ideep::tensor::desc(
dst_dims,
get_mkldnn_dtype(out.scalar_type()),
ideep::format_tag::any);
ideep::tensor onednn_bias;
if (with_bias) {
auto bias_value = bias.value();
if (bias_value.dim() == 1) {
auto b_reshape = bias_value.reshape({1, bias_value.size(0)});
onednn_bias = at::native::itensor_view_from_dense(b_reshape);
} else {
onednn_bias = at::native::itensor_view_from_dense(bias_value);
}
}
auto bias_desc = ideep::tensor::desc();
if (with_bias) {
bias_desc = ideep::tensor::desc(onednn_bias.get_dims(),
get_mkldnn_dtype(bias.value().scalar_type()),
ideep::format_tag::any);
}
auto op_attr = ideep::attr_t();
if (input_scale != 1.0f) {
op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
}
if (weight_scale != 1.0f) {
op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
}

op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
auto engine = ideep::engine::cpu_engine();
dnnl::matmul::primitive_desc primitive_desc = with_bias
? dnnl::matmul::primitive_desc(
engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr)
: dnnl::matmul::primitive_desc(
engine, src_desc, weights_desc, dst_desc, op_attr);
auto primitive = dnnl::matmul(primitive_desc);

// Prepare args and execute primitive
ideep::tensor scratchpad(primitive_desc.scratchpad_desc());
ideep::exec_args args;
args.insert({DNNL_ARG_SRC, src});
args.insert({DNNL_ARG_WEIGHTS, weight_t});
args.insert({DNNL_ARG_DST, dst});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
if (with_bias) {
args.insert({DNNL_ARG_BIAS, onednn_bias});
}
ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale));
ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale));

if (input_scale != 1.0f) {
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
}
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});

primitive.execute(ideep::stream::default_stream(), args);
return out;
}

} // namespace at

#endif // AT_MKLDNN_ENABLED
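One detail worth noting about the removed oneDNN path: because the scales are per-tensor, they commute with the matmul, which is what lets the kernel hand them to the primitive as DNNL_ARG_SRC / DNNL_ARG_WEIGHTS scale attributes instead of materializing scaled fp32 copies of the inputs. A small sketch verifying that identity with plain ATen tensors (illustrative only, not tied to oneDNN):

```cpp
#include <ATen/ATen.h>
#include <cassert>

int main() {
  // Per-tensor scales commute with matmul:
  //   (sa * A) @ (sb * B) == (sa * sb) * (A @ B)
  auto A = at::randn({4, 8});
  auto B = at::randn({8, 3});
  float sa = 0.5f, sb = 2.0f;
  auto scaled_first = at::matmul(A * sa, B * sb);
  auto scaled_after = at::matmul(A, B) * (sa * sb);
  assert(at::allclose(scaled_first, scaled_after, /*rtol=*/1e-5, /*atol=*/1e-6));
  return 0;
}
```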
12 changes: 0 additions & 12 deletions aten/src/ATen/native/mkldnn/Linear.h
@@ -35,15 +35,3 @@ C10_API Tensor mkl_linear(
} // namespace at

#endif // AT_MKLDNN_ENABLED()

namespace at::native {
Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out);
} // namespace at::native
22 changes: 1 addition & 21 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
@@ -57,10 +57,6 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
return ideep::tensor::data_type::bf16;
case ScalarType::Half:
return ideep::tensor::data_type::f16;
case ScalarType::Float8_e4m3fn:
return ideep::tensor::data_type::f8_e4m3;
case ScalarType::Float8_e5m2:
return ideep::tensor::data_type::f8_e5m2;
default:
TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
}
@@ -165,24 +161,8 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data
const_cast<void*>(tensor.const_data_ptr()) :
tensor.data_ptr()};
}
else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
return {{tensor.sizes().vec(),
ideep::tensor::data_type::f8_e4m3,
tensor.strides().vec()},
from_const_data_ptr ?
const_cast<void*>(tensor.const_data_ptr()) :
tensor.data_ptr()};
}
else if (tensor.scalar_type() == ScalarType::Float8_e5m2) {
return {{tensor.sizes().vec(),
ideep::tensor::data_type::f8_e5m2,
tensor.strides().vec()},
from_const_data_ptr ?
const_cast<void*>(tensor.const_data_ptr()) :
tensor.data_ptr()};
}
else {
TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input");
TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
}
}

2 changes: 0 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -7066,13 +7066,11 @@
- func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CPU: _scaled_mm_cpu
CUDA: _scaled_mm_cuda

- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CPU: _scaled_mm_out_cpu
CUDA: _scaled_mm_out_cuda

# NOTE [ Sparse: autograd and API ]
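With these two CPU dispatch entries gone, a CPU call into the op should once again fail at dispatch time. A hedged probe, assuming fp8-capable CPU tensors and the generated at::_scaled_mm wrapper whose argument order follows the schema above (the exact error text depends on the build):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto a = at::randn({16, 32}).to(at::ScalarType::Float8_e4m3fn);
  auto b = at::randn({32, 16}).to(at::ScalarType::Float8_e4m3fn);
  auto scale = at::ones({}, at::kFloat);  // per-tensor scales
  try {
    // Expected to throw a "could not run ... with arguments from the 'CPU' backend"
    // style error now that the CPU kernels are unregistered.
    auto out = at::_scaled_mm(a, b, scale, scale, /*bias=*/{}, /*scale_result=*/{},
                              /*out_dtype=*/at::ScalarType::BFloat16,
                              /*use_fast_accum=*/false);
  } catch (const c10::Error& e) {
    std::cout << "CPU _scaled_mm unavailable: " << e.what() << std::endl;
  }
  return 0;
}
```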
(The remaining 7 changed files in this commit are not shown here.)
