diff --git a/rebar.config b/rebar.config index 0ab03ff..6d85987 100644 --- a/rebar.config +++ b/rebar.config @@ -1,9 +1,11 @@ {erl_opts, [debug_info]}. -{deps, [{chatterbox, {pkg, ts_chatterbox}}, +{deps, [ + {chatterbox, ".*", {git, "https://github.com/novalabsxyz/chatterbox", {branch, "master"}}}, ctx, - acceptor_pool, - gproc]}. + {acceptor_pool, {git, "https://github.com/novalabsxyz/acceptor_pool", {branch, "master"}}}, + gproc + ]}. {grpc, [{protos, ["proto"]}, {service_modules, [{'grpc.health.v1.Health', "grpcbox_health"}, @@ -48,7 +50,9 @@ deprecated_function_calls, deprecated_functions]}. {project_plugins, [covertool, - {grpcbox_plugin, "~> 0.7.0"}, + {grpcbox_plugin, + {git, "https://github.com/novalabsxyz/grpcbox_plugin.git", + {branch, "master"}}}, rebar3_lint]}. {cover_enabled, true}. diff --git a/rebar.lock b/rebar.lock index 71b74b7..fc56d62 100644 --- a/rebar.lock +++ b/rebar.lock @@ -1,19 +1,21 @@ {"1.2.0", -[{<<"acceptor_pool">>,{pkg,<<"acceptor_pool">>,<<"1.0.0">>},0}, - {<<"chatterbox">>,{pkg,<<"ts_chatterbox">>,<<"0.11.0">>},0}, +[{<<"acceptor_pool">>, + {git,"https://github.com/novalabsxyz/acceptor_pool", + {ref,"56d676e00c11fd071a6bcc4059e3454960900af7"}}, + 0}, + {<<"chatterbox">>, + {git,"https://github.com/novalabsxyz/chatterbox", + {ref,"cbfe6e46b273f1552b57685c9f6daf710473c609"}}, + 0}, {<<"ctx">>,{pkg,<<"ctx">>,<<"0.6.0">>},0}, {<<"gproc">>,{pkg,<<"gproc">>,<<"0.8.0">>},0}, {<<"hpack">>,{pkg,<<"hpack_erl">>,<<"0.2.3">>},1}]}. [ {pkg_hash,[ - {<<"acceptor_pool">>, <<"43C20D2ACAE35F0C2BCD64F9D2BDE267E459F0F3FD23DAB26485BF518C281B21">>}, - {<<"chatterbox">>, <<"B8F372C706023EB0DE5BF2976764EDB27C70FE67052C88C1F6A66B3A5626847F">>}, {<<"ctx">>, <<"8FF88B70E6400C4DF90142E7F130625B82086077A45364A78D208ED3ED53C7FE">>}, {<<"gproc">>, <<"CEA02C578589C61E5341FCE149EA36CCEF236CC2ECAC8691FBA408E7EA77EC2F">>}, {<<"hpack">>, <<"17670F83FF984AE6CD74B1C456EDDE906D27FF013740EE4D9EFAA4F1BF999633">>}]}, {pkg_hash_ext,[ - {<<"acceptor_pool">>, <<"0CBCD83FDC8B9AD2EEE2067EF8B91A14858A5883CB7CD800E6FCD5803E158788">>}, - {<<"chatterbox">>, <<"722FE2BAD52913AB7E87D849FC6370375F0C961FFB2F0B5E6D647C9170C382A6">>}, {<<"ctx">>, <<"A14ED2D1B67723DBEBBE423B28D7615EB0BDCBA6FF28F2D1F1B0A7E1D4AA5FC2">>}, {<<"gproc">>, <<"580ADAFA56463B75263EF5A5DF4C86AF321F68694E7786CB057FD805D1E2A7DE">>}, {<<"hpack">>, <<"06F580167C4B8B8A6429040DF36CC93BBA6D571FAEAEC1B28816523379CBB23A">>}]} diff --git a/src/grpcbox_pool.erl b/src/grpcbox_pool.erl index 2adc28f..e8d94a2 100644 --- a/src/grpcbox_pool.erl +++ b/src/grpcbox_pool.erl @@ -3,7 +3,8 @@ -behaviour(acceptor_pool). -export([start_link/4, - accept_socket/3]). + accept_socket/3, + pool_sockets/1]). -export([init/1]). @@ -13,6 +14,9 @@ start_link(Name, ServerOpts, ChatterboxOpts, TransportOpts) -> accept_socket(Pool, Socket, Acceptors) -> acceptor_pool:accept_socket(Pool, Socket, Acceptors). +pool_sockets(Pool) -> + acceptor_pool:which_sockets(Pool). + init([ServerOpts, ChatterboxOpts, TransportOpts]) -> {Transport, SslOpts} = case TransportOpts of #{ssl := true, diff --git a/src/grpcbox_reflection_service.erl b/src/grpcbox_reflection_service.erl index e9c262c..fd8abfe 100644 --- a/src/grpcbox_reflection_service.erl +++ b/src/grpcbox_reflection_service.erl @@ -9,38 +9,39 @@ #{error_code => 12, error_message => "unimplemented method since extensions removed in proto3"}}). -server_reflection_info(Ref, Stream) -> - receive - {Ref, eos} -> - ok; - {Ref, Message} -> - handle_message(Message, Stream), - server_reflection_info(Ref, Stream) - end. +server_reflection_info(Message, Stream) -> + handle_message(Message, Stream). +handle_message(eos=_OriginalRequest, Stream) -> + {stop, Stream}; handle_message(#{message_request := {list_services, _}}=OriginalRequest, Stream) -> Services = list_services(), - grpcbox_stream:send(#{original_request => OriginalRequest, + Stream0 = grpcbox_stream:send(false, #{original_request => OriginalRequest, message_response => {list_services_response, - #{service => Services}}}, Stream); + #{service => Services}}}, Stream), + {ok, Stream0}; handle_message(#{message_request := {file_by_filename, Filename}}=OriginalRequest, Stream) -> Response = file_by_filename(Filename), - grpcbox_stream:send(#{original_request => OriginalRequest, - message_response => Response}, Stream); + Stream0 = grpcbox_stream:send(false, #{original_request => OriginalRequest, + message_response => Response}, Stream), + {ok, Stream0}; handle_message(#{message_request := {file_containing_symbol, Symbol}}=OriginalRequest, Stream) -> Response = file_containing_symbol(Symbol), - grpcbox_stream:send(#{original_request => OriginalRequest, - message_response => Response}, Stream); + Stream0 = grpcbox_stream:send(false, #{original_request => OriginalRequest, + message_response => Response}, Stream), + {ok, Stream0}; %% proto3 dropped extensions so we'll just return an empty result handle_message(#{message_request := {all_extension_numbers_of_type, _}}=OriginalRequest, Stream) -> - grpcbox_stream:send(#{original_request => OriginalRequest, + Stream0 = grpcbox_stream:send(false, #{original_request => OriginalRequest, message_response => ?UNIMPLEMENTED_RESPONSE}, - Stream); + Stream), + {ok, Stream0}; handle_message(#{message_request := {file_containing_extension, _}}=OriginalRequest, Stream) -> - grpcbox_stream:send(#{original_request => OriginalRequest, - message_response => ?UNIMPLEMENTED_RESPONSE}, Stream). + Stream0 = grpcbox_stream:send(false, #{original_request => OriginalRequest, + message_response => ?UNIMPLEMENTED_RESPONSE}, Stream), + {ok, Stream0}. %% diff --git a/src/grpcbox_services_sup.erl b/src/grpcbox_services_sup.erl index b9a68a1..6240de3 100644 --- a/src/grpcbox_services_sup.erl +++ b/src/grpcbox_services_sup.erl @@ -49,7 +49,7 @@ init([ServerOpts, GrpcOpts, ListenOpts, PoolOpts, TransportOpts, ServiceSupName] %% unique name for pool based on the ip and port it will listen on Name = pool_name(ListenOpts), - RestartStrategy = #{strategy => rest_for_one}, + RestartStrategy = #{strategy => rest_for_one, intensity => 5, period => 2}, Pool = #{id => grpcbox_pool, start => {grpcbox_pool, start_link, [Name, chatterbox:settings(server, ServerOpts), ChatterboxOpts, TransportOpts]}}, @@ -127,13 +127,22 @@ load_services([], _, _) -> ok; load_services([ServicePbModule | Rest], Services, ServicesTable) -> ServiceNames = ServicePbModule:get_service_names(), + %% NOTE: Methods value may be a map or a prop depending on gpb options when generating the services [begin + %% NOTE: Methods value may be a map or a prop depending on gpb options when generating the services {{service, _}, Methods} = ServicePbModule:get_service_def(ServiceName), %% throws exception if ServiceName isn't in the map or doesn't exist - try ServiceModule = maps:get(ServiceName, Services), + try + ServiceModule = maps:get(ServiceName, Services), {ServiceModule, ServiceModule:module_info(exports)} of {ServiceModule1, Exports} -> [begin + #{name := Name, + input := Input, + output := Output, + input_stream := InputStream, + output_stream := OutputStream, + opts := Opts} = ensure_map(P), SnakedMethodName = atom_snake_case(Name), case lists:member({SnakedMethodName, 2}, Exports) of true -> @@ -149,12 +158,7 @@ load_services([ServicePbModule | Rest], Services, ServicesTable) -> %% TODO: error? log? insert into ets as unimplemented? unimplemented_method end - end || #{name := Name, - input := Input, - output := Output, - input_stream := InputStream, - output_stream := OutputStream, - opts := Opts} <- Methods] + end || P <- Methods] catch _:_ -> %% TODO: error? log? insert into ets as unimplemented? @@ -179,3 +183,8 @@ atom_snake_case(Name) -> Snaked1 = string:replace(Snaked, ".", "_", all), Snaked2 = string:replace(Snaked1, "__", "_", all), list_to_atom(string:to_lower(unicode:characters_to_list(Snaked2))). + +ensure_map(S) when is_map(S)-> + S; +ensure_map(S) when is_list(S)-> + maps:from_list(S). diff --git a/src/grpcbox_socket.erl b/src/grpcbox_socket.erl index 993ea89..732b42e 100644 --- a/src/grpcbox_socket.erl +++ b/src/grpcbox_socket.erl @@ -11,53 +11,121 @@ code_change/3, terminate/2]). -%% public api +-record(state, { + pool, + listen_opts, + pool_opts, + socket, + mref +}). +%% public api start_link(Pool, ListenOpts, AcceptorOpts) -> gen_server:start_link(?MODULE, [Pool, ListenOpts, AcceptorOpts], []). %% gen_server api init([Pool, ListenOpts, PoolOpts]) -> - Port = maps:get(port, ListenOpts, 8080), - IPAddress = maps:get(ip, ListenOpts, {0, 0, 0, 0}), - AcceptorPoolSize = maps:get(size, PoolOpts, 10), - SocketOpts = maps:get(socket_options, ListenOpts, [{reuseaddr, true}, - {nodelay, true}, - {reuseaddr, true}, - {backlog, 32768}, - {keepalive, true}]), - %% Trapping exit so can close socket in terminate/2 - _ = process_flag(trap_exit, true), - Opts = [{active, false}, {mode, binary}, {packet, raw}, {ip, IPAddress} | SocketOpts], - case gen_tcp:listen(Port, Opts) of - {ok, Socket} -> - %% acceptor could close the socket if there is a problem - MRef = monitor(port, Socket), - grpcbox_pool:accept_socket(Pool, Socket, AcceptorPoolSize), - {ok, {Socket, MRef}}; - {error, Reason} -> - {stop, Reason} - end. + {ok, #state{pool = Pool, pool_opts = PoolOpts, listen_opts = ListenOpts}, 0}. handle_call(Req, _, State) -> {stop, {bad_call, Req}, State}. handle_cast(Req, State) -> {stop, {bad_cast, Req}, State}. - -handle_info({'DOWN', MRef, port, Socket, Reason}, {Socket, MRef} = State) -> - {stop, Reason, State}; -handle_info(_, State) -> +handle_info(timeout, State) -> + case start_listener(State) of + {ok, {Socket, MRef}} -> + {noreply, State#state{socket = Socket, mref = MRef}}; + _ -> + erlang:send_after(5000, self(), timeout), + {noreply, State} + end; +handle_info({'DOWN', MRef, port, Socket, _Reason}, #state{mref = MRef, socket = Socket} = State) -> + catch gen_tcp:close(Socket), + erlang:send_after(5000, self(), timeout), + {noreply, State}; +handle_info(_Msg, State) -> {noreply, State}. code_change(_, State, _) -> {ok, State}. -terminate(_, {Socket, MRef}) -> +terminate(_Reason, {Socket, MRef}) -> %% Socket may already be down but need to ensure it is closed to avoid %% eaddrinuse error on restart + %% this takes care of that, unless of course this process is killed... case demonitor(MRef, [flush, info]) of true -> gen_tcp:close(Socket); false -> ok end. + +%% ------------------------------------------------------------------ +%% Internal functions +%% ------------------------------------------------------------------ +start_listener(#state{ + pool = Pool, + listen_opts = ListenOpts, + pool_opts = PoolOpts} = _State) -> + Port = maps:get(port, ListenOpts, 8080), + IPAddress = maps:get(ip, ListenOpts, {0, 0, 0, 0}), + AcceptorPoolSize = maps:get(size, PoolOpts, 10), + SocketOpts = maps:get(socket_options, ListenOpts, [{reuseaddr, true}, + {nodelay, true}, + {reuseaddr, true}, + {backlog, 32768}, + {keepalive, true}]), + + Opts = [{active, false}, {mode, binary}, {packet, raw}, {ip, IPAddress} | SocketOpts], + case gen_tcp:listen(Port, Opts) of + {ok, Socket} -> + %% acceptor could close the socket if there is a problem + MRef = monitor(port, Socket), + {ok, _} = grpcbox_pool:accept_socket(Pool, Socket, AcceptorPoolSize), + {ok, {Socket, MRef}}; + {error, eaddrinuse} -> + %% our desired port is already in use + %% its likely this grpcbox_socket server has been killed ( for reason unknown ) and is restarting + %% previously it would have bound to the port before passing control to our acceptor pool + %% the socket remains open + %% in the restart scenario, the socket process would attempt to bind again + %% to the port and then stop, the sup would keep restarting it + %% and we would end up breaching the restart strategy of the parent sup + %% eventually taking down the entire tree + %% result of which is we have no active listener and grpcbox is effectively down + %% so now if we hit eaddrinuse, we check if our acceptor pool using it + %% if so we close the port here and stop this process + %% NOTE: issuing stop in init wont trigger terminate and so cant rely on + %% the socket being closed there + %% This allows the sup to restart things cleanly + %% We could try to reuse the exising port rather than closing it + %% but side effects were encountered there, so deliberately avoiding + + %% NOTE: acceptor_pool has a grace period for connections before it terminates + %% grpcbox_pool sets this to a default of 5 secs + %% this needs considered when deciding on related supervisor restart strategies + %% AND keep in mind the acceptor pool will continue accepting new connections + %% during this grace period + + %% get the current sockets in use by the acceptor pool + %% if one is bound to our target port then close it + %% need to allow for possibility of multiple services, each with its own socket + %% so we need to identify our interested socket via port number + PoolSockets = grpcbox_pool:pool_sockets(Pool), + MaybeHaveExistingSocket = + lists:foldl( + fun({inet_tcp, {_IP, BoundPortNumber}, Socket, _SockRef}, _Acc) when BoundPortNumber =:= Port -> + {ok, Socket}; + (_, Acc) -> + Acc + end, socket_not_found, PoolSockets), + case MaybeHaveExistingSocket of + {ok, Socket} -> + gen_tcp:close(Socket); + socket_not_found -> + noop + end, + {error, eaddrinuse}; + {error, Reason} -> + {error, Reason} + end. diff --git a/src/grpcbox_stream.erl b/src/grpcbox_stream.erl index 9693d41..b3136c9 100644 --- a/src/grpcbox_stream.erl +++ b/src/grpcbox_stream.erl @@ -6,12 +6,13 @@ -behaviour(h2_stream). --export([send/2, +-export([ send/3, send_headers/2, - add_headers/2, + update_headers/2, add_trailers/2, set_trailers/2, + update_trailers/2, code_to_status/1, error/2, ctx/1, @@ -26,6 +27,12 @@ on_receive_data/2, on_end_stream/1]). +%% state getters and setters +-export([stream_handler_state/1, + stream_handler_state/2, + stream_req_headers/1 +]). + -export_type([t/0, grpc_status/0, grpc_status_message/0, @@ -34,30 +41,29 @@ grpc_error_data/0, grpc_extended_error_response/0]). --record(state, {handler :: pid(), +-record(state, {handler :: pid(), + stream_handler_state :: any(), socket, auth_fun, - buffer :: binary(), - ctx :: ctx:ctx(), - services_table :: ets:tid(), - req_headers=[] :: list(), - full_method :: binary() | undefined, - input_ref :: reference() | undefined, - callback_pid :: pid() | undefined, - connection_pid :: pid(), - request_encoding :: gzip | identity | undefined, - response_encoding :: gzip | identity | undefined, - content_type :: proto | json | undefined, - resp_headers=[] :: list(), - resp_trailers=[] :: list(), - headers_sent=false :: boolean(), - trailers_sent=false :: boolean(), - unary_interceptor :: fun() | undefined, - stream_interceptor :: fun() | undefined, - stream_id :: stream_id(), - method :: #method{} | undefined, - stats_handler :: module() | undefined, - stats :: term() | undefined}). + buffer :: binary(), + ctx :: ctx:ctx(), + services_table :: ets:tid(), + req_headers=[] :: list(), + full_method :: binary() | undefined, + connection_pid :: pid(), + request_encoding :: gzip | identity | undefined, + response_encoding :: gzip | identity | undefined, + content_type :: proto | json | undefined, + resp_headers=[] :: list(), + resp_trailers=[] :: list(), + headers_sent=false :: boolean(), + trailers_sent=false :: boolean(), + unary_interceptor :: fun() | undefined, + stream_interceptor :: fun() | undefined, + stream_id :: stream_id(), + method :: #method{} | undefined, + stats_handler :: module() | undefined, + stats :: term() | undefined}). -type t() :: #state{}. @@ -72,6 +78,17 @@ }. -type grpc_extended_error_response() :: {grpc_extended_error, grpc_error_data()}. +-spec stream_handler_state(t()) -> any(). +stream_handler_state(#state{stream_handler_state = StreamHandlerState}) -> + StreamHandlerState. +-spec stream_handler_state(t(), any()) -> any(). +stream_handler_state(State, NewStreamHandlerState) -> + State#state{stream_handler_state = NewStreamHandlerState}. + +-spec stream_req_headers(t()) -> list(). +stream_req_headers(#state{req_headers = ReqHeaders}) -> + ReqHeaders. + init(ConnPid, StreamId, [Socket, ServicesTable, AuthFun, UnaryInterceptor, StreamInterceptor, StatsHandler]) -> process_flag(trap_exit, true), @@ -125,31 +142,38 @@ handle_service_lookup(Ctx, [Service, Method], State=#state{services_table=Servic method=M}, handle_auth(Ctx, State1); _ -> - end_stream(?GRPC_STATUS_UNIMPLEMENTED, <<"Method not found on server">>, State) + {ok, State1} = end_stream(?GRPC_STATUS_UNIMPLEMENTED, <<"Method not found on server">>, State), + _ = stop_stream(?REFUSED_STREAM, State1), + {ok, State1} end; handle_service_lookup(_, _, State) -> State1 = State#state{resp_headers=[{<<":status">>, <<"200">>}, {<<"user-agent">>, <<"grpc-erlang/0.1.0">>}]}, - end_stream(?GRPC_STATUS_UNIMPLEMENTED, <<"failed parsing path">>, State1), - {ok, State1}. + {ok, State2} = end_stream(?GRPC_STATUS_UNIMPLEMENTED, <<"failed parsing path">>, State1), + _ = stop_stream(?REFUSED_STREAM, State2), + {ok, State2}. handle_auth(_Ctx, State=#state{auth_fun=AuthFun, socket=Socket, - method=#method{input={_, InputStreaming}}}) -> + method=#method{module=Module, + function=Function}}) -> case authenticate(sock:peercert(Socket), AuthFun) of {true, _Identity} -> - case InputStreaming of - true -> - Ref = make_ref(), - Pid = proc_lib:spawn_link(?MODULE, handle_streams, - [Ref, State#state{handler=self()}]), - {ok, State#state{input_ref=Ref, - callback_pid=Pid}}; - _ -> - {ok, State} - end; + State0 = maybe_init_handler_state(Module, Function, State), + %% send resp headers after verifying client request + %% some clients require grpc headers to be sent within a defined time period + %% otherwise they assume the request has failed and bail out + %% previously server would only return headers upon first data msg send + %% this can cause issues with streaming connections, for example + %% if a client connects and there are no data msgs ready to be sent back to them + %% TODO: check what grpc spec says about this + %% TODO: sending the headers here negates update_headers/2 usefullness ? somewhere better to send em? + State1 = send_headers(State0), + {ok, State1}; _ -> - end_stream(?GRPC_STATUS_UNAUTHENTICATED, <<"">>, State) + {ok, State1} = end_stream(?GRPC_STATUS_UNAUTHENTICATED, <<"">>, State), + _ = stop_stream(?REFUSED_STREAM, State1), + {ok, State1} end. authenticate(_, undefined) -> @@ -173,27 +197,65 @@ handle_streams(Ref, State=#state{full_method=FullMethod, output_stream => false}, StreamInterceptor(Ref, State, ServerInfo, fun Module:Function/2) end) of - {ok, Response, State2} -> - send(Response, State2); + {ok, State1} -> + State1; + {ok, Response, State1} -> + State2 = send(false, Response, State1), + {ok, State3} = end_stream(State2), + _ = stop_stream(?STREAM_CLOSED, State3), + {ok, State3}; + {stop, State1} -> + {ok, State2} = end_stream(State1), + _ = stop_stream(?STREAM_CLOSED, State2), + {ok, State2}; + {stop, Response, State1} -> + State2 = send(false, Response, State1), + {ok, State3} = end_stream(State2), + _ = stop_stream(?STREAM_CLOSED, State3), + {ok, State3}; E={grpc_error, _} -> throw(E); E={grpc_extended_error, _} -> throw(E) end; + handle_streams(Ref, State=#state{full_method=FullMethod, stream_interceptor=StreamInterceptor, method=#method{module=Module, function=Function, output={_, true}}}) -> - case StreamInterceptor of - undefined -> - Module:Function(Ref, State); - _ -> - ServerInfo = #{full_method => FullMethod, - service => Module, - input_stream => true, - output_stream => true}, - StreamInterceptor(Ref, State, ServerInfo, fun Module:Function/2) + case (case StreamInterceptor of + undefined -> + Module:Function(Ref, State); + _ -> + ServerInfo = #{full_method => FullMethod, + service => Module, + input_stream => true, + output_stream => true}, + StreamInterceptor(Ref, State, ServerInfo, fun Module:Function/2) + end) of + {ok, State1} -> + State1; + {ok, Response, State1} -> + send(false, Response, State1); + {stop, State1} -> + {ok, State2} = end_stream(State1), + _ = stop_stream(?STREAM_CLOSED, State2), + {ok, State2}; + {stop, Response, State1} -> + State2 = send(false, Response, State1), + {ok, State3} = end_stream(State2), + _ = stop_stream(?STREAM_CLOSED, State3), + {ok, State3}; + {grpc_error, {Status, Message}} -> + {ok, State1} = end_stream(Status, Message, State), + _ = stop_stream(?STREAM_CLOSED, State1), + {ok, State1}; + {grpc_extended_error, #{status := Status, message := Message} = ErrorData} -> + State1 = add_trailers_from_error_data(ErrorData, State), + {ok, State2} = end_stream(Status, Message, State1), + _ = stop_stream(?STREAM_CLOSED, State2), + {ok, State2} end. on_send_push_promise(_, State) -> @@ -218,18 +280,24 @@ on_receive_data(Bin, State=#state{request_encoding=Encoding, {ok, State1#state{buffer=NewBuffer}} catch throw:{grpc_error, {Status, Message}} -> - end_stream(Status, Message, State); + {ok, State2} = end_stream(Status, Message, State), + _ = stop_stream(?STREAM_CLOSED, State2), + {ok, State2}; throw:{grpc_extended_error, #{status := Status, message := Message} = ErrorData} -> State2 = add_trailers_from_error_data(ErrorData, State), - end_stream(Status, Message, State2); + {ok, State3} = end_stream(Status, Message, State2), + _ = stop_stream(?STREAM_CLOSED, State3), + {ok, State3}; C:E:S -> + %% if we dont catch exceptions here, it ends up taking the h2 connection down + %% and thus one stream going down pulls ev thing down ?LOG_INFO("crash: class=~p exception=~p stacktrace=~p", [C, E, S]), - end_stream(?GRPC_STATUS_UNKNOWN, <<>>, State) + {ok, State2} = end_stream(?GRPC_STATUS_UNKNOWN, <<>>, State), + _ = stop_stream(?INTERNAL_ERROR, State2), + {ok, State2} end. -handle_message(EncodedMessage, State=#state{input_ref=Ref, - ctx=Ctx, - callback_pid=Pid, +handle_message(EncodedMessage, State=#state{ctx=Ctx, method=#method{proto=Proto, input={Input, InputStream}, output={_Output, OutputStream}}}) -> @@ -239,15 +307,10 @@ handle_message(EncodedMessage, State=#state{input_ref=Ref, stats_handler(Ctx, in_payload, #{uncompressed_size => erlang:external_size(Message), compressed_size => size(EncodedMessage)}, State), case {InputStream, OutputStream} of - {true, _} -> - Pid ! {Ref, Message}, - State1; - {false, true} -> - _ = proc_lib:spawn_link(?MODULE, handle_streams, - [Message, State1#state{handler=self()}]), - State1; {false, false} -> - handle_unary(Ctx1, Message, State1) + handle_unary(Ctx1, Message, State1); + {_, _} -> + handle_streams(Message, State1#state{handler=self()}) end catch error:{gpb_error, _} -> @@ -263,7 +326,8 @@ handle_unary(Ctx, Message, State=#state{unary_interceptor=UnaryInterceptor, output={_Output, _OutputStream}}}) -> Ctx1 = ctx_with_stream(Ctx, State), case (case UnaryInterceptor of - undefined -> Module:Function(Ctx1, Message); + undefined -> + Module:Function(Ctx1, Message); _ -> ServerInfo = #{full_method => FullMethod, service => Module}, @@ -283,19 +347,13 @@ on_end_stream(State) -> on_end_stream_(State), {ok, State}. -on_end_stream_(#state{input_ref=Ref, - callback_pid=Pid, - method=#method{input={_Input, true}, +on_end_stream_(State=#state{method=#method{input={_Input, true}, output={_Output, false}}}) -> - Pid ! {Ref, eos}; -on_end_stream_(#state{input_ref=Ref, - callback_pid=Pid, - method=#method{input={_Input, true}, + handle_streams(eos, State); +on_end_stream_(State = #state{method=#method{input={_Input, true}, output={_Output, true}}}) -> - Pid ! {Ref, eos}; -on_end_stream_(#state{input_ref=_Ref, - callback_pid=_Pid, - method=#method{input={_Input, false}, + handle_streams(eos, State); +on_end_stream_(#state{method=#method{input={_Input, false}, output={_Output, true}}}) -> ok; on_end_stream_(State=#state{method=#method{output={_Output, false}}}) -> @@ -332,6 +390,11 @@ end_stream(Status, Message, State=#state{connection_pid=ConnPid, State1 = stats_handler(Ctx1, rpc_end, {}, State), {ok, State1#state{trailers_sent=true}}. +stop_stream(Status, State=#state{ connection_pid=ConnPid, + stream_id=StreamId}) -> + h2_connection:rst_stream(ConnPid, StreamId, Status), + {ok, State}. + set_trailers(Ctx, Trailers) -> State = from_ctx(Ctx), ctx_with_stream(Ctx, State#state{resp_trailers=maps:to_list(Trailers)}). @@ -385,37 +448,25 @@ handle_call(ctx, State=#state{ctx=Ctx}) -> handle_call({ctx, Ctx}, State) -> {ok, ok, State#state{ctx=Ctx}}. -handle_info({add_headers, Headers}, State) -> - update_headers(Headers, State); -handle_info({add_trailers, Trailers}, State) -> - update_trailers(Trailers, State); -handle_info({send_proto, Message}, State) -> - send(false, Message, State); -handle_info({'EXIT', _, normal}, State) -> - end_stream(State), - State; -handle_info({'EXIT', _, {grpc_error, {Status, Message}}}, State) -> - end_stream(Status, Message, State), - State; -handle_info({'EXIT', _, {grpc_extended_error, #{status := Status, message := Message} = ErrorData}}, State) -> - State1 = add_trailers_from_error_data(ErrorData, State), - end_stream(Status, Message, State1), - State1; -handle_info({'EXIT', _, _Other}, State) -> - end_stream(?GRPC_STATUS_UNKNOWN, <<"process exited without reason">>, State), - State; -handle_info(_, State) -> - State. - - -add_headers(Headers, #state{handler=Pid}) -> - Pid ! {add_headers, Headers}. +handle_info(Msg, State=#state{method=#method{module=Module, function=Function}}) -> + %% if the handler module exports handle_info/3, then use that + %% the 3 version passes the invoked RPC which can be used + %% by the handler to accommodate any function specific handling + %% really this is a bespoke use case + %% fall back to handle_info/2 if the /3 is not exported + case erlang:function_exported(Module, handle_info, 3) of + true -> Module:handle_info(Function, Msg, State); + false -> + case erlang:function_exported(Module, handle_info, 2) of + true -> Module:handle_info(Msg, State); + false -> + State + end + end. add_trailers(Ctx, Trailers=#{}) -> State=#state{resp_trailers=RespTrailers} = from_ctx(Ctx), - ctx_with_stream(Ctx, State#state{resp_trailers=maps:to_list(Trailers) ++ RespTrailers}); -add_trailers(Headers, #state{handler=Pid}) -> - Pid ! {add_trailers, Headers}. + ctx_with_stream(Ctx, State#state{resp_trailers=maps:to_list(Trailers) ++ RespTrailers}). update_headers(Headers, State=#state{resp_headers=RespHeaders}) -> State#state{resp_headers=RespHeaders ++ Headers}. @@ -423,9 +474,6 @@ update_headers(Headers, State=#state{resp_headers=RespHeaders}) -> update_trailers(Trailers, State=#state{resp_trailers=RespTrailers}) -> State#state{resp_trailers=RespTrailers ++ Trailers}. -send(Message, #state{handler=Pid}) -> - Pid ! {send_proto, Message}. - send(End, Message, State=#state{headers_sent=false}) -> State1 = send_headers(State), send(End, Message, State1); @@ -527,3 +575,9 @@ maybe_encode_header_value(K, V) -> add_trailers_from_error_data(ErrorData, State) -> Trailers = maps:get(trailers, ErrorData, #{}), update_trailers(maps:to_list(Trailers), State). + +maybe_init_handler_state(Module, Function, State)-> + case erlang:function_exported(Module, init, 2) of + true -> Module:init(Function, State); + false -> State + end. diff --git a/test/grpcbox_SUITE.erl b/test/grpcbox_SUITE.erl index b8deeb3..318a640 100644 --- a/test/grpcbox_SUITE.erl +++ b/test/grpcbox_SUITE.erl @@ -185,10 +185,10 @@ init_per_testcase(stream_interceptor, Config) -> services => #{'routeguide.RouteGuide' => routeguide_route_guide}, stream_interceptor => fun(Ref, Stream, _ServerInfo, Handler) -> - grpcbox_stream:add_trailers([{<<"x-grpc-stream-interceptor">>, + Stream2 = grpcbox_stream:update_trailers([{<<"x-grpc-stream-interceptor">>, <<"true">>}], Stream), - Handler(Ref, Stream) + Handler(Ref, Stream2) end}, transport_opts => #{}}]), application:ensure_all_started(grpcbox), diff --git a/test/routeguide_route_guide.erl b/test/routeguide_route_guide.erl index 5e6a6f1..5e5ebb4 100644 --- a/test/routeguide_route_guide.erl +++ b/test/routeguide_route_guide.erl @@ -2,7 +2,9 @@ -include("grpcbox.hrl"). --export([get_feature/2, +-export([ + init/2, + get_feature/2, list_features/2, record_route/2, route_chat/2, @@ -31,6 +33,25 @@ #{name => string(), location => point()}. +%% define init functions required for each RPC, if required +init(_RPC=record_route, GrpcStream)-> + grpcbox_stream:stream_handler_state( + GrpcStream, + #{t_start => erlang:system_time(1), acc => []} + ); +init(_RPC=route_chat, GrpcStream)-> + grpcbox_stream:stream_handler_state( + GrpcStream, + [] + ); +init(_RPC=closed_stream, GrpcStream)-> + grpcbox_stream:stream_handler_state( + GrpcStream, + #{t_start => erlang:system_time(1), acc => []} + ); +init(_RPC, GrpcStream)-> + GrpcStream. + -spec get_feature(Ctx :: ctx:ctx(), Message :: point()) -> {ok, feature(), ctx:ctx()}. get_feature(Ctx, Message) -> Feature = #{name => find_point(Message, data()), @@ -39,47 +60,51 @@ get_feature(Ctx, Message) -> -spec list_features(Message::rectangle(), GrpcStream :: grpcbox_stream:t()) -> ok. list_features(_Message, GrpcStream) -> - grpcbox_stream:add_headers([{<<"info">>, <<"this is a test-implementation">>}], GrpcStream), - grpcbox_stream:send(#{name => <<"Tour Eiffel">>, + GrpcStream0 = grpcbox_stream:update_headers([{<<"info">>, <<"this is a test-implementation">>}], GrpcStream), + GrpcStream1 = grpcbox_stream:send(false, #{name => <<"Tour Eiffel">>, location => #{latitude => 3, - longitude => 5}}, GrpcStream), - grpcbox_stream:send(#{name => <<"Louvre">>, + longitude => 5}}, GrpcStream0), + GrpcStream2 = grpcbox_stream:send(false, #{name => <<"Louvre">>, location => #{latitude => 4, - longitude => 5}}, GrpcStream), - - grpcbox_stream:add_trailers([{<<"nr_of_points_sent">>, <<"2">>}], GrpcStream), - ok. - --spec record_route(reference(), GrpcStream :: grpcbox_stream:t()) -> {ok, route_summary(), grpcbox_stream:t()}. -record_route(Ref, GrpcStream) -> - record_route(Ref, #{t_start => erlang:system_time(1), - acc => []}, GrpcStream). - -record_route(Ref, Data=#{t_start := T0, acc := Points}, GrpcStream) -> - receive - {Ref, eos} -> - %% receiving 'eos' tells us that we need to return a result. - {ok, #{elapsed_time => erlang:system_time(1) - T0, - point_count => length(Points), - feature_count => count_features(Points), - distance => distance(Points)}, GrpcStream}; - {Ref, Point} -> - record_route(Ref, Data#{acc => [Point | Points]}, GrpcStream) - end. - --spec route_chat(reference(), GrpcStream :: grpcbox_stream:t()) -> ok. -route_chat(Ref, GrpcStream) -> - route_chat(Ref, [], GrpcStream). - -route_chat(Ref, Data, GrpcStream) -> - receive - {Ref, eos} -> - ok; - {Ref, #{location := Location} = P} -> - Messages = proplists:get_all_values(Location, Data), - [grpcbox_stream:send(Message, GrpcStream) || Message <- Messages], - route_chat(Ref, [{Location, P} | Data], GrpcStream) - end. + longitude => 5}}, GrpcStream1), + GrpcStream3 = grpcbox_stream:update_trailers([{<<"nr_of_points_sent">>, <<"2">>}], GrpcStream2), + {stop, GrpcStream3}. + +-spec record_route(ReqMessage :: any(), GrpcStream :: grpcbox_stream:t()) -> {stop, route_summary(), grpcbox_stream:t()} | {ok, grpcbox_stream:t()}. +record_route(ReqMessage, GrpcStream) -> + HandlerState = grpcbox_stream:stream_handler_state(GrpcStream), + record_route(ReqMessage, HandlerState, GrpcStream). + +-spec record_route(ReqMessage :: any(), HandlerState :: any(), GrpcStream :: grpcbox_stream:t()) -> {stop, route_summary(), grpcbox_stream:t()} | {ok, grpcbox_stream:t()}. +record_route(eos, _HandlerState=#{t_start := T0, acc := Points}, GrpcStream) -> + %% receiving 'eos' tells us that we need to return a result. + {stop, #{elapsed_time => erlang:system_time(1) - T0, + point_count => length(Points), + feature_count => count_features(Points), + distance => distance(Points)}, GrpcStream}; +record_route(ReqMessage, HandlerState=#{t_start := _T0, acc := Points}, GrpcStream) -> + NewStreamState0 = grpcbox_stream:stream_handler_state( + GrpcStream, + HandlerState#{acc => [ReqMessage | Points]} + ), + {ok, NewStreamState0}. + +-spec route_chat(ReqMessage :: any(), GrpcStream :: grpcbox_stream:t()) -> {stop, grpcbox_stream:t()} | {ok, grpcbox_stream:t()}. +route_chat(ReqMessage, GrpcStream) -> + HandlerState = grpcbox_stream:stream_handler_state(GrpcStream), + route_chat(ReqMessage, HandlerState, GrpcStream). + +-spec route_chat(ReqMessage :: any(), HandlerState :: any(), GrpcStream :: grpcbox_stream:t()) -> {stop, grpcbox_stream:t()} | {ok, grpcbox_stream:t()}. +route_chat(eos, _HandlerState, GrpcStream) -> + {stop, GrpcStream}; +route_chat(ReqMessage=#{location := Location}, HandlerState, GrpcStream) -> + Messages = proplists:get_all_values(Location, HandlerState), + [grpcbox_stream:send(false, Message, GrpcStream) || Message <- Messages], + NewStreamState0 = grpcbox_stream:stream_handler_state( + GrpcStream, + [{Location, ReqMessage} | HandlerState] + ), + {ok, NewStreamState0}. -spec generate_error(Ctx :: ctx:ctx(), Message :: map()) -> grpcbox_stream:grpc_extended_error_response(). generate_error(_Ctx, _Message) -> @@ -95,7 +120,7 @@ generate_error(_Ctx, _Message) -> -spec streaming_generate_error(Message :: map(), GrpcStream :: grpcbox_stream:t()) -> no_return(). streaming_generate_error(_Message, _GrpcStream) -> - exit( +%% exit( { grpc_extended_error, #{ status => ?GRPC_STATUS_INTERNAL, @@ -104,8 +129,8 @@ streaming_generate_error(_Message, _GrpcStream) -> <<"generate_error_trailer">> => <<"error_trailer">> } } - } - ). + }. +%% ). %% Supporting functions