diff --git a/docs/web/postprocess/index.ml b/docs/web/postprocess/index.ml
index b6065b1f..36d47bb1 100644
--- a/docs/web/postprocess/index.ml
+++ b/docs/web/postprocess/index.ml
@@ -1951,7 +1951,8 @@ let graphiql_expected = {|
|}
let sql_pool_expected = {|
-
val sql_pool : ?size:int -> string -> middleware
+
val sql_pool : ?size:int ->
+?post_connect:((module Caqti_lwt.CONNECTION) -> (unit, Caqti_error.t) Stdlib.result promise) -> string -> middleware
|}
diff --git a/src/dream.mli b/src/dream.mli
index 4307df6f..fbe0fb88 100644
--- a/src/dream.mli
+++ b/src/dream.mli
@@ -1746,11 +1746,17 @@ val graphiql : ?default_query:string -> string -> handler
{{:https://cheatsheetseries.owasp.org/cheatsheets/Database_Security_Cheat_Sheet.html}
OWASP {i Database Security Cheat Sheet}}. *)
-val sql_pool : ?size:int -> string -> middleware
+val sql_pool :
+ ?size:int ->
+ ?post_connect:
+ ((module Caqti_lwt.CONNECTION) -> (unit, Caqti_error.t) result promise) ->
+ string ->
+ middleware
(** Makes an SQL connection pool available to its inner handler. [?size] is the
maximum number of concurrent connections that the pool will support. The
default value is picked by the driver. Note that for SQLite, [?size] is
- capped to [1]. *)
+ capped to [1]. [post_connect] is an optional callback, which is called for
+ every new connection that is opened to the database. *)
val sql : request -> (Caqti_lwt.connection -> 'a promise) -> 'a promise
(** Runs the callback with a connection from the SQL pool. See example
diff --git a/src/sql/sql.ml b/src/sql/sql.ml
index cfa25f46..78f90dcf 100644
--- a/src/sql/sql.ml
+++ b/src/sql/sql.ml
@@ -25,12 +25,12 @@ let foreign_keys_on =
(Caqti_type.unit ->. Caqti_type.unit) "PRAGMA foreign_keys = ON"
[@ocaml.warning "-3"]
-let post_connect (module Db : Caqti_lwt.CONNECTION) =
+let standard_post_connect (module Db : Caqti_lwt.CONNECTION) =
match Caqti_driver_info.dialect_tag Db.driver_info with
| `Sqlite -> Db.exec foreign_keys_on ()
| _ -> Lwt.return (Ok ())
-let sql_pool ?size uri =
+let sql_pool ?size ?post_connect uri =
let pool_cell = ref None in
fun inner_handler request ->
@@ -49,7 +49,14 @@ let sql_pool ?size uri =
'sqlite' is not a valid scheme; did you mean 'sqlite3'?");
let pool =
let pool_config = Caqti_pool_config.create ?max_size:size () in
- Caqti_lwt_unix.connect_pool ~pool_config ~post_connect parsed_uri in
+ Caqti_lwt_unix.connect_pool ~pool_config ~post_connect:(fun db ->
+ Lwt_result.bind (standard_post_connect db) (fun () ->
+ match post_connect with
+ | Some f -> f db
+ | None -> Lwt_result.return ())
+ )
+ parsed_uri
+ in
match pool with
| Ok pool ->
pool_cell := Some pool;