Skip to content

Commit

Permalink
[Relay][VM] Clean up the VM and VM profiler code (apache#4391)
Browse files Browse the repository at this point in the history
* [VM] add a few more API to vm

* [VM][Fix] fix vm convert args

* [VM] a few fixes

* rename fields

* update

* update vm profiler

* x

* add doc

* lint

* fix test

* address comments
  • Loading branch information
icemelon authored and zhiics committed Nov 22, 2019
1 parent 1562eae commit 122a493
Show file tree
Hide file tree
Showing 11 changed files with 437 additions and 446 deletions.
326 changes: 183 additions & 143 deletions include/tvm/runtime/vm.h

Large diffs are not rendered by default.

58 changes: 7 additions & 51 deletions python/tvm/relay/backend/profiler_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,68 +22,24 @@
"""
from . import vm, _vm

def compile(mod, target=None, target_host=None, params=None):
"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
exec : Executable
The executable with profiling code.
"""
compiler = VMCompilerProfiler()
target = compiler.update_target(target)
target_host = compiler.update_target_host(target, target_host)
if params:
compiler.set_params(params)
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return vm.Executable(compiler._get_exec())

def enabled():
"""Whether vm profiler is enabled."""
return hasattr(_vm, "_VMCompilerProfiler")

class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime."""
def __init__(self):
super().__init__()
self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"]
self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"]
return hasattr(_vm, "_VirtualMachineDebug")

class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime."""
def __init__(self, mod):
super().__init__(mod)
super(VirtualMachineProfiler, self).__init__(mod)
m = mod.module if isinstance(mod, vm.Executable) else mod
self.mod = _vm._VirtualMachineDebug(m)
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
self._get_stat = self.mod["get_stat"]
self._set_input = self.mod["set_input"]
self._reset = self.mod["reset"]

def get_stat(self):
return self._get_stat()

def reset(self):
self._reset()
72 changes: 65 additions & 7 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,17 @@
ADT = _obj.ADT

def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
if isinstance(arg, _obj.Object):
cargs.append(arg)
elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.Tensor(arg))
elif isinstance(arg, (tuple, list)):
field_args = []
for field in arg:
_convert(field, field_args)
cargs.append(_obj.tuple_object(field_args))
else:
raise "unsupported type"
raise "Unsupported type: %s" % (type(arg))


def convert(args):
Expand All @@ -57,10 +59,13 @@ class Executable(object):
"""Relay VM executable"""
def __init__(self, mod):
self.mod = mod
self._function_params = {}
self._save = self.mod["save"]
self._get_lib = self.mod["get_lib"]
self._get_bytecode = self.mod["get_bytecode"]
self._get_stats = self.mod["get_stats"]
self._get_function_arity = self.mod["get_function_arity"]
self._get_function_param_name = self.mod["get_function_param_name"]

def save(self):
"""Save the Relay VM Executable.
Expand Down Expand Up @@ -239,6 +244,20 @@ def module(self):
"""Return the runtime module contained in a virtual machine executable."""
return self.mod

def get_function_params(self, func_name):
"""Get VM Function parameters"""
if func_name in self._function_params:
return self._function_params[func_name]
arity = self._get_function_arity(func_name)
assert arity >= 0
params = []
for i in range(arity):
p = self._get_function_param_name(func_name, i)
assert p
params.append(p)
self._function_params[func_name] = params
return params


class VirtualMachine(object):
"""Relay VM runtime."""
Expand All @@ -248,8 +267,10 @@ def __init__(self, mod):
"tvm.Module, but received {}".format(type(mod)))
m = mod.module if isinstance(mod, Executable) else mod
self.mod = _vm._VirtualMachine(m)
self._exec = mod
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
self._set_input = self.mod["set_input"]

def init(self, ctx):
"""Initialize the context in the VM.
Expand All @@ -262,7 +283,37 @@ def init(self, ctx):
args = [ctx.device_type, ctx.device_id]
self._init(*args)

def invoke(self, func_name, *args):
def set_input(self, func_name, *args, **kwargs):
"""Set the input to a function.
Parameters
----------
func_name : str
The name of the function.
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
kwargs: dict of str to NDArray or np.ndarray
Named arguments to the function.
"""
if kwargs:
func_params = self._exec.get_function_params(func_name)
new_args = [None] * len(func_params)
assert len(args) + len(kwargs) == len(func_params)
for k in kwargs:
idx = func_params.index(k)
new_args[idx] = kwargs[k]
idx = 0
for i, arg in enumerate(new_args):
if arg is None:
new_args[i] = args[idx]
idx += 1
args = new_args
cargs = convert(args)
self._set_input(func_name, *cargs)

def invoke(self, func_name, *args, **kwargs):
"""Invoke a function.
Parameters
Expand All @@ -273,28 +324,35 @@ def invoke(self, func_name, *args):
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
kwargs: dict of str to NDArray or np.ndarray
Named arguments to the function.
Returns
-------
result : Object
The output.
"""
cargs = convert(args)
return self._invoke(func_name, *cargs)
if args or kwargs:
self.set_input(func_name, *args, **kwargs)
return self._invoke(func_name)

def run(self, *args):
def run(self, *args, **kwargs):
"""Run the main function.
Parameters
----------
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
kwargs: dict of str to NDArray or np.ndarray
Named arguments to the function.
Returns
-------
result : Object
The output.
"""
return self.invoke("main", *args)
return self.invoke("main", *args, **kwargs)


def compile(mod, target=None, target_host=None, params=None):
Expand Down
50 changes: 0 additions & 50 deletions src/relay/backend/vm/profiler/compiler.cc

This file was deleted.

65 changes: 49 additions & 16 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <algorithm>
#include <memory>
#include <iostream>
#include <iomanip>
#include <sstream>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -67,44 +68,76 @@ PackedFunc Executable::GetFunction(const std::string& name,
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->Save();
});
} else if (name == "get_function_arity") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
*rv = this->GetFunctionArity(func_name);
});
} else if (name == "get_function_param_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
int index = args[1];
*rv = this->GetFunctionParameterName(func_name, index);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc(nullptr);
}
}

int Executable::GetFunctionArity(std::string func_name) const {
auto it = global_map.find(func_name);
if (it == global_map.end()) {
LOG(ERROR) << "Cannot find function " << func_name << " in executable";
return -1;
}
const auto& func = functions[it->second];
return func.params.size();
}

std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const {
auto it = global_map.find(func_name);
if (it == global_map.end()) {
LOG(ERROR) << "Cannot find function " << func_name << " in executable";
return "";
}
const auto& func = functions[it->second];
if (index > func.params.size()) {
LOG(ERROR) << "Invalid parameter index";
return "";
}
return func.params[index];
}

std::string Executable::GetBytecode() const {
std::ostringstream oss;

for (const auto& func : functions) {
for (size_t i = 0; i < functions.size(); ++i) {
const auto& func = functions[i];
// Print the header of the function format.
oss << "# func name, reg file size, param count, inst count:"
<< std::endl;
oss << func.name << " "
<< func.register_file_size << " "
<< func.params.size() << " "
<< func.instructions.size() << std::endl;

// Print pramams of a `VMFunction`.
oss << "# Parameters: "<< std::endl;
oss << "VM Function[" << i << "]: " << func.name << "(";
for (const auto& param : func.params) {
oss << param << " ";
oss << param << ", ";
}
oss << std::endl;
oss.seekp(-2, std::ios_base::end);
oss << ")" << std::endl;
oss << "# reg file size = " << func.register_file_size << std::endl;
oss << "# instruction count = " << func.instructions.size() << std::endl;

// Print the instructions of a `VMFunction`.
// The part after ";" is the instruction in text format.
oss << "hash, opcode, fields # inst(text):"<< std::endl;
for (const auto& instr : func.instructions) {
oss << "opcode, fields # inst(text):" << std::endl;
for (size_t idx = 0; idx < func.instructions.size(); ++idx) {
const auto& instr = func.instructions[idx];
const auto& serialized_instr = SerializeInstruction(instr);
oss << std::hex << "0x" << serialized_instr.Hash() << " "
<< std::dec << serialized_instr.opcode << " ";
oss << std::setw(2) << idx << ": " << serialized_instr.opcode << " ";
for (auto it : serialized_instr.fields) {
oss << it << " ";
}
oss << " # " << instr;
if (oss.str().back() != '\n') oss << std::endl;
}
oss << std::endl;
}

return oss.str();
Expand Down
Loading

0 comments on commit 122a493

Please sign in to comment.