Skip to content

Commit

Permalink
[Node] Allow alternative root names in ObjectPath::Root() (apache#14569)
Browse files Browse the repository at this point in the history
* [Node] Allow alternative root names in ObjectPath::Root()

Previously, the `ObjectPath` utility allowed tracking of an object's
location within a tree-like structure.  However, the base of the path
structure was hard-coded to be the string `<root>`.  For use cases
such as `StructuralEqual`, there is a clear root node.  However, other
cases such as using `ObjectPath` to specify an object's location
relative to a known variable, would require using that known
variable's name as the root, rather than the hard-coded string
`<root>`.

This commit adds an optional parameter to provide an alternative name
for the root node, to allow for these use cases.

* Updated python API, added unit test
  • Loading branch information
Lunderberg authored Apr 12, 2023
1 parent 3ef745c commit b4c1995
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 8 deletions.
6 changes: 4 additions & 2 deletions include/tvm/node/object_path.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class ObjectPathNode : public Object {
class ObjectPath : public ObjectRef {
public:
/*! \brief Create a path that represents the root object itself. */
static ObjectPath Root();
static ObjectPath Root(Optional<String> name = NullOpt);

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode);
};
Expand All @@ -135,7 +135,9 @@ class ObjectPath : public ObjectRef {

class RootPathNode final : public ObjectPathNode {
public:
explicit RootPathNode();
Optional<String> name;

explicit RootPathNode(Optional<String> name = NullOpt);

static constexpr const char* _type_key = "RootPath";
TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode);
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/runtime/object_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
via attribute access, array indexing etc.
"""

from typing import Optional

import tvm._ffi
from tvm.runtime import Object
from . import _ffi_node_api
Expand Down Expand Up @@ -52,8 +54,8 @@ def __init__(self) -> None:
)

@staticmethod
def root() -> "ObjectPath":
return _ffi_node_api.ObjectPathRoot()
def root(root_name: Optional[str] = None) -> "ObjectPath":
return _ffi_node_api.ObjectPathRoot(root_name)

def __eq__(self, other):
return _ffi_node_api.ObjectPathEqual(self, other)
Expand Down
20 changes: 16 additions & 4 deletions src/node/object_path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,31 @@ const ObjectPathNode* ObjectPathNode::ParentNode() const {

// ============== ObjectPath ==============

/* static */ ObjectPath ObjectPath::Root() { return ObjectPath(make_object<RootPathNode>()); }
/* static */ ObjectPath ObjectPath::Root(Optional<String> name) {
return ObjectPath(make_object<RootPathNode>(name));
}

TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root);

// ============== Individual path classes ==============

// ----- Root -----

RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {}
RootPathNode::RootPathNode(Optional<String> name) : ObjectPathNode(nullptr), name(name) {}

bool RootPathNode::LastNodeEqual(const ObjectPathNode* other_path) const {
const auto* other = static_cast<const RootPathNode*>(other_path);

bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; }
if (other->name.defined() != name.defined()) {
return false;
} else if (name && other->name) {
return name.value() == other->name.value();
} else {
return true;
}
}

std::string RootPathNode::LastNodeString() const { return "<root>"; }
std::string RootPathNode::LastNodeString() const { return name.value_or("<root>"); }

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<RootPathNode>(PrintObjectPathRepr);

Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_object_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ def test_root_path():
assert root.parent is None


def test_named_root_path():
root = ObjectPath.root("base_name")
assert isinstance(root, object_path.RootPath)
assert str(root) == "base_name"
assert len(root) == 1
assert root != ObjectPath.root()
assert root == ObjectPath.root("base_name")
assert root.parent is None


def test_path_attr():
path = ObjectPath.root().attr("foo")
assert isinstance(path, object_path.AttributeAccessPath)
Expand Down

0 comments on commit b4c1995

Please sign in to comment.