From 9b788b950af94fcb86b4f57df48f6d8d3fb4eca2 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Tue, 4 Feb 2025 17:32:05 +0800 Subject: [PATCH] refactor PooledObject --- include/luisa/xir/metadata.h | 12 +++++------- include/luisa/xir/pool.h | 9 +++++++-- include/luisa/xir/use.h | 2 +- include/luisa/xir/value.h | 12 ++++++------ src/xir/use.cpp | 4 +--- src/xir/value.cpp | 6 +++--- 6 files changed, 23 insertions(+), 22 deletions(-) diff --git a/include/luisa/xir/metadata.h b/include/luisa/xir/metadata.h index 64ebcc521..ccf50256f 100644 --- a/include/luisa/xir/metadata.h +++ b/include/luisa/xir/metadata.h @@ -19,7 +19,7 @@ class LC_XIR_API Metadata : public IntrusiveForwardNode { public: explicit Metadata(Pool *pool) noexcept; - [[nodiscard]] Pool *pool() const noexcept override { return _pool; } + [[nodiscard]] Pool *pool() noexcept override { return _pool; } [[nodiscard]] virtual DerivedMetadataTag derived_metadata_tag() const noexcept = 0; [[nodiscard]] virtual Metadata *clone(Pool *pool) const noexcept = 0; LUISA_XIR_DEFINED_ISA_METHOD(Metadata, metadata) @@ -31,14 +31,12 @@ class LC_XIR_API DerivedMetadata : public Base { using derived_metadata_type = Derived; using Super = DerivedMetadata; using Base::Base; + [[nodiscard]] static constexpr auto - static_derived_metadata_tag() noexcept { - return tag; - } + static_derived_metadata_tag() noexcept { return tag; } + [[nodiscard]] DerivedMetadataTag - derived_metadata_tag() const noexcept final { - return static_derived_metadata_tag(); - } + derived_metadata_tag() const noexcept final { return static_derived_metadata_tag(); } }; using MetadataList = IntrusiveForwardList; diff --git a/include/luisa/xir/pool.h b/include/luisa/xir/pool.h index 1850f25e7..2e7ad553e 100644 --- a/include/luisa/xir/pool.h +++ b/include/luisa/xir/pool.h @@ -31,7 +31,11 @@ class LC_XIR_API PooledObject { public: virtual ~PooledObject() noexcept = default; - [[nodiscard]] virtual Pool *pool() const noexcept = 0; + + [[nodiscard]] virtual Pool *pool() noexcept = 0; + [[nodiscard]] const Pool *pool() const noexcept { + return const_cast(this)->pool(); + } // make the object pinned to its memory location PooledObject(PooledObject &&) noexcept = delete; @@ -74,7 +78,8 @@ class LC_XIR_API PoolOwner { public: explicit PoolOwner(size_t init_pool_cap = 0u) noexcept; virtual ~PoolOwner() noexcept = default; - [[nodiscard]] Pool *pool() const noexcept { return _pool.get(); } + [[nodiscard]] Pool *pool() noexcept { return _pool.get(); } + [[nodiscard]] const Pool *pool() const noexcept { return _pool.get(); } }; }// namespace luisa::compute::xir diff --git a/include/luisa/xir/use.h b/include/luisa/xir/use.h index 3fdadfafa..5a14d3165 100644 --- a/include/luisa/xir/use.h +++ b/include/luisa/xir/use.h @@ -16,7 +16,7 @@ class LC_XIR_API Use final : public IntrusiveForwardNode { public: explicit Use(User *user, Value *value = nullptr) noexcept; void set_value(Value *value) noexcept; - [[nodiscard]] Pool *pool() const noexcept override; + [[nodiscard]] Pool *pool() noexcept override; [[nodiscard]] auto value() noexcept { return _value; } [[nodiscard]] auto value() const noexcept { return const_cast(_value); } [[nodiscard]] auto user() noexcept { return _user; } diff --git a/include/luisa/xir/value.h b/include/luisa/xir/value.h index ad82051b8..4d6f0231a 100644 --- a/include/luisa/xir/value.h +++ b/include/luisa/xir/value.h @@ -69,7 +69,7 @@ class LC_XIR_API GlobalValueModuleMixin { protected: explicit GlobalValueModuleMixin(Module *module) noexcept; ~GlobalValueModuleMixin() noexcept = default; - [[nodiscard]] Pool *_pool_from_parent_module() const noexcept; + [[nodiscard]] Pool *_pool_from_parent_module() noexcept; public: [[nodiscard]] Module *parent_module() noexcept { return _parent_module; } @@ -86,7 +86,7 @@ class DerivedGlobalValue : public DerivedValue, explicit DerivedGlobalValue(Module *module, Args &&...args) noexcept : DerivedValue{std::forward(args)...}, GlobalValueModuleMixin{module} {} - [[nodiscard]] Pool *pool() const noexcept final { return _pool_from_parent_module(); } + [[nodiscard]] Pool *pool() noexcept final { return _pool_from_parent_module(); } }; class LC_XIR_API LocalValueFunctionMixin { @@ -100,7 +100,7 @@ class LC_XIR_API LocalValueFunctionMixin { explicit LocalValueFunctionMixin(Function *function) noexcept; ~LocalValueFunctionMixin() noexcept = default; void _set_parent_function(Function *function) noexcept; - [[nodiscard]] Pool *_pool_from_parent_function() const noexcept; + [[nodiscard]] Pool *_pool_from_parent_function() noexcept; public: [[nodiscard]] Function *parent_function() noexcept { return _parent_function; } @@ -119,7 +119,7 @@ class DerivedFunctionScopeValue : public DerivedValue, explicit DerivedFunctionScopeValue(Function *function, Args &&...args) noexcept : DerivedValue{std::forward(args)...}, LocalValueFunctionMixin{function} {} - [[nodiscard]] Pool *pool() const noexcept final { return _pool_from_parent_function(); } + [[nodiscard]] Pool *pool() noexcept final { return _pool_from_parent_function(); } }; class LC_XIR_API LocalValueBlockMixin { @@ -132,7 +132,7 @@ class LC_XIR_API LocalValueBlockMixin { explicit LocalValueBlockMixin(BasicBlock *block) noexcept; ~LocalValueBlockMixin() noexcept = default; void _set_parent_block(BasicBlock *block) noexcept; - [[nodiscard]] Pool *_pool_from_parent_block() const noexcept; + [[nodiscard]] Pool *_pool_from_parent_block() noexcept; public: [[nodiscard]] BasicBlock *parent_block() noexcept { return _parent_block; } @@ -153,7 +153,7 @@ class DerivedBlockScopeValue : public DerivedValue, explicit DerivedBlockScopeValue(BasicBlock *block, Args &&...args) noexcept : DerivedValue{std::forward(args)...}, LocalValueBlockMixin{block} {} - [[nodiscard]] Pool *pool() const noexcept final { return _pool_from_parent_block(); } + [[nodiscard]] Pool *pool() noexcept final { return _pool_from_parent_block(); } }; }// namespace luisa::compute::xir diff --git a/src/xir/use.cpp b/src/xir/use.cpp index 75ca21e6a..fec1b1c29 100644 --- a/src/xir/use.cpp +++ b/src/xir/use.cpp @@ -15,8 +15,6 @@ void Use::set_value(Value *value) noexcept { _value = value; } -Pool *Use::pool() const noexcept { - return user()->pool(); -} +Pool *Use::pool() noexcept { return user()->pool(); } }// namespace luisa::compute::xir diff --git a/src/xir/value.cpp b/src/xir/value.cpp index 913ab6d7d..c0d59c3c9 100644 --- a/src/xir/value.cpp +++ b/src/xir/value.cpp @@ -22,7 +22,7 @@ GlobalValueModuleMixin::GlobalValueModuleMixin(Module *module) noexcept : _paren LUISA_DEBUG_ASSERT(_parent_module != nullptr, "Module must not be null."); } -Pool *GlobalValueModuleMixin::_pool_from_parent_module() const noexcept { +Pool *GlobalValueModuleMixin::_pool_from_parent_module() noexcept { return parent_module()->pool(); } @@ -36,7 +36,7 @@ void LocalValueFunctionMixin::_set_parent_function(Function *function) noexcept _parent_function = function; } -Pool *LocalValueFunctionMixin::_pool_from_parent_function() const noexcept { +Pool *LocalValueFunctionMixin::_pool_from_parent_function() noexcept { return parent_function()->pool(); } @@ -58,7 +58,7 @@ void LocalValueBlockMixin::_set_parent_block(BasicBlock *block) noexcept { _parent_block = block; } -Pool *LocalValueBlockMixin::_pool_from_parent_block() const noexcept { +Pool *LocalValueBlockMixin::_pool_from_parent_block() noexcept { return parent_block()->pool(); }