From b4c1995a98fc22a316a53c28b7eacb5240fc3f89 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 12 Apr 2023 09:19:10 -0500 Subject: [PATCH] [Node] Allow alternative root names in ObjectPath::Root() (#14569) * [Node] Allow alternative root names in ObjectPath::Root() Previously, the `ObjectPath` utility allowed tracking of an object's location within a tree-like structure. However, the base of the path structure was hard-coded to be the string ``. For use cases such as `StructuralEqual`, there is a clear root node. However, other cases such as using `ObjectPath` to specify an object's location relative to a known variable, would require using that known variable's name as the root, rather than the hard-coded string ``. This commit adds an optional parameter to provide an alternative name for the root node, to allow for these use cases. * Updated python API, added unit test --- include/tvm/node/object_path.h | 6 ++++-- python/tvm/runtime/object_path.py | 6 ++++-- src/node/object_path.cc | 20 ++++++++++++++++---- tests/python/unittest/test_object_path.py | 10 ++++++++++ 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h index 35f947a68f4f..97a62bfd2d8f 100644 --- a/include/tvm/node/object_path.h +++ b/include/tvm/node/object_path.h @@ -122,7 +122,7 @@ class ObjectPathNode : public Object { class ObjectPath : public ObjectRef { public: /*! \brief Create a path that represents the root object itself. */ - static ObjectPath Root(); + static ObjectPath Root(Optional name = NullOpt); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode); }; @@ -135,7 +135,9 @@ class ObjectPath : public ObjectRef { class RootPathNode final : public ObjectPathNode { public: - explicit RootPathNode(); + Optional name; + + explicit RootPathNode(Optional name = NullOpt); static constexpr const char* _type_key = "RootPath"; TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode); diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py index ecca85d53da3..ff223b75998c 100644 --- a/python/tvm/runtime/object_path.py +++ b/python/tvm/runtime/object_path.py @@ -20,6 +20,8 @@ via attribute access, array indexing etc. """ +from typing import Optional + import tvm._ffi from tvm.runtime import Object from . import _ffi_node_api @@ -52,8 +54,8 @@ def __init__(self) -> None: ) @staticmethod - def root() -> "ObjectPath": - return _ffi_node_api.ObjectPathRoot() + def root(root_name: Optional[str] = None) -> "ObjectPath": + return _ffi_node_api.ObjectPathRoot(root_name) def __eq__(self, other): return _ffi_node_api.ObjectPathEqual(self, other) diff --git a/src/node/object_path.cc b/src/node/object_path.cc index 9c49daa8c376..4d88873e7950 100644 --- a/src/node/object_path.cc +++ b/src/node/object_path.cc @@ -197,7 +197,9 @@ const ObjectPathNode* ObjectPathNode::ParentNode() const { // ============== ObjectPath ============== -/* static */ ObjectPath ObjectPath::Root() { return ObjectPath(make_object()); } +/* static */ ObjectPath ObjectPath::Root(Optional name) { + return ObjectPath(make_object(name)); +} TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root); @@ -205,11 +207,21 @@ TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root); // ----- Root ----- -RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {} +RootPathNode::RootPathNode(Optional name) : ObjectPathNode(nullptr), name(name) {} + +bool RootPathNode::LastNodeEqual(const ObjectPathNode* other_path) const { + const auto* other = static_cast(other_path); -bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } + if (other->name.defined() != name.defined()) { + return false; + } else if (name && other->name) { + return name.value() == other->name.value(); + } else { + return true; + } +} -std::string RootPathNode::LastNodeString() const { return ""; } +std::string RootPathNode::LastNodeString() const { return name.value_or(""); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); diff --git a/tests/python/unittest/test_object_path.py b/tests/python/unittest/test_object_path.py index f849c129df59..3fea5141c745 100644 --- a/tests/python/unittest/test_object_path.py +++ b/tests/python/unittest/test_object_path.py @@ -30,6 +30,16 @@ def test_root_path(): assert root.parent is None +def test_named_root_path(): + root = ObjectPath.root("base_name") + assert isinstance(root, object_path.RootPath) + assert str(root) == "base_name" + assert len(root) == 1 + assert root != ObjectPath.root() + assert root == ObjectPath.root("base_name") + assert root.parent is None + + def test_path_attr(): path = ObjectPath.root().attr("foo") assert isinstance(path, object_path.AttributeAccessPath)