diff --git a/include/caffe/util/cudnn.hpp b/include/caffe/util/cudnn.hpp index 92eece72..6be60774 100644 --- a/include/caffe/util/cudnn.hpp +++ b/include/caffe/util/cudnn.hpp @@ -3,6 +3,9 @@ #ifdef USE_ACCMI #ifdef USE_MIOPEN +#include +#include + #include #include "caffe/common.hpp" @@ -41,6 +44,48 @@ namespace caffe { namespace miopen { +class miopenHandleMap { +private: + miopenHandleMap() : map_(), lock_() {} +public: + static miopenHandleMap& getInstance() { + static miopenHandleMap instance_{}; + return instance_; + } + + ~miopenHandleMap() { + for (auto iter : map_) { + miopenDestroy(iter.second); + } + map_.clear(); + } + + miopenHandle_t getHandle(int device) { + miopenHandle_t ret = nullptr; + std::lock_guard l(lock_); + if (map_.find(device) != map_.end()) { + ret = map_[device]; + } + return ret; + } + + bool setHandle(int device, miopenHandle_t handle) { + bool ret = false; + std::lock_guard l(lock_); + if (map_.find(device) != map_.end()) { + LOG(FATAL) << "Duplicated MIOpen handle for device: " << device; + } else { + map_[device] = handle; + ret = true; + } + return ret; + } + +private: + std::unordered_map map_; + std::mutex lock_; +}; + template class dataType; template<> class dataType { public: diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp index 0496a4ca..22cd10e7 100644 --- a/src/caffe/layers/cudnn_conv_layer.cpp +++ b/src/caffe/layers/cudnn_conv_layer.cpp @@ -13,7 +13,7 @@ CuDNNConvolutionLayer::CuDNNConvolutionLayer(const LayerParameter& param) : ConvolutionLayer(param), handles_setup_(false), fwd_algo_(), bwd_weight_algo_(), bwd_data_algo_(), workspace_fwd_sizes_(), workspace_bwd_filter_sizes_(), workspace_bwd_data_sizes_(), - workspace() { } + workspace(), handle_(nullptr) { } /** * TODO(dox) explain cuDNN interface @@ -57,7 +57,18 @@ void CuDNNConvolutionLayer::LayerSetUp( for (int g = 0; g < this->group_ * WORKSPACE_PER_GROUP; g++) { #ifdef USE_MIOPEN - MIOPEN_CHECK(miopenCreateWithStream(&handle_, nullptr)); + int device; + HIP_CHECK(hipGetDevice(&device)); + + auto& hmap = caffe::miopen::miopenHandleMap::getInstance(); + handle_ = hmap.getHandle(device); + if (handle_ == nullptr) { + DLOG(INFO) << "Creating MIOpen handle on device: " << device; + MIOPEN_CHECK(miopenCreateWithStream(&handle_, nullptr)); + hmap.setHandle(device, handle_); + } else { + DLOG(INFO) << "Get MIOpen handle from cache on device: " << device; + } #endif workspace[g] = NULL; @@ -435,10 +446,6 @@ CuDNNConvolutionLayer::~CuDNNConvolutionLayer() { #endif #endif -#ifdef USE_MIOPEN - miopenDestroy(handle_); -#endif - hipFree(workspaceData); }