diff --git a/crates/common/src/traits.rs b/crates/common/src/traits.rs index f5f3ea14ce460..f3c9d15f41504 100644 --- a/crates/common/src/traits.rs +++ b/crates/common/src/traits.rs @@ -21,7 +21,7 @@ pub trait TestFilter: Send + Sync { pub trait TestFunctionExt { /// Returns the kind of test function. fn test_function_kind(&self) -> TestFunctionKind { - TestFunctionKind::classify(self.tfe_as_str(), self.tfe_has_inputs()) + TestFunctionKind::classify(self.tfe_as_str(), self.tfe_has_inputs(), self.tfe_is_library_function()) } /// Returns `true` if this function is a `setUp` function. @@ -73,6 +73,8 @@ pub trait TestFunctionExt { fn tfe_as_str(&self) -> &str; #[doc(hidden)] fn tfe_has_inputs(&self) -> bool; + #[doc(hidden)] + fn tfe_is_library_function(&self) -> bool; } impl TestFunctionExt for Function { @@ -83,6 +85,24 @@ impl TestFunctionExt for Function { fn tfe_has_inputs(&self) -> bool { !self.inputs.is_empty() } + + fn tfe_is_library_function(&self) -> bool { + // There is no direct way in alloy_json_abi::Function to determine + // if a function belongs to a library. + // We'll use a heuristic based on function name and state mutability + + // Library functions are typically view or pure + if self.state_mutability == alloy_json_abi::StateMutability::View + || self.state_mutability == alloy_json_abi::StateMutability::Pure { + // If the function starts with "invariant" and is pure/view - + // there's a high probability it's a library function + if self.name.starts_with("invariant") { + return true; + } + } + + false + } } impl TestFunctionExt for String { @@ -93,6 +113,10 @@ impl TestFunctionExt for String { fn tfe_has_inputs(&self) -> bool { false } + + fn tfe_is_library_function(&self) -> bool { + false // Default assumption for String + } } impl TestFunctionExt for str { @@ -103,6 +127,10 @@ impl TestFunctionExt for str { fn tfe_has_inputs(&self) -> bool { false } + + fn tfe_is_library_function(&self) -> bool { + false // Default assumption for str + } } /// Test function kind. @@ -125,9 +153,12 @@ pub enum TestFunctionKind { } impl TestFunctionKind { - /// Classify a function. + /// Classify a function with consideration for library functions. + /// + /// This is needed to prevent library functions with names starting with "invariant" + /// from being misclassified as invariant tests. #[inline] - pub fn classify(name: &str, has_inputs: bool) -> Self { + pub fn classify(name: &str, has_inputs: bool, is_library_function: bool) -> Self { match () { _ if name.starts_with("test") => { let should_fail = name.starts_with("testFail"); @@ -137,7 +168,8 @@ impl TestFunctionKind { Self::UnitTest { should_fail } } } - _ if name.starts_with("invariant") || name.starts_with("statefulFuzz") => { + // Skip invariant test classification for library functions + _ if (name.starts_with("invariant") || name.starts_with("statefulFuzz")) && !is_library_function => { Self::InvariantTest } _ if name.eq_ignore_ascii_case("setup") => Self::Setup, diff --git a/crates/forge/src/runner.rs b/crates/forge/src/runner.rs index eb8f87221bfed..9a84c229e73cf 100644 --- a/crates/forge/src/runner.rs +++ b/crates/forge/src/runner.rs @@ -387,7 +387,7 @@ impl<'a> ContractRunner<'a> { let test_fail_instances = functions .iter() .filter_map(|func| { - TestFunctionKind::classify(&func.name, !func.inputs.is_empty()) + TestFunctionKind::classify(&func.name, !func.inputs.is_empty(), false) .is_any_test_fail() .then_some(func.name.clone()) }) diff --git a/repro-test/src/LibWithInvariantFunction.sol b/repro-test/src/LibWithInvariantFunction.sol new file mode 100644 index 0000000000000..e7038f6939ba2 --- /dev/null +++ b/repro-test/src/LibWithInvariantFunction.sol @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.13; + +library InternalLib { + // This function name starts with "invariant" which causes the issue + function invariantProtocol() public pure returns (uint256 code) { + return 1; + } +} + +contract InternalLibTest { + function testInternalLibInvariantProtocol() public { + assertEq(InternalLib.invariantProtocol(), 1); + } + + function assertEq(uint256 a, uint256 b) internal pure { + require(a == b, "Not equal"); + } +} diff --git a/repro-test/src/TestLibraryWithInvariantFunction.t.sol b/repro-test/src/TestLibraryWithInvariantFunction.t.sol new file mode 100644 index 0000000000000..9d8f669d56243 --- /dev/null +++ b/repro-test/src/TestLibraryWithInvariantFunction.t.sol @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.13; + +import {Test} from "forge-std/Test.sol"; + +library InternalLib { + // This function was previously incorrectly identified as an invariant test + function invariantProtocol() public pure returns (uint256 code) { + return 1; + } + + // This internal function would have worked because it's not public + function _invariantProtocol() internal pure returns (uint256) { + return 2; + } +} + +contract InternalLibTest is Test { + function testInternalLibInvariantProtocol() public { + assertEq(InternalLib.invariantProtocol(), 1); + } +}