diff --git a/src/mqtt_sessions_process.erl b/src/mqtt_sessions_process.erl index 7a8afa2..5933a69 100644 --- a/src/mqtt_sessions_process.erl +++ b/src/mqtt_sessions_process.erl @@ -1,11 +1,11 @@ %% @author Marc Worrell -%% @copyright 2018-2024 Marc Worrell +%% @copyright 2018-2025 Marc Worrell %% @doc Process handling one single MQTT session. %% MQTT connections attach and detach from this session. Buffers outgoing %% messages if there is not connection attached. %% @end -%% Copyright 2018-2024 Marc Worrell +%% Copyright 2018-2025 Marc Worrell %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -42,6 +42,7 @@ update_user_context/2, get_transport/1, + is_connected/1, kill/1, incoming_connect/3, incoming_data/2, @@ -168,6 +169,15 @@ get_transport(Pid) -> {error, noproc} end. +-spec is_connected( pid() ) -> boolean(). +is_connected(Pid) -> + try + gen_server:call(Pid, is_connected, infinity) + catch + exit:{noproc, _} -> + false + end. + -spec kill( pid() ) -> ok. kill(Pid) when is_pid(Pid) -> MRef = monitor(process, Pid), @@ -237,6 +247,9 @@ handle_call(get_transport, _From, #state{ transport = undefined } = State) -> handle_call(get_transport, _From, #state{ transport = Transport } = State) -> {reply, {ok, Transport}, State}; +handle_call(is_connected, _From, #state{ is_connected = IsConnected } = State) -> + {reply, IsConnected, State}; + handle_call({incoming_data, NewData, ConnectionPid}, _From, #state{ incoming_data = Data, connection_pid = ConnectionPid } = State) -> Data1 = << Data/binary, NewData/binary >>, case handle_incoming_data(Data1, State) of @@ -299,7 +312,7 @@ handle_info({publish_job, JobPid}, #state{ publish_jobs = Jobs } = State) when i {noreply, State1}; handle_info({'DOWN', _Mref, process, Pid, _Reason}, #state{ connection_pid = Pid } = State) -> - State1 = do_disconnected(State), + State1 = cleanup_state_disconnected(State), {noreply, State1}; handle_info({'DOWN', _Mref, process, Pid, _Reason}, #state{ will_pid = Pid } = State) -> send_transport(#{ @@ -1002,12 +1015,9 @@ mark_packet_sent(PacketId, #state{ awaiting_ack = AwaitAck } = State) -> %% @doc Called when the connection disconnects or crashes/stops -do_disconnected(#state{ will_pid = WillPid } = State) -> - mqtt_sessions_will:disconnected(WillPid), - cleanup_state_disconnected(State). - %% @todo Cleanup pending messages and awaiting states. -cleanup_state_disconnected(State) -> +cleanup_state_disconnected(#state{ will_pid = WillPid } = State) -> + mqtt_sessions_will:disconnected(WillPid), delete_buffered_qos0(State#state{ connection_pid = undefined, transport = undefined, @@ -1056,7 +1066,7 @@ extract_will(#{ type := connect, will_flag := true, properties := Props } = Msg) force_disconnect(#state{ connection_pid = undefined, transport = undefined } = State) -> State; -force_disconnect(State) -> +force_disconnect(#state{ will_pid = WillPid } = State) -> State1 = disconnect_transport(State), if is_pid(State#state.connection_pid) -> diff --git a/src/mqtt_sessions_will.erl b/src/mqtt_sessions_will.erl index f3c1ccb..9db4d1b 100644 --- a/src/mqtt_sessions_will.erl +++ b/src/mqtt_sessions_will.erl @@ -46,6 +46,7 @@ user_context :: term(), session_expiry_interval :: non_neg_integer(), expiry_ref = undefined :: reference() | undefined, + interval_timer_ref = undefined, timer_ref = undefined, is_stopping :: boolean() }). @@ -53,6 +54,10 @@ %% The connect handshake must complete in 20 seconds. -define(CONNECT_EXPIRY_INTERVAL, 20). +%% Every minute we do a check with the session to see if it is connected. +%% This is to catch any missed disconnects. +-define(CONNECTED_CHECK_INTERVAL, 60). + -spec start_link( atom(), pid() ) -> {ok, pid()}. start_link(Pool, SessionPid) -> @@ -99,9 +104,11 @@ stop(Pid) -> init([ Pool, SessionPid ]) -> erlang:monitor(process, SessionPid), + {ok, Timer} = timer:send_interval(?CONNECTED_CHECK_INTERVAL * 1000, check_session_connected), State = #state{ pool = Pool, session_pid = SessionPid, + interval_timer_ref = Timer, will = #{}, is_stopping = false, session_expiry_interval = 0 @@ -136,8 +143,10 @@ handle_cast(reconnected, State) -> handle_cast({disconnected, IsWill, ExpiryInterval}, State) -> {noreply, do_disconnected(State, IsWill, ExpiryInterval)}; -handle_cast(disconnected, State) -> +handle_cast(disconnected, #state{ timer_ref = undefined } = State) -> {noreply, do_disconnected(State, true, undefined)}; +handle_cast(disconnected, State) -> + {noreply, State}; handle_cast({user_context, UserContext}, State) -> {noreply, State#state{ user_context = UserContext }}; @@ -158,11 +167,16 @@ handle_info({expired, Ref}, #state{ expiry_ref = Ref } = State) -> mqtt_sessions_process:kill(State#state.session_pid), do_publish_will(State), {stop, shutdown, State}; -handle_info({expired, Ref}, #state{ expiry_ref = Ref } = State) -> - do_publish_will(State), - {stop, shutdown, State}; handle_info({expired, _Ref}, State) -> % old timer - ignore + {noreply, State}; +handle_info(check_session_connected, #state{ session_pid = Pid, timer_ref = undefined } = State) -> + State1 = case mqtt_sessions_process:is_connected(Pid) of + true -> State; + false -> do_disconnected(State, true, undefined) + end, + {noreply, State1}; +handle_info(check_session_connected, #state{} = State) -> {noreply, State}. code_change(_Vsn, State, _Extra) ->