diff --git a/oneflow/api/common/folder_rule_table.h b/oneflow/api/common/folder_rule_table.h
new file mode 100644
index 00000000000..d945b8f64a1
--- /dev/null
+++ b/oneflow/api/common/folder_rule_table.h
@@ -0,0 +1,26 @@
+#ifndef ONEFLOW_API_COMMON_FOLDER_RULE_TABLE_H_
+#define ONEFLOW_API_COMMON_FOLDER_RULE_TABLE_H_
+
+#include <string>
+#include <vector>
+
+#include "oneflow/core/common/singleton.h"
+#include "oneflow/core/framework/folder_rule_table.h"
+
+namespace oneflow {
+
+// Returns the process-wide list of infix folding rules collected during MLIR
+// constant folding. The reference stays valid for the Singleton's lifetime.
+inline std::vector<std::string>& GetFolderRuleTable() {
+  auto* folder_rule_table = Singleton<FolderRuleTable>::Get();
+  return folder_rule_table->GetRules();
+}
+
+// Records one more folding rule (merged with subsumed rules inside FolderRuleTable).
+inline void AppendRuleToFolderRuleTable(const std::string& new_rule) {
+  auto* folder_rule_table = Singleton<FolderRuleTable>::Get();
+  folder_rule_table->Append(new_rule);
+}
+
+// Drops every recorded rule.
+inline void ResetFolderRuleTable() {
+  auto* folder_rule_table = Singleton<FolderRuleTable>::Get();
+  folder_rule_table->Reset();
+}
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_API_COMMON_FOLDER_RULE_TABLE_H_
diff --git a/oneflow/api/python/framework/folder_rule_table.cpp b/oneflow/api/python/framework/folder_rule_table.cpp
new file mode 100644
index 00000000000..650a00e9c27
--- /dev/null
+++ b/oneflow/api/python/framework/folder_rule_table.cpp
@@ -0,0 +1,15 @@
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <vector>
+#include "oneflow/api/common/folder_rule_table.h"
+#include "oneflow/api/python/of_api_registry.h"
+
+namespace py = pybind11;
+
+namespace oneflow {
+
+ONEFLOW_API_PYBIND11_MODULE("", m) {
+  // NOTE(review): GetFolderRuleTable takes no arguments, so reference_internal
+  // has no parent object to tie the returned reference's lifetime to; plain
+  // py::return_value_policy::reference (the Singleton outlives Python) or
+  // returning by value via pybind11/stl.h looks more appropriate -- confirm.
+  m.def("GetFolderRuleTable", &GetFolderRuleTable, py::return_value_policy::reference_internal);
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/framework/folder_rule_table.h b/oneflow/core/framework/folder_rule_table.h
new file mode 100644
index 00000000000..ffab11ff005
--- /dev/null
+++ b/oneflow/core/framework/folder_rule_table.h
@@ -0,0 +1,40 @@
+#ifndef ONEFLOW_CORE_FRAMEWORK_FOLDER_RULE_TABLE_H_
+#define ONEFLOW_CORE_FRAMEWORK_FOLDER_RULE_TABLE_H_
+
+#include <string>
+#include <vector>
+#include "oneflow/core/common/util.h"
+
+namespace oneflow {
+
+template<typename T>
+class 
Singleton;
+
+// Process-wide store of the constant-folding "infix rules" produced while the
+// MLIR folders merge frozen variables. Owned via Singleton<FolderRuleTable>.
+class FolderRuleTable final {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(FolderRuleTable);
+  ~FolderRuleTable() = default;
+
+  // Appends a rule. If an existing rule is a strict substring of new_rule, the
+  // existing entry is upgraded in place instead of appending a duplicate
+  // (new_rule subsumes the partially-folded rule).
+  void Append(const std::string& new_rule) {
+    for (auto& rule : infix_rules_) {
+      if (new_rule != rule && new_rule.find(rule) != std::string::npos) {
+        rule = new_rule;
+        return;
+      }
+    }
+    infix_rules_.push_back(new_rule);
+  }
+
+  void Reset() { infix_rules_.clear(); }
+
+  std::vector<std::string>& GetRules() { return infix_rules_; }
+
+ private:
+  friend class Singleton<FolderRuleTable>;
+  FolderRuleTable() = default;
+
+  std::vector<std::string> infix_rules_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_FRAMEWORK_FOLDER_RULE_TABLE_H_
diff --git a/oneflow/core/framework/multi_client_session_context.cpp b/oneflow/core/framework/multi_client_session_context.cpp
index 1ff15f13aec..5616700b36d 100644
--- a/oneflow/core/framework/multi_client_session_context.cpp
+++ b/oneflow/core/framework/multi_client_session_context.cpp
@@ -39,6 +39,7 @@ limitations under the License.
 #include "oneflow/core/job/collective_boxing/scheduler.h"
 #include "oneflow/core/graph/task_stream_index_manager.h"
 #include "oneflow/core/framework/variable_tensor_mgr.h"
+#include "oneflow/core/framework/folder_rule_table.h"
 #ifdef WITH_CUDA
 #include <cuda.h>
 #endif  // WITH_CUDA
@@ -113,6 +114,7 @@ Maybe<void> MultiClientSessionContext::TryInit(const ConfigProto& config_proto)
       Singleton<boxing::collective::Scheduler>::New();
       Singleton<TaskStreamIndexManager>::New();
       Singleton<VariableTensorMgr>::New();
+      Singleton<FolderRuleTable>::New();
     }
 
     is_inited_ = true;
diff --git a/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp b/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp
index c1a974e13be..074f9148d03 100644
--- a/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp
+++ b/oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp
@@ -30,6 +30,8 @@ limitations under the License.
#include "oneflow/core/functional/functional_api.yaml.h"
 #include "oneflow/core/job/lazy_mode.h"
 #include "oneflow/core/framework/variable_tensor_mgr.h"
+#include "oneflow/api/common/folder_rule_table.h"
+
 namespace mlir {
 namespace oneflow {
@@ -44,6 +46,21 @@ StringAttr GenNewVariableOpName(MLIRContext* ctx, const std::string& key = "") {
   return StringAttr::get(ctx, "variable_" + key + "_" + ::oneflow::NewUniqueId());
 }
 
+// Names the folded variable with the infix rule "op ( operand )" and records
+// the rule in the global FolderRuleTable so the Python side can replay the fold.
+StringAttr GenNewUnaryVariableOpName(MLIRContext* ctx, const std::string& operand_name,
+                                     const std::string& operator_name) {
+  std::string infix_rule = operator_name + " ( " + operand_name + " )";
+  ::oneflow::AppendRuleToFolderRuleTable(infix_rule);
+  return StringAttr::get(ctx, infix_rule);
+}
+
+// Binary counterpart: "( lhs ) op ( rhs )".
+StringAttr GenNewBinaryVariableOpName(MLIRContext* ctx, const std::string& lhs_operand_name,
+                                      const std::string& rhs_operand_name,
+                                      const std::string& operator_name) {
+  std::string infix_rule =
+      "( " + lhs_operand_name + " ) " + operator_name + " ( " + rhs_operand_name + " )";
+  ::oneflow::AppendRuleToFolderRuleTable(infix_rule);
+  return StringAttr::get(ctx, infix_rule);
+}
+
 bool MLIRDataTypesAreSame(const std::vector<::oneflow::DataType>& data_types) {
   if (data_types.empty() || data_types.size() == 1) { return true; }
   bool result = true;
@@ -63,6 +80,7 @@ bool DictionaryAttrsHaveSameDataType(const std::vector<DictionaryAttr>& attrs) {
 }
 
 OpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
+                       const std::string& operator_name,
                        const std::function<MaybeTensor(const TensorPtr&)>& f) {
   ::oneflow::LazyMode::Guard guard{false};
   if (!operands.front()) { return {}; }  // Important!
@@ -74,7 +92,9 @@ OpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
             attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr()));
   const auto result = f(tensor).GetPtrOrThrow();
   attrs.set("value", support::TensorToDenseElementsAttr(result, ctx));
-  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), GenNewVariableOpName(ctx));
+  auto operand_name = attr_dict.get("op_name").cast<StringAttr>().getValue().str();
+  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),
+            GenNewUnaryVariableOpName(ctx, operand_name, operator_name));
   attrs.set(OpTrait::TensorSource::getDataTypeAttrName(),
             attr_dict.get(OpTrait::TensorSource::getDataTypeAttrName()));
 
@@ -82,6 +102,7 @@ OpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
 }
 
 OpFoldResult BinaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
+                        const std::string& operator_name,
                         const std::function<MaybeTensor(const TensorPtr&, const TensorPtr&)>& f) {
   ::oneflow::LazyMode::Guard guard{false};
   if (!(operands.front() && operands.back())) { return {}; }  // Important!
@@ -107,7 +128,10 @@ OpFoldResult BinaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
 
   const auto result = f(lhs_tensor, rhs_tensor).GetPtrOrThrow();
   attrs.set("value", support::TensorToDenseElementsAttr(result, ctx));
-  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), GenNewVariableOpName(ctx));
+  auto lhs_operand_name = lhs_attr_dict.get("op_name").cast<StringAttr>().getValue().str();
+  auto rhs_operand_name = rhs_attr_dict.get("op_name").cast<StringAttr>().getValue().str();
+  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),
+            GenNewBinaryVariableOpName(ctx, lhs_operand_name, rhs_operand_name, operator_name));
   attrs.set(OpTrait::TensorSource::getDataTypeAttrName(),
             lhs_attr_dict.get(OpTrait::TensorSource::getDataTypeAttrName()));
 
@@ -129,22 +153,33 @@ OpFoldResult FrozenVariableOp::fold(FoldAdaptor adaptor) {
   return DictionaryAttr::get(getContext(), attrs);
 }
 
+// Renders a vector as "[a,b,]" -- note the trailing comma. The Python side
+// (tensor_folder_map.py) splits on ',' and drops the final empty token, so
+// this format must be kept in sync; do not "fix" the trailing comma here.
+template<typename T>
+std::string VectorToString(const std::vector<T>& vec) {
+  std::stringstream ss;
+  ss << "[";
+  for (const auto& elem : vec) { ss << elem << ","; }
+  ss 
<< "]";
+  return ss.str();
+}
+
 OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   auto operands = adaptor.getOperands();
-  return UnaryFold(getContext(), operands, [this](const auto& tensor) {
-    std::vector<int32_t> perm_;
-    for (auto& x : getPerm().getValue()) { perm_.emplace_back(x.cast<IntegerAttr>().getSInt()); }
+  std::vector<int32_t> perm_;
+  for (auto& x : getPerm().getValue()) { perm_.emplace_back(x.cast<IntegerAttr>().getSInt()); }
+  return UnaryFold(getContext(), operands, "Transpose(" + VectorToString(perm_) + ")",
+                   [this, &perm_](const auto& tensor) {
     return functional::Transpose(tensor, perm_);
   });
 }
 
 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   auto operands = adaptor.getOperands();
-  return UnaryFold(getContext(), operands, [this](const auto& tensor) {
-    std::vector<int64_t> shape_vec;
-    for (auto& x : getShape().getValue()) {
-      shape_vec.emplace_back(x.cast<IntegerAttr>().getValue().getSExtValue());
-    }
+  std::vector<int64_t> shape_vec;
+  for (auto& x : getShape().getValue()) {
+    shape_vec.emplace_back(x.cast<IntegerAttr>().getValue().getSExtValue());
+  }
+  return UnaryFold(getContext(), operands, "Reshape(" + VectorToString(shape_vec) + ")",
+                   [this, &shape_vec](const auto& tensor) {
     return functional::Reshape(
         tensor, ::oneflow::Shape(::oneflow::DimVector(shape_vec.begin(), shape_vec.end())));
   });
@@ -152,7 +187,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult ScalarAddOp::fold(FoldAdaptor adaptor) {
   auto operands = adaptor.getOperands();
-  return UnaryFold(getContext(), operands, [this](const auto& tensor) -> MaybeTensor {
+  // NOTE(review): the rule label always embeds the *int* operand, even when
+  // the lambda below folds with the float operand; float ScalarAdds would be
+  // replayed with the wrong constant on the Python side -- confirm and extend
+  // the label format before relying on the float path.
+  return UnaryFold(getContext(), operands, "ScalarAdd(" + std::to_string(getIntOperand()) + ")",
+                   [this](const auto& tensor) -> MaybeTensor {
     if (getHasIntOperand()) { return functional::ScalarAdd(tensor, getIntOperand(), 1, false); }
     if (getHasFloatOperand()) {
       return functional::ScalarAdd(tensor, getFloatOperand().convertToDouble(), 1, false);
@@ -164,24 +199,25 @@ OpFoldResult ScalarAddOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult SqrtOp::fold(FoldAdaptor adaptor) {
   auto operands = 
adaptor.getOperands();
-  return UnaryFold(getContext(), operands, functional::Sqrt);
+  return UnaryFold(getContext(), operands, "Sqrt", functional::Sqrt);
 }
 
 OpFoldResult BroadcastMulOp::fold(FoldAdaptor adaptor) {
   auto operands = adaptor.getOperands();
-  return BinaryFold(getContext(), operands, functional::Mul);
+  return BinaryFold(getContext(), operands, "BroadcastMul", functional::Mul);
 }
 
 OpFoldResult BroadcastDivOp::fold(FoldAdaptor adaptor) {
   auto operands = adaptor.getOperands();
-  return BinaryFold(getContext(), operands, functional::Div);
+  return BinaryFold(getContext(), operands, "BroadcastDiv", functional::Div);
 }
 
 OpFoldResult BroadcastSubOp::fold(FoldAdaptor adaptor) {
   auto operands = adaptor.getOperands();
-  return BinaryFold(getContext(), operands, [](const auto& lhs, const auto& rhs) -> MaybeTensor {
-    return functional::Sub(lhs, rhs, /*alpha=*/1.0, false);
-  });
+  return BinaryFold(getContext(), operands, "BroadcastSub",
+                    [](const auto& lhs, const auto& rhs) -> MaybeTensor {
+                      return functional::Sub(lhs, rhs, /*alpha=*/1.0, false);
+                    });
 }
 
 }  // namespace oneflow
diff --git a/python/oneflow/nn/graph/graph.py b/python/oneflow/nn/graph/graph.py
index c3b98d7839a..322a556af51 100644
--- a/python/oneflow/nn/graph/graph.py
+++ b/python/oneflow/nn/graph/graph.py
@@ -61,6 +61,8 @@ from oneflow.nn.optimizer.lr_scheduler import LRScheduler
 from oneflow.optim.optimizer import Optimizer
 
+from oneflow.nn.graph.tensor_folder_map import TensorFolderMap
+
 
 class Graph(object):
     r"""Base class for training or evaluating a neural network in static graph mode. 
@@ -282,6 +284,9 @@ def __call__(self, *args, **kwargs):
         if not self._is_compiled:
             self._compile(*args, **kwargs)
+
+        # generate the tensor folder map from the rules recorded during MLIR
+        # constant folding at compile time
+        self.tensor_folder_map = TensorFolderMap(oneflow._oneflow_internal.GetFolderRuleTable())
 
         return self.__run(*args, **kwargs)
diff --git a/python/oneflow/nn/graph/tensor_folder_map.py b/python/oneflow/nn/graph/tensor_folder_map.py
new file mode 100644
index 00000000000..6534b184359
--- /dev/null
+++ b/python/oneflow/nn/graph/tensor_folder_map.py
@@ -0,0 +1,141 @@
+import oneflow as flow
+from oneflow.framework.tensor import Tensor
+from typing import List, Callable, Tuple
+
+# Operator tokens that appear in the infix rules emitted by OneFlowOpFolders.cpp.
+# Unary ops may carry an argument suffix (e.g. "Reshape([1,2,])"), so they are
+# matched by substring; binary ops are bare tokens matched exactly.
+_unary_ops = set(['Transpose', 'Reshape', 'ScalarAdd', 'Sqrt'])
+_binary_ops = set(['BroadcastMul', 'BroadcastDiv', 'BroadcastSub'])
+
+def _is_unary_op(token):
+    return any(op in token for op in _unary_ops)
+
+def _is_binary_op(token):
+    return token in _binary_ops
+
+def _parse_int_list_arg(token):
+    # "Op([1,2,3,])" -> [1, 2, 3]; the C++ VectorToString always emits a
+    # trailing comma, hence the [:-1] dropping the final empty token.
+    arg_start_index = token.find('[')
+    arg_end_index = token.find(']')
+    return [int(x) for x in token[arg_start_index + 1:arg_end_index].split(",")[:-1]]
+
+def _cal_transpose(token, var):
+    # TODO: the exact transpose argument format is unverified -- resnet18
+    # never exercises Transpose, so this path has not been tested.
+    arg = _parse_int_list_arg(token)
+    return flow.transpose(var, *arg)
+
+def _cal_reshape(token, var):
+    arg = _parse_int_list_arg(token)
+    return flow.reshape(var, arg)
+
+def _cal_scalar_add(token, var):
+    # "ScalarAdd(3)" -> var + 3 (integer operand only; see the NOTE in
+    # OneFlowOpFolders.cpp about the float-operand case).
+    arg_start_index = token.find('(')
+    arg_end_index = token.find(')')
+    arg = int(token[arg_start_index + 1:arg_end_index])
+    return var + arg
+
+def _get_eval_func(postfix_rule: List[str]):
+    """Build a function that evaluates `postfix_rule` over the given tensors."""
+
+    def eval(*inputs: Tensor) -> Tensor:
+        # Standard postfix evaluation: operands are consumed from `inputs`
+        # in order; operators pop from / push onto `stack`.
+        cnt = 0
+        stack: List[Tensor] = []
+        for token in postfix_rule:
+            if _is_unary_op(token):
+                var = stack.pop()
+                if token.find('Transpose') != -1 :
+                    stack.append(_cal_transpose(token, var))
elif token.find('Reshape') != -1 :
+                    stack.append(_cal_reshape(token, var))
+                elif token.find('ScalarAdd') != -1 :
+                    stack.append(_cal_scalar_add(token, var))
+                elif token == 'Sqrt' :
+                    stack.append(flow.sqrt(var))
+                else :
+                    raise ValueError("Bad Unary Operator " + token)
+            elif _is_binary_op(token):
+                rhs = stack.pop()
+                lhs = stack.pop()
+                if token == 'BroadcastMul':
+                    stack.append(lhs * rhs)
+                elif token == 'BroadcastDiv':
+                    stack.append(lhs / rhs)
+                elif token == 'BroadcastSub':
+                    stack.append(lhs - rhs)
+                else:
+                    raise ValueError("Bad Binary Operator " + token)
+            else:  # operand token: consume the next input tensor
+                stack.append(inputs[cnt])
+                cnt += 1
+
+        assert len(stack) == 1, "Bad postfix rule: " + " ".join(postfix_rule)
+        return stack[0]
+
+    return eval
+
+
+def _transform_infix_to_postfix(infix_rule: str) -> Tuple[List[str], List[str]]:
+    """Shunting-yard conversion of a fully parenthesized infix rule to postfix.
+
+    Returns (postfix_rule, input_tensor_names); input_tensor_names lists the
+    operand tokens in the order they appear in the rule.
+    """
+    input_tensor_names = []
+    op_stack = []
+    postfix_rule = []
+    for token in infix_rule.split():
+        if token == "(" :
+            op_stack.append(token)
+        elif token == ")" :
+            while len(op_stack) != 0 and op_stack[-1] != "(" :
+                postfix_rule.append(op_stack.pop())
+            assert len(op_stack) != 0 and op_stack[-1] == "(", "input rules with unmatched brackets: " + infix_rule
+            op_stack.pop()
+        elif _is_unary_op(token) or _is_binary_op(token):
+            # Because all ops are wrapped in parentheses, we don't need to
+            # compare the priority of the stack top and the new op.
+            op_stack.append(token)
+        else:  # tensor name
+            postfix_rule.append(token)
+            input_tensor_names.append(token)
+    while len(op_stack) != 0 :
+        postfix_rule.append(op_stack.pop())
+
+    return postfix_rule, input_tensor_names
+
+# helper
+def _transform(infix_rule: str) -> Tuple[List[str], str, Callable[[List[Tensor]], Tensor]]:
+    '''
+    transform an infix eval rule to a TensorFolderMap element
+    for example, input rule is ( model.conv1.weight ) BroadcastMul ( Reshape ( ( model.bn1.weight ) BroadcastDiv ( Sqrt ( ScalarAdd ( model.bn1.running_var ) ) ) ) )
+    the output are [
input_names = [model.conv1.weight, model.bn1.weight, model.bn1.running_var]
+    output_name = infix_rule
+    a function to eval the rule: its inputs are the tensors named by
+    input_names, its output is the eval result, i.e. the tensor matching
+    output_name
+    ]
+    '''
+    postfix_rule, input_tensor_names = _transform_infix_to_postfix(infix_rule)
+
+    return input_tensor_names, infix_rule, _get_eval_func(postfix_rule)
+
+
+class TensorFolderMap:
+    """Maps each constant-folded tensor back to its source tensors and to a
+    function that reproduces the fold in eager mode.
+
+    Each element is (input tensor names, output tensor name,
+    function of (*Tensor) -> Tensor).
+    """
+
+    def __init__(self, rules: List[str]):
+        super().__init__()
+        self._map: List[Tuple[List[str], str, Callable[[List[Tensor]], Tensor]]] = []
+        for rule in rules:
+            self._map.append(_transform(rule))
+
+    @property
+    def map(self):
+        return self._map
+
+
+if __name__ == "__main__" :
+    infix_rule = "( model.conv1.weight ) BroadcastMul ( Reshape ( ( model.bn1.weight ) BroadcastDiv ( Sqrt ( ScalarAdd ( model.bn1.running_var ) ) ) ) )"
+    postfix_rule, input_tensor_names = _transform_infix_to_postfix(infix_rule)
+    print(" ".join(postfix_rule))
+    print(input_tensor_names)
diff --git a/python/oneflow/test/graph/test_tensor_folder_map.py b/python/oneflow/test/graph/test_tensor_folder_map.py
new file mode 100644
index 00000000000..e7e08b4a304
--- /dev/null
+++ b/python/oneflow/test/graph/test_tensor_folder_map.py
@@ -0,0 +1,50 @@
+import os
+# The MLIR env flags must be set before oneflow is imported to take effect.
+os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
+os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
+os.environ["ONEFLOW_DEBUG_MODE"] = "1"
+os.environ["ONEFLOW_MLIR_DUMPMLIR"] = "1"
+
+import copy
+import numpy as np
+import oneflow as flow
+import oneflow.nn as nn
+from flowvision.models.resnet import resnet18
+
+def test_tensor_folder_map():
+    data = flow.randn(1, 3, 224, 224)
+
+    model = resnet18(pretrained=False, progress=True)
+    model.eval()
+    eager_res = model(data)
+    copymodel = copy.deepcopy(model)
+    param_table = dict(copymodel.named_parameters(prefix='model'))
+    buffer_table = dict(copymodel.named_buffers(prefix='model'))
+    param_table.update(buffer_table)
+
+    class Resnet18Graph(nn.Graph):
+        def __init__(self):
+            super().__init__()
+            self.model = model
+
+        def build(self, *input):
+            return self.model(*input)
+
+    graph = Resnet18Graph()
+    _ = graph(data)
+
+    output_tensor_name_list, output_tensor_list = flow._oneflow_internal.DumpVariableTensorMgr()
+    for input_tensor_names, output_tensor_name, eval_func in graph.tensor_folder_map.map:
+        input_tensor_list = [param_table[name].data for name in input_tensor_names]
+        manual_output_tensor = eval_func(*input_tensor_list)
+
+        index = output_tensor_name_list.index(output_tensor_name)
+        automatic_output_tensor = output_tensor_list[index]
+
+        # The replayed fold must match what the MLIR folder computed.
+        assert np.allclose(manual_output_tensor.numpy(), automatic_output_tensor.numpy(), rtol=1e-2, atol=1e-2)
+
+        # Put the manually folded tensor back so the graph run below uses it.
+        flow._oneflow_internal.FillVariableTensorMgr([output_tensor_name], [manual_output_tensor])
+
+    lazy_res = graph(data)
+    assert np.allclose(eager_res.numpy(), lazy_res.numpy(), rtol=1e-2, atol=1e-2)
+
+if __name__ == "__main__":
+    test_tensor_folder_map()