diff --git a/capture_service/server.cc b/capture_service/server.cc index aebf6016..8c5ab528 100644 --- a/capture_service/server.cc +++ b/capture_service/server.cc @@ -16,7 +16,8 @@ limitations under the License. #include "server.h" -int main(int argc, char **argv) { - Dive::server_main(); +int main(int argc, char **argv) +{ + Dive::ServerMain(); return 0; } diff --git a/capture_service/server.h b/capture_service/server.h index 8e983b64..98eba237 100644 --- a/capture_service/server.h +++ b/capture_service/server.h @@ -18,5 +18,6 @@ limitations under the License. namespace Dive { -int server_main(); -} +int ServerMain(); +void StopServer(); +} // namespace Dive diff --git a/capture_service/service.cc b/capture_service/service.cc index 3f3ab0b0..193b666f 100644 --- a/capture_service/service.cc +++ b/capture_service/service.cc @@ -132,8 +132,26 @@ grpc::Status DiveServiceImpl::DownloadFile(grpc::ServerContext *cont return grpc::Status::OK; } +std::unique_ptr &GetServer() +{ + static std::unique_ptr server = nullptr; + return server; +} + +void StopServer() +{ + auto &server = GetServer(); + if (server) + { + LOGI("StopServer at service.cc"); + server->Shutdown(); + server = nullptr; + } +} + void RunServer(uint16_t port) { + LOGI("port is %d\n", port); std::string server_address = absl::StrFormat("0.0.0.0:%d", port); DiveServiceImpl service; @@ -142,12 +160,13 @@ void RunServer(uint16_t port) builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); builder.RegisterService(&service); - std::unique_ptr server(builder.BuildAndStart()); + auto &server = GetServer(); + server = builder.BuildAndStart(); LOGI("Server listening on %s", server_address.c_str()); server->Wait(); } -int server_main() +int ServerMain() { RunServer(absl::GetFlag(FLAGS_port)); return 0; diff --git a/layer/layer_common.cc b/layer/layer_common.cc index 22e3344a..ff36bc8f 100644 --- a/layer/layer_common.cc +++ b/layer/layer_common.cc @@ -14,6 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. */ +#ifdef _WIN32 +# include +#else +# include +#endif + #include "layer_common.h" #include @@ -58,7 +64,7 @@ ServerRunner::ServerRunner() LOGI("libwrap loaded: %d", is_libwrap_loaded); if (is_libwrap_loaded) { - server_thread = std::thread(Dive::server_main); + server_thread = std::thread(Dive::ServerMain); } } @@ -66,6 +72,7 @@ ServerRunner::~ServerRunner() { if (is_libwrap_loaded && server_thread.joinable()) { + Dive::StopServer(); LOGI("Wait for server thread to join"); server_thread.join(); } @@ -77,4 +84,41 @@ ServerRunner &GetServerRunner() return runner; } -} // namespace DiveLayer \ No newline at end of file +} // namespace DiveLayer + +void PreventLibraryUnload() +{ +#ifdef _WIN32 + HMODULE module = nullptr; + GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_PIN, + reinterpret_cast(&PreventLibraryUnload), + &module); +#else + Dl_info info; + if (dladdr(reinterpret_cast(&PreventLibraryUnload), &info)) + { + dlopen(info.dli_fname, RTLD_NOW | RTLD_NOLOAD | RTLD_LOCAL | RTLD_NODELETE); + } +#endif +} + +#ifdef _WIN32 +BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved) +{ + if (fdwReason == DLL_PROCESS_ATTACH) + { + [[maybe_unused]] auto &server = DiveLayer::GetServerRunner(); + PreventLibraryUnload(); + } + return TRUE; +} +#else +extern "C" +{ + __attribute__((constructor)) void InitializeLibrary() + { + [[maybe_unused]] auto &server = DiveLayer::GetServerRunner(); + PreventLibraryUnload(); + } +} +#endif \ No newline at end of file diff --git a/layer/openxr_layer.cc b/layer/openxr_layer.cc index 2ad7355d..8f4dd230 100644 --- a/layer/openxr_layer.cc +++ b/layer/openxr_layer.cc @@ -56,12 +56,6 @@ struct XrSessionData { XrSession session; XrGeneratedDispatchTable dispatch_table; - ServerRunner &server; - - XrSessionData() : - server(GetServerRunner()) - { - } }; static thread_local XrInstanceData *last_used_xr_instance_data = nullptr; @@ -190,8 +184,7 @@ XRAPI_ATTR XrResult XRAPI_CALL ApiDiveLayerXrDestroyInstance(XrInstance instance LOGD("ApiDiveLayerXrDestroyInstance\n"); XrResult result = XR_SUCCESS; - - auto sess_data = GetXrInstanceLayerData(DataKey(instance)); + auto sess_data = GetXrInstanceLayerData(DataKey(instance)); if (sess_data) { result = sess_data->dispatch_table.DestroyInstance(instance); diff --git a/layer/vk_layer_base.cc b/layer/vk_layer_base.cc index 878cf61d..a669b630 100644 --- a/layer/vk_layer_base.cc +++ b/layer/vk_layer_base.cc @@ -43,12 +43,6 @@ struct InstanceData { VkInstance instance; InstanceDispatchTable dispatch_table; - ServerRunner &server; - - InstanceData() : - server(GetServerRunner()) - { - } }; struct DeviceData