diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 6dc8f107e..b5c239afb 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -61,6 +61,7 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None): self.py_output_names = [] self.graph_output_names = [] self.build_options = [] + self.output_nodes = [] self.folder = folder self.graph_key = graph_key @@ -194,8 +195,27 @@ def parse_outputs(self): self.py_output_names.append(str(node)) self.output_args = real_output_args + if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + for output in self.output_args: + info = {} + info['format'] = 'ND' + if hasattr(output, 'meta'): + output = output.meta['val'] + if isinstance(output, torch.SymInt): + info['data_type'] = 'INT32' + elif isinstance(output, torch.SymBool): + info['data_type'] = 'BOOL' + else: + info['data_type'] = get_ascend_dtype(output.dtype) + self.output_nodes.append(info) if len(self.assign_args) > 0: self.graph_output_names.extend(list(zip(*self.assign_args))[0]) + for item in self.assign_args: + index = item[1] + info = {} + info['format'] = self.data_nodes[index]['format'] + info['data_type'] = self.data_nodes[index]['data_type'] + self.output_nodes.append(info) def gen_import_code(self): self.import_code.splice( @@ -206,7 +226,7 @@ def gen_import_code(self): import random from torch import empty_strided, as_strided, device from dicp.dynamo_bridge.compile import AsyncCompileKernel - from dicp.vendor.AscendGraph.compile_job import AscendCompileJob + from dicp.vendor.AscendGraph.compile_job import AscendGECompileAclRunJob, AscendGECompileGERunJob aten = torch.ops.aten assert_size_stride = torch._C._dynamo.guards.assert_size_stride @@ -461,6 +481,7 @@ def gen_graph_json(self): "build_options": self.build_options, "data_nodes": self.data_nodes, "common_nodes": self.common_nodes, + "output_nodes": self.output_nodes, } self.remove_symint(graph) return json.dumps(graph) @@ -468,9 +489,11 @@ def gen_graph_json(self): def gen_compile_graph_code(self): compile_graph_code = IndentedBuffer() graph_json = self.gen_graph_json() + compile_job_type = os.environ.get("DICP_ASCEND_COMPILE_JOB_TYPE", "AscendGECompileGERunJob") + assert compile_job_type in ["AscendGECompileGERunJob", "AscendGECompileAclRunJob"] compile_graph_code.splice( f""" - ascend_compile_job = AscendCompileJob('''{graph_json}''') + ascend_compile_job = {compile_job_type}('''{graph_json}''') async_compile = AsyncCompileKernel() kernel_cpp_0 = async_compile.compile_kernel(ascend_compile_job) """, strip=True diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ge_builder.h b/dicp/dicp/vendor/AscendGraph/codegen/ge_builder.h new file mode 100644 index 000000000..7426a93e1 --- /dev/null +++ b/dicp/dicp/vendor/AscendGraph/codegen/ge_builder.h @@ -0,0 +1,500 @@ +#ifndef DICP_ASCEND_GE_BUILDER_H +#define DICP_ASCEND_GE_BUILDER_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "all_ops.h" +#include "ascend_string.h" +#include "ge_api.h" +#include "ge_api_types.h" +#include "ge_error_codes.h" +#include "ge_ir_build.h" +#include "gnode.h" +#include "graph.h" +#include "graph_utils.h" +#include "tensor.h" +#include "types.h" + +#define FAILED -1 +#define SUCCESS 0 + +using namespace ge; +using json = nlohmann::json; +using OperatorMap = std::unordered_map; + +static std::unordered_set 
op_with_dynamic_inputs_outputs = { + "ConcatD", "IdentityN", "Pack", "SplitD", "SplitVD", "IncreFlashAttention"}; + +void check_op(std::unordered_map& op_map, + const std::string& op_name) { + if (op_map.count(op_name) > 0) { + throw std::runtime_error("op_name duplicated: " + op_name); + } +} + +void setTensorData(Tensor& tensor, uint8_t* src_data, uint64_t data_size, + const std::string& debug_name = "") { + auto status = tensor.SetData(reinterpret_cast(src_data), data_size); + if (status != ge::GRAPH_SUCCESS) { + std::cout << "Set " << debug_name << " tensor data failed!" << std::endl; + } +} + +ge::Tensor genTensor(const std::vector& tensor_shape, + ge::Format format, ge::DataType data_type) { + TensorDesc desc(ge::Shape(tensor_shape), format, data_type); + Tensor result(desc); + return result; +} + +template +ge::Tensor genTensorWithData(const std::vector& tensor_shape, + ge::Format format, ge::DataType data_type, + std::vector value) { + TensorDesc desc(ge::Shape(tensor_shape), format, data_type); + Tensor result(desc); + setTensorData(result, reinterpret_cast(value.data()), + value.size() * sizeof(T), "genTensorWithData"); + return result; +} + +ge::Operator genInput(const std::string op_name, + const std::vector shape, ge::Format format, + ge::DataType data_type, int index = -1) { + TensorDesc tensor_desc_data_op = + TensorDesc(ge::Shape(shape), format, data_type); + auto op = op::Data(op_name.c_str()); + op.update_input_desc_x(tensor_desc_data_op); + op.update_output_desc_y(tensor_desc_data_op); + if (index > -1) { + op.set_attr_index(index); + } + return op; +} + +class GEGraphBuilder { + public: + explicit GEGraphBuilder(const std::string& fusion_switch_file, + const std::string& ge_builder_config_file) + : _fusion_switch_file(fusion_switch_file), + _ge_builder_config_file(ge_builder_config_file) { + // 1. system init + std::map global_options; + + auto kSocVersion = aclrtGetSocName(); + global_options[ge::ir_option::SOC_VERSION] = kSocVersion; + global_options[ge::ir_option::FUSION_SWITCH_FILE] = + _fusion_switch_file.c_str(); + + auto raw_conf = parse_json_to_map(_ge_builder_config_file); + for (const auto& item : raw_conf) { + global_options[item.first.c_str()] = item.second.c_str(); + } + CALL_FUNC(aclgrphBuildInitialize(global_options)); + } + + void saveGraph(const std::string& path, const Graph& graph, + std::map& options) { + ModelBufferData model; + + auto status = aclgrphBuildModel(graph, options, model); + if (status == GRAPH_SUCCESS) { + std::cout << "Build Model SUCCESS!" << std::endl; + } else { + std::cout << "Build Model Failed! " << status << std::endl; + return; + } + + // 4. Save Ir Model + status = aclgrphSaveModel(path.c_str(), model); + if (status == GRAPH_SUCCESS) { + std::cout << "Save Offline Model SUCCESS!" << std::endl; + } else { + std::cout << "Save Offline Model Failed! " << status << std::endl; + } + } + + ~GEGraphBuilder() { + aclgrphBuildFinalize(); + std::cout << "aclgrphBuildFinalize success!" << std::endl; + } + + private: + std::string _fusion_switch_file; + std::string _ge_builder_config_file; +}; + +ge::Format get_ascend_format(const std::string& format) { + static std::unordered_map format_map = { + {"NCHW", FORMAT_NCHW}, + {"NHWC", FORMAT_NHWC}, + {"ND", FORMAT_ND}, + {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, + }; + if (format_map.count(format) > 0) { + return format_map[format]; + } + std::string error_msg = "invalid ascend foramt! 
format: " + format; + throw std::runtime_error(error_msg); +} + +ge::DataType get_ascend_datatype(const std::string& data_type) { + static std::unordered_map datatype_map = { + {"FLOAT", ge::DataType::DT_FLOAT}, {"FLOAT16", ge::DataType::DT_FLOAT16}, + {"INT32", ge::DataType::DT_INT32}, {"INT64", ge::DataType::DT_INT64}, + {"BOOL", ge::DataType::DT_BOOL}, {"UINT8", ge::DataType::DT_UINT8}, + {"BF16", ge::DataType::DT_BF16}, + }; + if (datatype_map.count(data_type) > 0) { + return datatype_map[data_type]; + } + std::string error_msg = "invalid ascend data type! data type: " + data_type; + throw std::runtime_error(error_msg); +} + +template +T genDynamicOp(const std::string& op_name) { + return T(op_name.c_str()); +} + +template +void parseDynamicInput(std::unordered_map& op_map, + T& op, const json& node) { + if (node.contains("dynamic_inputs")) { + for (const auto& i : node["dynamic_inputs"]) { + auto num = i["num"].get(); + auto name = i["name"].get(); + if (name == "x") { + op.create_dynamic_input_x(num); + for (const auto& item : i["value"]) { + auto index = item["index"].get(); + auto value = op_map[item["value"].get()]; + if (item.contains("edge")) { + op.set_dynamic_input_x(index, value, + item["edge"].get().c_str()); + } else { + op.set_dynamic_input_x(index, value); + } + } + } else { + throw std::runtime_error("invalid dynamic input name: " + name); + } + } + } +} + +template <> +void parseDynamicInput(std::unordered_map& op_map, + op::IncreFlashAttention& op, const json& node) { + if (node.contains("dynamic_inputs")) { + int kv_inputs_num = 0; + for (const auto& i : node["dynamic_inputs"]) { + auto num = i["num"].get(); + auto name = i["name"].get(); + if (name == "key") { + kv_inputs_num = static_cast(num); + op.create_dynamic_input_byindex_key(num, 1); + for (const auto& item : i["value"]) { + auto index = item["index"].get(); + auto value = op_map[item["value"].get()]; + op.set_dynamic_input_key(index, value); + } + } else if (name == "value") { + if (kv_inputs_num == 0 && static_cast(num) == kv_inputs_num) { + throw std::runtime_error( + "need first set dynamic key input for IncreFlashAttention Op" + "and kv_inputs_num == num !!"); + } + op.create_dynamic_input_byindex_value(num, 1 + num); + for (const auto& item : i["value"]) { + auto index = item["index"].get(); + auto value = op_map[item["value"].get()]; + op.set_dynamic_input_value(index, value); + } + } else { + throw std::runtime_error("invalid dynamic input name: " + name); + } + } + } +} + +template +void parseDynamicOutput(T& op, const json& node) { + if (node.contains("dynamic_outputs")) { + for (const auto& o : node["dynamic_outputs"]) { + auto name = o["name"].get(); + auto num = o["num"].get(); + if (name == "y") { + op.create_dynamic_output_y(num); + } else { + throw std::runtime_error("invalid dynamic output name: " + name); + } + } + } +} + +ge::Operator genDynamicOperator(OperatorMap& op_map, const json& node) { + auto op_type = node["op_type"].get(); + auto op_name = node["op_name"].get(); + + using OpHandler = std::function; + static const std::unordered_map handlers = { + {"ConcatD", + [](const std::string& op_name, OperatorMap& op_map, const json& node) { + auto op = genDynamicOp(op_name); + parseDynamicInput(op_map, op, node); + return op; + }}, + {"IdentityN", + [](const std::string& op_name, OperatorMap& op_map, const json& node) { + auto op = genDynamicOp(op_name); + parseDynamicInput(op_map, op, node); + parseDynamicOutput(op, node); + return op; + }}, + {"Pack", + [](const std::string& op_name, 
OperatorMap& op_map, const json& node) { + auto op = genDynamicOp(op_name); + parseDynamicInput(op_map, op, node); + return op; + }}, + {"IncreFlashAttention", + [](const std::string& op_name, OperatorMap& op_map, const json& node) { + auto op = genDynamicOp(op_name); + parseDynamicInput(op_map, op, node); + return op; + }}, + {"SplitD", + [](const std::string& op_name, OperatorMap& op_map, const json& node) { + auto op = genDynamicOp(op_name); + parseDynamicOutput(op, node); + return op; + }}, + {"SplitVD", + [](const std::string& op_name, OperatorMap& op_map, const json& node) { + auto op = genDynamicOp(op_name); + parseDynamicOutput(op, node); + return op; + }}}; + + auto it = handlers.find(op_type); + if (it != handlers.end()) { + return it->second(op_name, op_map, node); + } else { + throw std::runtime_error("invalid dynamic operator: " + op_type); + } +} + +template +T getValue(const json& node, const std::string& key) { + try { + return node.at(key).get(); + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + std::cerr << "JSON Node: " << node.dump(4) << std::endl; + throw std::runtime_error("getValue failed!"); + } +} + +TensorDesc getTensorDescFromJson(const json& desc) { + auto format = getValue(desc, "format"); + auto data_type = getValue(desc, "data_type"); + auto shape = getValue>(desc, "shape"); + TensorDesc tensor_desc(ge::Shape(shape), get_ascend_format(format), + get_ascend_datatype(data_type)); + return tensor_desc; +} + +void parseInputs(std::unordered_map& op_map, + ge::Operator& op, const json& inputs) { + for (const auto& i : inputs) { + auto name = getValue(i, "name").c_str(); + auto value = op_map[getValue(i, "value")]; + + if (i.contains("index")) { + op.SetInput(name, value, getValue(i, "index")); + } else if (i.contains("update_desc")) { + auto desc = i["update_desc"]; + auto tensor_desc = getTensorDescFromJson(desc); + auto output_name = getValue(desc, "output_name"); + if (output_name != "none") { + op_map[getValue(i, "value")].UpdateOutputDesc( + output_name.c_str(), tensor_desc); + } else { + op.UpdateInputDesc(name, tensor_desc); + } + op.SetInput(name, value); + } else { + op.SetInput(name, value); + } + } +} + +void parseOutputs(ge::Operator& op, const json& outputs) { + for (const auto& i : outputs) { + auto name = getValue(i, "output_name").c_str(); + auto tensor_desc = getTensorDescFromJson(i["update_desc"]); + op.UpdateOutputDesc(name, tensor_desc); + } +} + +template +void setTensorAttrHelper(ge::Operator& op, const std::string& attr_name, + const json& attr, const std::vector& dims, + const ge::Format& format, + const ge::DataType& data_type) { + auto value = getValue>(attr, "tensor_value"); + auto tensor = genTensorWithData(dims, format, data_type, value); + op.SetAttr(attr_name.c_str(), tensor); +} + +void setTensorAttr(ge::Operator& op, const std::string& attr_name, + const json& attr) { + auto cpp_data_type = getValue(attr, "tensor_cpp_data_type"); + auto data_type = + get_ascend_datatype(getValue(attr, "tensor_data_type")); + auto format = get_ascend_format(getValue(attr, "tensor_format")); + auto dims = getValue>(attr, "tensor_dims"); + + if (cpp_data_type == "FLOAT") { + setTensorAttrHelper(op, attr_name, attr, dims, format, data_type); + } else if (cpp_data_type == "FLOAT16") { + auto values = getValue>(attr, "tensor_value"); + std::vector half_values(values.begin(), values.end()); + auto tensor = genTensorWithData(dims, format, data_type, + half_values); + op.SetAttr(attr_name.c_str(), tensor); + } else 
if (cpp_data_type == "INT32") { + setTensorAttrHelper(op, attr_name, attr, dims, format, data_type); + } else if (cpp_data_type == "INT64") { + setTensorAttrHelper(op, attr_name, attr, dims, format, data_type); + } else { + throw std::runtime_error("invalid cpp data type: " + cpp_data_type); + } +} + +void parseAttrs(ge::Operator& op, const json& attrs) { + using AttrHandler = + std::function; + static const std::unordered_map handlers = { + {"str", + [](ge::Operator& op, const std::string& name, const json& attr) { + op.SetAttr(name.c_str(), getValue(attr, "value")); + }}, + {"dtype_str", + [](ge::Operator& op, const std::string& name, const json& attr) { + auto value = getValue(attr, "value"); + op.SetAttr(name.c_str(), get_ascend_datatype(value)); + }}, + {"list_int", + [](ge::Operator& op, const std::string& name, const json& attr) { + op.SetAttr(name.c_str(), + getValue>(attr, "value")); + }}, + {"list_float", + [](ge::Operator& op, const std::string& name, const json& attr) { + op.SetAttr(name.c_str(), getValue>(attr, "value")); + }}, + {"float", + [](ge::Operator& op, const std::string& name, const json& attr) { + op.SetAttr(name.c_str(), getValue(attr, "value")); + }}, + {"int", + [](ge::Operator& op, const std::string& name, const json& attr) { + op.SetAttr(name.c_str(), getValue(attr, "value")); + }}, + {"bool", + [](ge::Operator& op, const std::string& name, const json& attr) { + op.SetAttr(name.c_str(), getValue(attr, "value")); + }}, + {"int64_t", + [](ge::Operator& op, const std::string& name, const json& attr) { + op.SetAttr(name.c_str(), getValue(attr, "value")); + }}, + {"tensor", [](ge::Operator& op, const std::string& name, + const json& attr) { setTensorAttr(op, name, attr); }}}; + + for (const auto& attr : attrs) { + auto attr_name = getValue(attr, "name"); + auto value_type = getValue(attr, "value_type"); + + auto it = handlers.find(value_type); + if (it != handlers.end()) { + it->second(op, attr_name, attr); + } else { + throw std::runtime_error("Invalid attr value type: " + value_type); + } + } +} + +void parseCommonNode(std::unordered_map& op_map, + ge::Operator& op, const json& node) { + if (node.contains("inputs")) { + parseInputs(op_map, op, node["inputs"]); + } + if (node.contains("outputs")) { + parseOutputs(op, node["outputs"]); + } + if (node.contains("attrs")) { + parseAttrs(op, node["attrs"]); + } +} + +void buildGraph(Graph& graph, const json& graph_json, + std::vector& input_tensors) { + std::unordered_map op_map; + json data_nodes = graph_json["data_nodes"]; + for (const auto& node : graph_json["data_nodes"]) { + auto node_name = getValue(node, "op_name"); + auto format = get_ascend_format(getValue(node, "format")); + auto data_type = + get_ascend_datatype(getValue(node, "data_type")); + auto index = getValue(node, "index"); + auto dims = getValue>(node, "dims"); + check_op(op_map, node_name); + op_map[node_name] = genInput(node_name, dims, format, data_type, index); + graph.AddOp(op_map[node_name]); + + // add tensor to inputs + TensorDesc cur_desc(ge::Shape(dims), format, data_type); + Tensor cur_tensor(cur_desc); + input_tensors.emplace_back(cur_tensor); + } + for (const auto& node : graph_json["common_nodes"]) { + auto node_name = getValue(node, "op_name"); + auto op_type = getValue(node, "op_type"); + + check_op(op_map, node_name); + if (op_with_dynamic_inputs_outputs.count(op_type) > 0) { + op_map[node_name] = genDynamicOperator(op_map, node); + } else { + op_map[node_name] = ge::OperatorFactory::CreateOperator(node_name.c_str(), + 
op_type.c_str()); + } + parseCommonNode(op_map, op_map[node_name], node); + graph.AddOp(op_map[node_name]); + } + std::vector graph_inputs; + std::vector graph_outputs; + for (const auto& i : graph_json["input_names"]) { + graph_inputs.push_back(op_map[i.get()]); + } + for (const auto& i : graph_json["output_names"]) { + graph_outputs.push_back(op_map[i.get()]); + } + graph.SetInputs(graph_inputs).SetOutputs(graph_outputs); +} + +#endif // DICP_ASCEND_GE_BUILDER_H diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ge_builder_config.json b/dicp/dicp/vendor/AscendGraph/codegen/ge_builder_config.json new file mode 100644 index 000000000..cc506c80a --- /dev/null +++ b/dicp/dicp/vendor/AscendGraph/codegen/ge_builder_config.json @@ -0,0 +1,3 @@ +{ + "PRECISION_MODE_V2": "mixed_float16" +} \ No newline at end of file diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ge_graph.h b/dicp/dicp/vendor/AscendGraph/codegen/ge_graph.h new file mode 100644 index 000000000..8fadb9e9e --- /dev/null +++ b/dicp/dicp/vendor/AscendGraph/codegen/ge_graph.h @@ -0,0 +1,184 @@ +#ifndef DICP_ASCEND_GE_GRAPH_H +#define DICP_ASCEND_GE_GRAPH_H +#include +#include +#include +#include +#include +#include +#include + +#include "ge_api.h" +#include "graph_utils.h" + +using namespace ge; + +class GEGraph { + public: + explicit GEGraph(int graph_id, std::string graph_key, Graph& graph, + std::shared_ptr spec, + std::vector& input_tensors) + : graph_id_(graph_id), + graph_(std::move(graph)), + graph_key_(std::move(graph_key)), + spec_(std::move(spec)), + inputs(std::move(input_tensors)), + is_static(spec_->IsStatic()) { + if (is_static) { + prepare_static_output_tensors(); + prepare_input_output_tensordesc(); + } + } + + size_t const_memory_size() { + size_t size; + CALL_FUNC(spec_->GetConstMemorySize(size)); + return size; + } + + size_t feature_memory_size() { + size_t size; + CALL_FUNC(spec_->GetFeatureMemorySize(size)); + return size; + } + + size_t fixed_feature_memory_size() { + size_t size; + CALL_FUNC(spec_->GetFixedFeatureMemorySize(size)); + return size; + } + + int get_graph_id() const { return graph_id_; } + + std::vector& get_inputs() { return inputs; } + + std::vector& get_outputs() { return outputs; } + + void prepare_static_output_tensors() { + std::vector shapes; + std::vector dtypes; + CALL_FUNC(spec_->GetOutputShapes(shapes)); + CALL_FUNC(spec_->GetOutputDtypes(dtypes)); + + for (size_t i = 0; i < shapes.size(); ++i) { + outputs.emplace_back(ge::TensorDesc(shapes[i], ge::FORMAT_ND, dtypes[i])); + } + } + + void prepare_input_output_tensordesc() { + inputs_desc.reserve(inputs.size()); + outputs_desc.reserve(outputs.size()); + auto get_tensor_desc = [](const Tensor& tensor) { + return tensor.GetTensorDesc(); + }; + std::transform(inputs.begin(), inputs.end(), + std::back_inserter(inputs_desc), get_tensor_desc); + std::transform(outputs.begin(), outputs.end(), + std::back_inserter(outputs_desc), get_tensor_desc); + } + + void assemble_inputs(const std::vector& shapes, + const std::vector& dtypes, + const std::vector& formats) { + inputs.clear(); + CHECK(shapes.size() == dtypes.size() && shapes.size() == formats.size()); + auto size = shapes.size(); + for (unsigned int i = 0; i < size; ++i) { + inputs.emplace_back(ge::TensorDesc(shapes[i], formats[i], dtypes[i])); + inputs[i].SetPlacement(ge::Placement::kPlacementDevice); + } + } + + void assemble_outputs(const std::vector& shapes, + const std::vector& dtypes, + const std::vector& formats) { + outputs.clear(); + CHECK(shapes.size() == dtypes.size() && 
shapes.size() == formats.size()); + auto size = shapes.size(); + for (unsigned int i = 0; i < size; ++i) { + outputs.emplace_back(ge::TensorDesc(shapes[i], formats[i], dtypes[i])); + outputs[i].SetPlacement(ge::Placement::kPlacementDevice); + } + } + + void update_inputs(const std::vector& shapes) { + CHECK(inputs.size() == shapes.size()); + auto size = shapes.size(); + for (unsigned int i = 0; i < size; ++i) { + auto desc = inputs[i].GetTensorDesc(); + desc.SetShape(shapes[i]); + CALL_FUNC(inputs[i].SetTensorDesc(desc)); + } + } + + void update_outputs(const std::vector& shapes) { + CHECK(outputs.size() == shapes.size()); + auto size = shapes.size(); + for (unsigned int i = 0; i < size; ++i) { + auto desc = outputs[i].GetTensorDesc(); + desc.SetShape(shapes[i]); + CALL_FUNC(outputs[i].SetTensorDesc(desc)); + } + } + + std::vector> get_input_shapes() { + return get_shapes(inputs); + } + + std::vector get_input_dtypes() { return get_dtypes(inputs); } + + std::vector> get_output_shapes() { + return get_shapes(outputs); + } + + std::vector get_output_dtypes() { return get_dtypes(outputs); } + + void set_input_output_data(void* inputs_data[], void* outputs_data[], + int64_t inputs_data_size[], + int64_t outputs_data_size[]) { + const static ge::Tensor::DeleteFunc kDoNothing = [](uint8_t* data) {}; + for (unsigned long i = 0; i < inputs.size(); ++i) { + CALL_FUNC(inputs[i].ResetData(static_cast(inputs_data[i]), + static_cast(inputs_data_size[i]), + kDoNothing)); + } + for (unsigned long i = 0; i < outputs.size(); ++i) { + CALL_FUNC(outputs[i].ResetData(static_cast(outputs_data[i]), + static_cast(outputs_data_size[i]), + kDoNothing)); + } + } + + private: + std::vector> get_shapes( + const std::vector& tensors) const { + std::vector> shapes; + shapes.reserve(tensors.size()); + for (const auto& tensor : tensors) { + shapes.emplace_back(tensor.GetTensorDesc().GetShape().GetDims()); + } + return shapes; + } + + std::vector get_dtypes(const std::vector& tensors) const { + std::vector dtypes; + dtypes.reserve(tensors.size()); + for (const auto& tensor : tensors) { + dtypes.emplace_back( + static_cast(tensor.GetTensorDesc().GetDataType())); + } + return dtypes; + } + + int graph_id_; + Graph graph_; + std::string graph_key_; + std::shared_ptr spec_; + std::vector inputs; + std::vector outputs; + std::vector inputs_desc; + std::vector outputs_desc; + bool is_static; +}; + +#endif // DICP_ASCEND_GE_GRAPH_H diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ge_init_config.json b/dicp/dicp/vendor/AscendGraph/codegen/ge_init_config.json new file mode 100644 index 000000000..19dd3687d --- /dev/null +++ b/dicp/dicp/vendor/AscendGraph/codegen/ge_init_config.json @@ -0,0 +1,5 @@ +{ + "ge.graphRunMode": "0", + "ge.exec.precision_mode": "allow_fp32_to_fp16", + "ge.op_compiler_cache_mode": "disable" +} diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ge_runner.h b/dicp/dicp/vendor/AscendGraph/codegen/ge_runner.h new file mode 100644 index 000000000..777eebbed --- /dev/null +++ b/dicp/dicp/vendor/AscendGraph/codegen/ge_runner.h @@ -0,0 +1,98 @@ +#ifndef DICP_ASCEND_GE_RUNNER_H +#define DICP_ASCEND_GE_RUNNER_H + +#include +#include +#include + +#include "acl/acl.h" +#include "ascend_string.h" +#include "ge_api.h" +#include "ge_graph.h" +#include "graph_utils.h" + +using namespace ge; + +class GEGraphRunner { + public: + explicit GEGraphRunner(void* context, int device_id, + const std::string& config_file_path) + : device_id_(device_id), + config_file_path_(config_file_path), + session_(nullptr) { + 
init(static_cast(context)); + } + + GEGraphRunner() : device_id_(0), session_(nullptr) {} + + void init(aclrtContext context) { + // CALL_FUNC(aclrtSetCurrentContext(context)); + std::map config; + config["ge.exec.deviceId"] = std::to_string(device_id_).c_str(); + + auto kSocVersion = aclrtGetSocName(); + config[ge::ir_option::SOC_VERSION] = kSocVersion; + + auto raw_conf = parse_json_to_map(config_file_path_); + for (const auto& item : raw_conf) { + config[item.first.c_str()] = item.second.c_str(); + } + + for (const auto& item : config) { + std::cout << "ge init config: " << item.first.GetString() << " = " + << item.second.GetString() << std::endl; + } + CALL_FUNC(ge::GEInitialize(config)); + + std::map options; + session_ = std::make_unique(options); + } + + std::shared_ptr addGraph(int graph_id, + const Graph& graph, + const std::string& graph_key) { + std::map graph_options = { + {"ge.graph_key", graph_key.c_str()}}; + CALL_FUNC(session_->AddGraph(graph_id, graph, graph_options)); + CALL_FUNC(session_->CompileGraph(graph_id)); + return session_->GetCompiledGraphSummary(graph_id); + } + + void runGraphWithStreamAsync(std::shared_ptr& graph, void* stream) { + CALL_FUNC(session_->RunGraphWithStreamAsync(graph->get_graph_id(), stream, + graph->get_inputs(), + graph->get_outputs())); + } + + void setConstMem(int graph_id, const void* const memory, size_t size) { + CALL_FUNC(session_->SetGraphConstMemoryBase(graph_id, memory, size)); + } + + void setFixedFeatureMem(int graph_id, const void* const memory, size_t size) { + CALL_FUNC( + session_->SetGraphFixedFeatureMemoryBase(graph_id, memory, size);); + } + + void setFeatureMem(int graph_id, const void* const memory, size_t size) { + CALL_FUNC(session_->UpdateGraphFeatureMemoryBase(graph_id, memory, size);); + } + + ~GEGraphRunner() { + session_.reset(); + auto status = ge::GEFinalize(); + if (status != 0) { + std::cout << "GEFinalize failed!" << std::endl; + } + } + + int getDeviceId() { return device_id_; } + + private: + int device_id_; + std::string config_file_path_; + std::unique_ptr session_; + aclrtContext context; + aclrtStream stream; +}; + +#endif // DICP_ASCEND_GE_RUNNER_H diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_compile.cpp b/dicp/dicp/vendor/AscendGraph/codegen/graph_compile.cpp index 99f422dca..886885b60 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/graph_compile.cpp +++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_compile.cpp @@ -1,13 +1,192 @@ +#include +#include +#include +#include +#include +#include + +#include "ge_builder.h" +#include "ge_graph.h" +#include "ge_runner.h" #include "graph_utils.h" -static void compile(const std::string& graph_path, - const std::string& graph_json_file, - const std::string& fusion_switch_file) { +extern "C" { + +std::unordered_map> graph_manager; +std::unique_ptr graph_runner; + +void init(void* context, int device_id, const char* config_file_path) { + graph_runner = + std::make_unique(context, device_id, config_file_path); + std::cout << "graph runner init success!" 
<< std::endl; +} + +void release() { graph_runner.reset(); } + +void add_graph(int graph_id, const char* graph_json_file, + const char* graph_key) { + std::string graph_name = "BuildGraph"; + Graph graph(graph_name.c_str()); + + std::ifstream f(graph_json_file); + json graph_json = json::parse(f); + + std::vector input_tensors; + buildGraph(graph, graph_json, input_tensors); + + auto graph_spec = graph_runner->addGraph(graph_id, graph, graph_key); + auto acl_graph = std::make_shared(graph_id, graph_key, graph, + graph_spec, input_tensors); + graph_manager[graph_id] = std::move(acl_graph); +} + +size_t get_const_size(int graph_id) { + return graph_manager[graph_id]->const_memory_size(); +} + +size_t get_feature_size(int graph_id) { + return graph_manager[graph_id]->feature_memory_size(); +} + +size_t get_fixed_feature_size(int graph_id) { + return graph_manager[graph_id]->fixed_feature_memory_size(); +} + +std::string get_shapes(const std::vector>& shapes) { + std::ostringstream oss; + for (size_t i = 0; i < shapes.size(); ++i) { + for (size_t j = 0; j < shapes[i].size(); ++j) { + oss << shapes[i][j] << (j != shapes[i].size() - 1 ? "," : ""); + } + oss << (i != shapes.size() - 1 ? ";" : ""); + } + return oss.str(); +} + +void get_input_shapes(int graph_id, char* input_shapes) { + std::string str = get_shapes(graph_manager[graph_id]->get_input_shapes()); + strncpy(input_shapes, str.c_str(), str.size()); +} + +void get_output_shapes(int graph_id, char* output_shapes) { + std::string str = get_shapes(graph_manager[graph_id]->get_output_shapes()); + strncpy(output_shapes, str.c_str(), str.size()); +} + +std::string get_dtypes(const std::vector& dtypes) { + std::ostringstream oss; + for (size_t i = 0; i < dtypes.size(); ++i) { + oss << dtypes[i] << (i != dtypes.size() - 1 ? 
";" : ""); + } + return oss.str(); +} + +void get_input_dtypes(int graph_id, char* input_dtypes) { + std::string str = get_dtypes(graph_manager[graph_id]->get_input_dtypes()); + strncpy(input_dtypes, str.c_str(), str.size() + 1); +} + +void get_output_dtypes(int graph_id, char* output_dtypes) { + std::string str = get_dtypes(graph_manager[graph_id]->get_output_dtypes()); + strncpy(output_dtypes, str.c_str(), str.size() + 1); +} + +void update_inputs(int graph_id, int64_t** shapes, size_t* shape_sizes, + size_t outer_size) { + std::vector ge_shapes; + ge_shapes.reserve(outer_size); + for (size_t i = 0; i < outer_size; ++i) { + std::vector inner_shape(shapes[i], shapes[i] + shape_sizes[i]); + ge_shapes.emplace_back(inner_shape); + } + graph_manager[graph_id]->update_inputs(ge_shapes); +} + +void update_outputs(int graph_id, int64_t** shapes, size_t* shape_sizes, + size_t outer_size) { + std::vector ge_shapes; + ge_shapes.reserve(outer_size); + for (size_t i = 0; i < outer_size; ++i) { + std::vector inner_shape(shapes[i], shapes[i] + shape_sizes[i]); + ge_shapes.emplace_back(inner_shape); + } + graph_manager[graph_id]->update_outputs(ge_shapes); +} + +void assemble_inputs(int graph_id, int64_t** shapes, size_t* shape_sizes, + size_t outer_size, int* dtypes, int* formats) { + std::vector ge_shapes; + std::vector ge_dtypes; + std::vector ge_formats; + ge_shapes.reserve(outer_size); + ge_dtypes.reserve(outer_size); + ge_formats.reserve(outer_size); + for (size_t i = 0; i < outer_size; ++i) { + std::vector inner_shape(shapes[i], shapes[i] + shape_sizes[i]); + ge_shapes.emplace_back(inner_shape); + ge_dtypes.emplace_back(static_cast(dtypes[i])); + ge_formats.emplace_back(static_cast(formats[i])); + } + graph_manager[graph_id]->assemble_inputs(ge_shapes, ge_dtypes, ge_formats); +} + +void assemble_outputs(int graph_id, int64_t** shapes, size_t* shape_sizes, + size_t outer_size, int* dtypes, int* formats) { + std::vector ge_shapes; + std::vector ge_dtypes; + std::vector ge_formats; + ge_shapes.reserve(outer_size); + ge_dtypes.reserve(outer_size); + ge_formats.reserve(outer_size); + for (size_t i = 0; i < outer_size; ++i) { + std::vector inner_shape(shapes[i], shapes[i] + shape_sizes[i]); + ge_shapes.emplace_back(inner_shape); + ge_dtypes.emplace_back(static_cast(dtypes[i])); + ge_formats.emplace_back(static_cast(formats[i])); + } + graph_manager[graph_id]->assemble_outputs(ge_shapes, ge_dtypes, ge_formats); +} + +void set_graph_memory(int graph_id, void* const_mem_ptr, void* workspace_ptr, + size_t const_size, size_t workspace_size) { + graph_runner->setConstMem(graph_id, const_mem_ptr, const_size); + graph_runner->setFeatureMem(graph_id, workspace_ptr, workspace_size); +} + +void set_fixed_feature_graph_memory(int graph_id, void* workspace_ptr, + size_t workspace_size) { + graph_runner->setFixedFeatureMem(graph_id, workspace_ptr, workspace_size); +} + +void set_feature_graph_memory(int graph_id, void* workspace_ptr, + size_t workspace_size) { + graph_runner->setFeatureMem(graph_id, workspace_ptr, workspace_size); +} + +void set_const_graph_memory(int graph_id, void* workspace_ptr, + size_t workspace_size) { + graph_runner->setConstMem(graph_id, workspace_ptr, workspace_size); +} + +void run(aclrtContext context, int graph_id, void* stream, void* inputs_data[], + void* outputs_data[], int64_t inputs_data_size[], + int64_t outputs_data_size[]) { + CALL_FUNC(aclrtSetCurrentContext(context)); + graph_manager[graph_id]->set_input_output_data( + inputs_data, outputs_data, inputs_data_size, 
outputs_data_size); + graph_runner->runGraphWithStreamAsync(graph_manager[graph_id], stream); +} + +void compile_and_save(const char* graph_path, const char* graph_json_file, + const char* fusion_switch_file, + const char* global_options_file) { std::string graph_name = "BuildGraph"; Graph graph(graph_name.c_str()); std::ifstream f(graph_json_file); json graph_json = json::parse(f); - buildGraph(graph, graph_json); + + std::vector input_tensors; + buildGraph(graph, graph_json, input_tensors); std::map options; bool has_dynamic_shape = graph_json["has_dynamic_shape"].get(); @@ -19,14 +198,7 @@ static void compile(const std::string& graph_path, } } - AclgraphBuilder builder{fusion_switch_file}; + GEGraphBuilder builder{fusion_switch_file, global_options_file}; builder.saveGraph(graph_path, graph, options); } - -int main(int argc, char* argv[]) { - std::string graph_path{argv[1]}; - std::string graph_json_file{argv[2]}; - std::string fusion_switch_file{argv[3]}; - compile(graph_path, graph_json_file, fusion_switch_file); - return 0; } diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h index 8cee35353..2ed077807 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h +++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h @@ -1,5 +1,5 @@ -#ifndef DAVINCI_GRAPH_UTILS_H -#define DAVINCI_GRAPH_UTILS_H +#ifndef DICP_ASCEND_GRAPH_UTILS_H +#define DICP_ASCEND_GRAPH_UTILS_H #include #include #include @@ -29,392 +29,59 @@ #define SUCCESS 0 using json = nlohmann::json; -using namespace ge; -static std::unordered_set op_with_dynamic_inputs_outputs = { - "ConcatD", "IdentityN", "Pack", "SplitD", "IncreFlashAttention"}; - -void check_op(std::unordered_map& op_map, - const std::string& op_name) { - if (op_map.count(op_name) > 0) { - throw std::runtime_error("op_name duplicated!"); - } -} - -void setTensorData(Tensor& tensor, uint8_t* src_data, uint64_t data_size, - const std::string& debug_name = "") { - auto status = tensor.SetData(reinterpret_cast(src_data), data_size); - if (status != ge::GRAPH_SUCCESS) { - std::cout << "Set " << debug_name << " tensor data failed!" << std::endl; - } -} - -ge::Tensor genTensor(const std::vector& tensor_shape, - ge::Format format, ge::DataType data_type) { - TensorDesc desc(ge::Shape(tensor_shape), format, data_type); - Tensor result(desc); - return result; -} - -template -ge::Tensor genTensorWithData(const std::vector& tensor_shape, - ge::Format format, ge::DataType data_type, - std::vector value) { - TensorDesc desc(ge::Shape(tensor_shape), format, data_type); - Tensor result(desc); - setTensorData(result, reinterpret_cast(value.data()), - value.size() * sizeof(T), "genTensorWithData"); - return result; -} - -ge::Operator genInput(const std::string op_name, - const std::vector shape, ge::Format format, - ge::DataType data_type, int index = -1) { - TensorDesc tensor_desc_data_op = - TensorDesc(ge::Shape(shape), format, data_type); - auto op = op::Data(op_name.c_str()); - op.update_input_desc_x(tensor_desc_data_op); - op.update_output_desc_y(tensor_desc_data_op); - if (index > -1) { - op.set_attr_index(index); - } - return op; -} - -class AclgraphBuilder { - public: - explicit AclgraphBuilder(const std::string& fusion_switch_file) - : _fusion_switch_file(fusion_switch_file) { - // 1. 
system init - auto kSocVersion = aclrtGetSocName(); - std::map global_options = { - {AscendString(ge::ir_option::SOC_VERSION), AscendString(kSocVersion)}, - {AscendString(ge::ir_option::FUSION_SWITCH_FILE), - AscendString(_fusion_switch_file.c_str())}, - {AscendString(ge::ir_option::PRECISION_MODE), "allow_fp32_to_fp16"}, - }; - auto status = aclgrphBuildInitialize(global_options); - if (status != GRAPH_SUCCESS) { - std::cout << "aclgrphBuildInitialize failed!" << std::endl; - } else { - std::cout << "aclgrphBuildInitialize success!" << std::endl; - } - } - - void saveGraph(const std::string& path, const Graph& graph, - std::map& options) { - ModelBufferData model; - - auto status = aclgrphBuildModel(graph, options, model); - if (status == GRAPH_SUCCESS) { - std::cout << "Build Model SUCCESS!" << std::endl; - } else { - std::cout << "Build Model Failed! " << status << std::endl; - return; - } - - // 4. Save Ir Model - status = aclgrphSaveModel(path.c_str(), model); - if (status == GRAPH_SUCCESS) { - std::cout << "Save Offline Model SUCCESS!" << std::endl; - } else { - std::cout << "Save Offline Model Failed! " << status << std::endl; - } - } - - ~AclgraphBuilder() { - aclgrphBuildFinalize(); - std::cout << "aclgrphBuildFinalize success!" << std::endl; - } - - private: - std::string _fusion_switch_file; -}; - -ge::Format get_ascend_format(const std::string& format) { - static std::unordered_map format_map = { - {"NCHW", FORMAT_NCHW}, - {"NHWC", FORMAT_NHWC}, - {"ND", FORMAT_ND}, - {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, - }; - if (format_map.count(format) > 0) { - return format_map[format]; - } - throw std::runtime_error("invalid ascend foramt!"); -} - -ge::DataType get_ascend_datatype(const std::string& data_type) { - static std::unordered_map datatype_map = { - {"FLOAT", ge::DataType::DT_FLOAT}, {"FLOAT16", ge::DataType::DT_FLOAT16}, - {"INT32", ge::DataType::DT_INT32}, {"INT64", ge::DataType::DT_INT64}, - {"BOOL", ge::DataType::DT_BOOL}, {"UINT8", ge::DataType::DT_UINT8}, - {"BF16", ge::DataType::DT_BF16}, - }; - if (datatype_map.count(data_type) > 0) { - return datatype_map[data_type]; - } - throw std::runtime_error("invalid ascend data type!"); -} - -template -T genDynamicOp(const std::string& op_name) { - return T(op_name.c_str()); -} - -template -void parseDynamicInput(std::unordered_map& op_map, - T& op, const json& node) { - if (node.contains("dynamic_inputs")) { - for (const auto& i : node["dynamic_inputs"]) { - auto num = i["num"].get(); - auto name = i["name"].get(); - if (name == "x") { - op.create_dynamic_input_x(num); - for (const auto& item : i["value"]) { - auto index = item["index"].get(); - auto value = op_map[item["value"].get()]; - if (item.contains("edge")) { - op.set_dynamic_input_x(index, value, - item["edge"].get().c_str()); - } else { - op.set_dynamic_input_x(index, value); - } - } - } else { - throw std::runtime_error("invalid dynamic input name"); - } - } - } -} - -void parseIncreFlashAttentionDynamicInput( - std::unordered_map& op_map, - op::IncreFlashAttention& op, const json& node) { - if (node.contains("dynamic_inputs")) { - int kv_inputs_num = 0; - for (const auto& i : node["dynamic_inputs"]) { - auto num = i["num"].get(); - auto name = i["name"].get(); - if (name == "key") { - kv_inputs_num = static_cast(num); - op.create_dynamic_input_byindex_key(num, 1); - for (const auto& item : i["value"]) { - auto index = item["index"].get(); - auto value = op_map[item["value"].get()]; - op.set_dynamic_input_key(index, value); - } - } else if (name == "value") { - if 
(kv_inputs_num == 0 && num == kv_inputs_num) { - throw std::runtime_error( - "need first set dynamic key input for IncreFlashAttention Op" - "and kv_inputs_num == num !!"); - } - op.create_dynamic_input_byindex_value(num, 1 + num); - for (const auto& item : i["value"]) { - auto index = item["index"].get(); - auto value = op_map[item["value"].get()]; - op.set_dynamic_input_value(index, value); - } - } else { - throw std::runtime_error("invalid dynamic input name"); - } - } - } -} - -template -void parseDynamicOutput(T& op, const json& node) { - if (node.contains("dynamic_outputs")) { - for (const auto& o : node["dynamic_outputs"]) { - auto name = o["name"].get(); - auto num = o["num"].get(); - if (name == "y") { - op.create_dynamic_output_y(num); - } else { - throw std::runtime_error("invalid dynamic output name"); - } - } - } -} - -ge::Operator genDynamicOperator( - std::unordered_map& op_map, const json& node) { - auto op_type = node["op_type"].get(); - auto op_name = node["op_name"].get(); - if (op_type == "ConcatD") { - auto op = genDynamicOp(op_name); - parseDynamicInput(op_map, op, node); - return op; - } else if (op_type == "IdentityN") { - auto op = genDynamicOp(op_name); - parseDynamicInput(op_map, op, node); - parseDynamicOutput(op, node); - return op; - } else if (op_type == "Pack") { - auto op = genDynamicOp(op_name); - parseDynamicInput(op_map, op, node); - return op; - } else if (op_type == "IncreFlashAttention") { - auto op = genDynamicOp(op_name); - parseIncreFlashAttentionDynamicInput(op_map, op, node); - return op; - } else if (op_type == "SplitD") { - auto op = genDynamicOp(op_name); - parseDynamicOutput(op, node); - return op; - } - throw std::runtime_error("invalid dynamic opeartor!"); -} - -void parseCommonNode(std::unordered_map& op_map, - ge::Operator& op, const json& node) { - if (node.contains("inputs")) { - for (const auto& i : node["inputs"]) { - auto name = i["name"].get().c_str(); - auto value = op_map[i["value"].get()]; - if (i.contains("index")) { - op.SetInput(name, value, i["index"].get()); - } else if (i.contains("update_desc")) { - auto desc = i["update_desc"]; - auto format = desc["format"].get(); - auto data_type = desc["data_type"].get(); - auto shape = desc["shape"].get>(); - TensorDesc tensor_desc = - TensorDesc(ge::Shape(shape), get_ascend_format(format), - get_ascend_datatype(data_type)); - auto output_name = desc["output_name"].get(); - if (output_name != "none") { - op_map[i["value"].get()].UpdateOutputDesc( - output_name.c_str(), tensor_desc); - op.SetInput(name, value); - } else { - op.SetInput(name, value); - op.UpdateInputDesc(name, tensor_desc); - } - } else { - op.SetInput(name, value); - } - } - } - if (node.contains("outputs")) { - for (const auto& i : node["outputs"]) { - auto name = i["output_name"].get().c_str(); - auto desc = i["update_desc"]; - auto format = desc["format"].get(); - auto data_type = desc["data_type"].get(); - auto shape = desc["shape"].get>(); - TensorDesc tensor_desc = - TensorDesc(ge::Shape(shape), get_ascend_format(format), - get_ascend_datatype(data_type)); - op.UpdateOutputDesc(name, tensor_desc); - } - } - if (node.contains("attrs")) { - for (const auto& attr : node["attrs"]) { - auto attr_name = attr["name"].get(); - auto value_type = attr["value_type"]; - if (value_type == "str") { - op.SetAttr(attr_name, attr["value"].get()); - } else if (value_type == "dtype_str") { - auto value = attr["value"].get(); - op.SetAttr(attr_name, get_ascend_datatype(value)); - } else if (value_type == "list_int") { - auto value = 
attr["value"].get>(); - op.SetAttr(attr_name.c_str(), value); - } else if (value_type == "list_float") { - auto value = attr["value"].get>(); - op.SetAttr(attr_name.c_str(), value); - } else if (value_type == "float") { - auto value = attr["value"].get(); - op.SetAttr(attr_name.c_str(), value); - } else if (value_type == "int") { - auto value = attr["value"].get(); - op.SetAttr(attr_name.c_str(), value); - } else if (value_type == "bool") { - auto value = attr["value"].get(); - op.SetAttr(attr_name.c_str(), value); - } else if (value_type == "int64") { - auto value = attr["value"].get(); - op.SetAttr(attr_name.c_str(), value); - } else if (value_type == "tensor") { - auto cpp_data_type = attr["tensor_cpp_data_type"].get(); - auto data_type = - get_ascend_datatype(attr["tensor_data_type"].get()); - auto format = - get_ascend_format(attr["tensor_format"].get()); - auto tensor_dims = attr["tensor_dims"]; - auto dims = tensor_dims.get>(); - if (cpp_data_type == "FLOAT") { - auto value = attr["tensor_value"].get>(); - auto tensor = - genTensorWithData(dims, format, data_type, value); - op.SetAttr(attr_name.c_str(), tensor); - } else if (cpp_data_type == "FLOAT16") { - std::vector values = - attr["tensor_value"].get>(); - std::vector half_values; - for (auto& v : values) { - half_values.push_back(half_float::half(v)); - } - auto tensor = genTensorWithData( - dims, format, data_type, half_values); - op.SetAttr(attr_name.c_str(), tensor); - } else if (cpp_data_type == "INT32") { - auto value = attr["tensor_value"].get>(); - auto tensor = genTensorWithData(dims, format, data_type, value); - op.SetAttr(attr_name.c_str(), tensor); - } else if (cpp_data_type == "INT64") { - auto value = attr["tensor_value"].get>(); - auto tensor = - genTensorWithData(dims, format, data_type, value); - op.SetAttr(attr_name.c_str(), tensor); - } else { - throw std::runtime_error("invalid cpp data type!"); - } - } else { - throw std::runtime_error("invalid attr value type!"); - } - } - } -} - -void buildGraph(Graph& graph, const json& graph_json) { - std::unordered_map op_map; - json data_nodes = graph_json["data_nodes"]; - for (const auto& node : graph_json["data_nodes"]) { - auto node_name = node["op_name"].get(); - auto format = get_ascend_format(node["format"].get()); - auto data_type = get_ascend_datatype(node["data_type"].get()); - auto index = node["index"].get(); - auto dims = node["dims"].get>(); - check_op(op_map, node_name); - op_map[node_name] = genInput(node_name, dims, format, data_type, index); - graph.AddOp(op_map[node_name]); - } - for (const auto& node : graph_json["common_nodes"]) { - auto node_name = node["op_name"].get(); - auto op_type = node["op_type"].get(); - - check_op(op_map, node_name); - if (op_with_dynamic_inputs_outputs.count(op_type) > 0) { - op_map[node_name] = genDynamicOperator(op_map, node); +#define DICP_CHECK_ABORT(condition, ...) 
\ + do { \ + if (!(condition)) { \ + printf("[%s:%s:%d]: ", __FILE__, __FUNCTION__, __LINE__); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + std::abort(); \ + } \ + } while (0); + +#define DICP_ASCEND_CHECK_NULLPTR_ABORT(ptr) \ + DICP_CHECK_ABORT(ptr, "Variable is nullptr, please check.") + +#define TRACK_GE(x) \ + do { \ + static bool enable = std::getenv("DICP_NOT_TRACK_GE") == nullptr; \ + if (enable) { \ + printf("[%s: %d]:%s\n", __FILE__, __LINE__, x); \ + } \ + } while (0); + +#define CALL_FUNC(Expr) \ + do { \ + auto ret = Expr; \ + if (ret != SUCCESS) { \ + TRACK_GE(#Expr); \ + throw std::runtime_error("dicp call function failed."); \ + } \ + } while (0); + +#define CHECK(Expr) \ + do { \ + auto ret = Expr; \ + if (!ret) { \ + TRACK_GE(#Expr); \ + throw std::runtime_error("dicp check failed."); \ + } \ + } while (0); + +std::unordered_map parse_json_to_map( + const std::string& config_file) { + std::ifstream f(config_file); + json config_json = json::parse(f); + std::unordered_map conf; + for (const auto& elem : config_json.items()) { + if (elem.value().is_string()) { + conf[elem.key()] = elem.value().get(); } else { - op_map[node_name] = ge::OperatorFactory::CreateOperator(node_name.c_str(), - op_type.c_str()); + throw std::runtime_error("in config file, json value is not string!"); } - parseCommonNode(op_map, op_map[node_name], node); - graph.AddOp(op_map[node_name]); - } - std::vector graph_inputs; - std::vector graph_outputs; - for (const auto& i : graph_json["input_names"]) { - graph_inputs.push_back(op_map[i.get()]); - } - for (const auto& i : graph_json["output_names"]) { - graph_outputs.push_back(op_map[i.get()]); } - graph.SetInputs(graph_inputs).SetOutputs(graph_outputs); + return conf; } -#endif // DAVINCI_GRAPH_UTILS_H +#endif // DICP_ASCEND_GRAPH_UTILS_H diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py index 12c7bb193..f58c24018 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py @@ -1,5 +1,13 @@ import atexit +import ctypes import os +import math + +from dicp.vendor.AscendGraph.codegen.utils import get_ascend_dtype_num, get_ascend_format_num, get_torch_dtype + +from ctypes import POINTER, c_longlong, c_size_t, c_void_p, c_int64, c_int + +from pathlib import Path import acl import numpy as np @@ -66,24 +74,6 @@ ACL_HBM_MEM_P2P_NORMAL = 9 -def get_np_dtype(dtype): - if dtype == ACL_FLOAT: - return np.float32 - elif dtype == ACL_INT64: - return np.int64 - elif dtype == ACL_INT32: - return np.int32 - elif dtype == ACL_BOOL: - return np.bool_ - elif dtype == ACL_DOUBLE: - return np.float64 - elif dtype == ACL_COMPLEX64: - return np.complex64 - elif dtype == ACL_FLOAT16: - return np.float16 - raise RuntimeError("unsupported np dtype!") - - def get_tensor_dtype(dtype): if dtype == ACL_FLOAT: return torch.float32 @@ -122,21 +112,376 @@ def __init__(self): def init_work_weight_ptr(self): if self.work_ptr is None: - self.work_size = 26 * 1024 * 1024 * 1024 - self.work_ptr, ret = acl.rt.malloc(self.work_size, - ACL_MEM_MALLOC_HUGE_FIRST) - check_ret("acl.rt.malloc", ret) + free, _, ret = acl.rt.get_mem_info(ACL_HBM_MEM) + check_ret("acl.rt.get_mem_info", ret) + + self.work_size = int(6 * 1024 * 1024 * 1024) + self.work_tensor = torch.empty( + self.work_size, dtype=torch.bool, device=dipu_device_str) + self.work_ptr = self.work_tensor.data_ptr() def release_memory(self): print("Release bufferPtr from MemoryPool.") - if self.work_ptr is 
not None: - ret = acl.rt.free(self.work_ptr) - check_ret("acl.rt.free", ret) - self.work_ptr = None + self.work_tensor = None -zero_tensor = torch.randn(1).to(dipu_device_str) +class GraphCompiler: + def __init__(self): + self._lib_path = os.environ.get( + "DICP_ASCEND_GE_GRAPH_EXECUTOR", "/tmp/dicp_ascend/ge_graph.so") + self.graph_compiler = ctypes.CDLL(self._lib_path) + + +class GraphManager: + def __init__(self): + device_id = torch_dipu.current_device() + self._lib_path = os.environ.get( + "DICP_ASCEND_GE_GRAPH_EXECUTOR", "/tmp/dicp_ascend/ge_graph.so") + self.config_file = os.path.join( + str(Path(__file__).resolve().parent), 'ge_init_config.json') + self.graph_manager = ctypes.CDLL(self._lib_path) + + context, ret = acl.rt.get_context() + check_ret("acl.rt.get_context", ret) + self.graph_manager.init((c_void_p)(context), device_id, self.config_file.encode()) + atexit.register(self.release_graph) + + def release_graph(self): + self.graph_manager.release() + + +zero_tensor = torch.empty(1, device=dipu_device_str) +graph_manager = None +graph_compiler = None memory_pool = MemoryPool() +graph_id = 0 + + +def get_graph_manager(): + global graph_manager + if graph_manager is None: + graph_manager = GraphManager() + return graph_manager.graph_manager + + +def get_graph_compiler(): + global graph_compiler + if graph_compiler is None: + graph_compiler = GraphCompiler() + return graph_compiler.graph_compiler + + +class GEStaticGraphExecutor(object): + def __init__(self, graph_id, device_id): + self.device_id = device_id + self.graph_id = graph_id + print('### graph_id:', self.graph_id) + + # init + self.const_mem_size = graph_manager.graph_manager.get_const_size( + self.graph_id) + self.feature_mem_size = graph_manager.graph_manager.get_feature_size( + self.graph_id) + + # alloc memory + self.const_tensor = torch.empty( + self.const_mem_size, dtype=torch.bool, device='dipu') + self.const_ptr = self.const_tensor.data_ptr() + graph_manager.graph_manager.set_graph_memory(self.graph_id, c_void_p( + self.const_ptr), c_void_p(memory_pool.work_ptr), c_size_t(self.const_mem_size), c_size_t(memory_pool.work_size)) + + # prapare output info + output_shape_buffer = ctypes.create_string_buffer(10000) + output_dtype_buffer = ctypes.create_string_buffer(10000) + graph_manager.graph_manager.get_output_shapes( + self.graph_id, output_shape_buffer) + graph_manager.graph_manager.get_output_dtypes( + self.graph_id, output_dtype_buffer) + + output_shape_str = output_shape_buffer.value.decode('utf-8') + output_dtype_str = output_dtype_buffer.value.decode('utf-8') + shapes = output_shape_str.split(';') + dtypes = output_dtype_str.split(';') + + assert len(shapes) == len(dtypes) + + self.output_shapes = [] + self.output_dtypes = [] + self.output_datasize = [] + for item in shapes: + if item == '': + self.output_shapes.append([]) + continue + elems = item.split(',') + elems = [int(x) for x in elems] + self.output_shapes.append(elems) + for item in dtypes: + elem = int(item) + self.output_dtypes.append(elem) + for i in range(len(shapes)): + elem_size = math.prod(self.output_shapes[i]) if len( + self.output_shapes[i]) > 0 else 1 + self.output_datasize.append( + elem_size * acl.data_type_size(self.output_dtypes[i])) + self.output_datasize_c = ( + c_int64 * len(self.output_datasize))(*self.output_datasize) + + # prapare input info + input_shape_buffer = ctypes.create_string_buffer(50000) + input_dtype_buffer = ctypes.create_string_buffer(50000) + graph_manager.graph_manager.get_input_shapes( + self.graph_id, 
input_shape_buffer) + graph_manager.graph_manager.get_input_dtypes( + self.graph_id, input_dtype_buffer) + + input_shape_str = input_shape_buffer.value.decode('utf-8') + input_dtype_str = input_dtype_buffer.value.decode('utf-8') + shapes = input_shape_str.split(';') + dtypes = input_dtype_str.split(';') + + assert len(shapes) == len(dtypes) + + self.input_shapes = [] + self.input_dtypes = [] + self.input_datasize = [] + + for item in shapes: + if item == '': + self.input_shapes.append([]) + continue + elems = item.split(',') + elems = [int(x) for x in elems] + self.input_shapes.append(elems) + for item in dtypes: + elem = int(item) + self.input_dtypes.append(elem) + for i in range(len(shapes)): + elem_size = math.prod(self.input_shapes[i]) if len( + self.input_shapes[i]) > 0 else 1 + self.input_datasize.append( + elem_size * acl.data_type_size(self.input_dtypes[i])) + self.input_datasize_c = ( + c_int64 * len(self.input_datasize))(*self.input_datasize) + + @record_function('load_and_run_run') + def run(self, images, dims=None, output_shape=None, + out_stride=None, out_storage_offset=None, + allocated_output=None): + + def get_data_ptr(data): + if data.device.dtype != dipu_device_str: + data = data.to(dipu_device_str) + return data.data_ptr() + + inputs = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) + and x.device.type != dipu_device_str else x for x in images] + input_ptrs = [x.data_ptr() for x in inputs] + + input_ptrs_c = (c_void_p * len(inputs))(*input_ptrs) + output_ptrs = [] + output_tensors = [] + + if allocated_output: + allocated_output_tensor = {} + for output_index, input_index in allocated_output.items(): + allocated_output_tensor[output_index] = inputs[input_index] + for i, shape in enumerate(self.output_shapes): + if i in allocated_output.keys(): + item = allocated_output_tensor[i] + else: + item = torch.empty(shape, dtype=get_torch_dtype( + self.output_dtypes[i]), device=dipu_device_str) + output_ptrs.append(item.data_ptr()) + output_tensors.append(item) + else: + for i, shape in enumerate(self.output_shapes): + item = torch.empty(shape, dtype=get_torch_dtype( + self.output_dtypes[i]), device=dipu_device_str) + output_ptrs.append(item.data_ptr()) + output_tensors.append(item) + + output_ptrs_c = (c_void_p * len(output_tensors))(*output_ptrs) + context, ret = acl.rt.get_context() + current_stream, ret = acl.rt.create_stream() + graph_manager.graph_manager.run((c_void_p)(context), self.graph_id, (c_void_p)(current_stream), input_ptrs_c, + output_ptrs_c, self.input_datasize_c, self.output_datasize_c) + ret = acl.rt.synchronize_stream(current_stream) + ret = acl.rt.destroy_stream(current_stream) + check_ret("acl.rt.synchronize_stream", ret) + return output_tensors + + +class GEDynamicGraphExecutor(object): + def __init__(self, graph_id, device_id, input_nodes, output_nodes): + self.device_id = device_id + self.graph_id = graph_id + self.is_first_run = True + self.input_dtypes = [] + self.input_formats = [] + self.output_dtypes = [] + self.output_formats = [] + self.input_ascend_dtype_nums = [] + self.output_ascend_dtype_nums = [] + self.input_args_size = None + self.input_args_dtypes_array = None + self.input_args_formats_array = None + self.output_args_size = None + self.output_args_dtypes_array = None + self.output_args_formats_array = None + print('### graph_id:', self.graph_id) + + # init + self.fixed_feature_mem_size = graph_manager.graph_manager.get_fixed_feature_size( + self.graph_id) + + # alloc memory + self.fixed_feature_tensor = torch.empty( + 
+            self.fixed_feature_mem_size, dtype=torch.bool, device='dipu')
+        self.fixed_feature_ptr = self.fixed_feature_tensor.data_ptr()
+        graph_manager.graph_manager.set_fixed_feature_graph_memory(self.graph_id, ctypes.c_void_p(
+            self.fixed_feature_ptr), self.fixed_feature_mem_size)
+
+        # get input/output dtypes and formats
+        self.input_args_size = len(input_nodes)
+        for input in input_nodes:
+            dtype = get_ascend_dtype_num(input['data_type'])
+            format = get_ascend_format_num(input['format'])
+            self.input_formats.append(format)
+            self.input_dtypes.append(get_torch_dtype(dtype))
+            self.input_ascend_dtype_nums.append(dtype)
+        self.input_args_dtypes_array = (
+            c_int * self.input_args_size)(*self.input_ascend_dtype_nums)
+        self.input_args_formats_array = (
+            c_int * self.input_args_size)(*self.input_formats)
+
+        self.output_args_size = len(output_nodes)
+        for output in output_nodes:
+            dtype = get_ascend_dtype_num(output['data_type'])
+            format = get_ascend_format_num(output['format'])
+            self.output_ascend_dtype_nums.append(dtype)
+            self.output_dtypes.append(get_torch_dtype(dtype))
+            self.output_formats.append(format)
+        self.output_args_dtypes_array = (
+            c_int * self.output_args_size)(*self.output_ascend_dtype_nums)
+        self.output_args_formats_array = (
+            c_int * self.output_args_size)(*self.output_formats)
+
+    @record_function('load_and_run_run')
+    def run(self, images, dims=None, output_shape=None,
+            out_stride=None, out_storage_offset=None,
+            allocated_output=None):
+        assert len(images) > 0
+        inputs = [x.to(dipu_device_str) if isinstance(x, torch.Tensor)
+                  and x.device.type != dipu_device_str else x for x in images]
+        input_ptrs = [x.data_ptr() for x in inputs]
+
+        input_ptrs_c = (c_void_p * len(inputs))(*input_ptrs)
+        output_ptrs = []
+        output_tensors = []
+
+        allocated_output_tensor = None
+        if allocated_output:
+            allocated_output_tensor = {}
+            for output_index, input_index in allocated_output.items():
+                allocated_output_tensor[output_index] = inputs[input_index]
+
+        # assemble inputs/outputs
+        cur_input_shapes = []
+        cur_per_input_shape_size = []
+        cur_output_shapes = []
+        cur_per_output_shape_size = []
+        input_datasize = []
+        output_datasize = []
+        for index, i in enumerate(inputs):
+            shape = list(i.shape)
+            shape = shape if shape != [] else [1]
+            shape_data_size = math.prod(shape)
+            input_datasize.append(
+                shape_data_size * acl.data_type_size(self.input_ascend_dtype_nums[index]))
+            cur_input_shapes.append((c_longlong * len(shape))(*shape))
+            cur_per_input_shape_size.append(len(shape))
+        input_args_shapes_array_size = (
+            c_size_t * len(inputs))(*cur_per_input_shape_size)
+        input_args_shapes_array = (
+            POINTER(c_longlong) * len(inputs))(*cur_input_shapes)
+
+        for index, shape in enumerate(output_shape):
+            shape = shape if shape != [] else [1]
+            shape_data_size = math.prod(shape)
+            dtype = self.output_ascend_dtype_nums[index]
+            output_datasize.append(
+                shape_data_size * acl.data_type_size(dtype))
+            cur_output_shapes.append((c_longlong * len(shape))(*shape))
+            cur_per_output_shape_size.append(len(shape))
+        output_args_shapes_array_size = (
+            c_size_t * len(output_shape))(*cur_per_output_shape_size)
+        output_args_shapes_array = (
+            POINTER(c_longlong) * len(output_shape))(*cur_output_shapes)
+
+        if self.is_first_run:
+            graph_manager.graph_manager.assemble_inputs(self.graph_id, input_args_shapes_array, input_args_shapes_array_size,
+                                                        self.input_args_size, self.input_args_dtypes_array, self.input_args_formats_array)
+            graph_manager.graph_manager.assemble_outputs(self.graph_id, output_args_shapes_array, output_args_shapes_array_size,
+                                                         self.output_args_size, self.output_args_dtypes_array, self.output_args_formats_array)
+            self.is_first_run = False
+        else:
+            graph_manager.graph_manager.update_inputs(
+                self.graph_id, input_args_shapes_array, input_args_shapes_array_size, self.input_args_size)
+            graph_manager.graph_manager.update_outputs(
+                self.graph_id, output_args_shapes_array, output_args_shapes_array_size, self.output_args_size)
+
+        input_datasize_c = (
+            c_int64 * len(input_datasize))(*input_datasize)
+        output_datasize_c = (
+            c_int64 * len(output_datasize))(*output_datasize)
+
+        if allocated_output:
+            allocated_output_tensor = {}
+            for output_index, input_index in allocated_output.items():
+                allocated_output_tensor[output_index] = inputs[input_index]
+            for i, shape in enumerate(output_shape):
+                if i in allocated_output.keys():
+                    item = allocated_output_tensor[i]
+                else:
+                    item = torch.empty(
+                        shape, dtype=self.output_dtypes[i], device=dipu_device_str)
+                output_ptrs.append(item.data_ptr())
+                output_tensors.append(item)
+        else:
+            for i, shape in enumerate(output_shape):
+                item = torch.empty(
+                    shape, dtype=self.output_dtypes[i], device=dipu_device_str)
+                output_ptrs.append(item.data_ptr())
+                output_tensors.append(item)
+
+        output_ptrs_c = (c_void_p * len(output_tensors))(*output_ptrs)
+        context, ret = acl.rt.get_context()
+        current_stream, ret = acl.rt.create_stream()
+        graph_manager.graph_manager.run((c_void_p)(context), self.graph_id, (c_void_p)(current_stream), input_ptrs_c,
+                                        output_ptrs_c, input_datasize_c, output_datasize_c)
+        ret = acl.rt.synchronize_stream(current_stream)
+        check_ret("acl.rt.synchronize_stream", ret)
+        ret = acl.rt.destroy_stream(current_stream)
+        check_ret("acl.rt.destroy_stream", ret)
+        return output_tensors
+
+
+class GEModel():
+    def __init__(self, graph_id, device_id, is_static=True, input_nodes=None, output_nodes=None) -> None:
+        atexit.register(self.cleanup)
+        if is_static:
+            self.exe = GEStaticGraphExecutor(graph_id, device_id)
+        else:
+            self.exe = GEDynamicGraphExecutor(
+                graph_id, device_id, input_nodes, output_nodes)
+
+    def run(self, images, dims=None, output_shape=None,
+            out_stride=None, out_storage_offset=None, allocated_output=None):
+        return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset, allocated_output)
+
+    def cleanup(self):
+        if hasattr(self, 'exe'):
+            del self.exe
+
 
 class AscendExecutor(object):
@@ -257,7 +602,7 @@ def init_resource(self):
             dtype = acl.mdl.get_output_data_type(self.model_desc, i)
             dims, ret = acl.mdl.get_output_dims(self.model_desc, i)
             check_ret("acl.mdl.get_output_dims", ret)
-            self.output_dtypes.append(get_tensor_dtype(dtype))
+            self.output_dtypes.append(get_torch_dtype(dtype))
             self.output_dims.append(dims["dims"])
             self.output_size.append(temp_buffer_size)
             data_buf = acl.create_data_buffer(0, 1)
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/utils.py b/dicp/dicp/vendor/AscendGraph/codegen/utils.py
index 404d827f9..2fef82691 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/utils.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/utils.py
@@ -109,6 +109,15 @@ def symint_in_shape(shape):
     return False
 
 
+def get_ascend_format_num(format: str):
+    if format == 'NCHW':
+        return AclFormat.ACL_FORMAT_HWCN.value
+    elif format == 'ND':
+        return AclFormat.ACL_FORMAT_ND.value
+    else:
+        raise RuntimeError(f"unknown format ({format}) in get_ascend_format_num!")
+
+
 def get_ascend_dtype_num(dtype: str):
     if dtype == "FLOAT":
         return AclDataType.ACL_FLOAT.value
diff --git a/dicp/dicp/vendor/AscendGraph/compile_job.py b/dicp/dicp/vendor/AscendGraph/compile_job.py
index f9720bc31..f38de9725 100644
--- a/dicp/dicp/vendor/AscendGraph/compile_job.py
+++ b/dicp/dicp/vendor/AscendGraph/compile_job.py
@@ -1,17 +1,110 @@
 import os
 import subprocess
 import time
+import json
+import acl
+import torch
+import torch_dipu
 
 import dicp
 from dicp.dynamo_bridge.compile import DeviceCompileJob
 from torch._inductor.codecache import pick_vec_isa, cpp_compile_command, write, code_hash
 from torch._inductor import exc
+from dicp.vendor.AscendGraph.codegen import load_and_run
 
 
-class AscendCompileJob(DeviceCompileJob):
+
+class AscendGECompileGERunJob(DeviceCompileJob):
+    def __init__(self, source_code) -> None:
+        super().__init__()
+        third_party_path = dicp.__file__.replace(
+            '/__init__.py', '') + "/third_party"
+        graph_util_path = load_and_run.__file__.replace('/load_and_run.py', '')
+        source_path = graph_util_path + '/graph_compile.cpp'
+        source_include = graph_util_path + '/graph_utils.h'
+        compile_file_code = ''
+        for file in [source_path, source_include]:
+            with open(file, 'r') as f:
+                compile_file_code += f.read()
+        picked_vec_isa = pick_vec_isa()
+        self.device_id = torch_dipu.current_device()
+        self.graph = json.loads(source_code.strip())
+        self._key, self._input_path = write(
+            source_code.strip(),
+            "json",
+            extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa) +
+            str(self.device_id) + code_hash(compile_file_code)
+        )
+        self._lib_path = "/tmp/dicp_ascend/ge_graph.so"
+        json_util_path = third_party_path + '/nlohmann'
+        half_util_path = third_party_path + '/half/include'
+        self.fusion_switch_file = graph_util_path + '/fusion_switch.cfg'
+        self._cmd = ['/usr/bin/c++',
+                     '-D_GLIBCXX_USE_CXX11_ABI=0',
+                     '-fPIC',
+                     '-std=c++17',
+                     '-O3',
+                     '-shared',
+                     '-Wall',
+                     '-I/usr/local/Ascend/ascend-toolkit/latest/include',
+                     '-I/usr/local/Ascend/ascend-toolkit/latest/opp/built-in/op_proto/inc',
+                     '-I/usr/local/Ascend/ascend-toolkit/latest/include/graph',
+                     '-I/usr/local/Ascend/ascend-toolkit/latest/include/ge',
+                     '-I/usr/local/Ascend/ascend-toolkit/latest/parser',
+                     '-I/usr/local/Ascend/ascend-toolkit/latest/compiler/include',
+                     f'-I{graph_util_path}',
+                     f'-I{json_util_path}',
+                     f'-I{half_util_path}',
+                     '-L/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/stub',
+                     '-lgraph',
+                     '-lge_runner',
+                     source_path,
+                     '-o' + self._lib_path,
+                     '/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/stub/libgraph.so',
+                     '/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/stub/libge_runner.so',
+                     '/usr/local/Ascend/ascend-toolkit/latest/lib64/libgraph_base.so',
+                     '/usr/local/Ascend/ascend-toolkit/latest/runtime/lib64/stub/libascendcl.so',]
+
+    def _compile(self):
+        if not os.path.exists(self._lib_path):
+            os.system("mkdir -p /tmp/dicp_ascend")
+            start = time.time()
+            try:
+                print(' '.join(self._cmd))
+                subprocess.check_output(self._cmd, stderr=subprocess.STDOUT)
+            except subprocess.CalledProcessError as e:
+                raise exc.CppCompileError(self._cmd, e.output) from e
+            print('compile time:', time.time() - start)
+
+    def get_key(self):
+        return self._key
+
+    def get_compile_result(self):
+        self._compile()
+        context, ret = acl.rt.get_context()
+        graph_manager = load_and_run.get_graph_manager()
+        current_graph_id = load_and_run.graph_id
+        load_and_run.graph_id = load_and_run.graph_id + 1
+        graph_key = f'{self._key}_graph{current_graph_id}_device{self.device_id}'
+
+        graph_manager.add_graph(
+            current_graph_id, self._input_path.encode(), graph_key.encode())
+        ret = acl.rt.set_context(context)
+        is_static = not self.graph['has_dynamic_shape']
+        if is_static:
+            input_nodes = None
+            output_nodes = None
+        else:
+            input_nodes = self.graph['data_nodes']
+            output_nodes = self.graph['output_nodes']
+        return load_and_run.GEModel(current_graph_id, self.device_id, is_static, input_nodes=input_nodes, output_nodes=output_nodes)
+
+
+class AscendGECompileAclRunJob(DeviceCompileJob):
     def __init__(self, source_code) -> None:
         super().__init__()
-        third_party_path = dicp.__file__.replace('/__init__.py', '') + "/third_party"
+        third_party_path = dicp.__file__.replace(
+            '/__init__.py', '') + "/third_party"
         from dicp.vendor.AscendGraph.codegen import load_and_run
         graph_util_path = load_and_run.__file__.replace('/load_and_run.py', '')
         source_path = graph_util_path + '/graph_compile.cpp'
@@ -21,26 +114,27 @@ def __init__(self, source_code) -> None:
             with open(file, 'r') as f:
                 compile_file_code += f.read()
         picked_vec_isa = pick_vec_isa()
-        self._local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        self.device_id = torch_dipu.current_device()
         self._key, self._input_path = write(
             source_code.strip(),
            "json",
            extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa) +
-            'local_rank' + str(self._local_rank) + code_hash(compile_file_code)
+            str(self.device_id) + code_hash(compile_file_code)
         )
         self._output_graph_path = self._input_path[:-5] + '/graph'
-        print('output_path: ', self._output_graph_path)
         self._model_path = [f'{self._output_graph_path}.om',
                             f'{self._output_graph_path}_linux_x86_64.om']
-        self._lib_path = "/tmp/dicp_ascend/graph_compile"
+        self._lib_path = "/tmp/dicp_ascend/ge_graph.so"
         json_util_path = third_party_path + '/nlohmann'
         half_util_path = third_party_path + '/half/include'
         self.fusion_switch_file = graph_util_path + '/fusion_switch.cfg'
+        self.global_options_file = graph_util_path + '/ge_builder_config.json'
         self._cmd = ['/usr/bin/c++',
                      '-D_GLIBCXX_USE_CXX11_ABI=0',
                      '-fPIC',
-                     '-std=c++11',
+                     '-std=c++17',
                      '-O3',
+                     '-shared',
                      '-Wall',
                      '-I/usr/local/Ascend/ascend-toolkit/latest/include',
                      '-I/usr/local/Ascend/ascend-toolkit/latest/opp/built-in/op_proto/inc',
@@ -66,6 +160,7 @@ def _compile(self):
             os.system("mkdir -p /tmp/dicp_ascend")
             start = time.time()
             try:
+                print(' '.join(self._cmd))
                 subprocess.check_output(self._cmd, stderr=subprocess.STDOUT)
             except subprocess.CalledProcessError as e:
                 raise exc.CppCompileError(self._cmd, e.output) from e
@@ -76,11 +171,10 @@ def get_key(self):
 
     def build_graph(self, output_path, graph_path):
         self._compile()
-        cmd = [self._lib_path, output_path, graph_path, self.fusion_switch_file]
-        try:
-            subprocess.check_output(cmd, stderr=subprocess.STDOUT)
-        except subprocess.CalledProcessError as e:
-            raise exc.CppCompileError(cmd, e.output) from e
+
+        graph_compiler = load_and_run.get_graph_compiler()
+        graph_compiler.compile_and_save(output_path.encode(), graph_path.encode(
+        ), self.fusion_switch_file.encode(), self.global_options_file.encode())
 
     def get_compile_result(self):
         if (not os.path.exists(self._model_path[0]) and not os.path.exists(self._model_path[1])):
@@ -92,4 +186,4 @@ def get_compile_result(self):
             self._output_graph_path = origin_graph_path + '_linux_aarch64'
         assert (os.path.exists(self._output_graph_path + '.om'))
         from dicp.vendor.AscendGraph.codegen.load_and_run import AscendModel
-        return AscendModel(self._local_rank, self._output_graph_path + '.om')
+        return AscendModel(self.device_id, self._output_graph_path + '.om')
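
For reviewers, a minimal usage sketch of the runtime path added in load_and_run.py. The graph id, device id, and input shape below are made-up placeholders; in practice AscendGECompileGERunJob.get_compile_result() registers the graph and constructs the GEModel itself, so this only illustrates the call sequence, not an additional API.

    import torch
    import torch_dipu  # noqa: F401  (registers the 'dipu' device)
    from dicp.vendor.AscendGraph.codegen import load_and_run

    # Loads /tmp/dicp_ascend/ge_graph.so (or $DICP_ASCEND_GE_GRAPH_EXECUTOR) once
    # and initialises the GE graph manager for the current device.
    load_and_run.get_graph_manager()

    # Hypothetical ids: get_compile_result() normally allocates graph_id and
    # queries torch_dipu.current_device() itself after add_graph() has been called.
    model = load_and_run.GEModel(graph_id=0, device_id=0, is_static=True)

    x = torch.randn(4, 8, device='dipu')
    outputs = model.run([x])  # returns a list of output tensors allocated on 'dipu'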