Skip to content

Commit

Permalink
[Object] Implemented .as<T> for ObjectRef param, returns Optional<T> (a…
Browse files Browse the repository at this point in the history
…pache#14522)

* [Object] Implemented .as<T> for ObjectRef param, returns Optional<T>

Prior to this commit, the `ObjectRef::as<T>()` method could be used
for any `T` that inherits from `tvm::Object`, and would return a
`const T*` if the class could be cast to the specified type, or
`nullptr` otherwise.  However, if the
caller needed a `ObjectRef`, they would then need to call
`GetRef<MyObjRef>` to convert from a `const T*`.

This commit extends `ObjectRef::as<T>` to operate on a `T` that
inherits from `tvm::ObjectRef` as well.  In this case, the return type
is `Optional<T>`, returning either an instance of the specified
subclass, or `NullOpt` if the object was not an instance of the
specified subclass.  Example usage of this new conversion, along with
how it relates to existing functionality, is shown below.

```c++
// Unconditionally convert, throwing an exception if the object isn't
// of the specified type.  In contexts where the type of the object is
// unknown, this shouldn't be used.
PrimExpr expr = Downcast<PrimExpr>(obj);

// Protect the Downcast from throwing an exception using IsInstance.
// This avoids the error, but performs the type-checking twice.  In
// addition, it requires the caller to specify both the ObjectRef
// subclass and the Object subclass, even though these usually have a
// 1:1 correspondence.
if (obj->IsInstance<PrimExprNode>()) {
  PrimExpr expr = Downcast<PrimExpr>(obj);
}

// Perform both type-checking and downcasting with the ObjectRef::as()
// method, then use GetRef to convert to an ObjectRef.  This avoids
// double-checking the type, but still requires the caller to
if (const PrimExprNode* ptr = obj.as<PrimExprNode>()) {
  PrimExpr expr = GetRef<PrimExpr>(ptr);
}

// New method introduced by this PR.  The type-checking is only
// performed once, and the Object subclass is inferred from the
// ObjectRef subclass.
if (Optional<PrimExpr> opt = obj.as<PrimExpr>()) {
  PrimExpr expr = opt.value();
}
```

* Use the ObjectRef to Optional<ObjectRefSubclass> where possible

This commit looked for cases where `ObjectRef::as<ObjectSubclass>()`
was used to convert to a `const ObjectSubclass*` followed immediately
by a call to `GetRef<ObjectRefSubclass>()`, and replaced them with a
single call to `ObjectRef::as<ObjectRefSubclass>()`.

* Fixed usage in preprocess.cc

* Fix an updated usage in rolling buffer
  • Loading branch information
Lunderberg authored Apr 11, 2023
1 parent f28fcd1 commit 7766f3c
Show file tree
Hide file tree
Showing 103 changed files with 548 additions and 536 deletions.
9 changes: 9 additions & 0 deletions include/tvm/runtime/container/optional.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ class Optional : public ObjectRef {
static constexpr bool _type_is_nullable = true;
};

template <typename ObjectRefType, typename>
inline Optional<ObjectRefType> ObjectRef::as() const {
if (auto* ptr = this->as<typename ObjectRefType::ContainerType>()) {
return GetRef<ObjectRefType>(ptr);
} else {
return NullOptType{};
}
}

} // namespace runtime

// expose the functions to the root namespace.
Expand Down
39 changes: 33 additions & 6 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,10 @@ class ObjectPtr {
friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
};

// Forward declaration, to prevent circular includes.
template <typename T>
class Optional;

/*! \brief Base class of all object reference */
class ObjectRef {
public:
Expand Down Expand Up @@ -550,20 +554,43 @@ class ObjectRef {
bool unique() const { return data_.unique(); }
/*! \return The use count of the ptr, for debug purposes */
int use_count() const { return data_.use_count(); }

/*!
* \brief Try to downcast the internal Object to a
* raw pointer of a corresponding type.
*
* The function will return a nullptr if the cast failed.
*
* if (const Add *add = node_ref.As<Add>()) {
* // This is an add node
* }
* \tparam ObjectType the target type, must be a subtype of Object/
* if (const AddNode *ptr = node_ref.as<AddNode>()) {
* // This is an add node
* }
*
* \tparam ObjectType the target type, must be a subtype of Object
*/
template <typename ObjectType>
template <typename ObjectType, typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>>
inline const ObjectType* as() const;

/*!
* \brief Try to downcast the ObjectRef to a
* Optional<T> of the requested type.
*
* The function will return a NullOpt if the cast failed.
*
* if (Optional<Add> opt = node_ref.as<Add>()) {
* // This is an add node
* }
*
* \note While this method is declared in <tvm/runtime/object.h>,
* the implementation is in <tvm/runtime/container/optional.h> to
* prevent circular includes. This additional include file is only
* required in compilation units that uses this method.
*
* \tparam ObjectRefType the target type, must be a subtype of ObjectRef
*/
template <typename ObjectRefType,
typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
inline Optional<ObjectRefType> as() const;

/*! \brief type indicate the container type. */
using ContainerType = Object;
// Default type properties for the reference class.
Expand Down Expand Up @@ -861,7 +888,7 @@ inline bool Object::IsInstance() const {

inline bool Object::unique() const { return use_count() == 1; }

template <typename ObjectType>
template <typename ObjectType, typename>
inline const ObjectType* ObjectRef::as() const {
if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
return static_cast<ObjectType*>(data_.get());
Expand Down
16 changes: 8 additions & 8 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -672,8 +672,8 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
* \return The transformed SplitExpr.
*/
SplitExpr ToSplitExpr(PrimExpr expr) {
if (const auto* op = expr.as<SplitExprNode>()) {
return GetRef<SplitExpr>(op);
if (auto op = expr.as<SplitExpr>()) {
return op.value();
}
if (const auto* op = expr.as<SumExprNode>()) {
if (op->base == 0 && op->args.size() == 1) return op->args[0];
Expand Down Expand Up @@ -715,8 +715,8 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
* \return The transformed SumExpr.
*/
SumExpr ToSumExpr(PrimExpr expr) {
if (const auto* op = expr.as<SumExprNode>()) {
return GetRef<SumExpr>(op);
if (auto op = expr.as<SumExpr>()) {
return op.value();
}
ObjectPtr<SumExprNode> n = make_object<SumExprNode>();
n->dtype = expr.dtype();
Expand Down Expand Up @@ -748,8 +748,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) {

if (const auto* op = b.as<IntImmNode>()) {
ret.CopyOnWrite()->AddToSelf(op->value);
} else if (const auto* op = b.as<SumExprNode>()) {
ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), 1);
} else if (auto op = b.as<SumExpr>()) {
ret.CopyOnWrite()->AddToSelf(op.value(), 1);
} else {
ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), 1);
}
Expand All @@ -772,8 +772,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) {

if (const auto* op = b.as<IntImmNode>()) {
ret.CopyOnWrite()->AddToSelf(-op->value);
} else if (const auto* op = b.as<SumExprNode>()) {
ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), -1);
} else if (auto op = b.as<SumExpr>()) {
ret.CopyOnWrite()->AddToSelf(op.value(), -1);
} else {
ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), -1);
}
Expand Down
5 changes: 3 additions & 2 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a, Interval

// internal helper function to get an interval set
IntervalSet ToIntervalSet(IntSet set) {
if (auto* node = set.as<IntervalSetNode>()) {
return GetRef<IntervalSet>(node);
if (auto node = set.as<IntervalSet>()) {
return node.value();
}
DLOG(INFO) << "cannot resolve int set " << set;
return IntervalSet::Everything();
Expand Down Expand Up @@ -379,6 +379,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
IntervalSet min_set = this->Eval(val->min_value);
IntervalSet max_set = this->Eval(val->max_value);
--recur_depth_;

return IntervalSet(min_set->min_value, max_set->max_value);
}

Expand Down
76 changes: 38 additions & 38 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -723,10 +723,10 @@ class IterMapRewriter : public ExprMutator {
* \return The transformed IterSumExpr.
*/
static IterSumExpr ToIterSumExpr(const PrimExpr& expr) {
if (const auto* op = expr.as<IterSumExprNode>()) {
return GetRef<IterSumExpr>(op);
} else if (const auto* op = expr.as<IterSplitExprNode>()) {
return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype));
if (auto op = expr.as<IterSumExpr>()) {
return op.value();
} else if (auto op = expr.as<IterSplitExpr>()) {
return IterSumExpr({op.value()}, make_zero(expr->dtype));
} else {
ICHECK(!expr->IsInstance<IterMapExprNode>());
return IterSumExpr({}, expr);
Expand Down Expand Up @@ -1066,14 +1066,15 @@ bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>* input_iters,
}
}
// If it is a predicate for a single input iter
if (const auto* var_ptr = iter.as<VarNode>()) {
auto it = input_iters->find(GetRef<Var>(var_ptr));
if (auto opt = iter.as<Var>()) {
auto var = opt.value();
auto it = input_iters->find(var);
if (it != input_iters->end()) {
PrimExpr iter_min = (*it).second->min;
PrimExpr iter_max = (*it).second->min + (*it).second->extent;
if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value());
if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value());
input_iters->Set(GetRef<Var>(var_ptr), Range(iter_min, iter_max));
input_iters->Set(var, Range(iter_min, iter_max));
}
} else {
result->emplace_back(iter, lower_bound, upper_bound, 0);
Expand Down Expand Up @@ -1220,10 +1221,10 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) {

if (!b->IsInstance<IterMapExprNode>()) {
ret.CopyOnWrite()->base += b;
} else if (const auto* op = b.as<IterSumExprNode>()) {
AddToLhs(ret.CopyOnWrite(), GetRef<IterSumExpr>(op), 1);
} else if (const auto* op = b.as<IterSplitExprNode>()) {
AddToLhs(ret.CopyOnWrite(), GetRef<IterSplitExpr>(op), 1);
} else if (auto op = b.as<IterSumExpr>()) {
AddToLhs(ret.CopyOnWrite(), op.value(), 1);
} else if (auto op = b.as<IterSplitExpr>()) {
AddToLhs(ret.CopyOnWrite(), op.value(), 1);
} else {
AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), 1);
}
Expand Down Expand Up @@ -1255,10 +1256,10 @@ PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) {

if (!b->IsInstance<IterMapExprNode>()) {
ret.CopyOnWrite()->base -= b;
} else if (const auto* op = b.as<IterSumExprNode>()) {
AddToLhs(ret.CopyOnWrite(), GetRef<IterSumExpr>(op), -1);
} else if (const auto* op = b.as<IterSplitExprNode>()) {
AddToLhs(ret.CopyOnWrite(), GetRef<IterSplitExpr>(op), -1);
} else if (auto op = b.as<IterSumExpr>()) {
AddToLhs(ret.CopyOnWrite(), op.value(), -1);
} else if (auto op = b.as<IterSplitExpr>()) {
AddToLhs(ret.CopyOnWrite(), op.value(), -1);
} else {
AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), -1);
}
Expand Down Expand Up @@ -1692,10 +1693,10 @@ class IterMapToExprNormalizer : public ExprMutator {
private:
/*! \brief Override VisitExpr for iter expr type processing */
PrimExpr VisitExpr(const PrimExpr& expr) override {
if (const auto* op = expr.as<IterSplitExprNode>()) {
return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op));
} else if (const auto* op = expr.as<IterSumExprNode>()) {
return ConvertIterSumExpr(GetRef<IterSumExpr>(op));
if (auto op = expr.as<IterSplitExpr>()) {
return ConvertIterSplitExpr(op.value());
} else if (auto op = expr.as<IterSumExpr>()) {
return ConvertIterSumExpr(op.value());
} else {
return ExprMutator::VisitExpr(expr);
}
Expand All @@ -1712,10 +1713,10 @@ class IterMapToExprNormalizer : public ExprMutator {

PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) {
PrimExpr source;
if (const auto* op = expr->source->source.as<VarNode>()) {
source = GetRef<Var>(op);
} else if (const auto* op = expr->source->source.as<IterSumExprNode>()) {
source = ConvertIterSumExpr(GetRef<IterSumExpr>(op));
if (auto opt = expr->source->source.as<Var>()) {
source = opt.value();
} else if (auto opt = expr->source->source.as<IterSumExpr>()) {
source = ConvertIterSumExpr(opt.value());
} else {
source = VisitExpr(expr->source->source);
}
Expand Down Expand Up @@ -1854,10 +1855,10 @@ class SubspaceDivider {

private:
static IterSplitExpr GetAsSplit(const IterMapExpr& expr, const PrimExpr& extent) {
if (const auto* op = expr.as<IterSplitExprNode>()) {
return GetRef<IterSplitExpr>(op);
} else if (const auto* op = expr.as<IterSumExprNode>()) {
return IterSplitExpr(IterMark(GetRef<IterSumExpr>(op), extent));
if (auto op = expr.as<IterSplitExpr>()) {
return op.value();
} else if (auto op = expr.as<IterSumExpr>()) {
return IterSplitExpr(IterMark(op.value(), extent));
} else {
LOG(FATAL) << "Unknown IterMapExpr type";
}
Expand Down Expand Up @@ -1946,10 +1947,10 @@ class SubspaceDivider {
private:
DivisionResult AddBase(DivisionResult division, PrimExpr base) {
DivisionResult res = division;
if (const auto* op = division.inner.as<IterSplitExprNode>()) {
res.inner = IterSumExpr({GetRef<IterSplitExpr>(op)}, base);
} else if (const auto* op = division.inner.as<IterSumExprNode>()) {
const auto& expr = GetRef<IterSumExpr>(op);
if (auto op = division.inner.as<IterSplitExpr>()) {
res.inner = IterSumExpr({op.value()}, base);
} else if (auto op = division.inner.as<IterSumExpr>()) {
const auto& expr = op.value();
res.inner = IterSumExpr(expr->args, expr->base + base);
}
return res;
Expand All @@ -1976,9 +1977,9 @@ class SubspaceDivider {
return it->second;
}
const Array<IterSplitExpr>& splits = collector_.mark2splits_.at(expr->source);
if (const auto* iter_ptr = expr->source->source.as<VarNode>()) {
if (auto iter_ptr = expr->source->source.as<Var>()) {
// source is input_iter
bool inner = sub_iters_.count(GetRef<Var>(iter_ptr));
bool inner = sub_iters_.count(iter_ptr.value());
for (const IterSplitExpr& split : splits) {
if (inner) {
// 0*E(split)+split
Expand All @@ -1988,7 +1989,7 @@ class SubspaceDivider {
split_map_.emplace(split, DivisionResult::Outer(split, split->extent));
}
}
} else if (const auto* iter_ptr = expr->source->source.as<IterSumExprNode>()) {
} else if (auto iter_ptr = expr->source->source.as<IterSumExpr>()) {
// source = Y*E+X
// splits = [s1, s2, ..., sn]
// we can divide if there exists i, such that extent(s1)extent(s2)...extent(si)=extent(Y)
Expand All @@ -2001,8 +2002,7 @@ class SubspaceDivider {
// Case 2. splits = [s1, s2, s3] = [source / 4, (source / 2) % 2, source % 2],
// where extent(s1) = 3, extent(s2) = 2, extent(s3) = 2.
// It's impossible to rewrite s1, s2, s3 in the form of Y*E(X) + X.
DivisionResult mark_division =
DivideIterSumExpr(GetRef<IterSumExpr>(iter_ptr), expr->source->extent);
DivisionResult mark_division = DivideIterSumExpr(iter_ptr.value(), expr->source->extent);
if (splits.size() == 1) {
return mark_division;
}
Expand Down Expand Up @@ -2186,8 +2186,8 @@ class InverseAffineIterMapTransformer {
} else {
const auto* split_expr = expr.as<IterSplitExprNode>();
ICHECK(split_expr);
if (const auto* source = split_expr->source->source.as<IterMapExprNode>()) {
fvisit(GetRef<IterMapExpr>(source));
if (auto source = split_expr->source->source.as<IterMapExpr>()) {
fvisit(source.value());
}
}
post_dfs_order.push_back(expr.get());
Expand Down
4 changes: 2 additions & 2 deletions src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,8 @@ void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array<ObjectRef>& inputs,
stream << "def " << name << "(";
for (size_t i = 0; i < inputs.size(); ++i) {
if (i) stream << ", ";
if (auto tensor = inputs[i].as<TensorNode>()) {
stream << GetTensorID(GetRef<Tensor>(tensor));
if (auto tensor = inputs[i].as<Tensor>()) {
stream << GetTensorID(tensor.value());
} else {
auto var = inputs[i].as<VarNode>();
ICHECK(var) << "Input should either be a tensor or a variable!";
Expand Down
7 changes: 3 additions & 4 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ void GetBinds(const Array<ObjectRef>& args, bool compact,
*out_binds = binds;

for (const ObjectRef& x : args) {
if (const te::TensorNode* tensor_node = x.as<te::TensorNode>()) {
te::Tensor x_ref = GetRef<te::Tensor>(tensor_node);
if (auto tensor_node = x.as<te::Tensor>()) {
te::Tensor x_ref = tensor_node.value();
if (out_binds->find(x_ref) == out_binds->end()) {
tir::Buffer buf = tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype,
x_ref->op->name, -1, 0, compact);
Expand Down Expand Up @@ -183,8 +183,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {

CHECK_GE(phase_num_val, 0);

const tvm::transform::PassNode* pass_node = phase_pass[1].as<tvm::transform::PassNode>();
tvm::transform::Pass pass = GetRef<tvm::transform::Pass>(pass_node);
auto pass = Downcast<tvm::transform::Pass>(phase_pass[1]);
// Copy the pass into the correct phase
if (phase_num_val == 0) {
user_lower_phase0.push_back(pass);
Expand Down
12 changes: 6 additions & 6 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value))

PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
using runtime::ObjectTypeChecker;
if (auto* ptr = ref.as<tir::IterVarNode>()) {
return GetRef<tir::IterVar>(ptr)->var;
if (const auto* ptr = ref.as<tir::IterVarNode>()) {
return ptr->var;
}
if (auto* ptr = ref.as<te::TensorNode>()) {
return GetRef<te::Tensor>(ptr)();
if (auto opt = ref.as<te::Tensor>()) {
return opt.value()();
}
if (auto* ptr = ref.as<runtime::StringObj>()) {
return tir::StringImm(GetRef<runtime::String>(ptr));
if (auto opt = ref.as<runtime::String>()) {
return tir::StringImm(opt.value());
}
if (const auto* buffer_region = ref.as<tir::BufferRegionNode>()) {
Array<PrimExpr> indices;
Expand Down
4 changes: 2 additions & 2 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,8 @@ std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(

// All global definitions must be functions.
BaseFunc func;
if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
if (auto func_node = expr.as<BaseFunc>()) {
func = func_node.value();
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
// Function literal has been annotated with it's required global symbol.
gv_name = opt.value();
Expand Down
8 changes: 4 additions & 4 deletions src/ir/type_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) {
for (auto type_param : op->type_params) {
auto new_type_param = VisitType(type_param);
changed = changed || !new_type_param.same_as(type_param);
if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
type_params.push_back(GetRef<TypeVar>(tin));
if (auto tin = new_type_param.as<TypeVar>()) {
type_params.push_back(tin.value());
} else {
LOG(FATAL) << new_type_param;
}
Expand All @@ -126,8 +126,8 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) {
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
changed = changed || !new_type_cs.same_as(type_cs);
if (const TypeConstraintNode* tin = new_type_cs.as<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
if (auto tin = new_type_cs.as<TypeConstraint>()) {
type_constraints.push_back(tin.value());
} else {
LOG(FATAL) << new_type_cs;
}
Expand Down
Loading

0 comments on commit 7766f3c

Please sign in to comment.