Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add post_connect option to sql_pool #377

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/web/postprocess/index.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1951,7 +1951,8 @@ let graphiql_expected = {|<div class="spec value" id="val-graphiql">
|}

let sql_pool_expected = {|<div class="spec value" id="val-sql_pool">
<a href="#val-sql_pool" class="anchor"></a><code><span><span class="keyword">val</span> sql_pool : <span>?size:int <span class="arrow">-&gt;</span></span> <span>string <span class="arrow">-&gt;</span></span> <a href="#type-middleware">middleware</a></span></code>
<a href="#val-sql_pool" class="anchor"></a><code><span><span class="keyword">val</span> sql_pool : <span>?size:int <span class="arrow">-&gt;</span></span>
<span>?post_connect:<span>(<span><span>(<span class="keyword">module</span> <span class="xref-unresolved">Caqti_lwt</span>.CONNECTION)</span> <span class="arrow">-&gt;</span></span> <span><span><span>(unit,&nbsp;<span class="xref-unresolved">Caqti_error</span>.t)</span> <span class="xref-unresolved">Stdlib</span>.result</span> <a href="#type-promise">promise</a></span>)</span> <span class="arrow">-&gt;</span></span> <span>string <span class="arrow">-&gt;</span></span> <a href="#type-middleware">middleware</a></span></code>
</div>
|}

Expand Down
10 changes: 8 additions & 2 deletions src/dream.mli
Original file line number Diff line number Diff line change
Expand Up @@ -1746,11 +1746,17 @@ val graphiql : ?default_query:string -> string -> handler
{{:https://cheatsheetseries.owasp.org/cheatsheets/Database_Security_Cheat_Sheet.html}
OWASP {i Database Security Cheat Sheet}}. *)

val sql_pool : ?size:int -> string -> middleware
val sql_pool :
?size:int ->
?post_connect:
((module Caqti_lwt.CONNECTION) -> (unit, Caqti_error.t) result promise) ->
string ->
middleware
(** Makes an SQL connection pool available to its inner handler. [?size] is the
maximum number of concurrent connections that the pool will support. The
default value is picked by the driver. Note that for SQLite, [?size] is
capped to [1]. *)
capped to [1]. [post_connect] is an optional callback, which is called for
every new connection that is opened to the database. *)

val sql : request -> (Caqti_lwt.connection -> 'a promise) -> 'a promise
(** Runs the callback with a connection from the SQL pool. See example
Expand Down
13 changes: 10 additions & 3 deletions src/sql/sql.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ let foreign_keys_on =
(Caqti_type.unit ->. Caqti_type.unit) "PRAGMA foreign_keys = ON"
[@ocaml.warning "-3"]

let post_connect (module Db : Caqti_lwt.CONNECTION) =
let standard_post_connect (module Db : Caqti_lwt.CONNECTION) =
match Caqti_driver_info.dialect_tag Db.driver_info with
| `Sqlite -> Db.exec foreign_keys_on ()
| _ -> Lwt.return (Ok ())

let sql_pool ?size uri =
let sql_pool ?size ?post_connect uri =
let pool_cell = ref None in
fun inner_handler request ->

Expand All @@ -49,7 +49,14 @@ let sql_pool ?size uri =
'sqlite' is not a valid scheme; did you mean 'sqlite3'?");
let pool =
let pool_config = Caqti_pool_config.create ?max_size:size () in
Caqti_lwt_unix.connect_pool ~pool_config ~post_connect parsed_uri in
Caqti_lwt_unix.connect_pool ~pool_config ~post_connect:(fun db ->
Lwt_result.bind (standard_post_connect db) (fun () ->
match post_connect with
| Some f -> f db
| None -> Lwt_result.return ())
)
parsed_uri
in
match pool with
| Ok pool ->
pool_cell := Some pool;
Expand Down
Loading