diff --git a/docs/web/postprocess/index.ml b/docs/web/postprocess/index.ml index b6065b1f..36d47bb1 100644 --- a/docs/web/postprocess/index.ml +++ b/docs/web/postprocess/index.ml @@ -1951,7 +1951,8 @@ let graphiql_expected = {|
|} let sql_pool_expected = {|
- val sql_pool : ?size:int -> string -> middleware + val sql_pool : ?size:int -> +?post_connect:((module Caqti_lwt.CONNECTION) -> (unit, Caqti_error.t) Stdlib.result promise) -> string -> middleware
|} diff --git a/src/dream.mli b/src/dream.mli index 4307df6f..fbe0fb88 100644 --- a/src/dream.mli +++ b/src/dream.mli @@ -1746,11 +1746,17 @@ val graphiql : ?default_query:string -> string -> handler {{:https://cheatsheetseries.owasp.org/cheatsheets/Database_Security_Cheat_Sheet.html} OWASP {i Database Security Cheat Sheet}}. *) -val sql_pool : ?size:int -> string -> middleware +val sql_pool : + ?size:int -> + ?post_connect: + ((module Caqti_lwt.CONNECTION) -> (unit, Caqti_error.t) result promise) -> + string -> + middleware (** Makes an SQL connection pool available to its inner handler. [?size] is the maximum number of concurrent connections that the pool will support. The default value is picked by the driver. Note that for SQLite, [?size] is - capped to [1]. *) + capped to [1]. [post_connect] is an optional callback, which is called for + every new connection that is opened to the database. *) val sql : request -> (Caqti_lwt.connection -> 'a promise) -> 'a promise (** Runs the callback with a connection from the SQL pool. See example diff --git a/src/sql/sql.ml b/src/sql/sql.ml index cfa25f46..78f90dcf 100644 --- a/src/sql/sql.ml +++ b/src/sql/sql.ml @@ -25,12 +25,12 @@ let foreign_keys_on = (Caqti_type.unit ->. Caqti_type.unit) "PRAGMA foreign_keys = ON" [@ocaml.warning "-3"] -let post_connect (module Db : Caqti_lwt.CONNECTION) = +let standard_post_connect (module Db : Caqti_lwt.CONNECTION) = match Caqti_driver_info.dialect_tag Db.driver_info with | `Sqlite -> Db.exec foreign_keys_on () | _ -> Lwt.return (Ok ()) -let sql_pool ?size uri = +let sql_pool ?size ?post_connect uri = let pool_cell = ref None in fun inner_handler request -> @@ -49,7 +49,14 @@ let sql_pool ?size uri = 'sqlite' is not a valid scheme; did you mean 'sqlite3'?"); let pool = let pool_config = Caqti_pool_config.create ?max_size:size () in - Caqti_lwt_unix.connect_pool ~pool_config ~post_connect parsed_uri in + Caqti_lwt_unix.connect_pool ~pool_config ~post_connect:(fun db -> + Lwt_result.bind (standard_post_connect db) (fun () -> + match post_connect with + | Some f -> f db + | None -> Lwt_result.return ()) + ) + parsed_uri + in match pool with | Ok pool -> pool_cell := Some pool;