diff --git a/include/tvm/runtime/container/optional.h b/include/tvm/runtime/container/optional.h index 9961d5eeba0b..024986a8e037 100644 --- a/include/tvm/runtime/container/optional.h +++ b/include/tvm/runtime/container/optional.h @@ -153,6 +153,15 @@ class Optional : public ObjectRef { static constexpr bool _type_is_nullable = true; }; +template +inline Optional ObjectRef::as() const { + if (auto* ptr = this->as()) { + return GetRef(ptr); + } else { + return NullOptType{}; + } +} + } // namespace runtime // expose the functions to the root namespace. diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index e57167919823..b10aff96a116 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -507,6 +507,10 @@ class ObjectPtr { friend ObjectPtr GetObjectPtr(ObjType* ptr); }; +// Forward declaration, to prevent circular includes. +template +class Optional; + /*! \brief Base class of all object reference */ class ObjectRef { public: @@ -550,20 +554,43 @@ class ObjectRef { bool unique() const { return data_.unique(); } /*! \return The use count of the ptr, for debug purposes */ int use_count() const { return data_.use_count(); } + /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. * * The function will return a nullptr if the cast failed. * - * if (const Add *add = node_ref.As()) { - * // This is an add node - * } - * \tparam ObjectType the target type, must be a subtype of Object/ + * if (const AddNode *ptr = node_ref.as()) { + * // This is an add node + * } + * + * \tparam ObjectType the target type, must be a subtype of Object */ - template + template >> inline const ObjectType* as() const; + /*! + * \brief Try to downcast the ObjectRef to a + * Optional of the requested type. + * + * The function will return a NullOpt if the cast failed. + * + * if (Optional opt = node_ref.as()) { + * // This is an add node + * } + * + * \note While this method is declared in , + * the implementation is in to + * prevent circular includes. This additional include file is only + * required in compilation units that uses this method. + * + * \tparam ObjectRefType the target type, must be a subtype of ObjectRef + */ + template >> + inline Optional as() const; + /*! \brief type indicate the container type. */ using ContainerType = Object; // Default type properties for the reference class. @@ -861,7 +888,7 @@ inline bool Object::IsInstance() const { inline bool Object::unique() const { return use_count() == 1; } -template +template inline const ObjectType* ObjectRef::as() const { if (data_ != nullptr && data_->IsInstance()) { return static_cast(data_.get()); diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 14c91934d3b2..d1a69f10e20c 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -672,8 +672,8 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \return The transformed SplitExpr. */ SplitExpr ToSplitExpr(PrimExpr expr) { - if (const auto* op = expr.as()) { - return GetRef(op); + if (auto op = expr.as()) { + return op.value(); } if (const auto* op = expr.as()) { if (op->base == 0 && op->args.size() == 1) return op->args[0]; @@ -715,8 +715,8 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \return The transformed SumExpr. */ SumExpr ToSumExpr(PrimExpr expr) { - if (const auto* op = expr.as()) { - return GetRef(op); + if (auto op = expr.as()) { + return op.value(); } ObjectPtr n = make_object(); n->dtype = expr.dtype(); @@ -748,8 +748,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { if (const auto* op = b.as()) { ret.CopyOnWrite()->AddToSelf(op->value); - } else if (const auto* op = b.as()) { - ret.CopyOnWrite()->AddToSelf(GetRef(op), 1); + } else if (auto op = b.as()) { + ret.CopyOnWrite()->AddToSelf(op.value(), 1); } else { ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), 1); } @@ -772,8 +772,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { if (const auto* op = b.as()) { ret.CopyOnWrite()->AddToSelf(-op->value); - } else if (const auto* op = b.as()) { - ret.CopyOnWrite()->AddToSelf(GetRef(op), -1); + } else if (auto op = b.as()) { + ret.CopyOnWrite()->AddToSelf(op.value(), -1); } else { ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), -1); } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 1ad182aa8351..cf93a481c226 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -349,8 +349,8 @@ inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, Interval // internal helper function to get an interval set IntervalSet ToIntervalSet(IntSet set) { - if (auto* node = set.as()) { - return GetRef(node); + if (auto node = set.as()) { + return node.value(); } DLOG(INFO) << "cannot resolve int set " << set; return IntervalSet::Everything(); @@ -379,6 +379,7 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet min_set = this->Eval(val->min_value); IntervalSet max_set = this->Eval(val->max_value); --recur_depth_; + return IntervalSet(min_set->min_value, max_set->max_value); } diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 05af5b40702d..3f336b11be94 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -723,10 +723,10 @@ class IterMapRewriter : public ExprMutator { * \return The transformed IterSumExpr. */ static IterSumExpr ToIterSumExpr(const PrimExpr& expr) { - if (const auto* op = expr.as()) { - return GetRef(op); - } else if (const auto* op = expr.as()) { - return IterSumExpr({GetRef(op)}, make_zero(expr->dtype)); + if (auto op = expr.as()) { + return op.value(); + } else if (auto op = expr.as()) { + return IterSumExpr({op.value()}, make_zero(expr->dtype)); } else { ICHECK(!expr->IsInstance()); return IterSumExpr({}, expr); @@ -1066,14 +1066,15 @@ bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, } } // If it is a predicate for a single input iter - if (const auto* var_ptr = iter.as()) { - auto it = input_iters->find(GetRef(var_ptr)); + if (auto opt = iter.as()) { + auto var = opt.value(); + auto it = input_iters->find(var); if (it != input_iters->end()) { PrimExpr iter_min = (*it).second->min; PrimExpr iter_max = (*it).second->min + (*it).second->extent; if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value()); if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value()); - input_iters->Set(GetRef(var_ptr), Range(iter_min, iter_max)); + input_iters->Set(var, Range(iter_min, iter_max)); } } else { result->emplace_back(iter, lower_bound, upper_bound, 0); @@ -1220,10 +1221,10 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { if (!b->IsInstance()) { ret.CopyOnWrite()->base += b; - } else if (const auto* op = b.as()) { - AddToLhs(ret.CopyOnWrite(), GetRef(op), 1); - } else if (const auto* op = b.as()) { - AddToLhs(ret.CopyOnWrite(), GetRef(op), 1); + } else if (auto op = b.as()) { + AddToLhs(ret.CopyOnWrite(), op.value(), 1); + } else if (auto op = b.as()) { + AddToLhs(ret.CopyOnWrite(), op.value(), 1); } else { AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), 1); } @@ -1255,10 +1256,10 @@ PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { if (!b->IsInstance()) { ret.CopyOnWrite()->base -= b; - } else if (const auto* op = b.as()) { - AddToLhs(ret.CopyOnWrite(), GetRef(op), -1); - } else if (const auto* op = b.as()) { - AddToLhs(ret.CopyOnWrite(), GetRef(op), -1); + } else if (auto op = b.as()) { + AddToLhs(ret.CopyOnWrite(), op.value(), -1); + } else if (auto op = b.as()) { + AddToLhs(ret.CopyOnWrite(), op.value(), -1); } else { AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), -1); } @@ -1692,10 +1693,10 @@ class IterMapToExprNormalizer : public ExprMutator { private: /*! \brief Override VisitExpr for iter expr type processing */ PrimExpr VisitExpr(const PrimExpr& expr) override { - if (const auto* op = expr.as()) { - return ConvertIterSplitExpr(GetRef(op)); - } else if (const auto* op = expr.as()) { - return ConvertIterSumExpr(GetRef(op)); + if (auto op = expr.as()) { + return ConvertIterSplitExpr(op.value()); + } else if (auto op = expr.as()) { + return ConvertIterSumExpr(op.value()); } else { return ExprMutator::VisitExpr(expr); } @@ -1712,10 +1713,10 @@ class IterMapToExprNormalizer : public ExprMutator { PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) { PrimExpr source; - if (const auto* op = expr->source->source.as()) { - source = GetRef(op); - } else if (const auto* op = expr->source->source.as()) { - source = ConvertIterSumExpr(GetRef(op)); + if (auto opt = expr->source->source.as()) { + source = opt.value(); + } else if (auto opt = expr->source->source.as()) { + source = ConvertIterSumExpr(opt.value()); } else { source = VisitExpr(expr->source->source); } @@ -1854,10 +1855,10 @@ class SubspaceDivider { private: static IterSplitExpr GetAsSplit(const IterMapExpr& expr, const PrimExpr& extent) { - if (const auto* op = expr.as()) { - return GetRef(op); - } else if (const auto* op = expr.as()) { - return IterSplitExpr(IterMark(GetRef(op), extent)); + if (auto op = expr.as()) { + return op.value(); + } else if (auto op = expr.as()) { + return IterSplitExpr(IterMark(op.value(), extent)); } else { LOG(FATAL) << "Unknown IterMapExpr type"; } @@ -1946,10 +1947,10 @@ class SubspaceDivider { private: DivisionResult AddBase(DivisionResult division, PrimExpr base) { DivisionResult res = division; - if (const auto* op = division.inner.as()) { - res.inner = IterSumExpr({GetRef(op)}, base); - } else if (const auto* op = division.inner.as()) { - const auto& expr = GetRef(op); + if (auto op = division.inner.as()) { + res.inner = IterSumExpr({op.value()}, base); + } else if (auto op = division.inner.as()) { + const auto& expr = op.value(); res.inner = IterSumExpr(expr->args, expr->base + base); } return res; @@ -1976,9 +1977,9 @@ class SubspaceDivider { return it->second; } const Array& splits = collector_.mark2splits_.at(expr->source); - if (const auto* iter_ptr = expr->source->source.as()) { + if (auto iter_ptr = expr->source->source.as()) { // source is input_iter - bool inner = sub_iters_.count(GetRef(iter_ptr)); + bool inner = sub_iters_.count(iter_ptr.value()); for (const IterSplitExpr& split : splits) { if (inner) { // 0*E(split)+split @@ -1988,7 +1989,7 @@ class SubspaceDivider { split_map_.emplace(split, DivisionResult::Outer(split, split->extent)); } } - } else if (const auto* iter_ptr = expr->source->source.as()) { + } else if (auto iter_ptr = expr->source->source.as()) { // source = Y*E+X // splits = [s1, s2, ..., sn] // we can divide if there exists i, such that extent(s1)extent(s2)...extent(si)=extent(Y) @@ -2001,8 +2002,7 @@ class SubspaceDivider { // Case 2. splits = [s1, s2, s3] = [source / 4, (source / 2) % 2, source % 2], // where extent(s1) = 3, extent(s2) = 2, extent(s3) = 2. // It's impossible to rewrite s1, s2, s3 in the form of Y*E(X) + X. - DivisionResult mark_division = - DivideIterSumExpr(GetRef(iter_ptr), expr->source->extent); + DivisionResult mark_division = DivideIterSumExpr(iter_ptr.value(), expr->source->extent); if (splits.size() == 1) { return mark_division; } @@ -2186,8 +2186,8 @@ class InverseAffineIterMapTransformer { } else { const auto* split_expr = expr.as(); ICHECK(split_expr); - if (const auto* source = split_expr->source->source.as()) { - fvisit(GetRef(source)); + if (auto source = split_expr->source->source.as()) { + fvisit(source.value()); } } post_dfs_order.push_back(expr.get()); diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index bde64887856d..d3f50c1c2459 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -471,8 +471,8 @@ void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array& inputs, stream << "def " << name << "("; for (size_t i = 0; i < inputs.size(); ++i) { if (i) stream << ", "; - if (auto tensor = inputs[i].as()) { - stream << GetTensorID(GetRef(tensor)); + if (auto tensor = inputs[i].as()) { + stream << GetTensorID(tensor.value()); } else { auto var = inputs[i].as(); ICHECK(var) << "Input should either be a tensor or a variable!"; diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 569864a29edb..1962b9ab3b2d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -97,8 +97,8 @@ void GetBinds(const Array& args, bool compact, *out_binds = binds; for (const ObjectRef& x : args) { - if (const te::TensorNode* tensor_node = x.as()) { - te::Tensor x_ref = GetRef(tensor_node); + if (auto tensor_node = x.as()) { + te::Tensor x_ref = tensor_node.value(); if (out_binds->find(x_ref) == out_binds->end()) { tir::Buffer buf = tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact); @@ -183,8 +183,7 @@ Array CreatePassList(bool disable_loop_partition) { CHECK_GE(phase_num_val, 0); - const tvm::transform::PassNode* pass_node = phase_pass[1].as(); - tvm::transform::Pass pass = GetRef(pass_node); + auto pass = Downcast(phase_pass[1]); // Copy the pass into the correct phase if (phase_num_val == 0) { user_lower_phase0.push_back(pass); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 050d9b87a856..0e09568f158d 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -38,14 +38,14 @@ PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) PrimExpr PrimExpr::FromObject_(ObjectRef ref) { using runtime::ObjectTypeChecker; - if (auto* ptr = ref.as()) { - return GetRef(ptr)->var; + if (const auto* ptr = ref.as()) { + return ptr->var; } - if (auto* ptr = ref.as()) { - return GetRef(ptr)(); + if (auto opt = ref.as()) { + return opt.value()(); } - if (auto* ptr = ref.as()) { - return tir::StringImm(GetRef(ptr)); + if (auto opt = ref.as()) { + return tir::StringImm(opt.value()); } if (const auto* buffer_region = ref.as()) { Array indices; diff --git a/src/ir/module.cc b/src/ir/module.cc index 77316f55ed04..6151e271620c 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -322,8 +322,8 @@ std::pair IRModule::FromExprInContext( // All global definitions must be functions. BaseFunc func; - if (auto* func_node = expr.as()) { - func = GetRef(func_node); + if (auto func_node = expr.as()) { + func = func_node.value(); if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { // Function literal has been annotated with it's required global symbol. gv_name = opt.value(); diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc index 36838b62aabc..4a69c64fbd3b 100644 --- a/src/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -115,8 +115,8 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { for (auto type_param : op->type_params) { auto new_type_param = VisitType(type_param); changed = changed || !new_type_param.same_as(type_param); - if (const TypeVarNode* tin = new_type_param.as()) { - type_params.push_back(GetRef(tin)); + if (auto tin = new_type_param.as()) { + type_params.push_back(tin.value()); } else { LOG(FATAL) << new_type_param; } @@ -126,8 +126,8 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { for (auto type_cs : op->type_constraints) { auto new_type_cs = VisitType(type_cs); changed = changed || !new_type_cs.same_as(type_cs); - if (const TypeConstraintNode* tin = new_type_cs.as()) { - type_constraints.push_back(GetRef(tin)); + if (auto tin = new_type_cs.as()) { + type_constraints.push_back(tin.value()); } else { LOG(FATAL) << new_type_cs; } diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index 389c69fe9c8b..416753871244 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -58,8 +58,8 @@ void JSONDumps(ObjectRef json_obj, std::ostringstream& os) { std::vector> key_values; key_values.reserve(n); for (const auto& kv : *dict) { - if (const auto* k = kv.first.as()) { - key_values.emplace_back(GetRef(k), kv.second); + if (auto key = kv.first.as()) { + key_values.emplace_back(key.value(), kv.second); } else { LOG(FATAL) << "TypeError: Only string keys are supported in JSON dumps, but got: " << kv.first->GetTypeKey(); diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 3240496afe78..a19e6ea3fe23 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -131,8 +131,8 @@ class VerifyGPUCodeNode : public PostprocNode { bool Verify(const IRModule& mod) const { for (const auto& kv : mod->functions) { - if (const auto* prim_func = kv.second.as()) { - if (!tir::VerifyGPUCode(GetRef(prim_func), this->target_constraints_)) { + if (auto prim_func = kv.second.as()) { + if (!tir::VerifyGPUCode(prim_func.value(), this->target_constraints_)) { return false; } } diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc b/src/meta_schedule/postproc/verify_vtcm_limit.cc index a6b577de9acc..46bc7486e1df 100644 --- a/src/meta_schedule/postproc/verify_vtcm_limit.cc +++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc @@ -37,8 +37,8 @@ class VerifyVTCMLimitNode : public PostprocNode { bool Verify(const IRModule& mod) const { for (const auto& kv : mod->functions) { - if (const auto* prim_func = kv.second.as()) { - if (!tir::VerifyVTCMLimit(GetRef(prim_func), vtcm_capacity)) { + if (auto prim_func = kv.second.as()) { + if (!tir::VerifyVTCMLimit(prim_func.value(), vtcm_capacity)) { return false; } } diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index 48fbc82aba02..d1f42ca79dc4 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -51,15 +51,15 @@ class ScheduleFnNode : public SpaceGeneratorNode { return {sch}; } ObjectRef obj = rv; - if (const auto* sch = obj.as()) { - return {GetRef(sch)}; + if (auto sch = obj.as()) { + return {sch.value()}; } if (const auto* arr = obj.as()) { Array result; result.reserve(arr->size()); for (const ObjectRef& obj : *arr) { - if (const auto* sch = obj.as()) { - result.push_back(GetRef(sch)); + if (auto sch = obj.as()) { + result.push_back(sch.value()); } else { LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or " "List[Schedule], but got: " diff --git a/src/relay/analysis/call_graph.cc b/src/relay/analysis/call_graph.cc index d12ec7b98c6f..9d1041f56551 100644 --- a/src/relay/analysis/call_graph.cc +++ b/src/relay/analysis/call_graph.cc @@ -44,10 +44,9 @@ CallGraph::CallGraph(IRModule module) { n->module = std::move(module); auto gvar_funcs = n->module->functions; for (const auto& it : gvar_funcs) { - if (const auto* fn = it.second.as()) { - auto func = GetRef(fn); + if (auto func = it.second.as()) { // Add the global function to gradually build up the call graph. - n->AddToCallGraph(it.first, func); + n->AddToCallGraph(it.first, func.value()); } } data_ = std::move(n); @@ -76,9 +75,8 @@ void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) { LookupGlobalVar(Downcast(props.attrs.metadata["prim_shape_fn_var"])); cg_node->AddCalledGlobal(callee_cg_node); } - } else if (const auto* global_var_node = expr.as()) { - auto callee = GetRef(global_var_node); - CallGraphEntry* callee_cg_node = LookupGlobalVar(callee); + } else if (auto callee = expr.as()) { + CallGraphEntry* callee_cg_node = LookupGlobalVar(callee.value()); cg_node->AddCalledGlobal(callee_cg_node); } }); diff --git a/src/relay/analysis/get_calibration_data.cc b/src/relay/analysis/get_calibration_data.cc index 12bab1e38ddd..0d99e0a9ecad 100644 --- a/src/relay/analysis/get_calibration_data.cc +++ b/src/relay/analysis/get_calibration_data.cc @@ -93,10 +93,10 @@ IRModule GetCalibrateModule(IRModule module) { // module is mutable, hence, we make a copy of it. module.CopyOnWrite(); for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); + if (auto opt = pair.second.as()) { // we only collect the outputs for main function if (pair.first->name_hint == "main") { + auto func = opt.value(); Collector collector(module); PostOrderRewrite(func->body, &collector); auto new_outputs = collector.GetNewOutputs(); @@ -108,8 +108,8 @@ IRModule GetCalibrateModule(IRModule module) { } // reset the attribute of functions for running graph executor for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); + if (auto opt = pair.second.as()) { + auto func = opt.value(); if (func->GetAttr(attr::kCompiler)) { // we need to inline the functions in order to run grpah runtime func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1)); @@ -179,10 +179,9 @@ Map> GetCalibrateOutputMap(const IRModule& module) { size_t offset = 0; auto glob_funcs = module->functions; for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { + if (const auto* func = pair.second.as()) { if (pair.first->name_hint == "main") { OutputMapper output_mapper(&output_map, module, &offset); - auto func = GetRef(fn); PostOrderRewrite(func->body, &output_mapper); } } diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc index 05d5b36e3614..9f92ebaa8b47 100644 --- a/src/relay/analysis/match_exhaustion.cc +++ b/src/relay/analysis/match_exhaustion.cc @@ -169,10 +169,10 @@ Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Patt // Returns a list of all possible expansions. Array ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, const IRModule& mod) { - if (auto clause_ctor = clause_pat.as()) { - return ExpandWildcardsConstructor(GetRef(clause_ctor), cand, mod); - } else if (auto clause_tup = clause_pat.as()) { - return ExpandWildcardsTuple(GetRef(clause_tup), cand, mod); + if (auto clause_ctor = clause_pat.as()) { + return ExpandWildcardsConstructor(clause_ctor.value(), cand, mod); + } else if (auto clause_tup = clause_pat.as()) { + return ExpandWildcardsTuple(clause_tup.value(), cand, mod); } else { return {cand}; } diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index d40eb8a17c06..b6af97707159 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -160,7 +160,7 @@ class TypeSolver::Unifier : public TypeFunctor { // default: unify only if structural-equal Type VisitTypeDefault_(const Object* op, const Type& tn) final { ObjectRef nr = GetRef(op); - Type t1 = GetRef(nr.as()); + Type t1 = Downcast(nr); if (!tvm::StructuralEqual()(t1, tn)) { return Type(nullptr); } @@ -405,7 +405,7 @@ class TypeSolver::Propagator : public TypeFunctor { void VisitTypeDefault_(const Object* op) override { ObjectRef nr = GetRef(op); - Type t = GetRef(nr.as()); + Type t = Downcast(nr); UpdateRelSet(t); } @@ -489,7 +489,7 @@ class TypeSolver::Merger : public TypeFunctor { void VisitTypeDefault_(const Object* op) override { ObjectRef nr = GetRef(op); - Type t = GetRef(nr.as()); + Type t = Downcast(nr); TransferLinks(t); } diff --git a/src/relay/backend/annotate_used_memory.cc b/src/relay/backend/annotate_used_memory.cc index 001d7635e786..8e7ab68cbac9 100644 --- a/src/relay/backend/annotate_used_memory.cc +++ b/src/relay/backend/annotate_used_memory.cc @@ -185,9 +185,8 @@ class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { * \brief Check if a call is a primitive function callsite. */ bool CheckPrimitiveFunctionCall(const Call& callsite) { - if (const auto* var_node = callsite->op.as()) { - Var var = GetRef(var_node); - if (let_bound_prim_func_.find(var) != let_bound_prim_func_.end()) { + if (auto var = callsite->op.as()) { + if (let_bound_prim_func_.find(var.value()) != let_bound_prim_func_.end()) { return true; } } diff --git a/src/relay/backend/aot/aot_lower_main.cc b/src/relay/backend/aot/aot_lower_main.cc index fb13e8b66e5d..5688d22c1ba3 100644 --- a/src/relay/backend/aot/aot_lower_main.cc +++ b/src/relay/backend/aot/aot_lower_main.cc @@ -221,7 +221,7 @@ class AOTMainLowerer : public MixedModeVisitor { IRModule lowered_mod = GetRef(mod.CopyOnWrite()); auto lowered_main = lowered_mod->Lookup("main"); - auto lowered_main_func = GetRef(lowered_main.as()); + auto lowered_main_func = Downcast(lowered_main); // Assign StorageInfo to all the Relay exprs and get the return SIDs std::tie(expr_storage_map_, return_sid_) = CreateStorage(lowered_main_func); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 8f7098c24aea..77765bacd6e0 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1103,7 +1103,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { lowered_mod = transform::RemoveStandaloneReshapes()(lowered_mod); } auto lowered_main = lowered_mod->Lookup("main"); - auto lowered_main_func = GetRef(lowered_main.as()); + auto lowered_main_func = Downcast(lowered_main); // Post-lowering storage map for writing main func AOTOnDemandAllocator final_aot_allocator; diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc index f82014d5d1f5..1ce757a62fa9 100644 --- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -186,8 +186,8 @@ class ExtractConstantsMutator : public MixedModeMutator { // Since the constants are extracted from partitioned functions // a new call to global function is needed - if (auto* glob_var_node = post_call->op.as()) { - auto glob_var = GetRef(glob_var_node); + if (auto opt = post_call->op.as()) { + auto glob_var = opt.value(); auto glob_func = Downcast(mod_->Lookup(glob_var)); auto new_glob_func = VisitExpr(glob_func); if (!new_glob_func.same_as(glob_func)) { @@ -199,8 +199,8 @@ class ExtractConstantsMutator : public MixedModeMutator { // Since the constants are extracted from the local partitioned functions // a new call to local function is needed - if (auto* func_node = call->op.as()) { - Function func = GetRef(func_node); + if (auto opt = call->op.as()) { + Function func = opt.value(); auto new_func = VisitExpr(func); Array new_args = CreateNewCallArgsFromExtractedConstants(GetRef(post_call), func); final_call = Call(new_func, new_args); diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 73d479e6944e..33547f4bd85d 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -404,19 +404,19 @@ class RelayToTIRVisitor : public MixedModeMutator { void EmitPool2D(const GlobalVar& global_var, const Expr& expr, const String pool_name) { Call clip, pool; - Call final_call = GetRef(expr.as()); - Op final_op = GetRef(final_call->op.as()); + Call final_call = Downcast(expr); + Op final_op = Downcast(final_call->op); if (final_op->name == "clip") { clip = final_call; - Call clip_input = GetRef(clip->args[0].as()); - Op clip_input_op = GetRef(clip_input->op.as()); + Call clip_input = Downcast(clip->args[0]); + Op clip_input_op = Downcast(clip_input->op); if (clip_input_op->name == "cast") { - pool = GetRef(clip_input->args[0].as()); + pool = Downcast(clip_input->args[0]); } else { // max_pool2d pool = clip_input; } } else if (final_op->name == "cast") { - pool = GetRef(final_call->args[0].as()); + pool = Downcast(final_call->args[0]); } else { // max_pool2d pool = final_call; } @@ -556,11 +556,11 @@ class RelayToTIRVisitor : public MixedModeMutator { BinaryElementwiseClipPattern ParseBinaryElementwiseOpClipPattern(const Expr& expr) { BinaryElementwiseClipPattern pattern; - Call final_call = GetRef(expr.as()); + Call final_call = Downcast(expr); const OpNode* final_op = final_call->op.as(); if (final_op->name == "clip") { pattern.clip_op = final_call; - pattern.binary_op = GetRef(final_call->args[0].as()); + pattern.binary_op = Downcast(final_call->args[0]); } else { pattern.binary_op = final_call; pattern.clip_op = Optional{nullptr}; diff --git a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc index f64f485bfda2..6180fa85160f 100644 --- a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc +++ b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc @@ -72,8 +72,8 @@ class ScalarToTensorConstantMutator : public MixedModeMutator { final_call = ReplaceScalarWithTensorVariable(GetRef(call)); } - if (auto* glob_var_node = call->op.as()) { - GlobalVar global_var = GetRef(glob_var_node); + if (auto opt = call->op.as()) { + GlobalVar global_var = opt.value(); Function func = Downcast(mod_->Lookup(global_var)); auto new_body = VisitExpr(func->body); if (new_body.same_as(func->body)) { @@ -87,9 +87,8 @@ class ScalarToTensorConstantMutator : public MixedModeMutator { } // Substitute scalar constant with tensor constant in the call to composite function. - if (auto* func_node = call->op.as()) { - Function func = GetRef(func_node); - final_call = ReplaceScalarWithTensorConstant(GetRef(call), func); + if (auto func = call->op.as()) { + final_call = ReplaceScalarWithTensorConstant(GetRef(call), func.value()); } return final_call; diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index a622f96c81da..f35d4c6d48b2 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -191,7 +191,7 @@ class RemoveRedundantIdentities : public MixedModeMutator { } if (const auto* parent_callnode = current_arg.as()) { - if (const auto* parent_op = parent_callnode->op.as()) { + if (auto parent_op = parent_callnode->op.as()) { Call parent_call = GetRef(parent_callnode); if (parent_op->name == "contrib.ethosu.identity" && IdentityDoesNothing(parent_call) && CheckIdentityBetweenTransformOperations(call, parent_call)) { diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc index a0e0ac772fb0..23a873b2d392 100644 --- a/src/relay/backend/contrib/ethosu/preprocess.cc +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -189,8 +189,8 @@ class ExternalFuncIOHandler : public ExprRewriter { Expr Rewrite_(const CallNode* call, const Expr& post) final { auto post_call = Downcast(post); - if (auto glb_var_node = post_call->op.as()) { - auto glb_var = GetRef(glb_var_node); + if (auto optional_glb_var = post_call->op.as()) { + auto glb_var = optional_glb_var.value(); auto func = Downcast(module_->functions[glb_var]); // If the number of inputs and output are 1 --> no need to do anything @@ -233,9 +233,9 @@ class ExternalFuncIOHandler : public ExprRewriter { IRModule PreprocessExternalFuncIO_(const IRModule& module) { ExternalFuncIOHandler ex_func_io_handle(module); - auto func = GetRef(module->Lookup("main").as()); + auto func = Downcast(module->Lookup("main")); auto preprocessed = PostOrderRewrite(func, &ex_func_io_handle); - module->Update(module->GetGlobalVar("main"), GetRef(preprocessed.as())); + module->Update(module->GetGlobalVar("main"), Downcast(preprocessed)); return module; } diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index ad2b06695cc1..2b037181653c 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -68,7 +68,7 @@ class ConvertAddToSubtract : public MixedModeMutator { IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - Function main = GetRef(ir_module_->Lookup(main_global_var).as()); + Function main = Downcast(ir_module_->Lookup(main_global_var)); Function mutated_main = WithFields(main, main->params, VisitExpr(main->body)); ir_module_->Update(main_global_var, mutated_main); @@ -212,17 +212,18 @@ class ConvertAddToSubtract : public MixedModeMutator { return nullptr; } return function_node; - } else if (const auto* global_var_node = expr.as()) { - return AsLowerableFunction(ir_module_->Lookup(GetRef(global_var_node))); + } else if (auto global_var_node = expr.as()) { + return AsLowerableFunction(ir_module_->Lookup(global_var_node.value())); } else { return nullptr; } } const GlobalVarNode* AsAlreadyLoweredFunction(const Expr& expr) { - if (const auto* global_var_node = expr.as()) { - if (ir_module_->Lookup(GetRef(global_var_node)).as()) { - return global_var_node; + if (auto opt = expr.as()) { + auto global_var = opt.value(); + if (ir_module_->Lookup(global_var).as()) { + return global_var.get(); } } return nullptr; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index e6c5ac0d6ef3..865b5616edab 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -69,8 +69,8 @@ struct PairHash { // Analogue of FlattenTupleType for runtime ADT vs NDArray values. // TODO(mbs): Hoist somewhere sensible, maybe op/memory.h? void FlattenADTAux(const ObjectRef& object_ref, std::vector* out) { - if (const NDArray::ContainerType* ndarray = object_ref.as()) { - out->push_back(GetRef(ndarray)); + if (auto ndarray = object_ref.as()) { + out->push_back(ndarray.value()); } else if (const ADTObj* adt = object_ref.as()) { for (size_t i = 0; i < adt->size; ++i) { FlattenADTAux((*adt)[i], out); @@ -785,9 +785,8 @@ class Interpreter : public ExprFunctor, // Now we just evaluate and expect to find a closure. // TODO(@electriclilies): How should call_lowered behave with closures? ObjectRef fn_val = Eval(call_node->op); - if (const InterpreterClosureObj* closure_node = fn_val.as()) { - auto closure = GetRef(closure_node); - return Invoke(closure, args); + if (auto closure = fn_val.as()) { + return Invoke(closure.value(), args); } else if (const RecClosureObj* closure_node = fn_val.as()) { return Invoke(closure_node->clos, args, closure_node->bind); } else { @@ -799,8 +798,8 @@ class Interpreter : public ExprFunctor, } ObjectRef VisitExpr_(const LetNode* let) final { - if (auto func = let->value.as()) { - auto clo = MakeClosure(GetRef(func), let->var); + if (auto func = let->value.as()) { + auto clo = MakeClosure(func.value(), let->var); this->extend(let->var, clo); } else { auto value = Eval(let->value); @@ -1067,9 +1066,8 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De // Step 2: Evaluate target function to a closure. // ObjectRef object_ref = intrp->Eval(expr_to_eval); - if (const InterpreterClosureObj* closure_obj = object_ref.as()) { - InterpreterClosure closure = GetRef(closure_obj); - ICHECK(closure.defined()); + if (auto opt = object_ref.as()) { + InterpreterClosure closure = opt.value(); ICHECK(closure->func.defined()); return TypedPackedFunc)>([intrp, closure](Array args) { diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index ce47be361e23..816595474909 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -651,8 +651,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { BaseFunc base_func = module_->Lookup(GetRef(global_var_node)); return ResolveToPrimitive(base_func); } - } else if (const auto* prim_func_node = expr.as()) { - return GetRef(prim_func_node); + } else if (auto prim_func = expr.as()) { + return prim_func.value(); } else if (const auto* var_node = expr.as()) { auto itr = primitive_functions_.find(var_node); if (itr == primitive_functions_.end()) { @@ -726,8 +726,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } // Alas, WithAttr cannot work with base classes. - if (const auto* prim_func_node = original_function.as()) { - auto func_with_metadata = GetRef(prim_func_node); + if (auto opt = original_function.as()) { + auto func_with_metadata = opt.value(); func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var); func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target); @@ -737,9 +737,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } this->process_fn_(func_with_metadata); } else { - const auto* function_node = original_function.as(); - ICHECK(function_node); - auto func_with_metadata = GetRef(function_node); + auto func_with_metadata = original_function.as().value(); func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var); func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target); @@ -866,8 +864,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { BaseFunc primitive_func = ResolveToPrimitive(call_node->op); if (!primitive_func.defined()) { // Cases 5 and 6: Leave as ordinary call. - if (const auto* function_node = call_node->op.as()) { - process_fn_(GetRef(function_node)); + if (auto function = call_node->op.as()) { + process_fn_(function.value()); } return WithFields(GetRef(call_node), std::move(new_op), std::move(new_args)); } @@ -887,15 +885,14 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; // Case 4: If the function has already been lowered we just need to update the call. - if (const auto* prim_func_node = primitive_func.as()) { + if (auto prim_func = primitive_func.as()) { // Function should already be Target annotated by this point // but the TE Compiler metadata is still needed for the callback // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar for Target Hooks Optional opt_target = primitive_func->GetAttr(tvm::attr::kTarget); ICHECK(opt_target.defined()); auto prim_fn_var = Downcast(call_node->op); - tir::PrimFunc prim_func = GetRef(prim_func_node); - Map prim_fns = {{prim_fn_var, prim_func}}; + Map prim_fns = {{prim_fn_var, prim_func.value()}}; return MakeLoweredCall(primitive_func, prim_fn_var, std::move(new_args), call_node->span, opt_target.value(), prim_fns); } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index c29b3195a3fd..b67a4d8da5b6 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -693,13 +693,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { auto constructor = GetRef(constructor_node); Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers, NewRegister())); - } else if (const auto* var_node = call_node->op.as()) { + } else if (auto var = call_node->op.as()) { // If we are calling a variable, it must be the case that it is a closure so we // emit invoke closure here. - VisitExpr(GetRef(var_node)); + VisitExpr(var.value()); Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister())); - } else if (auto inner_call_node = call_node->op.as()) { - VisitExpr(GetRef(inner_call_node)); + } else if (auto inner_call = call_node->op.as()) { + VisitExpr(inner_call.value()); Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister())); } else { // Finally if there are any other cases this is a bug. @@ -921,12 +921,13 @@ void VMCompiler::LowerImpl(IRModule mod) { for (const auto& pair : context_.module->functions) { auto gvar = pair.first; - if (auto* n = pair.second.as()) { - if (n->HasNonzeroAttr(attr::kExtern)) { + if (auto opt = pair.second.as()) { + auto func = opt.value(); + if (func->HasNonzeroAttr(attr::kExtern)) { // Already compiled during lowering. continue; } - auto func = GetRef(n); + VMFunctionCompiler func_compiler(&context_, config_->host_virtual_device); auto vm_func = func_compiler.Compile(gvar, func); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index b2912f6263dc..ba94e4b19ec7 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -81,8 +81,8 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { Expr DeviceAwareVisitExpr_(const CallNode* call_node) final { auto call = Downcast(DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node)); - if (auto var_node = call_node->op.as()) { - auto var = GetRef(var_node); + if (auto opt = call_node->op.as()) { + auto var = opt.value(); if (!letrec_.empty() && var == letrec_.back()) { auto it = lambda_map_.find(var); ICHECK(it != lambda_map_.end()); diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index 993314a75b15..67ba1bd8594f 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -57,8 +57,8 @@ struct CallTracer : ExprVisitor { void VisitExpr_(const GlobalVarNode* op) final { called_funcs_.insert(op->name_hint); auto func = module_->Lookup(op->name_hint); - if (const auto* function_node = func.as()) { - VisitExpr(GetRef(function_node)); + if (auto function_node = func.as()) { + VisitExpr(function_node.value()); } // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein. } diff --git a/src/relay/collage/mock_cost_estimator.cc b/src/relay/collage/mock_cost_estimator.cc index b4bbdb2da336..78fd24840517 100644 --- a/src/relay/collage/mock_cost_estimator.cc +++ b/src/relay/collage/mock_cost_estimator.cc @@ -88,8 +88,7 @@ Cost MockCostEstimatorNode::Estimate(const IRModule& mod, const Target& target) double op_cost = static_cast(target_costs_.at(target->kind->name)->value); double cost = 0.0; for (const auto& kv : mod->functions) { - if (const auto* function_node = kv.second.as()) { - auto function = GetRef(function_node); + if (const auto* function = kv.second.as()) { if (kv.first->name_hint == "main") { // Only tensor args are allowed to main. for (const auto& param : function->params) { diff --git a/src/relay/collage/sub_graph.cc b/src/relay/collage/sub_graph.cc index dee72093fd2f..a6559ff5fdb5 100644 --- a/src/relay/collage/sub_graph.cc +++ b/src/relay/collage/sub_graph.cc @@ -325,8 +325,8 @@ std::pair SubExprKindAndLabel(const Expr& sub_expr) class Visitor : public ExprFunctor(const Expr&)> { private: std::pair VisitExpr_(const CallNode* call_node) final { - if (const auto* op_node = call_node->op.as()) { - auto op = GetRef(op_node); + if (auto optional = call_node->op.as()) { + auto op = optional.value(); static auto fpattern = Op::GetAttrMap("TOpPattern"); if (fpattern.count(op) == 0) { VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque"; diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 67c6bae6c5f9..185a92898d54 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -128,8 +128,8 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons return matches; } auto attributes = attr_pattern->attrs.as()->dict; - if (const auto* op_node = expr.as()) { - Op op = GetRef(op_node); + if (auto optional = expr.as()) { + Op op = optional.value(); for (auto kv : attributes) { auto attr_name = kv.first; auto attr_value = kv.second; diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 98e2ac0433b0..49ef3864aca8 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -532,8 +532,8 @@ Function SubstituteBoundVars(const Function& func, const tvm::Map& ar if (!args_map.count(func->params[i])) { new_params.push_back(func->params[i]); } else { - if (const VarNode* var = args_map[func->params[i]].as()) { - new_params.push_back(GetRef(var)); + if (auto var = args_map[func->params[i]].as()) { + new_params.push_back(var.value()); } else { ICHECK(false) << "Expected all values in args_map to be vars, but found " << args_map[func->params[i]]->GetTypeKey(); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index b73340df30ac..fd8c646ecf1c 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -147,8 +147,8 @@ TVM_REGISTER_GLOBAL("relay.ir.PrintIR") TVM_REGISTER_GLOBAL("relay.ir.WarnIfMalformed") .set_body_typed([](const IRModule& mod, const BaseFunc& base_func) -> void { - if (const auto* relay_func = base_func.as()) { - Function func = Downcast(relay::DeDup(GetRef(relay_func))); + if (auto relay_func = base_func.as()) { + Function func = Downcast(relay::DeDup(relay_func.value())); // Type check the item before we add it to the module. auto fv = relay::FreeVars(func); auto ftv = relay::FreeTypeVars(func, mod); diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 88b3e79e8ed1..65351562054e 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -224,8 +224,8 @@ RELAY_REGISTER_OP("memory.kill") .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); static void FlattenTupleTypeAux(const Type& type, std::vector* out) { - if (auto tt = type.as()) { - out->push_back(GetRef(tt)); + if (auto tt = type.as()) { + out->push_back(tt.value()); } else if (auto tuple_ty = type.as()) { for (auto field : tuple_ty->fields) { FlattenTupleTypeAux(field, out); diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index e8b58f414a43..71e58bb927f5 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -139,8 +139,8 @@ bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& att bool IdentityCompRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - if (auto* t0 = types[0].as()) { - Type out_type = TensorType(GetRef(t0)->shape, DataType::Bool()); + if (const auto* t0 = types[0].as()) { + Type out_type = TensorType(t0->shape, DataType::Bool()); reporter->Assign(types[1], out_type); return true; } diff --git a/src/relay/parser/parser.cc b/src/relay/parser/parser.cc index ae7fc52cbead..b519a1778ce0 100644 --- a/src/relay/parser/parser.cc +++ b/src/relay/parser/parser.cc @@ -1366,8 +1366,8 @@ class Parser { } // TODO(@jroesch): not sure about this being the right way to handle nulls. case TokenType::kIdentifier: { - if (auto text = next->data.as()) { - std::string id = GetRef(text); + if (auto text = next->data.as()) { + std::string id = text.value(); if (id == "nullptr") { Match(TokenType::kIdentifier); return ObjectRef(); diff --git a/src/relay/printer/relay_text_printer.cc b/src/relay/printer/relay_text_printer.cc index 5b47c262fd48..f6a5b2926aba 100644 --- a/src/relay/printer/relay_text_printer.cc +++ b/src/relay/printer/relay_text_printer.cc @@ -445,11 +445,11 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { } Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const BaseFunc& base_func) { - if (auto* n = base_func.as()) { - return PrintFunc(prefix, GetRef(n)); - } else if (auto* n = base_func.as()) { + if (auto func = base_func.as()) { + return PrintFunc(prefix, func.value()); + } else if (auto func = base_func.as()) { std::ostringstream os; - os << GetRef(n); + os << func.value(); return Doc::RawText(os.str()); } else { // def @xyz = meta['ExternalFunc'][id] @@ -896,18 +896,18 @@ Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_met Doc printed_attr; if (value.as()) { printed_attr << "?"; - } else if (auto str_obj = value.as()) { - printed_attr << Doc::StrLiteral(GetRef(str_obj)); + } else if (auto str_obj = value.as()) { + printed_attr << Doc::StrLiteral(str_obj.value()); } else if (force_meta) { printed_attr = meta_->GetMetaNode(Downcast(value)); - } else if (const auto* virtual_device_node = value.as()) { + } else if (auto virtual_device_node = value.as()) { if (show_meta_data_) { - printed_attr = meta_->GetMetaNode(GetRef(virtual_device_node)); + printed_attr = meta_->GetMetaNode(virtual_device_node.value()); } else { // Special case: The ReprPrinter for VirtualDeviceNodes is much easier to work with while // debugging. std::ostringstream os; - os << GetRef(virtual_device_node); + os << virtual_device_node.value(); return Doc::Text(os.str()); } } else if (const auto* base_attr_node = value.as()) { @@ -925,11 +925,11 @@ Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_met // Special case: Show maps fields as key=value pairs to help debugging. printed_attr << PrintMapAsAttributeValue(GetRef>(base_map_node)); } - } else if (const auto* global_var_node = value.as()) { + } else if (auto global_var = value.as()) { if (show_meta_data_) { - printed_attr = meta_->GetMetaNode(GetRef(global_var_node)); + printed_attr = meta_->GetMetaNode(global_var.value()); } else { - printed_attr << "'" << global_var_node->name_hint << "'"; + printed_attr << "'" << global_var.value()->name_hint << "'"; } } else { printed_attr = VisitAttr(value); diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc index b0e96cc47514..f268665ce212 100644 --- a/src/relay/transforms/canonicalize_cast.cc +++ b/src/relay/transforms/canonicalize_cast.cc @@ -68,8 +68,8 @@ class CastCanonicalizer : public ExprMutator { Expr VisitExpr_(const CallNode* call) { static auto fpattern = Op::GetAttrMap("TOpPattern"); - if (const OpNode* opnode = call->op.as()) { - auto pattern = fpattern[GetRef(opnode)]; + if (auto call_op = call->op.as()) { + auto pattern = fpattern[call_op.value()]; if (pattern <= kBroadcast) { Array call_args = call->args; bool unchanged = true; diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index f1e7e223541b..653659bb9a89 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -166,8 +166,8 @@ class OuterInliner : public MixedModeMutator { Expr Rewrite_(const CallNode* pre, const Expr& post) final { Call new_call = Downcast(post); - if (const auto* global_var_node = new_call->op.as()) { - auto global_var = GetRef(global_var_node); + if (auto global_var_node = new_call->op.as()) { + auto global_var = global_var_node.value(); if (std::find(global_vars_.begin(), global_vars_.end(), global_var) != global_vars_.end()) { BaseFunc base_func = mod_->Lookup(global_var); const auto* function_node = base_func.as(); diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index a83e757b6cfa..e2b350a439ea 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -549,8 +549,8 @@ Pass DeadCodeElimination(bool inline_once, bool ignore_impurity) { IRModule result(/*functions=*/{}, mod->type_definitions, mod->Imports(), mod->source_map, mod->attrs); for (const auto& kv : mod->functions) { - if (const auto* function_node = kv.second.as()) { - auto function = GetRef(function_node); + if (auto opt = kv.second.as()) { + auto function = opt.value(); VLOG(1) << "processing " << PrettyPrint(kv.first); diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index bf7839dfa48f..c9050c730d10 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -345,9 +345,9 @@ class RewriteOnDevices : public ExprMutator { Expr VisitExpr_(const LetNode* let_node) final { auto expr = GetRef(let_node); std::vector> bindings; - while (const auto* inner_let_node = expr.as()) { - Let inner_let = GetRef(inner_let_node); - Expr value = VisitExpr(inner_let_node->value); + while (auto opt = expr.as()) { + auto inner_let = opt.value(); + Expr value = VisitExpr(inner_let->value); OnDeviceProps props = GetOnDeviceProps(value); if (props.body.defined() && props.is_normal()) { VLOG(2) << "revising let-bound expression of let:" << std::endl @@ -356,7 +356,7 @@ class RewriteOnDevices : public ExprMutator { value = MaybeOnDeviceFixed(props.body, props.virtual_device); } bindings.emplace_back(inner_let, value); - expr = inner_let_node->body; + expr = inner_let->body; } expr = VisitExpr(expr); for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { @@ -438,10 +438,9 @@ class DeviceAnalyzer : public MixedModeVisitor { VLOG(2) << "collecting constraints from Relay Function '" << kv.first->name_hint << "'"; domains_->UnifyExprExact(kv.first, kv.second); VisitExpr(GetRef(function_node)); - } else if (const auto* prim_func_node = kv.second.as()) { + } else if (auto prim_func = kv.second.as()) { VLOG(2) << "collecting constraints from TIR PrimFunc '" << kv.first->name_hint << "'"; - domains_->UnifyExprExact( - kv.first, DomainForPrimFunc(kv.first, GetRef(prim_func_node))); + domains_->UnifyExprExact(kv.first, DomainForPrimFunc(kv.first, prim_func.value())); } else { VLOG(2) << "skipping '" << kv.first->name_hint << "'"; } @@ -917,10 +916,9 @@ class DeviceCapturer : public ExprMutator { if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { VLOG(2) << "capturing devices for Relay Function '" << kv.first->name_hint << "'"; result->Add(kv.first, Downcast(Mutate(GetRef(function_node)))); - } else if (const auto* prim_func_node = kv.second.as()) { + } else if (auto prim_func = kv.second.as()) { VLOG(2) << "capturing devices for TIR PrimFunc '" << kv.first->name_hint << "'"; - auto prim_func = GetRef(prim_func_node); - tir::PrimFunc new_prim_func = UpdatePrimFunc(kv.first, prim_func); + tir::PrimFunc new_prim_func = UpdatePrimFunc(kv.first, prim_func.value()); VLOG(2) << "Rewritten prim func:" << std::endl << PrettyPrint(prim_func) << std::endl << "to:" << std::endl @@ -1111,9 +1109,8 @@ class DeviceCapturer : public ExprMutator { // Iterate through chained lets, provided they all agree on their device type. VirtualDevice let_virtual_device = GetVirtualDevice(expr); std::vector> bindings; - while (const auto* inner_let_node = expr.as()) { - Expr inner_let = GetRef(inner_let_node); - if (GetVirtualDevice(inner_let) != let_virtual_device) { + while (const auto* inner_let = expr.as()) { + if (GetVirtualDevice(GetRef(inner_let)) != let_virtual_device) { // We have a device transition which needs to be handled. break; } @@ -1121,12 +1118,12 @@ class DeviceCapturer : public ExprMutator { // By using the fully-unconstrained virtual device for the 'lexical' scope we'll force the // let-bound value to *always* be wrapped by an "on_device" (see introductory comment for // motivation.) - Expr value = VisitChild(/*lexical_virtual_device=*/VirtualDevice::FullyUnconstrained(), - /*expected_virtual_device=*/GetVirtualDevice(inner_let_node->var), - /*child_virtual_device=*/GetVirtualDevice(inner_let_node->value), - inner_let_node->value); - bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); - expr = inner_let_node->body; + Expr value = + VisitChild(/*lexical_virtual_device=*/VirtualDevice::FullyUnconstrained(), + /*expected_virtual_device=*/GetVirtualDevice(inner_let->var), + /*child_virtual_device=*/GetVirtualDevice(inner_let->value), inner_let->value); + bindings.emplace_back(inner_let->var, value, inner_let->span); + expr = inner_let->body; } Expr body = VisitChild(/*lexical_virtual_device=*/let_virtual_device, /*expected_virtual_device=*/let_virtual_device, diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index bafdbd359141..a989cf53f818 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -250,8 +250,8 @@ class DynamicToStaticMutator : public MixedModeMutator { Expr PrepareInput(const Expr& expr) { BaseFunc func; - if (auto* func_node = expr.as()) { - func = GetRef(func_node); + if (auto func_node = expr.as()) { + func = func_node.value(); } else { func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {}); diff --git a/src/relay/transforms/eta_expand.cc b/src/relay/transforms/eta_expand.cc index 40b0a54ba38c..9759f732df4d 100644 --- a/src/relay/transforms/eta_expand.cc +++ b/src/relay/transforms/eta_expand.cc @@ -68,8 +68,8 @@ class EtaExpander : public ExprMutator { IRModule Expand() { for (GlobalVar global_var : mod_->GetGlobalVars()) { const BaseFunc base_func = mod_->Lookup(global_var); - if (auto* n = base_func.as()) { - const Function new_func = Downcast(VisitExpr(GetRef(n))); + if (auto func = base_func.as()) { + const Function new_func = Downcast(VisitExpr(func.value())); mod_->Update(global_var, new_func); } } @@ -119,9 +119,9 @@ class EtaExpander : public ExprMutator { return std::move(gvar); } const auto base_func = mod_->Lookup(gvar); - if (auto* ptr = base_func.as()) { + if (auto opt = base_func.as()) { // handle relay function, skip external functions. - auto func = GetRef(ptr); + auto func = opt.value(); tvm::Array params; tvm::Array args; for (size_t i = 0; i < func->params.size(); ++i) { diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index dba412f81688..34f986b251a2 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -235,8 +235,8 @@ class ConstantFolder : public MixedModeMutator { if (value->IsInstance()) { auto nd_array = Downcast(value); return Constant(nd_array); - } else if (const auto* val = value.as()) { - runtime::ADT adt = GetRef(val); + } else if (auto opt = value.as()) { + runtime::ADT adt = opt.value(); Array fields; for (size_t i = 0; i < adt.size(); ++i) { fields.push_back(ObjectToExpr(adt[i])); diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 1fb857cb1cb3..9c0d38b11587 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -187,8 +187,8 @@ class IndexedForwardGraphCreator : private ExprVisitor { // Finally if the operator position is not a call node we will // need to call Update, as it may be an arbitrary expression. OpPatternKind op_pattern = kOpaque; - if (const OpNode* opnode = call->op.as()) { - auto op = GetRef(opnode); + if (auto optional = call->op.as()) { + auto op = optional.value(); if (IsDynamic(call->checked_type()) && IsDataDependent(call)) { // output of a shape func can't be fed to a data-dependent shape func op_pattern = kOpaque; diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index 1cf7cb86692c..edf1e4c99f4d 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -78,8 +78,8 @@ Expr DeGlobal(const Optional& mod, const Expr& e) { if (mod.defined() && x) { BaseFunc base_func = mod.value()->Lookup(GetRef(x)); - if (auto* n = base_func.as()) { - return GetRef(n); + if (auto func = base_func.as()) { + return func.value(); } else { return e; } @@ -236,9 +236,9 @@ struct ReverseAD : ExprMutator { } Expr VisitCheckpoint(const CallNode* call) { - const OpNode* op_node = call->op.as(); - ICHECK(op_node) << "expected op in call"; - Op op_ref = GetRef(op_node); + auto optional = call->op.as(); + ICHECK(optional) << "expected op in call"; + Op op_ref = optional.value(); ICHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation"; auto x = call->args[0]; return LetList::With([&](LetList* ll) { @@ -261,14 +261,14 @@ struct ReverseAD : ExprMutator { } Expr VisitExpr_(const CallNode* call) final { - if (const OpNode* op_node = call->op.as()) { - Op op_ref = GetRef(op_node); + if (auto optional = call->op.as()) { + Op op_ref = optional.value(); if (op_ref->name == "annotation.checkpoint") { return VisitCheckpoint(call); } - ICHECK(rev_map.count(op_ref)) << op_node->name << " does not have reverse mode defined"; + ICHECK(rev_map.count(op_ref)) << op_ref->name << " does not have reverse mode defined"; return LetList::With([&](LetList* ll) { std::vector args; for (const auto& arg : call->args) { diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index 012b3579494f..564c0daef70f 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -185,9 +185,8 @@ IRModule Inline(const IRModule& module) { // functions. if (it->empty() || (it->IsRecursive() && it->size() == 1)) continue; auto base_func = module->Lookup(it->GetNameHint()); - if (const auto* fn = base_func.as()) { - auto func = GetRef(fn); - auto new_func = Inliner(it, cg.operator->()).Inline(func); + if (auto func = base_func.as()) { + auto new_func = Inliner(it, cg.operator->()).Inline(func.value()); // TODO(zhiics) Maybe move this to CallGraph, but updating function from // CallGraph arbitarily may lead to incorrect CallGraph. cg->module->Update(it->GetGlobalVar(), new_func); @@ -201,8 +200,7 @@ IRModule Inline(const IRModule& module) { if (cgn->IsRecursive() || original_entry.count(cgn)) continue; auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar()); // Skip calls to PrimFuncs since they can't be inlined. - if (const auto* fn = base_func.as()) { - auto func = GetRef(fn); + if (const auto* func = base_func.as()) { if (func->HasNonzeroAttr(attr::kInline)) { ICHECK_EQ(cgn->GetRefCount(), 0U) << cgn->GetNameHint() << " is marked as inline but not inlined."; diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index 079b790e74c0..548951f19404 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -157,8 +157,8 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator { } Expr VisitExpr_(const CallNode* call_node) final { - if (auto* op = (call_node->op).as()) { - Expr op_expr = GetRef(op); + if (auto op = call_node->op.as()) { + Expr op_expr = op.value(); if (op_expr == Op::Get("add")) { return CallGradCellFunction(call_node, module_->GetGlobalVar("AddGradCell")); diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index f791192e25c1..c574a5772c16 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -635,8 +635,8 @@ class PartialEvaluator : public ExprFunctor ICHECK(mod_.defined()); if (gv_map_.count(gv) == 0) { BaseFunc base_func = mod_->Lookup(gv); - if (auto* n = base_func.as()) { - Function func = GetRef(n); + if (auto opt = base_func.as()) { + auto func = opt.value(); InitializeFuncId(func); Func f = VisitFuncStatic(func, gv); gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); @@ -879,10 +879,10 @@ class PartialEvaluator : public ExprFunctor if (v->IsInstance()) { auto nd_array = Downcast(v); return HasStatic(MkSTensor(nd_array), ll->Push(Constant(nd_array))); - } else if (const runtime::ADTObj* op = v.as()) { + } else if (auto opt = v.as()) { std::vector fields; tvm::Array fields_dyn; - auto adt = GetRef(op); + auto adt = opt.value(); for (size_t i = 0; i < adt.size(); ++i) { PStatic ps = Reify(adt[i], ll); fields.push_back(ps); diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 32ca2878fdc9..0be68872dd9c 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -213,8 +213,8 @@ class Partitioner : public MixedModeMutator { IRModule Partition() { auto glob_funcs = module_->functions; for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - Function func = GetRef(fn); + if (auto opt = pair.second.as()) { + Function func = opt.value(); func = WithFields(func, func->params, VisitExpr(func->body)); module_->Update(pair.first, func); module_ = transform::InferType()(module_); @@ -426,8 +426,8 @@ IRModule RemoveDefaultAnnotations(IRModule module) { // module is mutable, hence, we make a copy of it. module.CopyOnWrite(); for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); + if (auto opt = pair.second.as()) { + auto func = opt.value(); DefaultRemover remover; auto removed = PostOrderRewrite(func->body, &remover); func = WithFields(func, func->params, removed); @@ -482,8 +482,8 @@ IRModule FlattenTupleOutputs(IRModule module) { // module is mutable, hence, we make a copy of it. module.CopyOnWrite(); for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - Function func = GetRef(fn); + if (auto opt = pair.second.as()) { + Function func = opt.value(); TupleOutFlattener to_flattener; auto removed = PostOrderRewrite(func->body, &to_flattener); func = WithFields(func, func->params, removed); @@ -505,8 +505,8 @@ class NameMangleExtFuncs : public MixedModeMutator { // Collect function names to be mangled and create // global mangled variables for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); + if (auto opt = pair.second.as()) { + auto func = opt.value(); if (func->GetAttr(attr::kCompiler).defined()) { auto fn_name_mangled = tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint)); GlobalVar gvar = GlobalVar(fn_name_mangled); @@ -521,8 +521,8 @@ class NameMangleExtFuncs : public MixedModeMutator { new_module->functions = {}; for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); + if (auto opt = pair.second.as()) { + auto func = opt.value(); if (func->GetAttr(attr::kCompiler).defined()) { auto new_dict = func->attrs->dict; diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index a9b7390c0374..6c104ce28298 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -581,8 +581,8 @@ class ConcretizeLikeRewrite : public DFPatternRewrite { const TensorTypeNode* like_ty = pre->checked_type().as(); Array cshape; for (const auto& dim : like_ty->shape) { - if (const auto* imm = dim.as()) { - cshape.push_back(Integer(GetRef(imm))); + if (auto imm = dim.as()) { + cshape.push_back(Integer(imm.value())); } else { // shape is not static, don't concretize return post; diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index d2eb48073f7d..420c1e50a22d 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -681,7 +681,7 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Var VisitVar(const Var& v) final { if (vmap_.count(v) == 0) { - vmap_[v] = GetRef(AttachCheckedType(v.as()).as()); + vmap_[v] = Downcast(AttachCheckedType(v.as())); } return vmap_.at(v); } @@ -969,9 +969,7 @@ Pass InferType() { // In the future we plan a unified type checker // that works on TIR and Relay at the same time. - if (auto* func_node = it.second.as()) { - auto func = GetRef(func_node); - + if (auto func = it.second.as()) { // // If a function already has type information we can skip checking it. // if (func->checked_type_.defined()) { // continue; @@ -980,7 +978,7 @@ Pass InferType() { // TODO(@jroesch): we should be able to move the type inferencer outside // of this function but it seems to be more stateful then I expect. auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value()); - auto updated_func = inferencer.Infer(it.first, func); + auto updated_func = inferencer.Infer(it.first, func.value()); pass_ctx->diag_ctx.value().Render(); diff --git a/src/runtime/debug.cc b/src/runtime/debug.cc index e5d9f0ead09e..37ab6ec5803a 100644 --- a/src/runtime/debug.cc +++ b/src/runtime/debug.cc @@ -108,10 +108,10 @@ void AppendADT(std::ostream& os, const ADT& adt, const DLDevice& host_device, bo void AppendRuntimeObject(std::ostream& os, const ObjectRef& object, const DLDevice& host_device, bool show_contents) { - if (const auto* adt_obj = object.as()) { - AppendADT(os, GetRef(adt_obj), host_device, show_contents); - } else if (const auto* nd_array_cont = object.as()) { - AppendNDArray(os, GetRef(nd_array_cont), host_device, show_contents); + if (auto adt = object.as()) { + AppendADT(os, adt.value(), host_device, show_contents); + } else if (auto nd_array_cont = object.as()) { + AppendNDArray(os, nd_array_cont.value(), host_device, show_contents); } else { os << "?"; } diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index aee0b4bb6253..154a1ab3b01b 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -173,10 +173,10 @@ void Reads(Array buffer_slices) { } Array reads; for (const ObjectRef& obj : buffer_slices) { - if (const auto* buffer_region = obj.as()) { - reads.push_back(GetRef(buffer_region)); - } else if (const auto* buffer_load = obj.as()) { - reads.push_back(BufferRegionFromLoad(GetRef(buffer_load))); + if (auto buffer_region = obj.as()) { + reads.push_back(buffer_region.value()); + } else if (auto buffer_load = obj.as()) { + reads.push_back(BufferRegionFromLoad(buffer_load.value())); } else { LOG(FATAL) << "Invalid type for buffer reads."; } @@ -193,10 +193,10 @@ void Writes(Array buffer_slices) { } Array writes; for (const ObjectRef& obj : buffer_slices) { - if (const auto* buffer_region = obj.as()) { - writes.push_back(GetRef(buffer_region)); - } else if (const auto* buffer_load = obj.as()) { - writes.push_back(BufferRegionFromLoad(GetRef(buffer_load))); + if (auto buffer_region = obj.as()) { + writes.push_back(buffer_region.value()); + } else if (auto buffer_load = obj.as()) { + writes.push_back(BufferRegionFromLoad(buffer_load.value())); } else { LOG(FATAL) << "Invalid type for buffer writes."; } @@ -576,8 +576,8 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) int n = buffer->strides.size(); for (int i = 0; i < n; ++i) { PrimExpr e = buffer->strides[i]; - if (const tvm::tir::VarNode* v = e.as()) { - Namer::Name(GetRef(v), name + "_s" + std::to_string(i)); + if (auto v = e.as()) { + Namer::Name(v.value(), name + "_s" + std::to_string(i)); } } }); @@ -608,11 +608,11 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg") .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { using namespace tvm::tir; - if (const auto* var = obj.as()) { - return Arg(name, GetRef(var)); + if (auto var = obj.as()) { + return Arg(name, var.value()); } - if (const auto* buffer = obj.as()) { - return Arg(name, GetRef(buffer)); + if (auto buffer = obj.as()) { + return Arg(name, buffer.value()); } LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey(); throw; @@ -657,10 +657,10 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") .set_body_typed([](ObjectRef thread_tag_or_var, PrimExpr extent) { - if (const auto* var = thread_tag_or_var.as()) { - return LaunchThread(GetRef(var), extent); - } else if (const auto* str = thread_tag_or_var.as()) { - return LaunchThread(GetRef(str), extent); + if (auto var = thread_tag_or_var.as()) { + return LaunchThread(var.value(), extent); + } else if (auto str = thread_tag_or_var.as()) { + return LaunchThread(str.value(), extent); } else { LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: " << thread_tag_or_var->GetTypeKey(); diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index 712796e7a1dd..e5a47d7ca2a5 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -294,54 +294,54 @@ String DocPrinter::GetString() const { void DocPrinter::PrintDoc(const Doc& doc) { size_t start_pos = output_.tellp(); - if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); - } else if (const auto* doc_node = doc.as()) { - PrintTypedDoc(GetRef(doc_node)); + if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); } else { LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); throw; diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 1c751d40f2e7..87e7bfbcd9d2 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -78,8 +78,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (const auto* stmt_block = doc.as()) { (*f)->stmts.push_back(stmt_block->stmts.back()); (*f)->stmts.back()->source_paths = std::move(doc->source_paths); - } else if (const auto* stmt = doc.as()) { - (*f)->stmts.push_back(GetRef(stmt)); + } else if (auto stmt = doc.as()) { + (*f)->stmts.push_back(stmt.value()); } else { (*f)->stmts.push_back(Downcast(doc)); } diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index 92a80eb36dba..0c9289a9d2a9 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -34,9 +34,9 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // std::unordered_map loop_vars; for (Frame f : d->frames) { if (const auto* tir_f = f.as()) { - if (const auto* for_loop = tir_f->tir.as()) { - for (const tir::ForNode* l = for_loop; l != nullptr; l = l->body.as()) { - loop_vars.insert(std::make_pair(l->loop_var.get(), GetRef(l))); + if (auto for_loop = tir_f->tir.as()) { + for (Optional loop = for_loop; loop; loop = loop.value()->body.as()) { + loop_vars.insert(std::make_pair(loop.value()->loop_var.get(), loop.value())); } } } diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 710f2eab22e2..6f160b940579 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -239,15 +239,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Op::GetAttrMap("TScriptDtypePrintLocation"); tir::ScriptDtypePrintLocation dtype_print_location = tir::ScriptDtypePrintLocation::kNone; ExprDoc prefix{nullptr}; - if (const auto* op = call->op.as()) { - String name = op_names.get(GetRef(op), op->name); - if (op_names.count(GetRef(op)) == 0) { + if (auto optional_op = call->op.as()) { + auto op = optional_op.value(); + String name = op_names.get(op, op->name); + if (op_names.count(op) == 0) { LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; } prefix = TIR(d, name); - if (dtype_locations.count(GetRef(op))) { - dtype_print_location = static_cast( - dtype_locations[GetRef(op)].IntValue()); + if (dtype_locations.count(op)) { + dtype_print_location = + static_cast(dtype_locations[op].IntValue()); } } else if (const auto* gv = call->op.as()) { prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op")); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 69fe8aa2b748..3fbc93f678e4 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1716,8 +1716,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { - if (auto* ptr_op = op->op.as()) { - auto call_op = GetRef(ptr_op); + if (auto opt_call_op = op->op.as()) { + auto call_op = opt_call_op.value(); if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic ICHECK_GE(op->args.size(), 1U); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 7807b46e4227..f6792c1a4e8b 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -526,8 +526,8 @@ void CodeGenC::PrintCallExtern(Type ret_type, String global_symbol, const Array< } void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (auto* ptr_op = op->op.as()) { - auto call_op = GetRef(ptr_op); + if (auto opt_call_op = op->op.as()) { + auto call_op = opt_call_op.value(); if (op->op.same_as(builtin::tvm_check_return())) { const CallNode* call = op->args[2].as(); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 3b3fdbc58a4b..0a84ed658034 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -687,8 +687,8 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Arr } void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { - if (auto* ptr_op = op->op.as()) { - Op call_op = GetRef(ptr_op); + if (auto opt_call_opt = op->op.as()) { + Op call_op = opt_call_opt.value(); // This is only for backward compatibility with __shfl_{up/down}. // A macro will be used to replace *_sync calls to legacy ones. if (op_need_warp_shuffle_.get(call_op, false)) { diff --git a/src/target/target.cc b/src/target/target.cc index 24a418709ff3..f05d4db2b888 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -427,10 +427,10 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, return GetRef(ObjTypeCheck(obj, "String")); } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing target - if (const auto* ptr = obj.as()) { - return GetRef(ptr); - } else if (const auto* ptr = obj.as()) { - return Target(TargetInternal::FromString(GetRef(ptr))); + if (auto opt = obj.as()) { + return opt.value(); + } else if (auto str = obj.as()) { + return Target(TargetInternal::FromString(str.value())); } else if (const auto* ptr = obj.as()) { for (const auto& kv : *ptr) { if (!kv.first->IsInstance()) { @@ -495,8 +495,8 @@ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { if (const auto* p = obj.as()) { return std::to_string(p->value); } - if (const auto* p = obj.as()) { - auto s = static_cast(GetRef(p)); + if (auto tvm_str = obj.as()) { + std::string s = tvm_str.value(); auto u = Uninterpret(s); if (u.find_first_of(' ') != std::string::npos && !IsQuoted(u)) { u = Quote(u); @@ -660,9 +660,7 @@ Map TargetNode::Export() const { return result; } -Optional TargetNode::GetHost() const { - return GetRef>(this->host.as()); -} +Optional TargetNode::GetHost() const { return this->host.as(); } int TargetNode::GetTargetDeviceType() const { if (Optional device_type = GetAttr("target_device_type")) { @@ -853,8 +851,8 @@ ObjectPtr TargetInternal::FromConfig(Map config) { // parse 'kind' if (config.count(kKind)) { - if (const auto* kind = config[kKind].as()) { - target->kind = GetTargetKind(GetRef(kind)); + if (auto kind = config[kKind].as()) { + target->kind = GetTargetKind(kind.value()); ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr)) << "Cannot use both set_attrs_preprocessor and set_target_parser"; @@ -878,8 +876,8 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // parse "tag" if (config.count(kTag)) { - if (const auto* tag = config[kTag].as()) { - target->tag = GetRef(tag); + if (auto tag = config[kTag].as()) { + target->tag = tag.value(); config.erase(kTag); } else { throw Error(": Expect type of field \"tag\" is String, but get type: " + @@ -896,8 +894,8 @@ ObjectPtr TargetInternal::FromConfig(Map config) { // user provided keys if (const auto* cfg_keys = config[kKeys].as()) { for (const ObjectRef& e : *cfg_keys) { - if (const auto* key = e.as()) { - keys.push_back(GetRef(key)); + if (auto key = e.as()) { + keys.push_back(key.value()); } else { throw Error( ": Expect 'keys' to be an array of strings, but it " @@ -912,8 +910,8 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // add device name if (config.count(kDeviceName)) { - if (const auto* device = config.at(kDeviceName).as()) { - keys.push_back(GetRef(device)); + if (auto device = config.at(kDeviceName).as()) { + keys.push_back(device.value()); } } if (!has_user_keys) { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index d1b2c10edf01..3a555e304cb0 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -135,10 +135,9 @@ void CheckOrSetAttr(Map* attrs, const String& name, const Str if (iter == attrs->end()) { attrs->Set(name, value); } else { - const auto* str = (*iter).second.as(); - ICHECK(str != nullptr && GetRef(str) == value) - << "ValueError: Expects \"" << name << "\" to be \"" << value - << "\", but gets: " << (*iter).second; + auto str = (*iter).second.as(); + ICHECK(str && str.value() == value) << "ValueError: Expects \"" << name << "\" to be \"" + << value << "\", but gets: " << (*iter).second; } } diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index cc52cf618dc1..dc0b1fbfb86f 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -126,8 +126,8 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { public: PrimFunc Process(PrimFunc func) { for (int i = 0, n = func->params.size(); i < n; ++i) { - if (const auto* v = func->params[i].as()) { - if (Optional buffer = func->buffer_map.Get(GetRef(v))) { + if (auto v = func->params[i].as()) { + if (Optional buffer = func->buffer_map.Get(v.value())) { buffer2index_[buffer.value()] = i; } } @@ -298,8 +298,8 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, // Step 5. Add script_parsing_detect_access attr for auto complete the whole IR. Map annotations; auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef { - if (const auto* tensor_value = value.as()) { - return info->tensor2buffers.at(GetRef(tensor_value)); + if (auto tensor_value = value.as()) { + return info->tensor2buffers.at(tensor_value.value()); } else { return value; } @@ -499,13 +499,12 @@ void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Arrayshape, placeholder->dtype, placeholder->name, "global"); info->tensor2buffers[tensor] = buffer; } - } else if (const auto* compute_op = op.as()) { + } else if (auto compute_op = op.as()) { // Case 2. ComputeOp (te.compute) - root_stmts->push_back( - GenerateStmtFromCompute(GetRef(compute_op), info, analyzer)); - } else if (const auto extern_op = op.as()) { + root_stmts->push_back(GenerateStmtFromCompute(compute_op.value(), info, analyzer)); + } else if (const auto extern_op = op.as()) { // Case 3. ExternOp (te.extern) - root_stmts->push_back(GenerateStmtFromExternOp(GetRef(extern_op), info)); + root_stmts->push_back(GenerateStmtFromExternOp(extern_op.value(), info)); } else { ICHECK(false) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". " << "Only te.placeholder and te.compute are allowed for now."; diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index 9da8ec435524..95fd7f134ed2 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -101,9 +101,8 @@ namespace transform { Pass VerifyVTCMLimit(const Integer& limit) { auto pass_func = [=](IRModule mod, PassContext ctx) { for (auto kv : mod->functions) { - if (auto* n = kv.second.as()) { - auto func = GetRef(n); - auto sizes = CalculateAllocatedBytes(func); + if (auto func = kv.second.as()) { + auto sizes = CalculateAllocatedBytes(func.value()); const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0); if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) { LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been " diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 86ce4e21351f..e2b935b19046 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -291,11 +291,11 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { tir::BufferLoad load; PrimExpr value; - if (auto* as_load = as_equal_node->a.as()) { - load = GetRef(as_load); + if (auto opt = as_equal_node->a.as()) { + load = opt.value(); value = as_equal_node->b; - } else if (auto* as_load = as_equal_node->b.as()) { - load = GetRef(as_load); + } else if (auto opt = as_equal_node->b.as()) { + load = opt.value(); value = as_equal_node->a; } else if (!from_assume_statement) { return; diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index ce9d5eaaf838..1c427e5fd965 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -220,10 +220,10 @@ double EstimateTIRFlops(const IRModule& mod) { } TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double { - if (const auto* mod = obj.as()) { - return EstimateTIRFlops(GetRef(mod)); - } else if (const auto* stmt = obj.as()) { - return EstimateTIRFlops(GetRef(stmt)); + if (auto mod = obj.as()) { + return EstimateTIRFlops(mod.value()); + } else if (auto stmt = obj.as()) { + return EstimateTIRFlops(stmt.value()); } else { LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: " << obj->GetTypeKey(); diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index 0d3b48dbc2c6..1255b5bb13e9 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -58,8 +58,8 @@ std::variant IdentifyMemCpyImpl(const For& loop, } BufferStore store; - if (auto* ptr = stmt.as()) { - store = GetRef(ptr); + if (auto opt = stmt.as()) { + store = opt.value(); } else { return static_cast( std::stringstream() @@ -68,8 +68,8 @@ std::variant IdentifyMemCpyImpl(const For& loop, } BufferLoad load; - if (auto* ptr = store->value.as()) { - load = GetRef(ptr); + if (auto opt = store->value.as()) { + load = opt.value(); } else { return static_cast( std::stringstream() diff --git a/src/tir/analysis/side_effect.cc b/src/tir/analysis/side_effect.cc index 7c5d39283774..e20e60d24a66 100644 --- a/src/tir/analysis/side_effect.cc +++ b/src/tir/analysis/side_effect.cc @@ -45,8 +45,8 @@ class ExprSideEffect : public ExprVisitor { void VisitExpr_(const CallNode* op) final { static auto op_call_effect = Op::GetAttrMap("TCallEffectKind"); - if (auto* ptr_op = op->op.as()) { - this->UpdateEffect(static_cast(op_call_effect[GetRef(ptr_op)]->value)); + if (auto opt = op->op.as()) { + this->UpdateEffect(static_cast(op_call_effect[opt.value()]->value)); } else { this->UpdateEffect(CallEffectKind::kOpaque); } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 3d6c66d0e193..f012f8a1b35e 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -328,9 +328,8 @@ namespace transform { Pass VerifyGPUCode(Map constraints) { auto pass_func = [=](IRModule mod, PassContext ctx) { for (auto kv : mod->functions) { - if (auto* n = kv.second.as()) { - auto func = GetRef(n); - auto errs = VerifyGPUCode_(func, constraints); + if (auto func = kv.second.as()) { + auto errs = VerifyGPUCode_(func.value(), constraints); if (errs.size() != 0) { std::stringstream s; for (auto& err : errs) { diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index a210a555b4cd..a990230e043a 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -195,9 +195,8 @@ namespace transform { Pass VerifyMemory() { auto pass_func = [=](IRModule mod, PassContext ctx) { for (auto kv : mod->functions) { - if (auto* n = kv.second.as()) { - auto func = GetRef(n); - auto errs = VerifyMemory_(func); + if (auto func = kv.second.as()) { + auto errs = VerifyMemory_(func.value()); if (errs.size() > 0) { std::stringstream s; for (auto& err : errs) { diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index d7ccb363c16e..e04dcf90aa79 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -146,9 +146,8 @@ namespace transform { Pass VerifySSA() { auto pass_func = [=](IRModule mod, PassContext ctx) { for (auto kv : mod->functions) { - if (auto* n = kv.second.as()) { - auto func = GetRef(n); - ICHECK(VerifySSA(func)) << "RuntimeError: IR is not in SSA form" << func; + if (auto func = kv.second.as()) { + ICHECK(VerifySSA(func.value())) << "RuntimeError: IR is not in SSA form" << func; } } return mod; diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index bc4566b1f124..9219dde2291b 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -513,7 +513,7 @@ TVM_REGISTER_GLOBAL("tir.Call") if (const auto* str = it.as()) { prim_expr_args.push_back(StringImm(str->data)); } else if (const auto* iter_var = it.as()) { - prim_expr_args.push_back(GetRef(iter_var)->var); + prim_expr_args.push_back(iter_var->var); } else if (const auto* br = it.as()) { Array indices; for (Range r : br->region) { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index f5063b222b9b..075bbcd3ace1 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -708,8 +708,8 @@ class IRSubstitute : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); // remap var node in attr - if (const auto* var_node = op->node.as()) { - if (auto mapped_var = vmap_(GetRef(var_node))) { + if (auto var_node = op->node.as()) { + if (auto mapped_var = vmap_(var_node.value())) { return AttrStmt(mapped_var, op->attr_key, op->value, op->body); } } @@ -770,10 +770,10 @@ void PreOrderVisit(const ObjectRef& stmt_or_expr, }; PreOrderVisitor visitor(fvisit); - if (const auto* stmt = stmt_or_expr.as()) { - visitor(GetRef(stmt)); - } else if (const auto* expr = stmt_or_expr.as()) { - visitor(GetRef(expr)); + if (auto stmt = stmt_or_expr.as()) { + visitor(stmt.value()); + } else if (auto expr = stmt_or_expr.as()) { + visitor(expr.value()); } else { LOG(FATAL) << "InternalError: PreOrderVisit does not accept object with type: " << stmt_or_expr->GetTypeKey(); @@ -840,8 +840,8 @@ class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); // remap var node in attr - if (const auto* var_node = op->node.as()) { - if (auto mapped_var = vmap_(GetRef(var_node))) { + if (auto var_node = op->node.as()) { + if (auto mapped_var = vmap_(var_node.value())) { return AttrStmt(mapped_var, op->attr_key, op->value, op->body); } } diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index b071b2d7e4a1..c31516234131 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -105,8 +105,8 @@ class SplitExprCollector { return; } exprs_.push_back(SplitExpr{GetRef(var), *lower_factor, *extent}); - } else if (const auto* iter_sum_expr = expr->source->source.as()) { - Visit(GetRef(iter_sum_expr)); + } else if (auto iter_sum_expr = expr->source->source.as()) { + Visit(iter_sum_expr.value()); } else { ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey(); } diff --git a/src/tir/schedule/analysis/reducer.cc b/src/tir/schedule/analysis/reducer.cc index 1ae4c5dd034b..ed59fe645026 100644 --- a/src/tir/schedule/analysis/reducer.cc +++ b/src/tir/schedule/analysis/reducer.cc @@ -437,20 +437,20 @@ std::pair, Array> GetInitValuesAndUpdatesFromReduct Array updates; // Step 1. Extract the BufferStores serving as block inits. - if (const auto* init = block->init.as()) { - inits.push_back(GetRef(init)); + if (auto init = block->init.as()) { + inits.push_back(init.value()); } else if (const auto* seq_init = block->init.as()) { std::unordered_set init_buffers; for (const Stmt& stmt : seq_init->seq) { - init = stmt.as(); - if (init == nullptr) { + auto init = stmt.as(); + if (!init) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/1); } - auto insert_result = init_buffers.insert(init->buffer.get()); + auto insert_result = init_buffers.insert(init.value()->buffer.get()); if (!insert_result.second) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/2); } - inits.push_back(GetRef(init)); + inits.push_back(init.value()); } } else { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/1); diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index f5c1978a1b25..92c3423bcbbb 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -97,11 +97,11 @@ struct AnnotateTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ObjectRef ann_val, String ann_key) { - if (const auto* block = block_or_loop_rv.as()) { - return sch->Annotate(GetRef(block), ann_key, ann_val); + if (auto block = block_or_loop_rv.as()) { + return sch->Annotate(block.value(), ann_key, ann_val); } - if (const auto* loop = block_or_loop_rv.as()) { - return sch->Annotate(GetRef(loop), ann_key, ann_val); + if (auto loop = block_or_loop_rv.as()) { + return sch->Annotate(loop.value(), ann_key, ann_val); } LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); throw; @@ -130,11 +130,11 @@ struct UnannotateTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String ann_key) { - if (const auto* block = block_or_loop_rv.as()) { - return sch->Unannotate(GetRef(block), ann_key); + if (auto block = block_or_loop_rv.as()) { + return sch->Unannotate(block.value(), ann_key); } - if (const auto* loop = block_or_loop_rv.as()) { - return sch->Unannotate(GetRef(loop), ann_key); + if (auto loop = block_or_loop_rv.as()) { + return sch->Unannotate(loop.value(), ann_key); } LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); throw; diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 69443f2e19bd..b1645d5cbd41 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -654,10 +654,10 @@ struct TensorizeTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin, Bool preserve_unit_iters) { - if (const auto* block = block_or_loop_rv.as()) { - sch->Tensorize(GetRef(block), intrin, preserve_unit_iters.operator bool()); - } else if (const auto* loop = block_or_loop_rv.as()) { - sch->Tensorize(GetRef(loop), intrin, preserve_unit_iters.operator bool()); + if (auto block = block_or_loop_rv.as()) { + sch->Tensorize(block.value(), intrin, preserve_unit_iters.operator bool()); + } else if (auto loop = block_or_loop_rv.as()) { + sch->Tensorize(loop.value(), intrin, preserve_unit_iters.operator bool()); } else { LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 5c2db64d6e4d..c2fc7ac24afc 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -2013,8 +2013,8 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde std::unordered_set covered; for (const PrimExpr& index : original_indices) { PreOrderVisit(index, [&](const ObjectRef& obj) -> bool { - if (const VarNode* var = obj.as()) { - covered.insert(GetRef(var)); + if (auto var = obj.as()) { + covered.insert(var.value()); } return true; }); diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index e657b4f4663d..de7bd930948f 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -291,10 +291,10 @@ static std::pair CreateInBoundBlock(const BlockRealizeNode* repl_dict.Set(origin_itervar->var, new_var + info.in_bound_region[i]->min); // update new loop range - Var loop_var = GetRef(realize->iter_values[i].as()); - if (loop_var.defined() && new_loop_ranges.count(loop_var)) { + if (auto opt = realize->iter_values[i].as(); opt && new_loop_ranges.count(opt.value())) { // if the block binding is the loop var with single child, mutate loop range // instead of insert extra block predicate + auto loop_var = opt.value(); new_loop_ranges.Set(loop_var, new_range); new_iter_binding.push_back(realize->iter_values[i]); repl_dict.Set(loop_var, loop_var + info.in_bound_region[i]->min); diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 87ec6e550dcd..588770d968ef 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -154,11 +154,11 @@ struct GetChildBlocksTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static Array UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { - if (const auto* block = block_or_loop_rv.as()) { - return sch->GetChildBlocks(GetRef(block)); + if (auto block = block_or_loop_rv.as()) { + return sch->GetChildBlocks(block.value()); } - if (const auto* loop = block_or_loop_rv.as()) { - return sch->GetChildBlocks(GetRef(loop)); + if (auto loop = block_or_loop_rv.as()) { + return sch->GetChildBlocks(loop.value()); } LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); throw; diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index a26843b7bd05..3270de23b520 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -1059,10 +1059,10 @@ struct AddUnitLoopTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static LoopRV UnpackedApplyToSchedule(Schedule sch, ObjectRef rv) { - if (const auto* block = rv.as()) { - return sch->AddUnitLoop(GetRef(block)); - } else if (const auto* loop = rv.as()) { - return sch->AddUnitLoop(GetRef(loop)); + if (auto block = rv.as()) { + return sch->AddUnitLoop(block.value()); + } else if (auto loop = rv.as()) { + return sch->AddUnitLoop(loop.value()); } else { LOG(FATAL) << "TypeError: AddUnitLoop expects a loop or block"; throw; diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index fdf473ff7972..20a044439b94 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -82,14 +82,14 @@ TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet") .set_body_typed([](Schedule self, ObjectRef obj) -> ObjectRef { - if (const auto* loop_rv = obj.as()) { - return self->Get(GetRef(loop_rv)); + if (auto loop_rv = obj.as()) { + return self->Get(loop_rv.value()); } - if (const auto* block_rv = obj.as()) { - return self->Get(GetRef(block_rv)); + if (auto block_rv = obj.as()) { + return self->Get(block_rv.value()); } - if (const auto* expr_rv = obj.as()) { - return self->Get(GetRef(expr_rv)); + if (auto expr_rv = obj.as()) { + return self->Get(expr_rv.value()); } LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << obj->GetTypeKey() << ". Its value is: " << obj; @@ -97,28 +97,28 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet") }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef") .set_body_typed([](Schedule self, ObjectRef obj) -> Optional { - if (const auto* loop_rv = obj.as()) { - return self->GetSRef(GetRef(loop_rv)); + if (auto loop_rv = obj.as()) { + return self->GetSRef(loop_rv.value()); } - if (const auto* block_rv = obj.as()) { - return self->GetSRef(GetRef(block_rv)); + if (auto block_rv = obj.as()) { + return self->GetSRef(block_rv.value()); } - if (const auto* stmt = obj.as()) { - return self->GetSRef(GetRef(stmt)); + if (auto stmt = obj.as()) { + return self->GetSRef(stmt.value()); } LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); throw; }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") .set_body_typed([](Schedule self, ObjectRef obj) -> void { - if (const auto* loop_rv = obj.as()) { - return self->RemoveRV(GetRef(loop_rv)); + if (auto loop_rv = obj.as()) { + return self->RemoveRV(loop_rv.value()); } - if (const auto* block_rv = obj.as()) { - return self->RemoveRV(GetRef(block_rv)); + if (auto block_rv = obj.as()) { + return self->RemoveRV(block_rv.value()); } - if (const auto* expr_rv = obj.as()) { - return self->RemoveRV(GetRef(expr_rv)); + if (auto expr_rv = obj.as()) { + return self->RemoveRV(expr_rv.value()); } LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); throw; @@ -138,11 +138,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") .set_body_typed([](Schedule self, ObjectRef rv) { - if (const auto* block_rv = rv.as()) { - return self->GetChildBlocks(GetRef(block_rv)); + if (auto block_rv = rv.as()) { + return self->GetChildBlocks(block_rv.value()); } - if (const auto* loop_rv = rv.as()) { - return self->GetChildBlocks(GetRef(loop_rv)); + if (auto loop_rv = rv.as()) { + return self->GetChildBlocks(loop_rv.value()); } LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() << ". Its value is: " << rv; @@ -164,10 +164,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorderBlockIterVar") .set_body_method(&ScheduleNode::ReorderBlockIterVar); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV { - if (const auto* loop_rv = rv.as()) { - return self->AddUnitLoop(GetRef(loop_rv)); - } else if (const auto* block_rv = rv.as()) { - return self->AddUnitLoop(GetRef(block_rv)); + if (auto loop_rv = rv.as()) { + return self->AddUnitLoop(loop_rv.value()); + } else if (auto block_rv = rv.as()) { + return self->AddUnitLoop(block_rv.value()); } else { LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() << ". Its value is: " << rv; @@ -229,10 +229,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") .set_body_method(&ScheduleNode::Blockize); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") .set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) { - if (const auto* block_rv = rv.as()) { - self->Tensorize(GetRef(block_rv), intrin, preserve_unit_iters); - } else if (const auto* loop_rv = rv.as()) { - self->Tensorize(GetRef(loop_rv), intrin, preserve_unit_iters); + if (auto block_rv = rv.as()) { + self->Tensorize(block_rv.value(), intrin, preserve_unit_iters); + } else if (auto loop_rv = rv.as()) { + self->Tensorize(loop_rv.value(), intrin, preserve_unit_iters); } else { LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() << ". Its value is: " << rv; @@ -243,11 +243,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, const ObjectRef& ann_val) { - if (const auto* block_rv = rv.as()) { - return self->Annotate(GetRef(block_rv), ann_key, ann_val); + if (auto block_rv = rv.as()) { + return self->Annotate(block_rv.value(), ann_key, ann_val); } - if (const auto* loop_rv = rv.as()) { - return self->Annotate(GetRef(loop_rv), ann_key, ann_val); + if (auto loop_rv = rv.as()) { + return self->Annotate(loop_rv.value(), ann_key, ann_val); } LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() << ". Its value is: " << rv; @@ -255,11 +255,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key) { - if (const auto* block_rv = rv.as()) { - return self->Unannotate(GetRef(block_rv), ann_key); + if (auto block_rv = rv.as()) { + return self->Unannotate(block_rv.value(), ann_key); } - if (const auto* loop_rv = rv.as()) { - return self->Unannotate(GetRef(loop_rv), ann_key); + if (auto loop_rv = rv.as()) { + return self->Unannotate(loop_rv.value(), ann_key); } LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() << ". Its value is: " << rv; diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index a7a1c0d482cc..aa2efacd0d35 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -421,8 +421,9 @@ class StateCreator : private StmtVisitor { StateCreator creator(self); for (const auto& kv : n->mod->functions) { const BaseFunc& base_func = kv.second; - if (const auto* func = base_func.as()) { - VerifyWellFormed(GetRef(func)); + if (auto opt = base_func.as()) { + auto func = opt.value(); + VerifyWellFormed(func); creator.VisitStmt(func->body); BlockInfoCollector::Collect(self, func->body); } diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index da7a1e395dca..4b10df7e9728 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -78,13 +78,13 @@ Array TranslateInputRVs(const Array& inputs, auto it = rv_map.find(input.get()); ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input; result.push_back(GetRef(it->second)); - } else if (const auto* expr = input.as()) { // RV: Expr - result.push_back(Substitute(GetRef(expr), f_subst_with_rv_map)); - } else if (const auto* index_map = input.as()) { - result.push_back(Substitute(GetRef(index_map), f_subst_with_rv_map)); - } else if (input->IsInstance()) { + } else if (auto expr = input.as()) { // RV: Expr + result.push_back(Substitute(expr.value(), f_subst_with_rv_map)); + } else if (auto index_map = input.as()) { + result.push_back(Substitute(index_map.value(), f_subst_with_rv_map)); + } else if (auto arr = input.as>()) { // Recursively convert elements of the array into a new list of ObjectRefs. - result.push_back(TranslateInputRVs(Downcast>(input), rv_map)); + result.push_back(TranslateInputRVs(arr.value(), rv_map)); } else { ICHECK(false) << "TypeError: Cannot recognize the type of an input random variable: " << input->GetTypeKey(); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 99fad558cfe2..cc57735df6dd 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -304,13 +304,13 @@ class BF16ComputeLegalizer : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - if (auto* buffer = op->node.as()) { - auto it = buffer_remap_.find(GetRef(buffer)); + if (auto buffer = op->node.as()) { + auto it = buffer_remap_.find(buffer.value()); if (it != buffer_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } - } else if (auto* var = op->node.as()) { - auto it = var_remap_.find(GetRef(var)); + } else if (auto var = op->node.as()) { + auto it = var_remap_.find(var.value()); if (it != var_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } @@ -523,13 +523,13 @@ class BF16StorageLegalizer : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - if (auto* buffer = op->node.as()) { - auto it = buffer_remap_.find(GetRef(buffer)); + if (auto buffer = op->node.as()) { + auto it = buffer_remap_.find(buffer.value()); if (it != buffer_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } - } else if (auto* var = op->node.as()) { - auto it = var_remap_.find(GetRef(var)); + } else if (auto var = op->node.as()) { + auto it = var_remap_.find(var.value()); if (it != var_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc index f9e620ba3322..1bb5c2fabc68 100644 --- a/src/tir/transforms/extract_constants.cc +++ b/src/tir/transforms/extract_constants.cc @@ -95,9 +95,8 @@ tvm::transform::Pass ExtractPrimFuncConstants() { auto pass_func = [=](IRModule module, tvm::transform::PassContext pc) { auto m = GetRef(module.CopyOnWrite()); for (const auto& kv : m->functions) { - BaseFunc f = kv.second; - if (f->IsInstance()) { - m->Update(kv.first, prim_func_pass(GetRef(f.as()), m, pc)); + if (auto func = kv.second.as()) { + m->Update(kv.first, prim_func_pass(func.value(), m, pc)); } } return m; diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index ffc58f3a42b7..494fd7184fc3 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -254,8 +254,8 @@ class HoistInfoCollector : public StmtExprVisitor { Var var; if (const auto* node_iter_var = op->node.as()) { var = node_iter_var->var; - } else if (const auto* node_var = op->node.as()) { - var = GetRef(node_var); + } else if (auto opt = op->node.as()) { + var = opt.value(); } else { return Parent::VisitStmt_(op); } diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index 410efba9e215..9fdab35c85f4 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -102,8 +102,8 @@ class RollingBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (auto b = op->node.as()) { - auto buffer = GetRef(b); + if (auto opt = op->node.as()) { + auto buffer = opt.value(); // Keep a dictionary associating attribute statements with the buffers // they reference. We'll need this if the buffer gets hoisted and we // need to hoist all of its attributes at the same time. @@ -213,7 +213,7 @@ class RollingBufferInjector : public StmtExprMutator { auto stmt{StmtExprMutator::VisitStmt_(op)}; op = stmt.as(); - if (rolling_buffers.count(GetRef(op->node.as()))) { + if (auto opt = op->node.as(); opt && rolling_buffers.count(opt.value())) { // Remove the attribute statements attached to rolling buffers // because they will have been hoisted to the relevant rolling // scope diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 8480189855b8..b2f95ad2d590 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -166,8 +166,8 @@ class CustomDatatypesLowerer : public StmtExprMutator { // remap these vars when needed // TODO(tvm-team): remove the rewriting once the buffer var // attrs are being refactored into the corresponding definition node - if (const auto* var_node = op->node.as()) { - auto it = var_remap_.find(GetRef(var_node)); + if (auto var_node = op->node.as()) { + auto it = var_remap_.find(var_node.value()); if (it != var_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index ce74fdc4c17b..9a702db69f55 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -138,10 +138,10 @@ class OpaqueBlockLower : public StmtExprMutator { PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) { if (!obj.defined()) { return PrimExpr(); - } else if (const PrimExprNode* expr = obj.as()) { - return GetRef(expr); - } else if (const StringObj* str = obj.as()) { - return std::move(StringImm(str->data)); + } else if (auto expr = obj.as()) { + return expr.value(); + } else if (auto str = obj.as()) { + return std::move(StringImm(str.value())); } else { LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj->GetTypeKey() << " not supported"; diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index ed3a2da19613..2f5fa6572159 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -305,8 +305,8 @@ Pass MakePackedAPI() { std::vector> updates; for (const auto& kv : mptr->functions) { - if (auto* n = kv.second.as()) { - PrimFunc func = GetRef(n); + if (auto opt = kv.second.as()) { + auto func = opt.value(); if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakePackedAPI(std::move(func)); diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index e44eb34068a6..e327b3094594 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -93,8 +93,8 @@ Pass MakeUnpackedAPI() { std::vector> updates; for (const auto& kv : mptr->functions) { - if (auto* n = kv.second.as()) { - PrimFunc func = GetRef(n); + if (auto opt = kv.second.as()) { + auto func = opt.value(); if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakeUnpackedAPI(std::move(func)); diff --git a/src/tir/transforms/profile_instrumentation.cc b/src/tir/transforms/profile_instrumentation.cc index 3a2ef796c688..7f6930e2e2bf 100644 --- a/src/tir/transforms/profile_instrumentation.cc +++ b/src/tir/transforms/profile_instrumentation.cc @@ -267,11 +267,10 @@ Pass InstrumentProfileIntrinsics() { if (reset_start_id) lwp::start_id = 0; std::vector> updates; for (const auto& kv : mptr->functions) { - if (auto* n = kv.second.as()) { - PrimFunc func = GetRef(n); - auto updated_func = - lwp::AddProfileBuiltins(func, max_instr_depth.IntValue(), min_instr_height.IntValue(), - instr_siblings, disable_func_instrumentation); + if (auto func = kv.second.as()) { + auto updated_func = lwp::AddProfileBuiltins(func.value(), max_instr_depth.IntValue(), + min_instr_height.IntValue(), instr_siblings, + disable_func_instrumentation); updates.push_back({kv.first, updated_func}); } } diff --git a/src/tir/transforms/reduce_branching_through_overcompute.cc b/src/tir/transforms/reduce_branching_through_overcompute.cc index 8c8824719276..0c3f7a9ba32f 100644 --- a/src/tir/transforms/reduce_branching_through_overcompute.cc +++ b/src/tir/transforms/reduce_branching_through_overcompute.cc @@ -85,8 +85,8 @@ class ElseBranchStripper : public StmtExprMutator { private: Stmt VisitStmt_(const IfThenElseNode* op) override { IfThenElse ret = Downcast(StmtExprMutator::VisitStmt_(op)); - auto as_eval = ret->else_case.as(); - if (as_eval && new_else_clauses_.count(GetRef(as_eval))) { + if (auto as_eval = ret->else_case.as(); + as_eval && new_else_clauses_.count(as_eval.value())) { return IfThenElse(ret->condition, ret->then_case); } else { return std::move(ret); diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index 7eac6645239e..8cb01dfe6d07 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -181,8 +181,8 @@ class RenewDefMutator : public StmtExprMutator { auto it = remap_.find(expr); if (it != remap_.end()) { return Downcast((*it).second); - } else if (const VarNode* var = expr.as()) { - return this->ReDefineVar(GetRef(var)); + } else if (auto var = expr.as()) { + return this->ReDefineVar(var.value()); } else { return ExprMutator::VisitExpr(expr); } diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 21660991ab37..7646d01f8e90 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -49,8 +49,8 @@ class UnsafeExprDetector : public ExprFunctor { } } return false; - } else if (auto* ptr_op = op->op.as()) { - auto effect_kind = op_call_effect_[GetRef(ptr_op)]; + } else if (auto opt = op->op.as()) { + auto effect_kind = op_call_effect_[opt.value()]; if (effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation) { for (PrimExpr e : op->args) { if (VisitExpr(e)) return true; diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 90d7bedcf97b..c21afe400c56 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -353,7 +353,7 @@ class ThreadSyncInserter : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); ICHECK_EQ(op->args.size(), 5U); - Var buffer_var(GetRef(op->args[1].as())); + Var buffer_var(Downcast(op->args[1])); const IntImmNode* flag = op->args[4].as(); if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index ade707eb6a92..b80a71aa311c 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -329,8 +329,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]}; return Call(op->dtype.with_lanes(lane), op->op, new_args); } - auto* op_ptr = op->op.as(); - bool vectorizable = op_ptr && op_vectorizable_.get(GetRef(op_ptr), false); + auto optional_op = op->op.as(); + bool vectorizable = optional_op && op_vectorizable_.get(optional_op.value(), false); if (!vectorizable) { // Cannot vectorize this op