diff --git a/src/depcache.erl b/src/depcache.erl index 8f665a2..4343d44 100644 --- a/src/depcache.erl +++ b/src/depcache.erl @@ -297,12 +297,14 @@ memo(Fun, Key, MaxAge, Server) -> Result :: any(). memo(Fun, Key, MaxAge, Dep, Server) -> Key1 = case Key of - undefined -> memo_key(Fun); - _ -> Key - end, + undefined -> memo_key(Fun); + _ -> Key + end, case ?MODULE:get_wait(Key1, Server) of {ok, Value} -> Value; + {throw, premature_exit} -> + ?MODULE:memo(Fun, Key, MaxAge, Dep, Server); {throw, R} -> throw(R); undefined -> @@ -323,31 +325,58 @@ memo(Fun, Key, MaxAge, Dep, Server) -> Server :: depcache_server(), Result :: any(). memo_key(Fun, Key, MaxAge, Dep, Server) -> - try - Value = - case Fun of - {M,F,A} -> erlang:apply(M,F,A); - {M,F} -> M:F(); - _ when is_function(Fun) -> Fun() - end, - {Value1, MaxAge1, Dep1} = - case Value of - #memo{value=V, max_age=MA, deps=D} -> - MA1 = case is_integer(MA) of true -> MA; false -> MaxAge end, - {V, MA1, Dep++D}; - _ -> - {Value, MaxAge, Dep} - end, - case MaxAge of - 0 -> memo_send_replies(Key, Value1, Server); - _ -> set(Key, Value1, MaxAge1, Dep1, Server) - end, - Value1 - catch - ?WITH_STACKTRACE(Class, R, S) - memo_send_errors(Key, {throw, R}, Server), - erlang:raise(Class, R, S) - end. + ExitWatcher = start_exit_watcher(Key, Server), + try + try + {Value1, MaxAge1, Dep1} = case apply_fun(Fun) of + #memo{value=V, max_age=MA, deps=D} -> + MA1 = case is_integer(MA) of + true -> MA; + false -> MaxAge + end, + {V, MA1, Dep++D}; + Value -> + {Value, MaxAge, Dep} + end, + case MaxAge of + 0 -> memo_send_replies(Key, Value1, Server); + _ -> set(Key, Value1, MaxAge1, Dep1, Server) + end, + + Value1 + catch + ?WITH_STACKTRACE(Class, R, S) + memo_send_errors(Key, {throw, R}, Server), + erlang:raise(Class, R, S) + end + after + stop_exit_watcher(ExitWatcher) + end. + +%% @private +%% @doc Monitors the current process... +%% Sends premature_exit throw to depcache server when it detects one. +start_exit_watcher(Key, Server) -> + Self = self(), + spawn(fun() -> + Ref = monitor(process, Self), + receive + done -> + erlang:demonitor(Ref); + {'DOWN', Ref, process, Self, _Reason} -> + memo_send_errors(Key, {throw, premature_exit}, Server) + end + end). + +stop_exit_watcher(Pid) -> + Pid ! done. + +%% @private +%% @doc Execute the memo function +%% Returns the result value +apply_fun({M,F,A}) -> erlang:apply(M,F,A); +apply_fun({M,F}) -> M:F(); +apply_fun(Fun) when is_function(Fun) -> Fun(). %% @private diff --git a/test/depcache_tests.erl b/test/depcache_tests.erl index 906892d..a511221 100644 --- a/test/depcache_tests.erl +++ b/test/depcache_tests.erl @@ -145,3 +145,28 @@ memo_raise_test() -> ?assertMatch({depcache_tests, raise_error, 0, _}, hd(S)) end, ok. + +memo_premature_kill_test() -> + {ok, C} = depcache:start_link(#{}), + + LongTask = fun() -> + Fun = fun() -> + timer:sleep(500), + done + end, + depcache:memo(Fun, test, C) + end, + + Pid = spawn(LongTask), + timer:kill_after(250, Pid), + timer:sleep(50), + ?assertEqual({throw, premature_exit}, depcache:get_wait(test, C)), + + % Check if another process takes over processing in case of pre-mature exits + Task = spawn(LongTask), + timer:kill_after(250, Task), + timer:sleep(50), + ?assertEqual(done, LongTask()), + + ok. +