Skip to content

Commit

Permalink
[METAL] Split kernels and compile them separately (apache#7980)
Browse files Browse the repository at this point in the history
  • Loading branch information
echuraev authored May 25, 2021
1 parent aefa0c8 commit dc5fc68
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 86 deletions.
1 change: 1 addition & 0 deletions apps/android_camera/app/src/main/jni/tvm_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#ifdef TVM_OPENCL_RUNTIME
#include "../src/runtime/opencl/opencl_device_api.cc"
#include "../src/runtime/opencl/opencl_module.cc"
#include "../src/runtime/source_utils.cc"
#endif

#ifdef TVM_VULKAN_RUNTIME
Expand Down
1 change: 1 addition & 0 deletions apps/android_rpc/app/src/main/jni/tvm_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
#ifdef TVM_OPENCL_RUNTIME
#include "../src/runtime/opencl/opencl_device_api.cc"
#include "../src/runtime/opencl/opencl_module.cc"
#include "../src/runtime/source_utils.cc"
#endif

#ifdef TVM_VULKAN_RUNTIME
Expand Down
1 change: 1 addition & 0 deletions golang/src/tvm_runtime_pack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,4 @@
// Uncomment the following lines to enable OpenCL
// #include "../../src/runtime/opencl/opencl_device_api.cc"
// #include "../../src/runtime/opencl/opencl_module.cc"
// #include "../../src/runtime/source_utils.cc"
79 changes: 45 additions & 34 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "../file_utils.h"
#include "../meta_data.h"
#include "../pack_args.h"
#include "../source_utils.h"
#include "../thread_storage_scope.h"
#include "metal_common.h"

Expand All @@ -43,7 +44,9 @@
public:
explicit MetalModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: data_(data), fmt_(fmt), fmap_(fmap), source_(source) {}
: data_(data), fmt_(fmt), fmap_(fmap), source_(source) {
parsed_kernels_ = SplitKernels(data);
}
const char* type_key() const final { return "metal"; }

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
Expand Down Expand Up @@ -71,6 +74,7 @@ void SaveToBinary(dmlc::Stream* stream) final {
return "";
}
}

// get a from primary context in device_id
id<MTLComputePipelineState> GetPipelineState(size_t device_id, const std::string& func_name) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
Expand All @@ -85,44 +89,52 @@ void SaveToBinary(dmlc::Stream* stream) final {
if (it != e.smap.end()) return it->second;
// compile
NSError* err_msg = nil;
if (e.lib == nil) {
if (fmt_ == "metal") {
MTLCompileOptions* opts = [MTLCompileOptions alloc];
opts.languageVersion = MTLLanguageVersion2_3;
opts.fastMathEnabled = YES;
// opts = nil;
e.lib = [w->devices[device_id]
newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()]
options:opts
error:&err_msg];
[opts dealloc];
if (e.lib == nil) {
LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
}
if (err_msg != nil) {
LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String];
}
} else {
// Build from library.
auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL);
auto data = dispatch_data_create(data_.c_str(), data_.length(), q,
^{
});
e.lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg];
if (err_msg != nil || e.lib == nil) {
LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
}
id<MTLLibrary> lib = nil;
std::string source;
auto kernel = parsed_kernels_.find(func_name);
// If this kernel is not found in parsed_kernels_, the kernels were emitted together
// without an explicit separator. In that case we fall back to data_, which holds all
// kernels. This is done for backward compatibility.
if (kernel != parsed_kernels_.end())
source = kernel->second;
else
source = data_;
if (fmt_ == "metal") {
MTLCompileOptions* opts = [MTLCompileOptions alloc];
opts.languageVersion = MTLLanguageVersion2_3;
opts.fastMathEnabled = YES;
// opts = nil;
lib =
[w->devices[device_id] newLibraryWithSource:[NSString stringWithUTF8String:source.c_str()]
options:opts
error:&err_msg];
[opts dealloc];
if (lib == nil) {
LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
}
if (err_msg != nil) {
LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String];
}
} else {
// Build from library.
auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL);
auto data = dispatch_data_create(source.c_str(), source.length(), q,
^{
});
lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg];
if (err_msg != nil || lib == nil) {
LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
}
}
id<MTLFunction> f =
[e.lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
id<MTLFunction> f = [lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
ICHECK(f != nil) << "cannot find function " << func_name;
id<MTLComputePipelineState> state =
[w->devices[device_id] newComputePipelineStateWithFunction:f error:&err_msg];
ICHECK(state != nil) << "cannot get state:"
<< " for function " << func_name
<< [[err_msg localizedDescription] UTF8String];
[f release];
[lib release];
// The state.threadExecutionWidth can change dynamically according
// to the resource constraints of the kernel, so it does not strictly hold.
// Turn off warp-aware optimization for now.
Expand All @@ -135,13 +147,10 @@ void SaveToBinary(dmlc::Stream* stream) final {
private:
// device specific entry
struct DeviceEntry {
// library
id<MTLLibrary> lib = nil;
// state cache;
std::unordered_map<std::string, id<MTLComputePipelineState> > smap;
std::unordered_map<std::string, id<MTLComputePipelineState>> smap;

~DeviceEntry() {
if (lib != nil) [lib release];
for (auto&& kv : smap) {
[kv.second release];
}
Expand All @@ -159,6 +168,8 @@ void SaveToBinary(dmlc::Stream* stream) final {
std::vector<DeviceEntry> finfo_;
// internal mutex when updating the module
std::mutex mutex_;
// parsed kernel data
std::unordered_map<std::string, std::string> parsed_kernels_;
};

// a wrapped function class to get packed func.
Expand Down
8 changes: 0 additions & 8 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,6 @@ class OpenCLModuleNode : public ModuleNode {
cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e);

/*
* \brief Splits the provided serialized source file into separate
* source for each kernel primitive.
* \param source The serialized program source file (fmt: cl)
* \return Mapping from primitive name to kernel source
*/
std::unordered_map<std::string, std::string> SplitKernels(std::string source) const;

private:
// The workspace, need to keep reference to use it in destructor.
// In case of static destruction order problem.
Expand Down
39 changes: 6 additions & 33 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <unordered_map>
#include <vector>

#include "../source_utils.h"
#include "opencl_common.h"

namespace tvm {
Expand Down Expand Up @@ -188,6 +189,11 @@ void OpenCLModuleNode::Init() {

// split into source artifacts for each kernel
parsed_kernels_ = SplitKernels(GetSource("cl"));
ICHECK(!parsed_kernels_.empty()) << "The OpenCL module expects a kernel delimited "
<< "source from code generation, but no kernel "
<< "delimiter was found.";
ICHECK_EQ(workspace_->num_registered_kernels, parsed_kernels_.size())
<< "The number of registered kernels does not match number of parsed kernel sources";
// zero initialize cl_program pointers for each device kernel
for (auto& kv : parsed_kernels_) {
programs_.insert({kv.first, std::vector<cl_program>(workspace_->devices.size(), nullptr)});
Expand Down Expand Up @@ -242,39 +248,6 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
return kernel;
}

/*
 * Cut the serialized OpenCL source into one entry per kernel. Every kernel is
 * introduced by a "// Function: <name>" line; the text between that line and
 * the next marker (or the end of the source) is stored under <name>. At most
 * workspace_->num_registered_kernels markers are consumed.
 */
std::unordered_map<std::string, std::string> OpenCLModuleNode::SplitKernels(
    std::string source) const {
  std::unordered_map<std::string, std::string> kernels;
  if (!source.empty()) {
    const std::string marker{"// Function: "};
    size_t cursor = source.find(marker);
    ICHECK(cursor != std::string::npos) << "The OpenCL module expects a kernel delimited "
                                        << "source from code generation, but no kernel "
                                        << "delimiter was found.";
    size_t parsed = 0;
    while (parsed < workspace_->num_registered_kernels) {
      // The kernel name runs from just past the marker to the end of that line.
      const size_t name_begin = cursor + marker.size();
      const size_t name_end = source.find('\n', name_begin);
      std::string kernel_name = source.substr(name_begin, name_end - name_begin);
      // The body starts on the following line and stops at the next marker;
      // when there is no further marker, substr takes everything that is left.
      const size_t body_begin = name_end + 1;
      const size_t next_marker = source.find(marker, body_begin);
      const size_t body_length =
          (next_marker == std::string::npos) ? next_marker : next_marker - body_begin;
      kernels.insert({kernel_name, source.substr(body_begin, body_length)});
      if (next_marker == std::string::npos) {
        break;
      }
      cursor = next_marker;
      ++parsed;
    }
  }
  ICHECK_EQ(workspace_->num_registered_kernels, kernels.size())
      << "The number of registered kernels does not match number of parsed kernel sources";
  return kernels;
}

Module OpenCLModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
Expand Down
49 changes: 49 additions & 0 deletions src/runtime/source_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file source_utils.cc
*/
#include "source_utils.h"

namespace tvm {
namespace runtime {

/*!
 * \brief Split a concatenated kernel source into per-kernel sources.
 *
 * Each kernel is introduced by a line that starts with \p delimiter; the rest
 * of that line is the kernel name. The text from the following line up to the
 * next delimiter (or the end of \p source) is the kernel body.
 *
 * \param source The source code holding zero or more delimited kernels.
 * \param delimiter The marker that precedes each kernel name.
 * \return Mapping from kernel name to kernel source. Empty when \p source or
 *         \p delimiter is empty, or when no delimiter occurs in \p source.
 *         If a name occurs more than once, the first occurrence is kept.
 */
std::unordered_map<std::string, std::string> SplitKernels(std::string source,
                                                          std::string delimiter) {
  std::unordered_map<std::string, std::string> split_kernels;
  // An empty delimiter matches at every position and never advances the scan,
  // so there is nothing meaningful to split on.
  if (source.empty() || delimiter.empty()) {
    return split_kernels;
  }
  size_t begin = source.find(delimiter);
  while (begin != std::string::npos) {
    begin += delimiter.size();
    size_t end = source.find('\n', begin);
    if (end == std::string::npos) {
      // The last delimiter line has no trailing newline: the remainder is the
      // kernel name and its body is empty. Stopping here avoids wrapping
      // npos + 1 back to 0, which would restart the scan and never terminate.
      split_kernels.insert({source.substr(begin), std::string()});
      break;
    }
    std::string func_name = source.substr(begin, end - begin);
    begin = end + 1;
    end = source.find(delimiter, begin);
    // With end == npos, substr returns everything up to the end of source.
    std::string func_source =
        source.substr(begin, (end == std::string::npos) ? end : end - begin);
    split_kernels.insert({func_name, func_source});
    begin = end;
  }
  return split_kernels;
}
} // namespace runtime
} // namespace tvm
44 changes: 44 additions & 0 deletions src/runtime/source_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file source_utils.h
* \brief Minimum source manipulation utils for runtime.
*/

#ifndef TVM_RUNTIME_SOURCE_UTILS_H_
#define TVM_RUNTIME_SOURCE_UTILS_H_

#include <string>
#include <unordered_map>

namespace tvm {
namespace runtime {
/*!
 * \brief Split a source file into separate kernels using the specified delimiter.
 *
 * Each kernel is expected to be preceded by a line consisting of the delimiter
 * followed by the kernel name.
 *
 * \param source The source code of the kernels.
 * \param delimiter The delimiter that marks the start of each kernel; the rest of
 *        the delimiter line is taken as the kernel name.
 * \return Mapping from kernel (primitive) name to that kernel's source.
 */
std::unordered_map<std::string, std::string> SplitKernels(std::string source,
std::string delimiter = "// Function: ");
} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_SOURCE_UTILS_H_
25 changes: 14 additions & 11 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -325,27 +325,30 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO
runtime::Module BuildMetal(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenMetal cg;
cg.Init(output_ssa);

std::stringstream code;
std::stringstream source;
std::string fmt = "metal";
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
code << "// Function: " << kv.first->name_hint << std::endl;
CodeGenMetal cg;
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
std::string fsource = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
source << fsource;
fsource = (*f)(fsource).operator std::string();
fmt = "metallib";
}
code << fsource;
}

std::string code = cg.Finish();
std::string fmt = "metal";
std::string source = "";
if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
source = code;
code = (*f)(code).operator std::string();
fmt = "metallib";
}
return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source);
return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str());
}

TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);
Expand Down

0 comments on commit dc5fc68

Please sign in to comment.