Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 133 additions & 19 deletions src/tofu_graph.c
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,62 @@ static void accumulate_grad(tofu_graph_node* node, tofu_tensor* grad_contrib)
}
}

/* Helper: Reduce gradient to match input shape when broadcasting occurred
* When broadcasting happens in forward pass (e.g., [3,1] * [3,4] -> [3,4]),
* the gradient has the output shape [3,4], but we need to reduce it back to [3,1]
* by summing over the broadcast dimensions.
*/
/* Collapse a gradient tensor back to the shape of `input` after broadcasting.
 * Forward broadcasting (e.g. [3,1] * [3,4] -> [3,4]) produces a gradient with
 * the output's shape; the chain rule then requires summing that gradient over
 * every axis that was expanded, so it matches `input`'s shape again.
 *
 * Returns `grad` itself when nothing had to be reduced; otherwise returns a
 * freshly allocated tensor that the caller owns (free with
 * tofu_tensor_free_data_too). */
static tofu_tensor* reduce_grad_for_broadcast(tofu_tensor* grad, const tofu_tensor* input)
{
    tofu_tensor* result = grad;
    const int out_rank = grad->ndim;
    const int in_rank = input->ndim;

    /* Walk axes right-to-left so trailing dimensions line up, mirroring the
     * standard broadcasting alignment rule. */
    for (int k = 0; k < out_rank; k++) {
        const int out_axis = out_rank - 1 - k;
        const int in_axis = in_rank - 1 - k;

        /* An axis must be summed away when it was prepended during broadcast
         * (no counterpart in `input`), or when `input` held size 1 there while
         * the gradient holds more. */
        const int broadcasted =
            (in_axis < 0) ||
            (input->dims[in_axis] == 1 && result->dims[out_axis] > 1);

        if (!broadcasted)
            continue;

        tofu_tensor* summed = tofu_tensor_sumreduce(result, NULL, out_axis);
        if (result != grad)
            tofu_tensor_free_data_too(result); /* drop intermediate, keep caller's grad */
        result = summed;
        /* sumreduce leaves the rank intact; dims[out_axis] becomes 1. */
    }

    /* sumreduce keeps reduced axes as size-1 entries, so a rank mismatch here
     * means leading size-1 axes still need squeezing down to input's rank. */
    if (result->ndim != in_rank) {
        /* reshape yields a view; clone it so we return an owning tensor with
         * no dangling reference to the intermediate. */
        tofu_tensor* view = tofu_tensor_reshape(result, in_rank, input->dims);
        tofu_tensor* owned = tofu_tensor_clone(view);
        tofu_tensor_free(view);
        if (result != grad)
            tofu_tensor_free_data_too(result);
        result = owned;
    }

    return result;
}

/* Backward functions for each operation */

/* Matmul backward: y = A @ B */
Expand Down Expand Up @@ -885,23 +941,46 @@ static void matmul_backward(tofu_graph_node* node)

/* ∂L/∂A = (∂L/∂y) @ B^T */
if (A->requires_grad) {
tofu_tensor* B_T = tofu_tensor_transpose(B->value, NULL, NULL);
tofu_tensor* grad_A = tofu_tensor_matmul(grad_y_2d, B_T, NULL);
/* For batch matmul, transpose only the last 2 dimensions (matrix dims), not batch dims */
tofu_tensor* B_T = NULL;
if (B->value->ndim == 2) {
/* Simple 2D case */
B_T = tofu_tensor_transpose(B->value, NULL, NULL);
} else {
/* N-D case: transpose only last 2 dims */
int* axes = (int*)malloc(B->value->ndim * sizeof(int));
for (int i = 0; i < B->value->ndim - 2; i++) {
axes[i] = i; /* Keep batch dimensions in order */
}
axes[B->value->ndim - 2] = B->value->ndim - 1; /* Swap last two */
axes[B->value->ndim - 1] = B->value->ndim - 2;
B_T = tofu_tensor_transpose(B->value, NULL, axes);
free(axes);
}
tofu_tensor* grad_A_full = tofu_tensor_matmul(grad_y_2d, B_T, NULL);

/* Reshape back to original shape if needed */
/* Reduce gradient if broadcasting occurred in batch dimensions */
tofu_tensor* grad_A = reduce_grad_for_broadcast(grad_A_full, A->value);

/* Reshape back to original shape if needed (for 1D inputs) */
if (is_A_1d && grad_A->ndim == 2) {
if (grad_A->dims[0] == 1) {
/* Squeeze [1, n] to [n] */
int new_dims[] = {grad_A->dims[1]};
tofu_tensor* grad_A_1d = tofu_tensor_reshape(grad_A, 1, new_dims);
tofu_tensor_free(grad_A);
if (grad_A != grad_A_full) {
tofu_tensor_free_data_too(grad_A);
}
grad_A = grad_A_1d;
}
}

accumulate_grad(A, grad_A);
tofu_tensor_free_data_too(B_T);
tofu_tensor_free_data_too(grad_A);
tofu_tensor_free_data_too(grad_A_full);
if (grad_A != grad_A_full) {
tofu_tensor_free_data_too(grad_A);
}
}

/* ∂L/∂B = A^T @ (∂L/∂y) */
Expand All @@ -916,8 +995,22 @@ static void matmul_backward(tofu_graph_node* node)
A_val = A_2d;
}

/* Transpose: [1, n] → [n, 1] */
tofu_tensor* A_T = tofu_tensor_transpose(A_val, NULL, NULL);
/* Transpose only matrix dimensions (last 2) */
tofu_tensor* A_T = NULL;
if (A_val->ndim == 2) {
/* Simple 2D case */
A_T = tofu_tensor_transpose(A_val, NULL, NULL);
} else {
/* N-D case: transpose only last 2 dims */
int* axes = (int*)malloc(A_val->ndim * sizeof(int));
for (int i = 0; i < A_val->ndim - 2; i++) {
axes[i] = i; /* Keep batch dimensions in order */
}
axes[A_val->ndim - 2] = A_val->ndim - 1; /* Swap last two */
axes[A_val->ndim - 1] = A_val->ndim - 2;
A_T = tofu_tensor_transpose(A_val, NULL, axes);
free(axes);
}

/* Reshape grad_y if needed for matmul with A_T */
tofu_tensor* grad_y_for_B = grad_y;
Expand All @@ -929,7 +1022,10 @@ static void matmul_backward(tofu_graph_node* node)
grad_y_for_B = grad_y_B_reshaped;
}

tofu_tensor* grad_B = tofu_tensor_matmul(A_T, grad_y_for_B, NULL);
tofu_tensor* grad_B_full = tofu_tensor_matmul(A_T, grad_y_for_B, NULL);

/* Reduce gradient if broadcasting occurred in batch dimensions */
tofu_tensor* grad_B = reduce_grad_for_broadcast(grad_B_full, B->value);

accumulate_grad(B, grad_B);

Expand All @@ -938,7 +1034,10 @@ static void matmul_backward(tofu_graph_node* node)
if (grad_y_B_reshaped)
tofu_tensor_free(grad_y_B_reshaped);
tofu_tensor_free_data_too(A_T);
tofu_tensor_free_data_too(grad_B);
tofu_tensor_free_data_too(grad_B_full);
if (grad_B != grad_B_full) {
tofu_tensor_free_data_too(grad_B);
}
}

/* Free temporary grad_y_2d if we created it */
Expand All @@ -962,13 +1061,20 @@ static void add_backward(tofu_graph_node* node)

/* ∂L/∂x = ∂L/∂z (sum over broadcast dimensions if needed) */
if (x->requires_grad) {
/* TODO: Handle broadcasting properly - for now assume same shape */
accumulate_grad(x, grad_z);
tofu_tensor* grad_x = reduce_grad_for_broadcast(grad_z, x->value);
accumulate_grad(x, grad_x);
if (grad_x != grad_z) {
tofu_tensor_free_data_too(grad_x);
}
}

/* ∂L/∂y = ∂L/∂z */
/* ∂L/∂y = ∂L/∂z (sum over broadcast dimensions if needed) */
if (y->requires_grad) {
accumulate_grad(y, grad_z);
tofu_tensor* grad_y = reduce_grad_for_broadcast(grad_z, y->value);
accumulate_grad(y, grad_y);
if (grad_y != grad_z) {
tofu_tensor_free_data_too(grad_y);
}
}
}

Expand All @@ -985,18 +1091,26 @@ static void mul_backward(tofu_graph_node* node)
if (!grad_z)
return;

/* ∂L/∂x = ∂L/∂z * y */
/* ∂L/∂x = ∂L/∂z * y (sum over broadcast dimensions if needed) */
if (x->requires_grad) {
tofu_tensor* grad_x = tofu_tensor_elew_broadcast(grad_z, y->value, NULL, TOFU_MUL);
tofu_tensor* grad_x_full = tofu_tensor_elew_broadcast(grad_z, y->value, NULL, TOFU_MUL);
tofu_tensor* grad_x = reduce_grad_for_broadcast(grad_x_full, x->value);
accumulate_grad(x, grad_x);
tofu_tensor_free_data_too(grad_x);
tofu_tensor_free_data_too(grad_x_full);
if (grad_x != grad_x_full) {
tofu_tensor_free_data_too(grad_x);
}
}

/* ∂L/∂y = ∂L/∂z * x */
/* ∂L/∂y = ∂L/∂z * x (sum over broadcast dimensions if needed) */
if (y->requires_grad) {
tofu_tensor* grad_y = tofu_tensor_elew_broadcast(grad_z, x->value, NULL, TOFU_MUL);
tofu_tensor* grad_y_full = tofu_tensor_elew_broadcast(grad_z, x->value, NULL, TOFU_MUL);
tofu_tensor* grad_y = reduce_grad_for_broadcast(grad_y_full, y->value);
accumulate_grad(y, grad_y);
tofu_tensor_free_data_too(grad_y);
tofu_tensor_free_data_too(grad_y_full);
if (grad_y != grad_y_full) {
tofu_tensor_free_data_too(grad_y);
}
}
}

Expand Down
Loading