
BPF task work WIP #9132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/uapi/linux/bpf.h
@@ -7375,6 +7375,10 @@ struct bpf_timer {
	__u64 __opaque[2];
} __attribute__((aligned(8)));

struct bpf_task_work {
	__u64 __opaque[2];
} __attribute__((aligned(8)));

struct bpf_wq {
	__u64 __opaque[2];
} __attribute__((aligned(8)));
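For context, struct bpf_task_work follows the existing bpf_timer/bpf_wq pattern: an opaque, 8-byte-aligned blob reserving room for kernel-internal state. A minimal sketch of how a program might eventually embed it in a map value, mirroring bpf_timer usage (hypothetical: this WIP does not yet teach map/BTF field parsing about the new struct, and struct elem / tw_map are made-up names):

/* Hypothetical map-value embedding, mirroring bpf_timer/bpf_wq usage.
 * Nothing in this WIP wires struct bpf_task_work into btf_parse_fields()
 * yet, so this is a sketch of the expected end state, not working code.
 */
struct elem {
	struct bpf_task_work tw;
};

struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 1);
	__type(key, int);
	__type(value, struct elem);
} tw_map SEC(".maps");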
59 changes: 57 additions & 2 deletions kernel/bpf/helpers.c
@@ -24,7 +24,7 @@
#include <linux/bpf_mem_alloc.h>
#include <linux/kasan.h>
#include <linux/bpf_verifier.h>

#include <linux/task_work.h>
#include "../../lib/kstrtox.h"

/* If kernel subsystem is allowing eBPF programs to call this function,
@@ -1141,6 +1141,61 @@ enum bpf_async_type {
	BPF_ASYNC_TYPE_WQ,
};

struct bpf_defer {
	struct bpf_map *map;
	bpf_callback_t callback_fn;
	struct callback_head work;
};

struct bpf_defer_work {
	struct bpf_defer *defer;
} __attribute__((aligned(8)));

/*
static void bpf_task_work_callback(struct callback_head *cb)
{
	struct bpf_defer *defer = container_of(cb, struct bpf_defer, work);
	bpf_callback_t callback_fn;

	printk("Callback called %p\n", defer);

	callback_fn = defer->callback_fn;
	printk("Callback called is %p\n", callback_fn);
	if (callback_fn) {
		printk("Callback called 2 %p\n", callback_fn);
		printk("Key size %d\n", defer->map->key_size);
		callback_fn(0, 0, 0, 0, 0);
		printk("Callback called 3 %p\n", callback_fn);
	}
}
*/
__bpf_kfunc int bpf_task_work_schedule(void *callback__ign)
{
	bpf_callback_t callback_fn;
	//struct bpf_defer *defer;
	/*
	struct bpf_defer_work *defer_work = (struct bpf_defer_work *)task_work;

	BTF_TYPE_EMIT(struct bpf_task_work);

	defer = bpf_map_kmalloc_node(map, sizeof(struct bpf_defer), GFP_ATOMIC, map->numa_node);
	if (!defer) {
		return -ENOMEM;
	}
	//defer->map = map;
	defer->work.func = bpf_task_work_callback;
	defer->work.next = NULL;
	defer->callback_fn = callback__ign;
	printk("Callback is %p\n", callback__ign);
	defer_work->defer = defer;
	printk("Scheduling callback\n");
	*/
	callback_fn = (bpf_callback_t)callback__ign;
	callback_fn(0, 0, 0, 0, 0);
	//task_work_add(NULL, &defer->work, TWA_RESUME);
	printk("Callback scheduled\n");
	return 0;
}

static DEFINE_PER_CPU(struct bpf_hrtimer *, hrtimer_running);

static enum hrtimer_restart bpf_timer_cb(struct hrtimer *hrtimer)
@@ -3303,7 +3358,7 @@ BTF_ID_FLAGS(func, bpf_rbtree_first, KF_RET_NULL)
BTF_ID_FLAGS(func, bpf_rbtree_root, KF_RET_NULL)
BTF_ID_FLAGS(func, bpf_rbtree_left, KF_RET_NULL)
BTF_ID_FLAGS(func, bpf_rbtree_right, KF_RET_NULL)

BTF_ID_FLAGS(func, bpf_task_work_schedule)
#ifdef CONFIG_CGROUPS
BTF_ID_FLAGS(func, bpf_cgroup_acquire, KF_ACQUIRE | KF_RCU | KF_RET_NULL)
BTF_ID_FLAGS(func, bpf_cgroup_release, KF_RELEASE)
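For orientation, here is a sketch of where the commented-out code above appears to be headed: allocate a struct bpf_defer from the map's allocator, stash the callback, and queue it with task_work_add() so the callback runs in task context. The struct names, bpf_map_kmalloc_node() and bpf_task_work_callback() come from the patch itself; the final signature and the choice of current/TWA_RESUME are assumptions:

/* Sketch only, not part of this patch. Assumes the kfunc eventually
 * receives the map and the embedded struct bpf_task_work, and targets
 * the current task, as the commented-out code and the disabled
 * task_work_add() call suggest.
 */
static int bpf_task_work_schedule_sketch(struct bpf_map *map,
					 struct bpf_defer_work *defer_work,
					 bpf_callback_t callback_fn)
{
	struct bpf_defer *defer;

	defer = bpf_map_kmalloc_node(map, sizeof(*defer), GFP_ATOMIC,
				     map->numa_node);
	if (!defer)
		return -ENOMEM;

	defer->map = map;
	defer->callback_fn = callback_fn;
	init_task_work(&defer->work, bpf_task_work_callback);
	defer_work->defer = defer;

	/* Invokes bpf_task_work_callback() when current returns to user space. */
	return task_work_add(current, &defer->work, TWA_RESUME);
}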
59 changes: 49 additions & 10 deletions kernel/bpf/verifier.c
@@ -517,9 +517,11 @@ static bool is_sync_callback_calling_function(enum bpf_func_id func_id)
	       func_id == BPF_FUNC_user_ringbuf_drain;
}

static bool is_kfunc_callback_calling_function(enum bpf_func_id func_id);

static bool is_async_callback_calling_function(enum bpf_func_id func_id)
{
	return func_id == BPF_FUNC_timer_set_callback ||
	       is_kfunc_callback_calling_function(func_id);
}

static bool is_callback_calling_function(enum bpf_func_id func_id)
@@ -10781,6 +10783,24 @@ static int set_rbtree_add_callback_state(struct bpf_verifier_env *env,
	return 0;
}

static int set_task_work_schedule_callback_state(struct bpf_verifier_env *env,
						 struct bpf_func_state *caller,
						 struct bpf_func_state *callee,
						 int insn_idx)
{
	callee->regs[BPF_REG_1] = caller->regs[BPF_REG_1];

	/* unused */
	__mark_reg_not_init(env, &callee->regs[BPF_REG_2]);
	__mark_reg_not_init(env, &callee->regs[BPF_REG_3]);
	__mark_reg_not_init(env, &callee->regs[BPF_REG_4]);
	__mark_reg_not_init(env, &callee->regs[BPF_REG_5]);
	callee->in_callback_fn = true;
	callee->callback_ret_range = retval_range(0, 1);
	return 0;
}

static bool is_rbtree_lock_required_kfunc(u32 btf_id);

/* Are we currently verifying the callback for a rbtree helper that must
@@ -12108,6 +12128,7 @@ enum special_kfunc_type {
	KF_bpf_res_spin_lock_irqsave,
	KF_bpf_res_spin_unlock_irqrestore,
	KF___bpf_trap,
	KF_bpf_task_work_schedule,
};

BTF_ID_LIST(special_kfunc_list)
@@ -12174,6 +12195,12 @@ BTF_ID(func, bpf_res_spin_unlock)
BTF_ID(func, bpf_res_spin_lock_irqsave)
BTF_ID(func, bpf_res_spin_unlock_irqrestore)
BTF_ID(func, __bpf_trap)
BTF_ID(func, bpf_task_work_schedule)

static bool is_kfunc_callback_calling_function(enum bpf_func_id func_id)
{
	return func_id == special_kfunc_list[KF_bpf_task_work_schedule];
}

static bool is_kfunc_ret_null(struct bpf_kfunc_call_arg_meta *meta)
{
@@ -12608,7 +12635,7 @@ static bool is_sync_callback_calling_kfunc(u32 btf_id)

static bool is_async_callback_calling_kfunc(u32 btf_id)
{
	return btf_id == special_kfunc_list[KF_bpf_wq_set_callback_impl] ||
	       btf_id == special_kfunc_list[KF_bpf_task_work_schedule];
}

static bool is_bpf_throw_kfunc(struct bpf_insn *insn)
@@ -12861,7 +12888,7 @@ static bool check_css_task_iter_allowlist(struct bpf_verifier_env *env)

static int check_kfunc_args(struct bpf_verifier_env *env, struct bpf_kfunc_call_arg_meta *meta,
			    int insn_idx)
{ // todo check
	const char *func_name = meta->func_name, *ref_tname;
	const struct btf *btf = meta->btf;
	const struct btf_param *args;
@@ -13666,6 +13693,16 @@ static int check_kfunc_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
	if (err < 0)
		return err;

	if (meta.func_id == special_kfunc_list[KF_bpf_task_work_schedule]) {
		err = push_callback_call(env, insn, insn_idx, meta.subprogno,
					 set_task_work_schedule_callback_state);
		if (err) {
			verbose(env, "kfunc %s#%d failed callback verification\n",
				func_name, meta.func_id);
			return err;
		}
	}

	if (meta.func_id == special_kfunc_list[KF_bpf_rbtree_add_impl]) {
		err = push_callback_call(env, insn, insn_idx, meta.subprogno,
					 set_rbtree_add_callback_state);
@@ -16700,7 +16737,7 @@ static int check_ld_imm(struct bpf_verifier_env *env, struct bpf_insn *insn)
		return 0;
	}

	if (insn->src_reg == BPF_PSEUDO_FUNC) { // todo check
		struct bpf_prog_aux *aux = env->prog->aux;
		u32 subprogno = find_subprog(env,
					     env->insn_idx + insn->imm + 1);
@@ -20161,7 +20198,7 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
		}

		if (insn[0].src_reg == BPF_PSEUDO_FUNC) {
			aux = &env->insn_aux_data[i]; // todo check
			aux->ptr_type = PTR_TO_FUNC;
			goto next_insn;
		}
@@ -21202,7 +21239,7 @@ static int jit_subprogs(struct bpf_verifier_env *env)
	 * now populate all bpf_calls with correct addresses and
	 * run last pass of JIT
	 */
	for (i = 0; i < env->subprog_cnt; i++) { // Check what this is doing
		insn = func[i]->insnsi;
		for (j = 0; j < func[i]->len; j++, insn++) {
			if (bpf_pseudo_func(insn)) {
@@ -21458,7 +21495,9 @@ static int fixup_kfunc_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
	insn->imm = BPF_CALL_IMM(desc->addr);
	if (insn->off)
		return 0;
	if (desc->func_id == special_kfunc_list[KF_bpf_task_work_schedule]) {
		printk("Can patch program here\n");
	} else if (desc->func_id == special_kfunc_list[KF_bpf_obj_new_impl] ||
		   desc->func_id == special_kfunc_list[KF_bpf_percpu_obj_new_impl]) {
		struct btf_struct_meta *kptr_struct_meta = env->insn_aux_data[insn_idx].kptr_struct_meta;
		struct bpf_insn addr[2] = { BPF_LD_IMM64(BPF_REG_2, (long)kptr_struct_meta) };
@@ -21604,8 +21643,8 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
		mark_subprog_exc_cb(env, env->exception_callback_subprog);
	}

	for (i = 0; i < insn_cnt;) { // When i == 25 insns[27] is our instruction, i == 27 is call
		if (insn->code == (BPF_ALU64 | BPF_MOV | BPF_X) && insn->imm) {
			if ((insn->off == BPF_ADDR_SPACE_CAST && insn->imm == 1) ||
			    (((struct bpf_map *)env->prog->aux->arena)->map_flags & BPF_F_NO_USER_CONV)) {
				/* convert to 32-bit mov that clears upper 32-bit */
@@ -24109,7 +24148,7 @@ int bpf_check(struct bpf_prog **prog, union bpf_attr *attr, bpfptr_t uattr, __u32 uattr_size)
	ret = convert_ctx_accesses(env);

	if (ret == 0)
		ret = do_misc_fixups(env); // kkl overwrites are here !!!!!!

	/* do 32-bit optimization after insn patching has done so those patched
	 * insns could be handled correctly.
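A note on the callback-state helper added above: set_task_work_schedule_callback_state() copies the caller's R1 into the callee and marks R2-R5 as not initialized, so the scheduled callback gets exactly one usable argument and, per callback_ret_range, must return 0 or 1. Since bpf_task_work_schedule() takes a single argument, R1 at the call site is the callback pointer itself, which reads like a WIP placeholder. On the BPF side, the implied callback shape is roughly (sketch; the name and argument meaning are hypothetical):

/* One usable argument (caller's R1); return value constrained to [0, 1]. */
static int task_work_cb(__u64 arg)
{
	return 0;
}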
4 changes: 4 additions & 0 deletions tools/include/uapi/linux/bpf.h
@@ -7375,6 +7375,10 @@ struct bpf_timer {
	__u64 __opaque[2];
} __attribute__((aligned(8)));

struct bpf_task_work {
	__u64 __opaque[2];
} __attribute__((aligned(8)));

struct bpf_wq {
	__u64 __opaque[2];
} __attribute__((aligned(8)));
48 changes: 48 additions & 0 deletions tools/testing/selftests/bpf/prog_tests/test_task_work.c
@@ -0,0 +1,48 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2025 Meta Platforms, Inc. and affiliates. */
#include <test_progs.h>
#include <string.h>
#include <stdio.h>
#include "task_work.skel.h"

static void test_task_work_run(void)
{
	struct task_work *skel;
	struct bpf_program *prog;
	//struct bpf_link *link;
	char data[5000];
	int err, prog_fd;
	//int err;
	LIBBPF_OPTS(bpf_test_run_opts, opts,
		.data_in = &data,
		.data_size_in = sizeof(data),
		.repeat = 1,
	);

	skel = task_work__open();
	if (!ASSERT_OK_PTR(skel, "task_work__open"))
		return;

	err = task_work__load(skel);
	if (!ASSERT_OK(err, "task_work__load"))
		goto cleanup;

	prog = bpf_object__find_program_by_name(skel->obj, "test_task_work");
	prog_fd = bpf_program__fd(prog);
	fprintf(stderr, "Running a program\n");
	err = bpf_prog_test_run_opts(prog_fd, &opts);
	sleep(20);
	if (!ASSERT_OK(err, "test_run"))
		goto cleanup;

	fprintf(stderr, "Going to sleep\n");
	sleep(20);
cleanup:
	task_work__destroy(skel);
}

void test_task_work(void)
{
	if (test__start_subtest("test_task_work_run"))
		test_task_work_run();
}
27 changes: 27 additions & 0 deletions tools/testing/selftests/bpf/progs/task_work.c
@@ -0,0 +1,27 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2025 Meta Platforms, Inc. and affiliates. */

#include <vmlinux.h>
#include <string.h>
#include <stdbool.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>
#include "bpf_misc.h"
#include "errno.h"

char _license[] SEC("license") = "GPL";

/* kfunc added by this patch; declared here so the program compiles */
extern int bpf_task_work_schedule(void *callback__ign) __ksym;

static __u64 test_cb(__u64 p)
{
	bpf_printk("Hello map %llu\n", p);
	return 0;
}

volatile int cnt = 0;

SEC("xdp")
int test_task_work(struct xdp_md *xdp)
{
	bpf_task_work_schedule((void *)test_cb);
	return XDP_PASS;
}
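Assuming the usual BPF selftest workflow (an assumption; the diff shows no Makefile change, but prog_tests/*.c files are picked up automatically), the new subtest would run as ./test_progs -t task_work from tools/testing/selftests/bpf. The two sleep(20) calls in the runner above make it take the better part of a minute, consistent with WIP debugging of the deferred callback rather than an intended part of the final test.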