Skip to content

Commit

Permalink
Fix a problem where the session could persist after disconnect (#26)
Browse files Browse the repository at this point in the history
* Fix a problem where the session could persist after disconnect

* Ensure that will is published on missing connection
  • Loading branch information
mworrell authored Jan 16, 2025
1 parent 056c2e0 commit 7707787
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 13 deletions.
28 changes: 19 additions & 9 deletions src/mqtt_sessions_process.erl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
%% @author Marc Worrell <[email protected]>
%% @copyright 2018-2024 Marc Worrell
%% @copyright 2018-2025 Marc Worrell
%% @doc Process handling one single MQTT session.
%% MQTT connections attach and detach from this session. Buffers outgoing
%% messages if there is not connection attached.
%% @end

%% Copyright 2018-2024 Marc Worrell
%% Copyright 2018-2025 Marc Worrell
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -42,6 +42,7 @@
update_user_context/2,

get_transport/1,
is_connected/1,
kill/1,
incoming_connect/3,
incoming_data/2,
Expand Down Expand Up @@ -168,6 +169,15 @@ get_transport(Pid) ->
{error, noproc}
end.

-spec is_connected( pid() ) -> boolean().
is_connected(Pid) ->
try
gen_server:call(Pid, is_connected, infinity)
catch
exit:{noproc, _} ->
false
end.

-spec kill( pid() ) -> ok.
kill(Pid) when is_pid(Pid) ->
MRef = monitor(process, Pid),
Expand Down Expand Up @@ -237,6 +247,9 @@ handle_call(get_transport, _From, #state{ transport = undefined } = State) ->
handle_call(get_transport, _From, #state{ transport = Transport } = State) ->
{reply, {ok, Transport}, State};

handle_call(is_connected, _From, #state{ is_connected = IsConnected } = State) ->
{reply, IsConnected, State};

handle_call({incoming_data, NewData, ConnectionPid}, _From, #state{ incoming_data = Data, connection_pid = ConnectionPid } = State) ->
Data1 = << Data/binary, NewData/binary >>,
case handle_incoming_data(Data1, State) of
Expand Down Expand Up @@ -299,7 +312,7 @@ handle_info({publish_job, JobPid}, #state{ publish_jobs = Jobs } = State) when i
{noreply, State1};

handle_info({'DOWN', _Mref, process, Pid, _Reason}, #state{ connection_pid = Pid } = State) ->
State1 = do_disconnected(State),
State1 = cleanup_state_disconnected(State),
{noreply, State1};
handle_info({'DOWN', _Mref, process, Pid, _Reason}, #state{ will_pid = Pid } = State) ->
send_transport(#{
Expand Down Expand Up @@ -1002,12 +1015,9 @@ mark_packet_sent(PacketId, #state{ awaiting_ack = AwaitAck } = State) ->


%% @doc Called when the connection disconnects or crashes/stops
do_disconnected(#state{ will_pid = WillPid } = State) ->
mqtt_sessions_will:disconnected(WillPid),
cleanup_state_disconnected(State).

%% @todo Cleanup pending messages and awaiting states.
cleanup_state_disconnected(State) ->
cleanup_state_disconnected(#state{ will_pid = WillPid } = State) ->
mqtt_sessions_will:disconnected(WillPid),
delete_buffered_qos0(State#state{
connection_pid = undefined,
transport = undefined,
Expand Down Expand Up @@ -1056,7 +1066,7 @@ extract_will(#{ type := connect, will_flag := true, properties := Props } = Msg)

force_disconnect(#state{ connection_pid = undefined, transport = undefined } = State) ->
State;
force_disconnect(State) ->
force_disconnect(#state{ will_pid = WillPid } = State) ->
State1 = disconnect_transport(State),
if
is_pid(State#state.connection_pid) ->
Expand Down
22 changes: 18 additions & 4 deletions src/mqtt_sessions_will.erl
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@
user_context :: term(),
session_expiry_interval :: non_neg_integer(),
expiry_ref = undefined :: reference() | undefined,
interval_timer_ref = undefined,
timer_ref = undefined,
is_stopping :: boolean()
}).

%% The connect handshake must complete in 20 seconds.
-define(CONNECT_EXPIRY_INTERVAL, 20).

%% Every minute we do a check with the session to see if it is connected.
%% This is to catch any missed disconnects.
-define(CONNECTED_CHECK_INTERVAL, 60).


-spec start_link( atom(), pid() ) -> {ok, pid()}.
start_link(Pool, SessionPid) ->
Expand Down Expand Up @@ -99,9 +104,11 @@ stop(Pid) ->

init([ Pool, SessionPid ]) ->
erlang:monitor(process, SessionPid),
{ok, Timer} = timer:send_interval(?CONNECTED_CHECK_INTERVAL * 1000, check_session_connected),
State = #state{
pool = Pool,
session_pid = SessionPid,
interval_timer_ref = Timer,
will = #{},
is_stopping = false,
session_expiry_interval = 0
Expand Down Expand Up @@ -136,8 +143,10 @@ handle_cast(reconnected, State) ->

handle_cast({disconnected, IsWill, ExpiryInterval}, State) ->
{noreply, do_disconnected(State, IsWill, ExpiryInterval)};
handle_cast(disconnected, State) ->
handle_cast(disconnected, #state{ timer_ref = undefined } = State) ->
{noreply, do_disconnected(State, true, undefined)};
handle_cast(disconnected, State) ->
{noreply, State};

handle_cast({user_context, UserContext}, State) ->
{noreply, State#state{ user_context = UserContext }};
Expand All @@ -158,11 +167,16 @@ handle_info({expired, Ref}, #state{ expiry_ref = Ref } = State) ->
mqtt_sessions_process:kill(State#state.session_pid),
do_publish_will(State),
{stop, shutdown, State};
handle_info({expired, Ref}, #state{ expiry_ref = Ref } = State) ->
do_publish_will(State),
{stop, shutdown, State};
handle_info({expired, _Ref}, State) ->
% old timer - ignore
{noreply, State};
handle_info(check_session_connected, #state{ session_pid = Pid, timer_ref = undefined } = State) ->
State1 = case mqtt_sessions_process:is_connected(Pid) of
true -> State;
false -> do_disconnected(State, true, undefined)
end,
{noreply, State1};
handle_info(check_session_connected, #state{} = State) ->
{noreply, State}.

code_change(_Vsn, State, _Extra) ->
Expand Down

0 comments on commit 7707787

Please sign in to comment.