Skip to content

Commit

Permalink
Improve unsafe casting
Browse files Browse the repository at this point in the history
  • Loading branch information
mbarbin committed Oct 26, 2024
1 parent 9729918 commit da244ca
Showing 1 changed file with 63 additions and 10 deletions.
73 changes: 63 additions & 10 deletions src/provider.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,38 @@ module Trait = struct
let compare_by_uid id1 id2 = Uid.compare (uid id1) (uid id2)
let same (id1 : _ t) (id2 : _ t) = phys_same id1 id2
let implement = Binding0.implement

module Unsafe_cast : sig
(* We limit unsafe casting to cases where the first parameter is already
determined to be the same. *)
val same_witness : ('a, 'i1, _) t -> ('a, 'i2, _) t -> ('i1, 'i2) Type.eq option
end = struct
(* We create [some_witness_val] at top level so that [same_witness] do not allocate. *)
let some_witness_val = Some Type.Equal
let some_witness_repr = Obj.repr some_witness_val

let same_witness
: type a i1 i2. (a, i1, _) t -> (a, i2, _) t -> (i1, i2) Type.eq option
=
fun t1 t2 ->
if same t1 t2 then (Obj.obj some_witness_repr : (i1, i2) Type.eq option) else None
;;

let _f a b =
(* This expression is meant to help create a build error if we change the
type of [some_witness_val], acting as a reminder to go and update
[same_witness] too. For example, if you change [some_witness_val] to
something like:
{[
let some_witness_val = 1
]}
You'll notice that [same_witness] still compiles (although would be
terribly broken), but this expression no longer type checks. *)
(if phys_same a b then some_witness_val else same_witness a b) [@coverage off]
;;
end
end

module Binding = struct
Expand Down Expand Up @@ -130,14 +162,35 @@ module Handler = struct
else (
let mid = (from + to_) / 2 in
let (Binding.T { trait = elt; implementation } as binding) = t.(mid) in
match Trait.compare_by_uid elt trait |> Ordering.of_int with
| Equal ->
match Trait.Unsafe_cast.same_witness elt trait with
| Some Type.Equal ->
if update_cache then t.(0) <- binding;
if_found (Obj.magic implementation)
| Less ->
binary_search t ~trait ~update_cache ~if_not_found ~if_found ~from:(mid + 1) ~to_
| Greater ->
binary_search t ~trait ~update_cache ~if_not_found ~if_found ~from ~to_:(mid - 1))
if_found implementation
| None ->
(match Trait.compare_by_uid elt trait |> Ordering.of_int with
| Equal ->
(* [same_witness a b => (uid a = uid b)] but the converse might not
hold. We treat as invalid usages cases where traits (t1, t2) would
have the same uids without being physically equal. *)
assert false
| Less ->
binary_search
t
~trait
~update_cache
~if_not_found
~if_found
~from:(mid + 1)
~to_
| Greater ->
binary_search
t
~trait
~update_cache
~if_not_found
~if_found
~from
~to_:(mid - 1)))
;;

let make_lookup
Expand All @@ -154,9 +207,9 @@ module Handler = struct
then if_not_found ~trait_info:(Trait.info trait)
else (
let (Binding.T { trait = cached_id; implementation }) = t.(0) in
if Trait.same trait cached_id
then if_found (Obj.magic implementation)
else
match Trait.Unsafe_cast.same_witness trait cached_id with
| Some Type.Equal -> if_found implementation
| None ->
binary_search
t
~trait
Expand Down

0 comments on commit da244ca

Please sign in to comment.