Skip to content

Commit a65e79d

Browse files
committed
outline call_impl to save on code size
This does cause more move constructions, as shown by the needed update to test_copy_move. Up to reviewers whether they want more code size or more moves.
1 parent f34a039 commit a65e79d

File tree

3 files changed

+140
-28
lines changed

3 files changed

+140
-28
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
10+
//
11+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://llvm.org/LICENSE.txt for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//===----------------------------------------------------------------------===//
16+
//
17+
// This file contains some extensions to <functional>.
18+
//
19+
// No library is required when using these functions.
20+
//
21+
//===----------------------------------------------------------------------===//
22+
// Extra additions to <functional>
23+
//===----------------------------------------------------------------------===//
24+
25+
/// An efficient, type-erasing, non-owning reference to a callable. This is
26+
/// intended for use as the type of a function parameter that is not used
27+
/// after the function in question returns.
28+
///
29+
/// This class does not own the callable, so it is not in general safe to store
30+
/// a function_ref.
31+
32+
// pybind11: modified again from executorch::runtime::FunctionRef
33+
// - renamed back to function_ref
34+
// - use pybind11 enable_if_t, remove_cvref_t, and remove_reference_t
35+
36+
// torch::executor: modified from llvm::function_ref
37+
// - renamed to FunctionRef
38+
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
39+
// - use namespaced internal::remove_cvref_t
40+
41+
#pragma once
42+
43+
#include <pybind11/detail/common.h>
44+
45+
#include <cstdint>
46+
#include <type_traits>
47+
#include <utility>
48+
49+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
50+
PYBIND11_NAMESPACE_BEGIN(detail)
51+
52+
//===----------------------------------------------------------------------===//
53+
// Features from C++20
54+
//===----------------------------------------------------------------------===//
55+
56+
template <typename Fn>
class function_ref;

/// Type-erasing, non-owning reference to a callable invocable as `Ret(Params...)`.
///
/// Only the address of the referenced callable is kept; the callable itself is never
/// copied or moved. The referenced object therefore has to stay alive for as long as
/// the function_ref may be invoked, which makes this type suitable for function
/// parameters but not, in general, for storage.
///
/// A default-constructed or nullptr-constructed function_ref is empty: it converts to
/// `false` and must not be invoked.
template <typename Ret, typename... Params>
class function_ref<Ret(Params...)> {
    /// Trampoline that casts the erased address back to the concrete callable type and
    /// invokes it. Null while the reference is empty.
    Ret (*callback)(intptr_t callable, Params... params) = nullptr;

    /// Address of the referenced callable, stored as an integer. Zero-initialized so
    /// that empty references hold a well-defined value; operator== reads this member
    /// and would otherwise inspect an indeterminate value for empty references.
    intptr_t callable = 0;

    /// Restores the concrete type `Callable` from the erased address and forwards the
    /// arguments to it.
    template <typename Callable>
    static Ret callback_fn(intptr_t callable, Params... params) {
        return (*reinterpret_cast<Callable *>(callable))(std::forward<Params>(params)...);
    }

public:
    /// Creates an empty reference.
    function_ref() = default;

    /// Creates an empty reference from a null pointer constant.
    function_ref(std::nullptr_t) {}

    /// Binds to `callable` without taking ownership; only its address is recorded.
    template <typename Callable>
    function_ref(
        Callable &&callable,
        // This is not the copy-constructor.
        enable_if_t<!std::is_same<remove_cvref_t<Callable>, function_ref>::value> * = nullptr,
        // Functor must be callable and return a suitable type.
        enable_if_t<
            std::is_void<Ret>::value
            || std::is_convertible<decltype(std::declval<Callable>()(std::declval<Params>()...)),
                                   Ret>::value> * = nullptr)
        : callback(callback_fn<remove_reference_t<Callable>>),
          callable(reinterpret_cast<intptr_t>(&callable)) {}

    /// Invokes the referenced callable. The reference must not be empty.
    Ret operator()(Params... params) const {
        return callback(callable, std::forward<Params>(params)...);
    }

    /// True when a callable is bound.
    explicit operator bool() const { return callback != nullptr; }

    /// Two references compare equal when they record the same callable address; the
    /// trampoline is not part of the comparison. Empty references compare equal to each
    /// other.
    bool operator==(const function_ref<Ret(Params...)> &Other) const {
        return callable == Other.callable;
    }
};
96+
PYBIND11_NAMESPACE_END(detail)
97+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)

include/pybind11/pybind11.h

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "detail/dynamic_raw_ptr_cast_if_possible.h"
1414
#include "detail/exception_translation.h"
1515
#include "detail/function_record_pyobject.h"
16+
#include "detail/function_ref.h"
1617
#include "detail/init.h"
1718
#include "detail/native_enum_data.h"
1819
#include "detail/using_smart_holder.h"
@@ -379,6 +380,40 @@ class cpp_function : public function {
379380
return unique_function_record(new detail::function_record());
380381
}
381382

383+
private:
384+
// This is outlined from the dispatch lambda in initialize to save
385+
// on code size. Crucially, we use function_ref to type-erase the
386+
// actual function lambda so that we can get code reuse for
387+
// functions with the same Return, Args, and Guard.
388+
template <typename Return, typename Guard, typename ArgsConverter, typename... Args>
389+
static handle call_impl(detail::function_call &call, detail::function_ref<Return(Args...)> f) {
390+
using namespace detail;
391+
using cast_out
392+
= make_caster<conditional_t<std::is_void<Return>::value, void_type, Return>>;
393+
394+
ArgsConverter args_converter;
395+
if (!args_converter.load_args(call)) {
396+
return PYBIND11_TRY_NEXT_OVERLOAD;
397+
}
398+
399+
/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
400+
return_value_policy policy
401+
= return_value_policy_override<Return>::policy(call.func.policy);
402+
403+
/* Perform the function call */
404+
handle result;
405+
if (call.func.is_setter) {
406+
(void) std::move(args_converter).template call<Return, Guard>(f);
407+
result = none().release();
408+
} else {
409+
result = cast_out::cast(
410+
std::move(args_converter).template call<Return, Guard>(f), policy, call.parent);
411+
}
412+
413+
return result;
414+
}
415+
416+
protected:
382417
/// Special internal constructor for functors, lambda functions, etc.
383418
template <typename Func, typename Return, typename... Args, typename... Extra>
384419
void initialize(Func &&f, Return (*)(Args...), const Extra &...extra) {
@@ -441,13 +476,6 @@ class cpp_function : public function {
441476

442477
/* Dispatch code which converts function arguments and performs the actual function call */
443478
rec->impl = [](function_call &call) -> handle {
444-
cast_in args_converter;
445-
446-
/* Try to cast the function arguments into the C++ domain */
447-
if (!args_converter.load_args(call)) {
448-
return PYBIND11_TRY_NEXT_OVERLOAD;
449-
}
450-
451479
/* Invoke call policy pre-call hook */
452480
process_attributes<Extra...>::precall(call);
453481

@@ -456,24 +484,11 @@ class cpp_function : public function {
456484
: call.func.data[0]);
457485
auto *cap = const_cast<capture *>(reinterpret_cast<const capture *>(data));
458486

459-
/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
460-
return_value_policy policy
461-
= return_value_policy_override<Return>::policy(call.func.policy);
462-
463-
/* Function scope guard -- defaults to the compile-to-nothing `void_type` */
464-
using Guard = extract_guard_t<Extra...>;
465-
466-
/* Perform the function call */
467-
handle result;
468-
if (call.func.is_setter) {
469-
(void) std::move(args_converter).template call<Return, Guard>(cap->f);
470-
result = none().release();
471-
} else {
472-
result = cast_out::cast(
473-
std::move(args_converter).template call<Return, Guard>(cap->f),
474-
policy,
475-
call.parent);
476-
}
487+
auto result = call_impl<Return,
488+
/* Function scope guard -- defaults to the compile-to-nothing
489+
`void_type` */
490+
extract_guard_t<Extra...>,
491+
cast_in>(call, detail::function_ref<Return(Args...)>(cap->f));
477492

478493
/* Invoke call policy post-call hook */
479494
process_attributes<Extra...>::postcall(call, result);

tests/test_copy_move.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ def test_move_and_copy_loads():
7070

7171
assert c_m.copy_assignments + c_m.copy_constructions == 0
7272
assert c_m.move_assignments == 6
73-
assert c_m.move_constructions == 9
73+
assert c_m.move_constructions == 21
7474
assert c_mc.copy_assignments + c_mc.copy_constructions == 0
7575
assert c_mc.move_assignments == 5
76-
assert c_mc.move_constructions == 8
76+
assert c_mc.move_constructions == 18
7777
assert c_c.copy_assignments == 4
78-
assert c_c.copy_constructions == 6
78+
assert c_c.copy_constructions == 14
7979
assert c_m.alive() + c_mc.alive() + c_c.alive() == 0
8080

8181

0 commit comments

Comments
 (0)