Skip to content
Merged
6 changes: 0 additions & 6 deletions infini_train/include/autograd/linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ class Tensor;

namespace infini_train::autograd {

struct LinearGradFlags {
bool input = false;
bool weight = false;
bool bias = false;
};

class Linear : public Function {
public:
static constexpr char kType[] = "LinearFunction";
Expand Down
2 changes: 2 additions & 0 deletions infini_train/include/autograd/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,7 @@ class Matmul : public Function {

private:
int64_t out_features_ = 0;
std::vector<int64_t> input1_dims_;
std::vector<int64_t> input2_dims_;
};
} // namespace infini_train::autograd
80 changes: 80 additions & 0 deletions infini_train/include/common/cuda/gemm.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#pragma once

#include <cublas_v2.h>

#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"

namespace infini_train::kernels::cuda {

/**
* Parameter bundle for a single GEMM call:
* C = alpha * op(A) * op(B) + beta * C
*
* batch_count == 1 → non-batched path (cublasGemmEx)
* batch_count > 1 → strided-batched (cublasGemmStridedBatchedEx)
*
* When batch_count == 1, stride_a/b/c are unused and must be left at 0.
*/
struct GemmParams {
cublasOperation_t trans_a = CUBLAS_OP_N;
cublasOperation_t trans_b = CUBLAS_OP_N;

int m = 0; // rows of op(A) and C
int n = 0; // cols of op(B) and C
int k = 0; // cols of op(A) == rows of op(B)

const void *A = nullptr;
int lda = 0;
const void *B = nullptr;
int ldb = 0;
void *C = nullptr;
int ldc = 0;

float alpha = 1.0f;
float beta = 0.0f;

// batch_count=1: non-batched (Linear path); stride_a/b/c must be 0
// batch_count>1: strided-batched (Matmul path)
int batch_count = 1;
long long stride_a = 0;
long long stride_b = 0;
long long stride_c = 0;

DataType input_dtype; // dtype of A and B
DataType output_dtype; // dtype of C (may differ, e.g. bf16 in → fp32 out)
};

/**
* Execute the GEMM described by `p` via cuBLAS.
* Dispatches to cublasGemmEx (batch_count==1) or
* cublasGemmStridedBatchedEx (batch_count>1).
* Uses CUBLAS_COMPUTE_32F for all input dtypes to ensure precision.
* Aborts on cuBLAS error (via CUBLAS_CHECK / LOG(FATAL)).
*/
void GemmCuda(const Device &device, const GemmParams &p);

/**
* Parameter bundle for a single SGEMV call (fp32 only):
* y = alpha * op(A) * x + beta * y
*
* op(A) is m_phys-by-n_phys when trans==N, or n_phys-by-m_phys when trans==T,
* where m_phys and n_phys are the physical (pre-transpose) row/col counts of A.
*/
struct SgemvParams {
cublasOperation_t trans = CUBLAS_OP_N;
int m = 0;
int n = 0;
const float *A = nullptr;
int lda = 0;
const float *x = nullptr;
int incx = 1;
float *y = nullptr;
int incy = 1;
float alpha = 1.0f;
float beta = 0.0f;
};

void SgemvCuda(const Device &device, const SgemvParams &p);

} // namespace infini_train::kernels::cuda
30 changes: 21 additions & 9 deletions infini_train/src/autograd/linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,29 @@ std::vector<std::shared_ptr<Tensor>> Linear::Backward(const std::vector<std::sha
const auto &grad_output = grad_outputs[0];

CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Linear::Backward";
LinearGradFlags grad_flags = {.input = needs_input_grad_[0],
.weight = needs_input_grad_.size() > 1 && needs_input_grad_[1],
.bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]};
bool need_grad_input = needs_input_grad_[0];
bool need_grad_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1];
bool need_grad_bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2];

auto device = grad_output->GetDevice().type();
// TODO: skip autograd graph construction entirely when no input requires grad
auto [grad_input, grad_weight, grad_bias]
= Dispatcher::Instance()
.Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
{device, "LinearBackward"}, input, weight, transpose_, in_features_, out_features_, input_dims_,
grad_output, bias_, grad_flags);

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_weight = nullptr;
std::shared_ptr<Tensor> grad_bias = nullptr;

if (need_grad_input) {
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
{device, "LinearBackwardInput"}, weight, grad_output, transpose_, in_features_, out_features_, input_dims_);
}
if (need_grad_weight) {
grad_weight = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>(
{device, "LinearBackwardWeight"}, input, grad_output, transpose_, in_features_, out_features_);
}
if (need_grad_bias) {
grad_bias = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "LinearBackwardBias"}, grad_output,
out_features_);
}

if (bias_) {
return {grad_input, grad_weight, grad_bias};
} else {
Expand Down
41 changes: 33 additions & 8 deletions infini_train/src/autograd/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,19 @@ void Matmul::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
// FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be
// determined by autocast, not derived from output->Dtype().
auto compute_dtype = output->Dtype();
saved_tensors_ = {
input1->Dtype() == compute_dtype ? input1 : std::make_shared<Tensor>(input1->To(compute_dtype)),
input2->Dtype() == compute_dtype ? input2 : std::make_shared<Tensor>(input2->To(compute_dtype)),

// grad_input1 = grad_output @ input2^T, so input2 is needed
// grad_input2 = grad_output^T @ input1, so input1 is needed
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto cast = [&](const std::shared_ptr<Tensor> &t) {
return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
};

saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr};
Comment thread
Chamberlain0w0 marked this conversation as resolved.
input1_dims_ = input1->Dims();
input2_dims_ = input2->Dims();
out_features_ = output->Dims()[0];
}

Expand All @@ -45,10 +54,26 @@ std::vector<std::shared_ptr<Tensor>> Matmul::Backward(const std::vector<std::sha
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];

auto device = input1->GetDevice().type();
auto [grad_input1, grad_input2]
= Dispatcher::Instance().Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
{device, "MatmulBackward"}, input1, input2, grad_output);
return {grad_input1, grad_input2};
CHECK(!needs_input_grad_.empty()) << "needs_input_grad_ not populated in Matmul::Backward";
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];

auto device = grad_output->GetDevice().type();

std::shared_ptr<Tensor> grad_input = nullptr;
std::shared_ptr<Tensor> grad_other = nullptr;

if (need_grad_input1) {
CHECK(input2 != nullptr) << "input2 not saved but need_grad_input1 is true";
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput"}, input2,
grad_output, input1_dims_);
}
if (need_grad_input2) {
CHECK(input1 != nullptr) << "input1 not saved but need_grad_input2 is true";
grad_other = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardOther"}, input1,
grad_output, input2_dims_);
}

return {grad_input, grad_other};
}
} // namespace infini_train::autograd
Loading
Loading