From 04c131546ee6a1e635cac49885fac5f939094858 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Fri, 7 Jun 2024 10:47:53 +0200 Subject: [PATCH 01/14] Mirage_crypto.Block.ECB with {de,en}crypt_into Also provide unsafe_{en,de}crypt_into for further performance. --- bench/speed.ml | 17 ++++++++++++- src/cipher_block.ml | 50 +++++++++++++++++++++++++++++-------- src/mirage_crypto.mli | 58 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 11 deletions(-) diff --git a/bench/speed.ml b/bench/speed.ml index 90d44425..4c68d140 100644 --- a/bench/speed.ml +++ b/bench/speed.ml @@ -45,6 +45,15 @@ let throughput title f = Printf.printf " % 5d: %04f MB/s (%d iters in %.03f s)\n%!" size (bw /. mb) iters time +let throughput_into title f = + Printf.printf "\n* [%s]\n%!" title ; + sizes |> List.iter @@ fun size -> + Gc.full_major () ; + let dst = Bytes.create size in + let (iters, time, bw) = burn (f dst) size in + Printf.printf " % 5d: %04f MB/s (%d iters in %.03f s)\n%!" + size (bw /. mb) iters time + let count_period = 10. let count f n = @@ -353,7 +362,13 @@ let benchmarks = [ bm "aes-128-ecb" (fun name -> let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 16) in - throughput name (fun cs -> AES.ECB.encrypt ~key cs)) ; + throughput_into name + (fun dst cs -> AES.ECB.encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-128-ecb-unsafe" (fun name -> + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 16) in + throughput_into name + (fun dst cs -> AES.ECB.unsafe_encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; bm "aes-128-cbc-e" (fun name -> let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) diff --git a/src/cipher_block.ml b/src/cipher_block.ml index 3dfa1fcb..afd963cc 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -28,6 +28,10 @@ module Block = struct val block_size : int val encrypt : key:key -> string -> string val decrypt : key:key -> string -> string + val encrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + val decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + val unsafe_encrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + val unsafe_decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit end module type CBC = sig @@ -134,17 +138,43 @@ module Modes = struct let of_secret = Core.of_secret - let (encrypt, decrypt) = - let ecb xform key src = - let n = String.length src in - if n mod block_size <> 0 then invalid_arg "ECB: length %u" n; - let dst = Bytes.create n in - xform ~key ~blocks:(n / block_size) src 0 dst 0 ; - Bytes.unsafe_to_string dst - in - (fun ~key:(key, _) src -> ecb Core.encrypt key src), - (fun ~key:(_, key) src -> ecb Core.decrypt key src) + let unsafe_ecb xform key src src_off dst dst_off len = + xform ~key ~blocks:(len / block_size) src src_off dst dst_off + + let ecb xform key src src_off dst dst_off len = + if len mod block_size <> 0 then + invalid_arg "ECB: length %u not of block size" len; + if String.length src - src_off < len then + invalid_arg "ECB: source length %u - src_off %u < len %u" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "ECB: dst length %u - dst_off %u < len %u" + (Bytes.length dst) dst_off len; + unsafe_ecb xform key src src_off dst dst_off len + + let encrypt_into ~key:(key, _) src ~src_off dst ~dst_off len = + ecb Core.encrypt key src src_off dst dst_off len + + let unsafe_encrypt_into ~key:(key, _) src ~src_off dst ~dst_off len = + unsafe_ecb Core.encrypt key src src_off dst dst_off len + + let decrypt_into ~key:(_, key) src ~src_off dst ~dst_off len = + ecb Core.decrypt key src src_off dst dst_off len + let unsafe_decrypt_into ~key:(_, key) src ~src_off dst ~dst_off len = + unsafe_ecb Core.decrypt key src src_off dst dst_off len + + let encrypt ~key src = + let len = String.length src in + let dst = Bytes.create len in + encrypt_into ~key src ~src_off:0 dst ~dst_off:0 len; + Bytes.unsafe_to_string dst + + let decrypt ~key src = + let len = String.length src in + let dst = Bytes.create len in + decrypt_into ~key src ~src_off:0 dst ~dst_off:0 len; + Bytes.unsafe_to_string dst end module CBC_of (Core : Block.Core) : Block.CBC = struct diff --git a/src/mirage_crypto.mli b/src/mirage_crypto.mli index d33b8420..18c6b117 100644 --- a/src/mirage_crypto.mli +++ b/src/mirage_crypto.mli @@ -157,12 +157,70 @@ module Block : sig module type ECB = sig type key + val of_secret : string -> key + (** Construct the encryption key corresponding to [secret]. + + @raise Invalid_argument if the length of [secret] is not in + {{!key_sizes}[key_sizes]}. *) val key_sizes : int array + (** Key sizes allowed with this cipher. *) + val block_size : int + (** The size of a single block. *) + val encrypt : key:key -> string -> string + (** [encrypt ~key src] encrypts [src] into a freshly allocated buffer of the + same size using [key]. + + @raise Invalid_argument if the length of [src] is not a multiple of + {!block_size}. *) + val decrypt : key:key -> string -> string + (** [decrypt ~key src] decrypts [src] into a freshly allocated buffer of the + same size using [key]. + + @raise Invalid_argument if the length of [src] is not a multiple of + {!block_size}. *) + + val encrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + (** [encrypt_into ~key src ~src_off dst dst_off len] encrypts [len] octets + from [src] starting at [src_off] into [dst] starting at [dst_off]. + + @raise Invalid_argument if [len] is not a multiple of {!block_size}. + @raise Invalid_argument if [String.length src - src_off < len]. + @raise Invalid_argument if [Bytes.length dst - dst_off < len]. *) + + val decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + (** [decrypt_into ~key src ~src_off dst dst_off len] decrypts [len] octets + from [src] starting at [src_off] into [dst] starting at [dst_off]. + + @raise Invalid_argument if [len] is not a multiple of {!block_size}. + @raise Invalid_argument if [String.length src - src_off < len]. + @raise Invalid_argument if [Bytes.length dst - dst_off < len]. *) + + (**/**) + val unsafe_encrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + (** [unsafe_encrypt_into ~key src ~src_off dst dst_off len] encrypts [len] + octets from [src] starting at [src_off] into [dst] starting at + [dst_off]. Since buffer lengths and block sizes are not checked, this + may cause memory issues if an invariant is violated: + {ul + {- [len] must be a multiple of {!block_size},} + {- [String.length src - src_off >= len],} + {- [Bytes.length dst - dst_off >= len].}} *) + + val unsafe_decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + (** [unsafe_decrypt_into ~key src ~src_off dst dst_off len] decrypts [len] + octets from [src] starting at [src_off] into [dst] starting at + [dst_off]. Since buffer lengths and block sizes are not checked, this + may cause memory issues if an invariant is violated: + {ul + {- [len] must be a multiple of {!block_size},} + {- [String.length src - src_off >= len],} + {- [Bytes.length dst - dst_off >= len].}} *) + (**/**) end (** {e Cipher-block chaining} mode. *) From e8614b0bd7ba2fa84974bd8684f5e6728ffd246d Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Fri, 7 Jun 2024 11:53:32 +0200 Subject: [PATCH 02/14] Mirage_crypto.Block.CBC now has {de,en}crypt_into functionality This may avoid buffer allocations. There are as well unsafe functions for those feeling bounds checks are unnecessary. --- bench/speed.ml | 26 +++++++++++- src/cipher_block.ml | 93 +++++++++++++++++++++++++++++++------------ src/mirage_crypto.mli | 74 +++++++++++++++++++++++++++++++--- 3 files changed, 161 insertions(+), 32 deletions(-) diff --git a/bench/speed.ml b/bench/speed.ml index 4c68d140..d0f7c2c6 100644 --- a/bench/speed.ml +++ b/bench/speed.ml @@ -373,12 +373,34 @@ let benchmarks = [ bm "aes-128-cbc-e" (fun name -> let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) and iv = Mirage_crypto_rng.generate 16 in - throughput name (fun cs -> AES.CBC.encrypt ~key ~iv cs)) ; + throughput_into name + (fun dst cs -> AES.CBC.encrypt_into ~key ~iv cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-128-cbc-e-unsafe" (fun name -> + let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) + and iv = Mirage_crypto_rng.generate 16 in + throughput_into name + (fun dst cs -> AES.CBC.unsafe_encrypt_into ~key ~iv cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-128-cbc-e-unsafe-inplace" (fun name -> + let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) + and iv = Mirage_crypto_rng.generate 16 in + throughput name + (fun cs -> + let b = Bytes.unsafe_of_string cs in + AES.CBC.unsafe_encrypt_into_inplace ~key ~iv b ~dst_off:0 (String.length cs))) ; bm "aes-128-cbc-d" (fun name -> let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) and iv = Mirage_crypto_rng.generate 16 in - throughput name (fun cs -> AES.CBC.decrypt ~key ~iv cs)) ; + throughput_into name + (fun dst cs -> AES.CBC.decrypt_into ~key ~iv cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-128-cbc-d-unsafe" (fun name -> + let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) + and iv = Mirage_crypto_rng.generate 16 in + throughput_into name + (fun dst cs -> AES.CBC.unsafe_decrypt_into ~key ~iv cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; bm "aes-128-ctr" (fun name -> let key = Mirage_crypto_rng.generate 16 |> AES.CTR.of_secret diff --git a/src/cipher_block.ml b/src/cipher_block.ml index afd963cc..729f54e8 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -44,7 +44,19 @@ module Block = struct val encrypt : key:key -> iv:string -> string -> string val decrypt : key:key -> iv:string -> string -> string - val next_iv : iv:string -> string -> string + val next_iv : ?off:int -> string -> iv:string -> string + + val encrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + val decrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + + val unsafe_encrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + val unsafe_decrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + val unsafe_encrypt_into_inplace : key:key -> iv:string -> + bytes -> dst_off:int -> int -> unit end module type CTR = sig @@ -186,40 +198,71 @@ module Modes = struct let of_secret = Core.of_secret - let bounds_check ~iv cs = - if String.length iv <> block then invalid_arg "CBC: IV length %u" (String.length iv); - if String.length cs mod block <> 0 then - invalid_arg "CBC: argument length %u" (String.length cs) + let bounds_check ?(off = 0) ~iv cs = + if String.length iv <> block then + invalid_arg "CBC: IV length %u not of block size" (String.length iv); + if (String.length cs - off) mod block <> 0 then + invalid_arg "CBC: argument length %u (off %u) not of block size" + (String.length cs) off - let next_iv ~iv cs = - bounds_check ~iv cs ; - if String.length cs > 0 then + let next_iv ?(off = 0) cs ~iv = + bounds_check ~iv cs ~off ; + if String.length cs > off then String.sub cs (String.length cs - block_size) block_size else iv - let encrypt ~key:(key, _) ~iv src = - bounds_check ~iv src ; - let dst = Bytes.of_string src in + let unsafe_encrypt_into_inplace ~key:(key, _) ~iv dst ~dst_off len = let rec loop iv iv_i dst_i = function - 0 -> () - | b -> Native.xor_into_bytes iv iv_i dst dst_i block ; - Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string dst) dst_i dst dst_i ; - loop (Bytes.unsafe_to_string dst) dst_i (dst_i + block) (b - 1) + | 0 -> () + | b -> + Native.xor_into_bytes iv iv_i dst dst_i block ; + Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string dst) dst_i dst dst_i ; + loop (Bytes.unsafe_to_string dst) dst_i (dst_i + block) (b - 1) in - loop iv 0 0 (Bytes.length dst / block) ; + loop iv 0 dst_off (len / block) + + let unsafe_encrypt_into ~key ~iv src ~src_off dst ~dst_off len = + Bytes.unsafe_blit_string src src_off dst dst_off len; + unsafe_encrypt_into_inplace ~key ~iv dst ~dst_off len + + let encrypt_into ~key ~iv src ~src_off dst ~dst_off len = + bounds_check ~off:src_off ~iv src; + if String.length src - src_off < len then + invalid_arg "CBC: src has insufficient length (%u - src_off:%u < len %u)" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "CBC: dst has insufficient length (%u - dst_off:%u < len %u)" + (Bytes.length dst) dst_off len; + unsafe_encrypt_into ~key ~iv src ~src_off dst ~dst_off len + + let encrypt ~key ~iv src = + let dst = Bytes.create (String.length src) in + encrypt_into ~key ~iv src ~src_off:0 dst ~dst_off:0 (String.length src); Bytes.unsafe_to_string dst - let decrypt ~key:(_, key) ~iv src = - bounds_check ~iv src ; - let msg = Bytes.create (String.length src) - and b = String.length src / block in + let unsafe_decrypt_into ~key:(_, key) ~iv src ~src_off dst ~dst_off len = + let b = len / block in if b > 0 then begin - Core.decrypt ~key ~blocks:b src 0 msg 0 ; - Native.xor_into_bytes iv 0 msg 0 block ; - Native.xor_into_bytes src 0 msg block ((b - 1) * block) ; - end ; - Bytes.unsafe_to_string msg + Core.decrypt ~key ~blocks:b src src_off dst dst_off ; + Native.xor_into_bytes iv 0 dst dst_off block ; + Native.xor_into_bytes src src_off dst (dst_off + block) ((b - 1) * block) ; + end + let decrypt_into ~key ~iv src ~src_off dst ~dst_off len = + bounds_check ~off:src_off ~iv src; + if String.length src - src_off < len then + invalid_arg "CBC: src has insufficient length (%u - src_off:%u < len %u)" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "CBC: dst has insufficient length (%u - dst_off:%u < len %u)" + (Bytes.length dst) dst_off len; + unsafe_decrypt_into ~key ~iv src ~src_off dst ~dst_off len + + let decrypt ~key ~iv src = + let len = String.length src in + let msg = Bytes.create len in + decrypt_into ~key ~iv src ~src_off:0 msg ~dst_off:0 len; + Bytes.unsafe_to_string msg end module CTR_of (Core : Block.Core) (Ctr : Counters.S) : diff --git a/src/mirage_crypto.mli b/src/mirage_crypto.mli index 18c6b117..ca5f46f9 100644 --- a/src/mirage_crypto.mli +++ b/src/mirage_crypto.mli @@ -253,8 +253,8 @@ module Block : sig @raise Invalid_argument if [iv] is not [block_size], or [msg] is not [k * block_size] long. *) - val next_iv : iv:string -> string -> string - (** [next_iv ~iv ciphertext] is the first [iv] {e following} the + val next_iv : ?off:int -> string -> iv:string -> string + (** [next_iv ~iv ciphertext ~off] is the first [iv] {e following} the encryption that used [iv] to produce [ciphertext]. For protocols which perform inter-message chaining, this is the [iv] @@ -266,9 +266,73 @@ module Block : sig {[encrypt ~iv msg1 || encrypt ~iv:(next_iv ~iv (encrypt ~iv msg1)) msg2 == encrypt ~iv (msg1 || msg2)]} - @raise Invalid_argument if the length of [iv] is not [block_size], or - the length of [ciphertext] is not [k * block_size] for some [k]. *) - end + @raise Invalid_argument if the length of [iv] is not [block_size]. + @raise Invalid_argument if the length of [ciphertext] is not a multiple + of [block_size]. *) + + val encrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [encrypt_into ~key ~iv src ~src_off dst dst_off len] encrypts [len] + octets from [src] starting at [src_off] into [dst] starting at [dst_off]. + + @raise Invalid_argument if the length of [iv] is not {!block_size}. + @raise Invalid_argument if [len] is not a multiple of {!block_size}. + @raise Invalid_argument if [String.length src - src_off < len]. + @raise Invalid_argument if [Bytes.length dst - dst_off < len]. *) + + val decrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [decrypt_into ~key ~iv src ~src_off dst dst_off len] decrypts [len] + octets from [src] starting at [src_off] into [dst] starting at [dst_off]. + + @raise Invalid_argument if the length of [iv] is not {!block_size}. + @raise Invalid_argument if [len] is not a multiple of {!block_size}. + @raise Invalid_argument if [String.length src - src_off < len]. + @raise Invalid_argument if [Bytes.length dst - dst_off < len]. *) + + (**/**) + val unsafe_encrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [unsafe_encrypt_into ~key ~iv src ~src_off dst dst_off len] encrypts [len] + octets from [src] starting at [src_off] into [dst] starting at [dst_off]. + + It is unsafe since buffer lengths are not checks. This may casue memory + issues if an invariant is violated: + {ul + {- the length of [iv] must be {!block_size},} + {- [len] must be a multiple of {!block_size},} + {- [String.length src - src_off >= len],} + {- [Bytes.length dst - dst_off >= len].}} *) + + val unsafe_decrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [unsafe_decrypt_into ~key ~iv src ~src_off dst dst_off len] decrypts [len] + octets from [src] starting at [src_off] into [dst] starting at [dst_off]. + + It is unsafe since buffer lengths are not checks. This may casue memory + issues if an invariant is violated: + {ul + {- the length of [iv] must be {!block_size},} + {- [len] must be a multiple of {!block_size},} + {- [String.length src - src_off >= len],} + {- [Bytes.length dst - dst_off >= len].}} *) + + val unsafe_encrypt_into_inplace : key:key -> iv:string -> + bytes -> dst_off:int -> int -> unit + (** [unsafe_encrypt_into_inplace ~key ~iv dst dst_off len] encrypts [len] + octets from [dst] starting at [dst_off] into [dst] starting at [dst_off]. + + The [dst] buffer must contain the message to be encrypted. + + It is unsafe since buffer lengths are not checks. This may casue memory + issues if an invariant is violated: + {ul + {- the length of [iv] must be {!block_size},} + {- [len] must be a multiple of {!block_size},} + {- [String.length src - src_off >= len],} + {- [Bytes.length dst - dst_off >= len].}} *) + (**/**) +end (** {e Counter} mode. *) module type CTR = sig From 5cbad0a3dd7127fd78d236e2a1deace4602bdb86 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Fri, 7 Jun 2024 18:25:44 +0200 Subject: [PATCH 03/14] counters: add an offset parameter --- src/cipher_block.ml | 30 +++++++++++++++--------------- src/native.ml | 6 +++--- src/native/mirage_crypto.h | 2 +- src/native/misc.c | 4 ++-- src/native/misc_sse.c | 8 ++++---- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/cipher_block.ml b/src/cipher_block.ml index 729f54e8..508a1225 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -99,7 +99,7 @@ module Counters = struct val size : int val add : ctr -> int64 -> ctr val of_octets : string -> ctr - val unsafe_count_into : ctr -> bytes -> blocks:int -> unit + val unsafe_count_into : ctr -> bytes -> off:int -> blocks:int -> unit end module C64be = struct @@ -107,10 +107,10 @@ module Counters = struct let size = 8 let of_octets cs = String.get_int64_be cs 0 let add = Int64.add - let unsafe_count_into t buf ~blocks = - let tmp = Bytes.create 8 in - Bytes.set_int64_be tmp 0 t; - Native.count8be tmp buf ~blocks + let unsafe_count_into t buf ~off ~blocks = + let ctr = Bytes.create 8 in + Bytes.set_int64_be ctr 0 t; + Native.count8be ~ctr buf ~off ~blocks end module C128be = struct @@ -123,10 +123,10 @@ module Counters = struct let w0' = Int64.add w0 n in let flip = if Int64.logxor w0 w0' < 0L then w0' > w0 else w0' < w0 in ((if flip then Int64.succ w1 else w1), w0') - let unsafe_count_into (w1, w0) buf ~blocks = - let tmp = Bytes.create 16 in - Bytes.set_int64_be tmp 0 w1; Bytes.set_int64_be tmp 8 w0; - Native.count16be tmp buf ~blocks + let unsafe_count_into (w1, w0) buf ~off ~blocks = + let ctr = Bytes.create 16 in + Bytes.set_int64_be ctr 0 w1; Bytes.set_int64_be ctr 8 w0; + Native.count16be ~ctr buf ~off ~blocks end module C128be32 = struct @@ -134,10 +134,10 @@ module Counters = struct let add (w1, w0) n = let hi = 0xffffffff00000000L and lo = 0x00000000ffffffffL in (w1, Int64.(logor (logand hi w0) (add n w0 |> logand lo))) - let unsafe_count_into (w1, w0) buf ~blocks = - let tmp = Bytes.create 16 in - Bytes.set_int64_be tmp 0 w1; Bytes.set_int64_be tmp 8 w0; - Native.count16be4 tmp buf ~blocks + let unsafe_count_into (w1, w0) buf ~off ~blocks = + let ctr = Bytes.create 16 in + Bytes.set_int64_be ctr 0 w1; Bytes.set_int64_be ctr 8 w0; + Native.count16be4 ~ctr buf ~off ~blocks end end @@ -280,13 +280,13 @@ module Modes = struct let stream ~key ~ctr n = let blocks = imax 0 n / block_size in let buf = Bytes.create n in - Ctr.unsafe_count_into ctr ~blocks buf ; + Ctr.unsafe_count_into ctr buf ~off:0 ~blocks ; Core.encrypt ~key ~blocks (Bytes.unsafe_to_string buf) 0 buf 0 ; let slack = imax 0 n mod block_size in if slack <> 0 then begin let buf' = Bytes.create block_size in let ctr = Ctr.add ctr (Int64.of_int blocks) in - Ctr.unsafe_count_into ctr ~blocks:1 buf' ; + Ctr.unsafe_count_into ctr buf' ~off:0 ~blocks:1 ; Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string buf') 0 buf' 0 ; Bytes.unsafe_blit buf' 0 buf (blocks * block_size) slack end; diff --git a/src/native.ml b/src/native.ml index 55437a70..bf081a26 100644 --- a/src/native.ml +++ b/src/native.ml @@ -37,9 +37,9 @@ end * Unsolved: bounds-checked XORs are slowing things down considerably... *) external xor_into_bytes : string -> int -> bytes -> int -> int -> unit = "mc_xor_into_bytes" [@@noalloc] -external count8be : bytes -> bytes -> blocks:int -> unit = "mc_count_8_be" [@@noalloc] -external count16be : bytes -> bytes -> blocks:int -> unit = "mc_count_16_be" [@@noalloc] -external count16be4 : bytes -> bytes -> blocks:int -> unit = "mc_count_16_be_4" [@@noalloc] +external count8be : ctr:bytes -> bytes -> off:int -> blocks:int -> unit = "mc_count_8_be" [@@noalloc] +external count16be : ctr:bytes -> bytes -> off:int -> blocks:int -> unit = "mc_count_16_be" [@@noalloc] +external count16be4 : ctr:bytes -> bytes -> off:int -> blocks:int -> unit = "mc_count_16_be_4" [@@noalloc] external misc_mode : unit -> int = "mc_misc_mode" [@@noalloc] diff --git a/src/native/mirage_crypto.h b/src/native/mirage_crypto.h index 0542db2f..6608a1b1 100644 --- a/src/native/mirage_crypto.h +++ b/src/native/mirage_crypto.h @@ -114,6 +114,6 @@ CAMLprim value mc_xor_into_bytes_generic (value b1, value off1, value b2, value off2, value n); CAMLprim value -mc_count_16_be_4_generic (value ctr, value dst, value blocks); +mc_count_16_be_4_generic (value ctr, value dst, value off, value blocks); #endif /* H__MIRAGE_CRYPTO */ diff --git a/src/native/misc.c b/src/native/misc.c index ba9590f8..dea76e18 100644 --- a/src/native/misc.c +++ b/src/native/misc.c @@ -60,9 +60,9 @@ mc_xor_into_bytes_generic (value b1, value off1, value b2, value off2, value n) } #define __export_counter(name, f) \ - CAMLprim value name (value ctr, value dst, value blocks) { \ + CAMLprim value name (value ctr, value dst, value off, value blocks) { \ f ( (uint64_t*) Bp_val (ctr), \ - (uint64_t*) _bp_uint8 (dst), Long_val (blocks) ); \ + (uint64_t*) _bp_uint8_off (dst, off), Long_val (blocks) ); \ return Val_unit; \ } diff --git a/src/native/misc_sse.c b/src/native/misc_sse.c index 1f2265da..c155d468 100644 --- a/src/native/misc_sse.c +++ b/src/native/misc_sse.c @@ -48,11 +48,11 @@ mc_xor_into_bytes (value b1, value off1, value b2, value off2, value n) { } #define __export_counter(name, f) \ - CAMLprim value name (value ctr, value dst, value blocks) { \ - _mc_switch_accel(ssse3, \ - name##_generic (ctr, dst, blocks), \ + CAMLprim value name (value ctr, value dst, value off, value blocks) { \ + _mc_switch_accel(ssse3, \ + name##_generic (ctr, dst, off, blocks), \ f ( (uint64_t*) Bp_val (ctr), \ - (uint64_t*) _bp_uint8 (dst), Long_val (blocks) )) \ + (uint64_t*) _bp_uint8_off (dst, off), Long_val (blocks) )) \ return Val_unit; \ } From 148b4d4af73c4fc187e4c62877415132e1585866 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Fri, 7 Jun 2024 19:00:30 +0200 Subject: [PATCH 04/14] Mirage_crypto.Block.CTR with {de,en}crypt_into --- bench/speed.ml | 7 ++- src/cipher_block.ml | 73 +++++++++++++++++++++------- src/mirage_crypto.mli | 110 +++++++++++++++++++++++++++--------------- 3 files changed, 133 insertions(+), 57 deletions(-) diff --git a/bench/speed.ml b/bench/speed.ml index d0f7c2c6..58202f02 100644 --- a/bench/speed.ml +++ b/bench/speed.ml @@ -405,7 +405,12 @@ let benchmarks = [ bm "aes-128-ctr" (fun name -> let key = Mirage_crypto_rng.generate 16 |> AES.CTR.of_secret and ctr = Mirage_crypto_rng.generate 16 |> AES.CTR.ctr_of_octets in - throughput name (fun cs -> AES.CTR.encrypt ~key ~ctr cs)) ; + throughput_into name (fun dst cs -> AES.CTR.encrypt_into ~key ~ctr cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-128-ctr-unsafe" (fun name -> + let key = Mirage_crypto_rng.generate 16 |> AES.CTR.of_secret + and ctr = Mirage_crypto_rng.generate 16 |> AES.CTR.ctr_of_octets in + throughput_into name (fun dst cs -> AES.CTR.unsafe_encrypt_into ~key ~ctr cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; bm "aes-128-gcm" (fun name -> let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) diff --git a/src/cipher_block.ml b/src/cipher_block.ml index 508a1225..22181578 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -64,18 +64,29 @@ module Block = struct type key val of_secret : string -> key - type ctr - val key_sizes : int array val block_size : int + type ctr + val add_ctr : ctr -> int64 -> ctr + val next_ctr : ?off:int -> string -> ctr:ctr -> ctr + val ctr_of_octets : string -> ctr + val stream : key:key -> ctr:ctr -> int -> string val encrypt : key:key -> ctr:ctr -> string -> string val decrypt : key:key -> ctr:ctr -> string -> string - val add_ctr : ctr -> int64 -> ctr - val next_ctr : ctr:ctr -> string -> ctr - val ctr_of_octets : string -> ctr + val stream_into : key:key -> ctr:ctr -> bytes -> off:int -> int -> unit + val encrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + val decrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + + val unsafe_stream_into : key:key -> ctr:ctr -> bytes -> off:int -> int -> unit + val unsafe_encrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + val unsafe_decrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit end module type GCM = sig @@ -277,30 +288,58 @@ module Modes = struct let (key_sizes, block_size) = Core.(key, block) let of_secret = Core.e_of_secret - let stream ~key ~ctr n = - let blocks = imax 0 n / block_size in - let buf = Bytes.create n in - Ctr.unsafe_count_into ctr buf ~off:0 ~blocks ; - Core.encrypt ~key ~blocks (Bytes.unsafe_to_string buf) 0 buf 0 ; - let slack = imax 0 n mod block_size in + let unsafe_stream_into ~key ~ctr buf ~off len = + let blocks = imax 0 len / block_size in + Ctr.unsafe_count_into ctr buf ~off ~blocks ; + Core.encrypt ~key ~blocks (Bytes.unsafe_to_string buf) off buf off ; + let slack = imax 0 len mod block_size in if slack <> 0 then begin let buf' = Bytes.create block_size in let ctr = Ctr.add ctr (Int64.of_int blocks) in Ctr.unsafe_count_into ctr buf' ~off:0 ~blocks:1 ; Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string buf') 0 buf' 0 ; - Bytes.unsafe_blit buf' 0 buf (blocks * block_size) slack - end; + Bytes.unsafe_blit buf' 0 buf (off + blocks * block_size) slack + end + + let stream_into ~key ~ctr buf ~off len = + if Bytes.length buf - off < len then + invalid_arg "CTR: buffer length %u - off %u < len %u" + (Bytes.length buf) off len; + unsafe_stream_into ~key ~ctr buf ~off len + + let stream ~key ~ctr n = + let buf = Bytes.create n in + unsafe_stream_into ~key ~ctr buf ~off:0 n; Bytes.unsafe_to_string buf + let unsafe_encrypt_into ~key ~ctr src ~src_off dst ~dst_off len = + unsafe_stream_into ~key ~ctr dst ~off:dst_off len; + Uncommon.unsafe_xor_into src ~src_off dst ~dst_off len + + let encrypt_into ~key ~ctr src ~src_off dst ~dst_off len = + if String.length src - src_off < len then + invalid_arg "CTR: src length %u - src_off %u < len %u" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "CTR: dst length %u - dst_off %u < len %u" + (Bytes.length dst) dst_off len; + unsafe_encrypt_into ~key ~ctr src ~src_off dst ~dst_off len + let encrypt ~key ~ctr src = - let res = Bytes.unsafe_of_string (stream ~key ~ctr (String.length src)) in - Native.xor_into_bytes src 0 res 0 (String.length src) ; - Bytes.unsafe_to_string res + let len = String.length src in + let dst = Bytes.create len in + encrypt_into ~key ~ctr src ~src_off:0 dst ~dst_off:0 len; + Bytes.unsafe_to_string dst let decrypt = encrypt + let decrypt_into = encrypt_into + + let unsafe_decrypt_into = unsafe_encrypt_into + let add_ctr = Ctr.add - let next_ctr ~ctr msg = add_ctr ctr (Int64.of_int @@ String.length msg // block_size) + let next_ctr ?(off = 0) msg ~ctr = + add_ctr ctr (Int64.of_int @@ (String.length msg - off) // block_size) let ctr_of_octets = Ctr.of_octets end diff --git a/src/mirage_crypto.mli b/src/mirage_crypto.mli index ca5f46f9..0231c158 100644 --- a/src/mirage_crypto.mli +++ b/src/mirage_crypto.mli @@ -202,20 +202,18 @@ module Block : sig (**/**) val unsafe_encrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit - (** [unsafe_encrypt_into ~key src ~src_off dst dst_off len] encrypts [len] - octets from [src] starting at [src_off] into [dst] starting at - [dst_off]. Since buffer lengths and block sizes are not checked, this - may cause memory issues if an invariant is violated: + (** [unsafe_encrypt_into] is {!encrypt_into}, but without bounds checks. + + This may cause memory issues if an invariant is violated: {ul {- [len] must be a multiple of {!block_size},} {- [String.length src - src_off >= len],} {- [Bytes.length dst - dst_off >= len].}} *) val unsafe_decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit - (** [unsafe_decrypt_into ~key src ~src_off dst dst_off len] decrypts [len] - octets from [src] starting at [src_off] into [dst] starting at - [dst_off]. Since buffer lengths and block sizes are not checked, this - may cause memory issues if an invariant is violated: + (** [unsafe_decrypt_into] is {!decrypt_into}, but without bounds checks. + + This may cause memory issues if an invariant is violated: {ul {- [len] must be a multiple of {!block_size},} {- [String.length src - src_off >= len],} @@ -260,8 +258,8 @@ module Block : sig For protocols which perform inter-message chaining, this is the [iv] for the next message. - It is either [iv], when [len ciphertext = 0], or the last block of - [ciphertext]. Note that + It is either [iv], when [String.length ciphertext - off = 0], or the + last block of [ciphertext]. Note that {[encrypt ~iv msg1 || encrypt ~iv:(next_iv ~iv (encrypt ~iv msg1)) msg2 == encrypt ~iv (msg1 || msg2)]} @@ -293,11 +291,9 @@ module Block : sig (**/**) val unsafe_encrypt_into : key:key -> iv:string -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit - (** [unsafe_encrypt_into ~key ~iv src ~src_off dst dst_off len] encrypts [len] - octets from [src] starting at [src_off] into [dst] starting at [dst_off]. + (** [unsafe_encrypt_into] is {!encrypt_into}, but without bounds checks. - It is unsafe since buffer lengths are not checks. This may casue memory - issues if an invariant is violated: + This may casue memory issues if an invariant is violated: {ul {- the length of [iv] must be {!block_size},} {- [len] must be a multiple of {!block_size},} @@ -306,11 +302,9 @@ module Block : sig val unsafe_decrypt_into : key:key -> iv:string -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit - (** [unsafe_decrypt_into ~key ~iv src ~src_off dst dst_off len] decrypts [len] - octets from [src] starting at [src_off] into [dst] starting at [dst_off]. + (** [unsafe_decrypt_into] is {!decrypt_into}, but without bounds checks. - It is unsafe since buffer lengths are not checks. This may casue memory - issues if an invariant is violated: + This may casue memory issues if an invariant is violated: {ul {- the length of [iv] must be {!block_size},} {- [len] must be a multiple of {!block_size},} @@ -319,13 +313,10 @@ module Block : sig val unsafe_encrypt_into_inplace : key:key -> iv:string -> bytes -> dst_off:int -> int -> unit - (** [unsafe_encrypt_into_inplace ~key ~iv dst dst_off len] encrypts [len] - octets from [dst] starting at [dst_off] into [dst] starting at [dst_off]. - - The [dst] buffer must contain the message to be encrypted. + (** [unsafe_encrypt_into_inplace] is {!unsafe_encrypt_into}, but assumes + that [dst] already contains the mesage to be encrypted. - It is unsafe since buffer lengths are not checks. This may casue memory - issues if an invariant is violated: + This may casue memory issues if an invariant is violated: {ul {- the length of [iv] must be {!block_size},} {- [len] must be a multiple of {!block_size},} @@ -353,6 +344,27 @@ end type ctr + val add_ctr : ctr -> int64 -> ctr + (** [add_ctr ctr n] adds [n] to [ctr]. *) + + val next_ctr : ?off:int -> string -> ctr:ctr -> ctr + (** [next_ctr ~off msg ~ctr] is the state of the counter after encrypting or + decrypting [msg] at offset [off] with the counter [ctr]. + + For protocols which perform inter-message chaining, this is the + counter for the next message. + + It is computed as [C.add ctr (ceil (len msg / block_size))]. Note that + if [len msg1 = k * block_size], + +{[encrypt ~ctr msg1 || encrypt ~ctr:(next_ctr ~ctr msg1) msg2 + == encrypt ~ctr (msg1 || msg2)]} + + *) + + val ctr_of_octets : string -> ctr + (** [ctr_of_octets buf] converts the value of [buf] into a counter. *) + val stream : key:key -> ctr:ctr -> int -> string (** [stream ~key ~ctr n] is the raw keystream. @@ -371,31 +383,51 @@ end val encrypt : key:key -> ctr:ctr -> string -> string (** [encrypt ~key ~ctr msg] is - [stream ~key ~ctr ~off (len msg) lxor msg]. *) + [stream ~key ~ctr (len msg) lxor msg]. *) val decrypt : key:key -> ctr:ctr -> string -> string (** [decrypt] is [encrypt]. *) - val add_ctr : ctr -> int64 -> ctr - (** [add_ctr ctr n] adds [n] to [ctr]. *) + val stream_into : key:key -> ctr:ctr -> bytes -> off:int -> int -> unit + (** [stream_into ~key ~ctr dst ~off len] is the raw key stream put into + [dst] starting at [off]. - val next_ctr : ctr:ctr -> string -> ctr - (** [next_ctr ~ctr msg] is the state of the counter after encrypting or - decrypting [msg] with the counter [ctr]. + @raise Invalid_argument if [Bytes.length dst - off < len]. *) - For protocols which perform inter-message chaining, this is the - counter for the next message. + val encrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [encrypt_into ~key ~ctr src ~src_off dst ~dst_off len] produces the + key stream into [dst] at [dst_off], and then xors it with [src] at + [src_off]. - It is computed as [C.add ctr (ceil (len msg / block_size))]. Note that - if [len msg1 = k * block_size], + @raise Invalid_argument if [Bytes.length dst - off < len]. + @raise Invalid_argument if [String.length src - off < len]. *) -{[encrypt ~ctr msg1 || encrypt ~ctr:(next_ctr ~ctr msg1) msg2 - == encrypt ~ctr (msg1 || msg2)]} + val decrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [decrypt_into] is {!encrypt_into}. *) - *) + (**/**) + val unsafe_stream_into : key:key -> ctr:ctr -> bytes -> off:int -> int -> unit + (** [unsafe_stream_into] is {!stream_into}, but without bounds checks. - val ctr_of_octets : string -> ctr - (** [ctr_of_octets buf] converts the value of [buf] into a counter. *) + This may cause memory issues if the invariant is violated: + {ul + {- [Bytes.length buf - off >= len].}} *) + + val unsafe_encrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [unsafe_encrypt_into] is {!encrypt_into}, but without bounds checks. + + This may cause memory issues if an invariant is violated: + {ul + {- [Bytes.length dst - off >= len],} + {- [String.length src - off < len].}} *) + + val unsafe_decrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [unsafe_decrypt_into] is {!unsafe_encrypt_into}. *) + (**/**) end (** {e Galois/Counter Mode}. *) From 0316c43c9d5bfc4eb075646b01242de20e49e18b Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Fri, 7 Jun 2024 23:13:05 +0200 Subject: [PATCH 05/14] GCM and ChaCha have {de,en}crypt_into now --- bench/speed.ml | 24 ++++++-- src/aead.ml | 12 ++++ src/chacha20.ml | 119 +++++++++++++++++++++++++----------- src/cipher_block.ml | 97 ++++++++++++++++++++++------- src/mirage_crypto.mli | 63 ++++++++++++++++++- src/native.ml | 6 +- src/native/ghash_ctmul.c | 4 +- src/native/ghash_generic.c | 4 +- src/native/ghash_pclmul.c | 6 +- src/native/mirage_crypto.h | 2 +- src/native/poly1305-donna.c | 8 +-- src/poly1305.ml | 12 ++-- 12 files changed, 272 insertions(+), 85 deletions(-) diff --git a/bench/speed.ml b/bench/speed.ml index 58202f02..9578ede9 100644 --- a/bench/speed.ml +++ b/bench/speed.ml @@ -45,11 +45,11 @@ let throughput title f = Printf.printf " % 5d: %04f MB/s (%d iters in %.03f s)\n%!" size (bw /. mb) iters time -let throughput_into title f = +let throughput_into ?(add = 0) title f = Printf.printf "\n* [%s]\n%!" title ; sizes |> List.iter @@ fun size -> Gc.full_major () ; - let dst = Bytes.create size in + let dst = Bytes.create (size + add) in let (iters, time, bw) = burn (f dst) size in Printf.printf " % 5d: %04f MB/s (%d iters in %.03f s)\n%!" size (bw /. mb) iters time @@ -356,9 +356,16 @@ let benchmarks = [ fst ecdh_shares); bm "chacha20-poly1305" (fun name -> - let key = Mirage_crypto.Chacha20.of_secret (Mirage_crypto_rng.generate 32) + let key = Chacha20.of_secret (Mirage_crypto_rng.generate 32) and nonce = Mirage_crypto_rng.generate 8 in - throughput name (Mirage_crypto.Chacha20.authenticate_encrypt ~key ~nonce)) ; + throughput_into ~add:Chacha20.tag_size name + (fun dst cs -> Chacha20.authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))) ; + + bm "chacha20-poly1305-unsafe" (fun name -> + let key = Chacha20.of_secret (Mirage_crypto_rng.generate 32) + and nonce = Mirage_crypto_rng.generate 8 in + throughput_into ~add:Chacha20.tag_size name + (fun dst cs -> Chacha20.unsafe_authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))) ; bm "aes-128-ecb" (fun name -> let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 16) in @@ -415,7 +422,14 @@ let benchmarks = [ bm "aes-128-gcm" (fun name -> let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) and nonce = Mirage_crypto_rng.generate 12 in - throughput name (fun cs -> AES.GCM.authenticate_encrypt ~key ~nonce cs)); + throughput_into ~add:AES.GCM.tag_size name + (fun dst cs -> AES.GCM.authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs - AES.GCM.tag_size) (String.length cs))); + + bm "aes-128-gcm-unsafe" (fun name -> + let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) + and nonce = Mirage_crypto_rng.generate 12 in + throughput_into ~add:AES.GCM.tag_size name + (fun dst cs -> AES.GCM.unsafe_authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs - AES.GCM.tag_size) (String.length cs))); bm "aes-128-ghash" (fun name -> let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) diff --git a/src/aead.ml b/src/aead.ml index a03214e1..30b716a1 100644 --- a/src/aead.ml +++ b/src/aead.ml @@ -10,4 +10,16 @@ module type AEAD = sig string -> string * string val authenticate_decrypt_tag : key:key -> nonce:string -> ?adata:string -> tag:string -> string -> string option + val authenticate_encrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> bytes -> dst_off:int -> + tag_off:int -> int -> unit + val authenticate_decrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> tag_off:int -> bytes -> + dst_off:int -> int -> bool + val unsafe_authenticate_encrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> bytes -> dst_off:int -> + tag_off:int -> int -> unit + val unsafe_authenticate_decrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> tag_off:int -> bytes -> + dst_off:int -> int -> bool end diff --git a/src/chacha20.ml b/src/chacha20.ml index f0d97840..a3ef469c 100644 --- a/src/chacha20.ml +++ b/src/chacha20.ml @@ -42,77 +42,126 @@ let init ctr ~key ~nonce = Bytes.unsafe_blit_string nonce 0 state nonce_off (String.length nonce) ; state, inc -let crypt ~key ~nonce ?(ctr = 0L) data = +let crypt_into ~key ~nonce ~ctr src ~src_off dst ~dst_off len = let state, inc = init ctr ~key ~nonce in - let l = String.length data in - let block_count = l // block in + let block_count = len // block in let last_len = - let last = l mod block in + let last = len mod block in if last = 0 then block else last in - let res = Bytes.create l in let rec loop i = function | 0 -> () | 1 -> if last_len = block then begin - chacha20_block state i res ; - Native.xor_into_bytes data i res i block + chacha20_block state (dst_off + i) dst ; + Native.xor_into_bytes src (src_off + i) dst (dst_off + i) block end else begin let buf = Bytes.create block in chacha20_block state 0 buf ; - Native.xor_into_bytes data i buf 0 last_len ; - Bytes.unsafe_blit buf 0 res i last_len + Native.xor_into_bytes src (src_off + i) buf 0 last_len ; + Bytes.unsafe_blit buf 0 dst (dst_off + i) last_len end | n -> - chacha20_block state i res ; - Native.xor_into_bytes data i res i block ; + chacha20_block state (dst_off + i) dst ; + Native.xor_into_bytes src (src_off + i) dst (dst_off + i) block ; inc state; loop (i + block) (n - 1) in - loop 0 block_count ; + loop 0 block_count + +let crypt ~key ~nonce ?(ctr = 0L) data = + let l = String.length data in + let res = Bytes.create l in + crypt_into ~key ~nonce ~ctr data ~src_off:0 res ~dst_off:0 l; Bytes.unsafe_to_string res module P = Poly1305.It +let tag_size = P.mac_size + let generate_poly1305_key ~key ~nonce = crypt ~key ~nonce (String.make 32 '\000') -let mac ~key ~adata ciphertext = - let pad16 b = - let len = String.length b mod 16 in +let mac_into ~key ~adata src ~src_off len dst ~dst_off = + let pad16 l = + let len = l mod 16 in if len = 0 then "" else String.make (16 - len) '\000' - and len = + and len_buf = let data = Bytes.create 16 in Bytes.set_int64_le data 0 (Int64.of_int (String.length adata)); - Bytes.set_int64_le data 8 (Int64.of_int (String.length ciphertext)); + Bytes.set_int64_le data 8 (Int64.of_int len); Bytes.unsafe_to_string data in - P.macl ~key [ adata ; pad16 adata ; ciphertext ; pad16 ciphertext ; len ] + let p1 = pad16 (String.length adata) and p2 = pad16 len in + P.mac_into ~key [ adata, 0, String.length adata ; + p1, 0, String.length p1 ; + src, src_off, len ; + p2, 0, String.length p2 ; + len_buf, 0, String.length len_buf ] + dst ~dst_off + +let mac ~key ~adata ciphertext = + let r = Bytes.create tag_size in + mac_into ~key ~adata ciphertext ~src_off:0 (String.length ciphertext) r ~dst_off:0; + Bytes.unsafe_to_string r -let authenticate_encrypt_tag ~key ~nonce ?(adata = "") data = +let unsafe_authenticate_encrypt_into ~key ~nonce ?(adata = "") src ~src_off dst ~dst_off ~tag_off len = let poly1305_key = generate_poly1305_key ~key ~nonce in - let ciphertext = crypt ~key ~nonce ~ctr:1L data in - let mac = mac ~key:poly1305_key ~adata ciphertext in - ciphertext, mac + crypt_into ~key ~nonce ~ctr:1L src ~src_off dst ~dst_off len; + mac_into ~key:poly1305_key ~adata (Bytes.unsafe_to_string dst) ~src_off:dst_off len dst ~dst_off:tag_off + +let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = + if String.length src - src_off < len then + invalid_arg "Chacha20: src length %u - src_off %u < len %u" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "Chacha20: dst length %u - dst_off %u < len %u" + (Bytes.length dst) dst_off len; + if Bytes.length dst - tag_off < tag_size then + invalid_arg "Chacha20: dst length %u - tag_off %u < tag_size %u" + (Bytes.length dst) tag_off tag_size; + unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len let authenticate_encrypt ~key ~nonce ?adata data = - let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata data in - cdata ^ ctag + let l = String.length data in + let dst = Bytes.create (l + tag_size) in + unsafe_authenticate_encrypt_into ~key ~nonce ?adata data ~src_off:0 dst ~dst_off:0 ~tag_off:l l; + Bytes.unsafe_to_string dst + +let authenticate_encrypt_tag ~key ~nonce ?adata data = + let r = authenticate_encrypt ~key ~nonce ?adata data in + String.sub r 0 (String.length data), String.sub r (String.length data) tag_size -let authenticate_decrypt_tag ~key ~nonce ?(adata = "") ~tag data = +let unsafe_authenticate_decrypt_into ~key ~nonce ?(adata = "") src ~src_off ~tag_off dst ~dst_off len = let poly1305_key = generate_poly1305_key ~key ~nonce in - let ctag = mac ~key:poly1305_key ~adata data in - let plain = crypt ~key ~nonce ~ctr:1L data in - if Eqaf.equal tag ctag then Some plain else None + let ctag = Bytes.create tag_size in + mac_into ~key:poly1305_key ~adata src ~src_off len ctag ~dst_off:0; + crypt_into ~key ~nonce ~ctr:1L src ~src_off dst ~dst_off len; + Eqaf.equal (String.sub src tag_off tag_size) (Bytes.unsafe_to_string ctag) + +let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = + if String.length src - src_off < len then + invalid_arg "Chacha20: src length %u - src_off %u < len %u" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "Chacha20: dst length %u - dst_off %u < len %u" + (Bytes.length dst) dst_off len; + if String.length src - tag_off < tag_size then + invalid_arg "Chacha20: src length %u - tag_off %u < tag_size %u" + (String.length src) tag_off tag_size; + unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt ~key ~nonce ?adata data = - if String.length data < P.mac_size then + if String.length data < tag_size then None else - let cipher, tag = - let p = String.length data - P.mac_size in - String.sub data 0 p, String.sub data p P.mac_size - in - authenticate_decrypt_tag ~key ~nonce ?adata ~tag cipher + let l = String.length data - tag_size in + let r = Bytes.create l in + if unsafe_authenticate_decrypt_into ~key ~nonce ?adata data ~src_off:0 ~tag_off:l r ~dst_off:0 l then + Some (Bytes.unsafe_to_string r) + else + None -let tag_size = P.mac_size +let authenticate_decrypt_tag ~key ~nonce ?adata ~tag data = + let cdata = data ^ tag in + authenticate_decrypt ~key ~nonce ?adata cdata diff --git a/src/cipher_block.ml b/src/cipher_block.ml index 22181578..db137b0c 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -347,6 +347,7 @@ module Modes = struct type key val derive : string -> key val digesti : key:key -> (string Uncommon.iter) -> string + val digesti_off_len : key:key -> (string * int * int) Uncommon.iter -> string val tagsize : int end = struct type key = string @@ -357,10 +358,15 @@ module Modes = struct let k = Bytes.create keysize in Native.GHASH.keyinit cs k; Bytes.unsafe_to_string k + let digesti_off_len ~key i = + let res = Bytes.make tagsize '\x00' in + i (fun (cs, off, len) -> Native.GHASH.ghash key res cs off len); + Bytes.unsafe_to_string res let digesti ~key i = let res = Bytes.make tagsize '\x00' in - i (fun cs -> Native.GHASH.ghash key res cs (String.length cs)); + i (fun cs -> Native.GHASH.ghash key res cs 0 (String.length cs)); Bytes.unsafe_to_string res + end module GCM_of (C : Block.Core) : Block.GCM = struct @@ -397,36 +403,74 @@ module Modes = struct CTR.ctr_of_octets @@ GHASH.digesti ~key:hkey @@ iter2 nonce (pack64s 0L (bits64 nonce)) - let tag ~key ~hkey ~ctr ?(adata = "") cdata = - CTR.encrypt ~key ~ctr @@ - GHASH.digesti ~key:hkey @@ - iter3 adata cdata (pack64s (bits64 adata) (bits64 cdata)) + let unsafe_tag_into ~key ~hkey ~ctr ?(adata = "") cdata ~off ~len dst ~tag_off = + CTR.unsafe_encrypt_into ~key ~ctr + (GHASH.digesti_off_len ~key:hkey + (iter3 (adata, 0, String.length adata) (cdata, off, len) + (pack64s (bits64 adata) (Int64.of_int (len * 8)), 0, 16))) + ~src_off:0 dst ~dst_off:tag_off tag_size - let authenticate_encrypt_tag ~key:{ key; hkey } ~nonce ?adata data = - let ctr = counter ~hkey nonce in - let cdata = CTR.(encrypt ~key ~ctr:(add_ctr ctr 1L) data) in - let ctag = tag ~key ~hkey ~ctr ?adata cdata in - cdata, ctag + let unsafe_authenticate_encrypt_into ~key:{ key; hkey } ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = + let ctr = counter ~hkey nonce in + CTR.(unsafe_encrypt_into ~key ~ctr:(add_ctr ctr 1L) src ~src_off dst ~dst_off len); + unsafe_tag_into ~key ~hkey ~ctr ?adata (Bytes.unsafe_to_string dst) ~off:dst_off ~len dst ~tag_off + + let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = + if String.length src - src_off < len then + invalid_arg "GCM: source length %u - src_off %u < len %u" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "GCM: dst length %u - dst_off %u < len %u" + (Bytes.length dst) dst_off len; + if Bytes.length dst - tag_off < tag_size then + invalid_arg "GCM: dst length %u - tag_off %u < tag_size %u" + (Bytes.length dst) tag_off tag_size; + unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len let authenticate_encrypt ~key ~nonce ?adata data = - let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata data in - cdata ^ ctag + let l = String.length data in + let dst = Bytes.create (l + tag_size) in + unsafe_authenticate_encrypt_into ~key ~nonce ?adata data ~src_off:0 dst ~dst_off:0 ~tag_off:l l; + Bytes.unsafe_to_string dst + + let authenticate_encrypt_tag ~key ~nonce ?adata data = + let r = authenticate_encrypt ~key ~nonce ?adata data in + String.sub r 0 (String.length data), + String.sub r (String.length data) tag_size - let authenticate_decrypt_tag ~key:{ key; hkey } ~nonce ?adata ~tag:tag_data cipher = - let ctr = counter ~hkey nonce in - let data = CTR.(encrypt ~key ~ctr:(add_ctr ctr 1L) cipher) in - let ctag = tag ~key ~hkey ~ctr ?adata cipher in - if Eqaf.equal tag_data ctag then Some data else None + let unsafe_authenticate_decrypt_into ~key:{ key; hkey } ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = + let ctr = counter ~hkey nonce in + CTR.(unsafe_encrypt_into ~key ~ctr:(add_ctr ctr 1L) src ~src_off dst ~dst_off len); + let ctag = Bytes.create tag_size in + unsafe_tag_into ~key ~hkey ~ctr ?adata src ~off:src_off ~len ctag ~tag_off:0; + Eqaf.equal (String.sub src tag_off tag_size) (Bytes.unsafe_to_string ctag) + + let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = + if String.length src - src_off < len then + invalid_arg "GCM: source length %u - src_off %u < len %u" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "GCM: dst length %u - dst_off %u < len %u" + (Bytes.length dst) dst_off len; + if String.length src - tag_off < tag_size then + invalid_arg "GCM: src length %u - tag_off %u < tag_size %u" + (String.length src) tag_off tag_size; + unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt ~key ~nonce ?adata cdata = if String.length cdata < tag_size then None else - let cipher, tag = - String.sub cdata 0 (String.length cdata - tag_size), - String.sub cdata (String.length cdata - tag_size) tag_size - in - authenticate_decrypt_tag ~key ~nonce ?adata ~tag cipher + let l = String.length cdata - tag_size in + let data = Bytes.create l in + if unsafe_authenticate_decrypt_into ~key ~nonce ?adata cdata ~src_off:0 ~tag_off:l data ~dst_off:0 l then + Some (Bytes.unsafe_to_string data) + else + None + + let authenticate_decrypt_tag ~key ~nonce ?adata ~tag:tag_data cipher = + let cdata = cipher ^ tag_data in + authenticate_decrypt ~key ~nonce ?adata cdata end module CCM16_of (C : Block.Core) : Block.CCM16 = struct @@ -465,6 +509,15 @@ module Modes = struct String.sub data (String.length data - tag_size) tag_size in authenticate_decrypt_tag ~key ~nonce ?adata ~tag data + + let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = + assert false + let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off = + assert false + let unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = + assert false + let unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off = + assert false end end diff --git a/src/mirage_crypto.mli b/src/mirage_crypto.mli index 0231c158..deaca37b 100644 --- a/src/mirage_crypto.mli +++ b/src/mirage_crypto.mli @@ -74,8 +74,8 @@ module Poly1305 : sig (** [maci ~key iter] is the all-in-one mac computation: [get (feedi (empty ~key) iter)]. *) - val macl : key:string -> string list -> string - (** [macl ~key datas] computes the [mac] of [datas]. *) + val mac_into : key:string -> (string * int * int) list -> bytes -> dst_off:int -> unit + (** [mac_into ~key datas dst dst_off] computes the [mac] of [datas]. *) end (** {1 Symmetric-key cryptography} *) @@ -141,6 +141,65 @@ module type AEAD = sig returned. @raise Invalid_argument if [nonce] is not of the right size. *) + + (** {1 Authenticated encryption and decryption into existing buffers} *) + + val authenticate_encrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> bytes -> dst_off:int -> + tag_off:int -> int -> unit + (** [authenticate_encrypt_into ~key ~nonce ~adata msg ~src_off dst ~dst_off ~tag_off len] + encrypts [msg] starting at [src_off] with [key] and [nonce]. The output + is put into [dst] at [dst_off], the tag into [dst] at [tag_off]. + + @raise Invalid_argument if [nonce] is not of the right size. + @raise Invalid_argument if [String.length msg - src_off < len]. + @raise Invalid_argument if [Bytes.length dst - dst_off < len]. + @raise Invalid_argument if [Bytes.length dst - tag_off < tag_size]. + *) + + val authenticate_decrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> tag_off:int -> bytes -> + dst_off:int -> int -> bool + (** [authenticate_decrypt_into ~key ~nonce ~adata msg ~src_off ~tag_off dst ~dst_off] + computes the authentication tag using [key], [nonce], and [adata], and + decrypts the encrypted data from [msg] starting at [src_off] into [dst] + starting at [dst_off]. If the authentication tags match, [true] is + returned, and the decrypted data is in [dst]. + + @raise Invalid_argument if [nonce] is not of the right size. + @raise Invalid_argument if [String.length msg - src_off < len]. + @raise Invalid_argument if [Bytes.length dst - dst_off < len]. + @raise Invalid_argument if [String.length msg - tag_off < tag_size]. *) + + (**/**) + val unsafe_authenticate_encrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> bytes -> dst_off:int -> + tag_off:int -> int -> unit + (** [unsafe_authenticate_encrypt_into] is {!authenticate_encrypt_into}, but + without bounds checks. + + @raise Invalid_argument if [nonce] is not of the right size. + + This may cause memory issues if an invariant is violated: + {ul + {- [String.length msg - src_off >= len].} + {- [Bytes.length dst - dst_off >= len].} + {- [Bytes.length dst - tag_off >= tag_size].}} *) + + val unsafe_authenticate_decrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> tag_off:int -> bytes -> + dst_off:int -> int -> bool + (** [unsafe_authenticate_decrypt_into] is {!authenticate_decrypt_into}, but + without bounds checks. + + @raise Invalid_argument if [nonce] is not of the right size. + + This may cause memory issues if an invariant is violated: + {ul + {- [String.length msg - src_off >= len].} + {- [Bytes.length dst - dst_off >= len].} + {- [String.length msg - tag_off >= tag_size].}} *) + (**/**) end (** Block ciphers. diff --git a/src/native.ml b/src/native.ml index bf081a26..f6d59da3 100644 --- a/src/native.ml +++ b/src/native.ml @@ -20,8 +20,8 @@ end module Poly1305 = struct external init : bytes -> string -> unit = "mc_poly1305_init" [@@noalloc] - external update : bytes -> string -> int -> unit = "mc_poly1305_update" [@@noalloc] - external finalize : bytes -> bytes -> unit = "mc_poly1305_finalize" [@@noalloc] + external update : bytes -> string -> int -> int -> unit = "mc_poly1305_update" [@@noalloc] + external finalize : bytes -> bytes -> int -> unit = "mc_poly1305_finalize" [@@noalloc] external ctx_size : unit -> int = "mc_poly1305_ctx_size" [@@noalloc] external mac_size : unit -> int = "mc_poly1305_mac_size" [@@noalloc] end @@ -29,7 +29,7 @@ end module GHASH = struct external keysize : unit -> int = "mc_ghash_key_size" [@@noalloc] external keyinit : string -> bytes -> unit = "mc_ghash_init_key" [@@noalloc] - external ghash : string -> bytes -> string -> int -> unit = "mc_ghash" [@@noalloc] + external ghash : string -> bytes -> string -> int -> int -> unit = "mc_ghash" [@@noalloc] external mode : unit -> int = "mc_ghash_mode" [@@noalloc] end diff --git a/src/native/ghash_ctmul.c b/src/native/ghash_ctmul.c index 7788fd05..bb1a2b05 100644 --- a/src/native/ghash_ctmul.c +++ b/src/native/ghash_ctmul.c @@ -290,8 +290,8 @@ CAMLprim value mc_ghash_init_key_generic (value key, value m) { return Val_unit; } -CAMLprim value mc_ghash_generic (value m, value hash, value src, value len) { - br_ghash_ctmul(Bp_val(hash), Bp_val(m), _st_uint8(src), Int_val(len)); +CAMLprim value mc_ghash_generic (value m, value hash, value src, value off, value len) { + br_ghash_ctmul(Bp_val(hash), Bp_val(m), _st_uint8_off(src, off), Int_val(len)); return Val_unit; } diff --git a/src/native/ghash_generic.c b/src/native/ghash_generic.c index 2cc49532..68768c85 100644 --- a/src/native/ghash_generic.c +++ b/src/native/ghash_generic.c @@ -101,9 +101,9 @@ CAMLprim value mc_ghash_init_key_generic (value key, value m) { } CAMLprim value -mc_ghash_generic (value m, value hash, value src, value len) { +mc_ghash_generic (value m, value hash, value src, value off, value len) { __ghash ((__uint128_t *) Bp_val (m), (uint64_t *) Bp_val (hash), - _st_uint8 (src), Int_val (len) ); + _st_uint8_off (src, off), Int_val (len) ); return Val_unit; } diff --git a/src/native/ghash_pclmul.c b/src/native/ghash_pclmul.c index 58ca02ea..7c7ea95b 100644 --- a/src/native/ghash_pclmul.c +++ b/src/native/ghash_pclmul.c @@ -204,11 +204,11 @@ CAMLprim value mc_ghash_init_key (value key, value m) { } CAMLprim value -mc_ghash (value k, value hash, value src, value len) { +mc_ghash (value k, value hash, value src, value off, value len) { _mc_switch_accel(pclmul, - mc_ghash_generic(k, hash, src, len), + mc_ghash_generic(k, hash, src, off, len), __ghash ( (__m128i *) Bp_val (k), (__m128i *) Bp_val (hash), - (__m128i *) _st_uint8 (src), Int_val (len) )) + (__m128i *) _st_uint8_off (src, off), Int_val (len) )) return Val_unit; } diff --git a/src/native/mirage_crypto.h b/src/native/mirage_crypto.h index 6608a1b1..5496a965 100644 --- a/src/native/mirage_crypto.h +++ b/src/native/mirage_crypto.h @@ -105,7 +105,7 @@ CAMLprim value mc_ghash_key_size_generic (__unit ()); CAMLprim value mc_ghash_init_key_generic (value key, value m); CAMLprim value -mc_ghash_generic (value m, value hash, value src, value len); +mc_ghash_generic (value m, value hash, value src, value off, value len); CAMLprim value mc_xor_into_generic (value b1, value off1, value b2, value off2, value n); diff --git a/src/native/poly1305-donna.c b/src/native/poly1305-donna.c index 567649ab..46991dc2 100644 --- a/src/native/poly1305-donna.c +++ b/src/native/poly1305-donna.c @@ -59,13 +59,13 @@ CAMLprim value mc_poly1305_init (value ctx, value key) { return Val_unit; } -CAMLprim value mc_poly1305_update (value ctx, value buf, value len) { - poly1305_update ((poly1305_context *) Bytes_val(ctx), _st_uint8(buf), Int_val(len)); +CAMLprim value mc_poly1305_update (value ctx, value buf, value off, value len) { + poly1305_update ((poly1305_context *) Bytes_val(ctx), _st_uint8_off(buf, off), Int_val(len)); return Val_unit; } -CAMLprim value mc_poly1305_finalize (value ctx, value mac) { - poly1305_finish ((poly1305_context *) Bytes_val(ctx), Bytes_val(mac)); +CAMLprim value mc_poly1305_finalize (value ctx, value mac, value off) { + poly1305_finish ((poly1305_context *) Bytes_val(ctx), _bp_uint8_off(mac, off)); return Val_unit; } diff --git a/src/poly1305.ml b/src/poly1305.ml index eb571b82..aec53354 100644 --- a/src/poly1305.ml +++ b/src/poly1305.ml @@ -11,7 +11,7 @@ module type S = sig val mac : key:string -> string -> string val maci : key:string -> string iter -> string - val macl : key:string -> string list -> string + val mac_into : key:string -> (string * int * int) list -> bytes -> dst_off:int -> unit end module It : S = struct @@ -31,7 +31,7 @@ module It : S = struct ctx let update ctx data = - P.update ctx data (String.length data) + P.update ctx data 0 (String.length data) let feed ctx cs = let t = dup ctx in @@ -45,7 +45,7 @@ module It : S = struct let final ctx = let res = Bytes.create mac_size in - P.finalize ctx res; + P.finalize ctx res 0; Bytes.unsafe_to_string res let get ctx = final (dup ctx) @@ -54,8 +54,8 @@ module It : S = struct let maci ~key iter = feedi (empty ~key) iter |> final - let macl ~key datas = + let mac_into ~key datas dst ~dst_off = let ctx = empty ~key in - List.iter (update ctx) datas; - final ctx + List.iter (fun (d, off, len) -> P.update ctx d off len) datas; + P.finalize ctx dst dst_off end From 5243e87a348fee12cb46aa753909a8a92981beec Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Sun, 9 Jun 2024 12:50:28 +0200 Subject: [PATCH 06/14] CCM16 with {de,en}crypt_into --- src/ccm.ml | 64 +++++++++++++++++++-------------------- src/chacha20.ml | 5 --- src/cipher_block.ml | 74 ++++++++++++++++++++++++++++++++------------- 3 files changed, 84 insertions(+), 59 deletions(-) diff --git a/src/ccm.ml b/src/ccm.ml index 8b68b7f3..23c30cf7 100644 --- a/src/ccm.ml +++ b/src/ccm.ml @@ -74,10 +74,8 @@ let prepare_header nonce adata plen tlen = type mode = Encrypt | Decrypt -let crypto_core ~cipher ~mode ~key ~nonce ~maclen ~adata data = - let datalen = String.length data in - let cbcheader = prepare_header nonce adata datalen maclen in - let dst = Bytes.create datalen in +let crypto_core_into ~cipher ~mode ~key ~nonce ~maclen ~adata src ~src_off dst ~dst_off len = + let cbcheader = prepare_header nonce adata len maclen in let small_q = 15 - String.length nonce in let ctr_flag_val = flags 0 0 (small_q - 1) in @@ -104,54 +102,54 @@ let crypto_core ~cipher ~mode ~key ~nonce ~maclen ~adata data = doit (Bytes.make block_size '\x00') 0 cbcheader 0 in - let rec loop iv ctr src src_off dst dst_off= + let rec loop iv ctr src src_off dst dst_off len = let cbcblock, cbc_off = match mode with | Encrypt -> src, src_off | Decrypt -> Bytes.unsafe_to_string dst, dst_off in - match String.length src - src_off with - | 0 -> iv - | x when x < block_size -> + if len = 0 then + iv + else if len < block_size then begin let buf = Bytes.make block_size '\x00' in - Bytes.unsafe_blit dst dst_off buf 0 x; + Bytes.unsafe_blit dst dst_off buf 0 len ; ctrblock ctr buf ; - Bytes.unsafe_blit buf 0 dst dst_off x ; - unsafe_xor_into src ~src_off dst ~dst_off x ; - Bytes.unsafe_blit_string cbcblock cbc_off buf 0 x; - Bytes.unsafe_fill buf x (block_size - x) '\x00'; + Bytes.unsafe_blit buf 0 dst dst_off len ; + unsafe_xor_into src ~src_off dst ~dst_off len ; + Bytes.unsafe_blit_string cbcblock cbc_off buf 0 len ; + Bytes.unsafe_fill buf len (block_size - len) '\x00'; cbc (Bytes.unsafe_to_string buf) cbc_off iv 0 ; iv - | _ -> + end else begin ctrblock ctr dst ; unsafe_xor_into src ~src_off dst ~dst_off block_size ; cbc cbcblock cbc_off iv 0 ; - loop iv (succ ctr) src (src_off + block_size) dst (dst_off + block_size) + loop iv (succ ctr) src (src_off + block_size) dst (dst_off + block_size) (len - block_size) + end in - let last = loop cbcprep 1 data 0 dst 0 in - let t = Bytes.sub last 0 maclen in - (dst, t) + let last = loop cbcprep 1 src src_off dst dst_off len in + (* assert (maclen = Bytes.length last); *) + (* assert (block_size = maclen); *) + last + +let crypto_core ~cipher ~mode ~key ~nonce ~maclen ~adata data = + let datalen = String.length data in + let dst = Bytes.create datalen in + let t = crypto_core_into ~cipher ~mode ~key ~nonce ~maclen ~adata data ~src_off:0 dst ~dst_off:0 datalen in + dst, t let crypto_t t nonce cipher key = let ctr = gen_ctr nonce 0 in cipher ~key (Bytes.unsafe_to_string ctr) ~src_off:0 ctr ~dst_off:0 ; unsafe_xor_into (Bytes.unsafe_to_string ctr) ~src_off:0 t ~dst_off:0 (Bytes.length t) -let valid_nonce nonce = - let nsize = String.length nonce in - if nsize < 7 || nsize > 13 then - invalid_arg "CCM: nonce length not between 7 and 13: %u" nsize - -let generation_encryption ~cipher ~key ~nonce ~maclen ~adata data = - valid_nonce nonce; - let cdata, t = crypto_core ~cipher ~mode:Encrypt ~key ~nonce ~maclen ~adata data in +let unsafe_generation_encryption_into ~cipher ~key ~nonce ~maclen ~adata src ~src_off dst ~dst_off ~tag_off len = + let t = crypto_core_into ~cipher ~mode:Encrypt ~key ~nonce ~maclen ~adata src ~src_off dst ~dst_off len in crypto_t t nonce cipher key ; - Bytes.unsafe_to_string cdata, Bytes.unsafe_to_string t + Bytes.unsafe_blit t 0 dst tag_off maclen -let decryption_verification ~cipher ~key ~nonce ~maclen ~adata ~tag data = - valid_nonce nonce; - let cdata, t = crypto_core ~cipher ~mode:Decrypt ~key ~nonce ~maclen ~adata data in +let unsafe_decryption_verification_into ~cipher ~key ~nonce ~maclen ~adata src ~src_off ~tag_off dst ~dst_off len = + let tag = String.sub src tag_off maclen in + let t = crypto_core_into ~cipher ~mode:Decrypt ~key ~nonce ~maclen ~adata src ~src_off dst ~dst_off len in crypto_t t nonce cipher key ; - match Eqaf.equal tag (Bytes.unsafe_to_string t) with - | true -> Some (Bytes.unsafe_to_string cdata) - | false -> None + Eqaf.equal tag (Bytes.unsafe_to_string t) diff --git a/src/chacha20.ml b/src/chacha20.ml index a3ef469c..07ed07d7 100644 --- a/src/chacha20.ml +++ b/src/chacha20.ml @@ -100,11 +100,6 @@ let mac_into ~key ~adata src ~src_off len dst ~dst_off = len_buf, 0, String.length len_buf ] dst ~dst_off -let mac ~key ~adata ciphertext = - let r = Bytes.create tag_size in - mac_into ~key ~adata ciphertext ~src_off:0 (String.length ciphertext) r ~dst_off:0; - Bytes.unsafe_to_string r - let unsafe_authenticate_encrypt_into ~key ~nonce ?(adata = "") src ~src_off dst ~dst_off ~tag_off len = let poly1305_key = generate_poly1305_key ~key ~nonce in crypt_into ~key ~nonce ~ctr:1L src ~src_off dst ~dst_off len; diff --git a/src/cipher_block.ml b/src/cipher_block.ml index db137b0c..e196fc06 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -486,38 +486,70 @@ module Modes = struct let (key_sizes, block_size) = C.(key, block) let cipher ~key src ~src_off dst ~dst_off = - if String.length src - src_off < block_size || Bytes.length dst - dst_off < block_size then - invalid_arg "src len %u, dst len %u" (String.length src - src_off) (Bytes.length dst - dst_off); C.encrypt ~key ~blocks:1 src src_off dst dst_off - let authenticate_encrypt_tag ~key ~nonce ?(adata = "") cs = - Ccm.generation_encryption ~cipher ~key ~nonce ~maclen:tag_size ~adata cs + let unsafe_authenticate_encrypt_into ~key ~nonce ?(adata = "") src ~src_off dst ~dst_off ~tag_off len = + Ccm.unsafe_generation_encryption_into ~cipher ~key ~nonce ~maclen:tag_size + ~adata src ~src_off dst ~dst_off ~tag_off len + + let valid_nonce nonce = + let nsize = String.length nonce in + if nsize < 7 || nsize > 13 then + invalid_arg "CCM: nonce length not between 7 and 13: %u" nsize + + let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = + if String.length src - src_off < len then + invalid_arg "CCM: source length %u - src_off %u < len %u" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "CCM: dst length %u - dst_off %u < len %u" + (Bytes.length dst) dst_off len; + if Bytes.length dst - tag_off < tag_size then + invalid_arg "CCM: dst length %u - tag_off %u < tag_size %u" + (Bytes.length dst) tag_off tag_size; + valid_nonce nonce; + unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len let authenticate_encrypt ~key ~nonce ?adata cs = - let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata cs in - cdata ^ ctag + valid_nonce nonce; + let l = String.length cs in + let dst = Bytes.create (l + tag_size) in + unsafe_authenticate_encrypt_into ~key ~nonce ?adata cs ~src_off:0 dst ~dst_off:0 ~tag_off:l l; + Bytes.unsafe_to_string dst + + let authenticate_encrypt_tag ~key ~nonce ?adata cs = + let res = authenticate_encrypt ~key ~nonce ?adata cs in + String.sub res 0 (String.length cs), String.sub res (String.length cs) tag_size - let authenticate_decrypt_tag ~key ~nonce ?(adata = "") ~tag cs = - Ccm.decryption_verification ~cipher ~key ~nonce ~maclen:tag_size ~adata ~tag cs + let unsafe_authenticate_decrypt_into ~key ~nonce ?(adata = "") src ~src_off ~tag_off dst ~dst_off len = + Ccm.unsafe_decryption_verification_into ~cipher ~key ~nonce ~maclen:tag_size ~adata src ~src_off ~tag_off dst ~dst_off len + + let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = + if String.length src - src_off < len then + invalid_arg "CCM: source length %u - src_off %u < len %u" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "CCM: dst length %u - dst_off %u < len %u" + (Bytes.length dst) dst_off len; + if String.length src - tag_off < tag_size then + invalid_arg "CCM: src length %u - tag_off %u < tag_size %u" + (String.length src) tag_off tag_size; + valid_nonce nonce; + unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt ~key ~nonce ?adata data = if String.length data < tag_size then None else - let data, tag = - String.sub data 0 (String.length data - tag_size), - String.sub data (String.length data - tag_size) tag_size - in - authenticate_decrypt_tag ~key ~nonce ?adata ~tag data + let dlen = String.length data - tag_size in + let dst = Bytes.create dlen in + if authenticate_decrypt_into ~key ~nonce ?adata data ~src_off:0 ~tag_off:dlen dst ~dst_off:0 dlen then + Some (Bytes.unsafe_to_string dst) + else + None - let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = - assert false - let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off = - assert false - let unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = - assert false - let unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off = - assert false + let authenticate_decrypt_tag ~key ~nonce ?adata ~tag cs = + authenticate_decrypt ~key ~nonce ?adata (cs ^ tag) end end From 4482eebd341070b43f375774f711053bfc0421f2 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Sun, 9 Jun 2024 16:19:07 +0200 Subject: [PATCH 07/14] minor adjustments to speed --- bench/speed.ml | 50 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/bench/speed.ml b/bench/speed.ml index 9578ede9..3b1c90e5 100644 --- a/bench/speed.ml +++ b/bench/speed.ml @@ -372,6 +372,22 @@ let benchmarks = [ throughput_into name (fun dst cs -> AES.ECB.encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + bm "aes-192-ecb" (fun name -> + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 24) in + throughput_into name (fun dst cs -> AES.ECB.encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-192-ecb-unsafe" (fun name -> + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 24) in + throughput_into name (fun dst cs -> AES.ECB.unsafe_encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-256-ecb" (fun name -> + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 32) in + throughput_into name (fun dst cs -> AES.ECB.encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-256-ecb-unsafe" (fun name -> + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 32) in + throughput_into name (fun dst cs -> AES.ECB.unsafe_encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + bm "aes-128-ecb-unsafe" (fun name -> let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 16) in throughput_into name @@ -423,35 +439,45 @@ let benchmarks = [ let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) and nonce = Mirage_crypto_rng.generate 12 in throughput_into ~add:AES.GCM.tag_size name - (fun dst cs -> AES.GCM.authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs - AES.GCM.tag_size) (String.length cs))); + (fun dst cs -> AES.GCM.authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))); bm "aes-128-gcm-unsafe" (fun name -> let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) and nonce = Mirage_crypto_rng.generate 12 in throughput_into ~add:AES.GCM.tag_size name - (fun dst cs -> AES.GCM.unsafe_authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs - AES.GCM.tag_size) (String.length cs))); + (fun dst cs -> AES.GCM.unsafe_authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))); bm "aes-128-ghash" (fun name -> let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) and nonce = Mirage_crypto_rng.generate 12 in - throughput name (fun cs -> AES.GCM.authenticate_encrypt ~key ~nonce ~adata:cs "")); + throughput_into ~add:AES.GCM.tag_size name + (fun dst cs -> AES.GCM.authenticate_encrypt_into ~key ~nonce ~adata:cs "" ~src_off:0 dst ~dst_off:0 ~tag_off:0 0)); + + bm "aes-128-ghash-unsafe" (fun name -> + let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) + and nonce = Mirage_crypto_rng.generate 12 in + throughput_into ~add:AES.GCM.tag_size name + (fun dst cs -> AES.GCM.unsafe_authenticate_encrypt_into ~key ~nonce ~adata:cs "" ~src_off:0 dst ~dst_off:0 ~tag_off:0 0)); bm "aes-128-ccm" (fun name -> let key = AES.CCM16.of_secret (Mirage_crypto_rng.generate 16) and nonce = Mirage_crypto_rng.generate 10 in - throughput name (fun cs -> AES.CCM16.authenticate_encrypt ~key ~nonce cs)); - - bm "aes-192-ecb" (fun name -> - let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 24) in - throughput name (fun cs -> AES.ECB.encrypt ~key cs)) ; + throughput_into ~add:AES.CCM16.tag_size name + (fun dst cs -> AES.CCM16.authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))); - bm "aes-256-ecb" (fun name -> - let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 32) in - throughput name (fun cs -> AES.ECB.encrypt ~key cs)) ; + bm "aes-128-ccm-unsafe" (fun name -> + let key = AES.CCM16.of_secret (Mirage_crypto_rng.generate 16) + and nonce = Mirage_crypto_rng.generate 10 in + throughput_into ~add:AES.CCM16.tag_size name + (fun dst cs -> AES.CCM16.unsafe_authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))); bm "d3des-ecb" (fun name -> let key = DES.ECB.of_secret (Mirage_crypto_rng.generate 24) in - throughput name (fun cs -> DES.ECB.encrypt ~key cs)) ; + throughput_into name (fun dst cs -> DES.ECB.encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "d3des-ecb-unsafe" (fun name -> + let key = DES.ECB.of_secret (Mirage_crypto_rng.generate 24) in + throughput_into name (fun dst cs -> DES.ECB.unsafe_encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; bm "fortuna" (fun name -> let open Mirage_crypto_rng.Fortuna in From 3399544574dc05be1c1e3113fd71c47cbac21a69 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Tue, 18 Jun 2024 14:13:23 +0200 Subject: [PATCH 08/14] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Reynir Björnsson --- src/mirage_crypto.mli | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mirage_crypto.mli b/src/mirage_crypto.mli index deaca37b..04eb7324 100644 --- a/src/mirage_crypto.mli +++ b/src/mirage_crypto.mli @@ -148,7 +148,7 @@ module type AEAD = sig ?adata:string -> string -> src_off:int -> bytes -> dst_off:int -> tag_off:int -> int -> unit (** [authenticate_encrypt_into ~key ~nonce ~adata msg ~src_off dst ~dst_off ~tag_off len] - encrypts [msg] starting at [src_off] with [key] and [nonce]. The output + encrypts [len] bytes of [msg] starting at [src_off] with [key] and [nonce]. The output is put into [dst] at [dst_off], the tag into [dst] at [tag_off]. @raise Invalid_argument if [nonce] is not of the right size. @@ -160,9 +160,9 @@ module type AEAD = sig val authenticate_decrypt_into : key:key -> nonce:string -> ?adata:string -> string -> src_off:int -> tag_off:int -> bytes -> dst_off:int -> int -> bool - (** [authenticate_decrypt_into ~key ~nonce ~adata msg ~src_off ~tag_off dst ~dst_off] + (** [authenticate_decrypt_into ~key ~nonce ~adata msg ~src_off ~tag_off dst ~dst_off len] computes the authentication tag using [key], [nonce], and [adata], and - decrypts the encrypted data from [msg] starting at [src_off] into [dst] + decrypts the [len] bytes encrypted data from [msg] starting at [src_off] into [dst] starting at [dst_off]. If the authentication tags match, [true] is returned, and the decrypted data is in [dst]. From 08a8b16d7da8326fbf54a3832e0549a0be456826 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Tue, 18 Jun 2024 14:30:21 +0200 Subject: [PATCH 09/14] revise bounds checks (cc @reynir @palainp), also check off >= 0 --- src/cipher_block.ml | 102 +++++++++++++++----------------------------- 1 file changed, 35 insertions(+), 67 deletions(-) diff --git a/src/cipher_block.ml b/src/cipher_block.ml index e196fc06..af958e28 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -152,6 +152,15 @@ module Counters = struct end end +let check_offset ~tag ~buf ~off ~len actual_len = + if off < 0 then + invalid_arg "%s: %s off %u < 0" + tag buf off; + if actual_len - off < len then + invalid_arg "%s: %s length %u - off %u < len %u" + tag buf actual_len off len +[@@inline] + module Modes = struct module ECB_of (Core : Block.Core) : Block.ECB = struct @@ -167,12 +176,8 @@ module Modes = struct let ecb xform key src src_off dst dst_off len = if len mod block_size <> 0 then invalid_arg "ECB: length %u not of block size" len; - if String.length src - src_off < len then - invalid_arg "ECB: source length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "ECB: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; + check_offset ~tag:"ECB" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"ECB" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_ecb xform key src src_off dst dst_off len let encrypt_into ~key:(key, _) src ~src_off dst ~dst_off len = @@ -209,15 +214,16 @@ module Modes = struct let of_secret = Core.of_secret - let bounds_check ?(off = 0) ~iv cs = + let block_size_check ?(off = 0) ~iv cs = if String.length iv <> block then invalid_arg "CBC: IV length %u not of block size" (String.length iv); if (String.length cs - off) mod block <> 0 then invalid_arg "CBC: argument length %u (off %u) not of block size" (String.length cs) off + [@@inline] let next_iv ?(off = 0) cs ~iv = - bounds_check ~iv cs ~off ; + block_size_check ~iv cs ~off ; if String.length cs > off then String.sub cs (String.length cs - block_size) block_size else iv @@ -237,13 +243,9 @@ module Modes = struct unsafe_encrypt_into_inplace ~key ~iv dst ~dst_off len let encrypt_into ~key ~iv src ~src_off dst ~dst_off len = - bounds_check ~off:src_off ~iv src; - if String.length src - src_off < len then - invalid_arg "CBC: src has insufficient length (%u - src_off:%u < len %u)" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "CBC: dst has insufficient length (%u - dst_off:%u < len %u)" - (Bytes.length dst) dst_off len; + block_size_check ~off:src_off ~iv src; + check_offset ~tag:"CBC" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CBC" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_encrypt_into ~key ~iv src ~src_off dst ~dst_off len let encrypt ~key ~iv src = @@ -260,13 +262,9 @@ module Modes = struct end let decrypt_into ~key ~iv src ~src_off dst ~dst_off len = - bounds_check ~off:src_off ~iv src; - if String.length src - src_off < len then - invalid_arg "CBC: src has insufficient length (%u - src_off:%u < len %u)" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "CBC: dst has insufficient length (%u - dst_off:%u < len %u)" - (Bytes.length dst) dst_off len; + block_size_check ~off:src_off ~iv src; + check_offset ~tag:"CBC" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CBC" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_decrypt_into ~key ~iv src ~src_off dst ~dst_off len let decrypt ~key ~iv src = @@ -302,9 +300,7 @@ module Modes = struct end let stream_into ~key ~ctr buf ~off len = - if Bytes.length buf - off < len then - invalid_arg "CTR: buffer length %u - off %u < len %u" - (Bytes.length buf) off len; + check_offset ~tag:"CTR" ~buf:"buf" ~off ~len (Bytes.length buf); unsafe_stream_into ~key ~ctr buf ~off len let stream ~key ~ctr n = @@ -317,12 +313,8 @@ module Modes = struct Uncommon.unsafe_xor_into src ~src_off dst ~dst_off len let encrypt_into ~key ~ctr src ~src_off dst ~dst_off len = - if String.length src - src_off < len then - invalid_arg "CTR: src length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "CTR: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; + check_offset ~tag:"CTR" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CTR" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_encrypt_into ~key ~ctr src ~src_off dst ~dst_off len let encrypt ~key ~ctr src = @@ -416,15 +408,9 @@ module Modes = struct unsafe_tag_into ~key ~hkey ~ctr ?adata (Bytes.unsafe_to_string dst) ~off:dst_off ~len dst ~tag_off let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = - if String.length src - src_off < len then - invalid_arg "GCM: source length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "GCM: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; - if Bytes.length dst - tag_off < tag_size then - invalid_arg "GCM: dst length %u - tag_off %u < tag_size %u" - (Bytes.length dst) tag_off tag_size; + check_offset ~tag:"GCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"GCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + check_offset ~tag:"GCM" ~buf:"dst tag" ~off:tag_off ~len:tag_size (Bytes.length dst); unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len let authenticate_encrypt ~key ~nonce ?adata data = @@ -446,15 +432,9 @@ module Modes = struct Eqaf.equal (String.sub src tag_off tag_size) (Bytes.unsafe_to_string ctag) let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = - if String.length src - src_off < len then - invalid_arg "GCM: source length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "GCM: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; - if String.length src - tag_off < tag_size then - invalid_arg "GCM: src length %u - tag_off %u < tag_size %u" - (String.length src) tag_off tag_size; + check_offset ~tag:"GCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"GCM" ~buf:"src tag" ~off:tag_off ~len:tag_size (String.length src); + check_offset ~tag:"GCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt ~key ~nonce ?adata cdata = @@ -498,15 +478,9 @@ module Modes = struct invalid_arg "CCM: nonce length not between 7 and 13: %u" nsize let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = - if String.length src - src_off < len then - invalid_arg "CCM: source length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "CCM: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; - if Bytes.length dst - tag_off < tag_size then - invalid_arg "CCM: dst length %u - tag_off %u < tag_size %u" - (Bytes.length dst) tag_off tag_size; + check_offset ~tag:"CCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + check_offset ~tag:"CCM" ~buf:"dst tag" ~off:tag_off ~len:tag_size (Bytes.length dst); valid_nonce nonce; unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len @@ -525,15 +499,9 @@ module Modes = struct Ccm.unsafe_decryption_verification_into ~cipher ~key ~nonce ~maclen:tag_size ~adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = - if String.length src - src_off < len then - invalid_arg "CCM: source length %u - src_off %u < len %u" - (String.length src) src_off len; - if Bytes.length dst - dst_off < len then - invalid_arg "CCM: dst length %u - dst_off %u < len %u" - (Bytes.length dst) dst_off len; - if String.length src - tag_off < tag_size then - invalid_arg "CCM: src length %u - tag_off %u < tag_size %u" - (String.length src) tag_off tag_size; + check_offset ~tag:"CCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CCM" ~buf:"src tag" ~off:tag_off ~len:tag_size (String.length src); + check_offset ~tag:"CCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); valid_nonce nonce; unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len From acf74f4f30d360fe3aae4a93aff5e1ae77679c86 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Tue, 18 Jun 2024 14:39:52 +0200 Subject: [PATCH 10/14] revise block_size check --- src/cipher_block.ml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/cipher_block.ml b/src/cipher_block.ml index af958e28..249d1ee1 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -214,16 +214,16 @@ module Modes = struct let of_secret = Core.of_secret - let block_size_check ?(off = 0) ~iv cs = + let check_block_size ~iv len = if String.length iv <> block then invalid_arg "CBC: IV length %u not of block size" (String.length iv); - if (String.length cs - off) mod block <> 0 then - invalid_arg "CBC: argument length %u (off %u) not of block size" - (String.length cs) off + if len mod block <> 0 then + invalid_arg "CBC: argument length %u not of block size" + len [@@inline] let next_iv ?(off = 0) cs ~iv = - block_size_check ~iv cs ~off ; + check_block_size ~iv (String.length cs - off) ; if String.length cs > off then String.sub cs (String.length cs - block_size) block_size else iv @@ -243,7 +243,7 @@ module Modes = struct unsafe_encrypt_into_inplace ~key ~iv dst ~dst_off len let encrypt_into ~key ~iv src ~src_off dst ~dst_off len = - block_size_check ~off:src_off ~iv src; + check_block_size ~iv len; check_offset ~tag:"CBC" ~buf:"src" ~off:src_off ~len (String.length src); check_offset ~tag:"CBC" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_encrypt_into ~key ~iv src ~src_off dst ~dst_off len @@ -262,7 +262,7 @@ module Modes = struct end let decrypt_into ~key ~iv src ~src_off dst ~dst_off len = - block_size_check ~off:src_off ~iv src; + check_block_size ~iv len; check_offset ~tag:"CBC" ~buf:"src" ~off:src_off ~len (String.length src); check_offset ~tag:"CBC" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); unsafe_decrypt_into ~key ~iv src ~src_off dst ~dst_off len From 829ceb51c0ace39419923e44bdfad0ac4b524850 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Tue, 18 Jun 2024 14:40:04 +0200 Subject: [PATCH 11/14] update documentation, esp off < 0 --- src/mirage_crypto.mli | 46 +++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/mirage_crypto.mli b/src/mirage_crypto.mli index 04eb7324..e9e6e371 100644 --- a/src/mirage_crypto.mli +++ b/src/mirage_crypto.mli @@ -248,16 +248,16 @@ module Block : sig from [src] starting at [src_off] into [dst] starting at [dst_off]. @raise Invalid_argument if [len] is not a multiple of {!block_size}. - @raise Invalid_argument if [String.length src - src_off < len]. - @raise Invalid_argument if [Bytes.length dst - dst_off < len]. *) + @raise Invalid_argument if [src_off < 0 || String.length src - src_off < len]. + @raise Invalid_argument if [dst_off < 0 || Bytes.length dst - dst_off < len]. *) val decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit (** [decrypt_into ~key src ~src_off dst dst_off len] decrypts [len] octets from [src] starting at [src_off] into [dst] starting at [dst_off]. @raise Invalid_argument if [len] is not a multiple of {!block_size}. - @raise Invalid_argument if [String.length src - src_off < len]. - @raise Invalid_argument if [Bytes.length dst - dst_off < len]. *) + @raise Invalid_argument if [src_off < 0 || String.length src - src_off < len]. + @raise Invalid_argument if [dst_off < 0 || Bytes.length dst - dst_off < len]. *) (**/**) val unsafe_encrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit @@ -266,8 +266,8 @@ module Block : sig This may cause memory issues if an invariant is violated: {ul {- [len] must be a multiple of {!block_size},} - {- [String.length src - src_off >= len],} - {- [Bytes.length dst - dst_off >= len].}} *) + {- [src_off >= 0 && String.length src - src_off >= len],} + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len].}} *) val unsafe_decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit (** [unsafe_decrypt_into] is {!decrypt_into}, but without bounds checks. @@ -275,8 +275,8 @@ module Block : sig This may cause memory issues if an invariant is violated: {ul {- [len] must be a multiple of {!block_size},} - {- [String.length src - src_off >= len],} - {- [Bytes.length dst - dst_off >= len].}} *) + {- [src_off >= 0 && String.length src - src_off >= len],} + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len].}} *) (**/**) end @@ -334,8 +334,8 @@ module Block : sig @raise Invalid_argument if the length of [iv] is not {!block_size}. @raise Invalid_argument if [len] is not a multiple of {!block_size}. - @raise Invalid_argument if [String.length src - src_off < len]. - @raise Invalid_argument if [Bytes.length dst - dst_off < len]. *) + @raise Invalid_argument if [src_off < 0 || String.length src - src_off < len]. + @raise Invalid_argument if [dst_off < 0 || Bytes.length dst - dst_off < len]. *) val decrypt_into : key:key -> iv:string -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit @@ -344,8 +344,8 @@ module Block : sig @raise Invalid_argument if the length of [iv] is not {!block_size}. @raise Invalid_argument if [len] is not a multiple of {!block_size}. - @raise Invalid_argument if [String.length src - src_off < len]. - @raise Invalid_argument if [Bytes.length dst - dst_off < len]. *) + @raise Invalid_argument if [src_off < 0 || String.length src - src_off < len]. + @raise Invalid_argument if [dst_off < 0 || Bytes.length dst - dst_off < len]. *) (**/**) val unsafe_encrypt_into : key:key -> iv:string -> string -> src_off:int -> @@ -356,8 +356,8 @@ module Block : sig {ul {- the length of [iv] must be {!block_size},} {- [len] must be a multiple of {!block_size},} - {- [String.length src - src_off >= len],} - {- [Bytes.length dst - dst_off >= len].}} *) + {- [src_off >= 0 && String.length src - src_off >= len],} + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len].}} *) val unsafe_decrypt_into : key:key -> iv:string -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit @@ -367,8 +367,8 @@ module Block : sig {ul {- the length of [iv] must be {!block_size},} {- [len] must be a multiple of {!block_size},} - {- [String.length src - src_off >= len],} - {- [Bytes.length dst - dst_off >= len].}} *) + {- [src_off >= 0 && String.length src - src_off >= len],} + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len].}} *) val unsafe_encrypt_into_inplace : key:key -> iv:string -> bytes -> dst_off:int -> int -> unit @@ -379,8 +379,8 @@ module Block : sig {ul {- the length of [iv] must be {!block_size},} {- [len] must be a multiple of {!block_size},} - {- [String.length src - src_off >= len],} - {- [Bytes.length dst - dst_off >= len].}} *) + {- [src_off >= 0 && String.length src - src_off >= len],} + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len].}} *) (**/**) end @@ -459,8 +459,8 @@ end key stream into [dst] at [dst_off], and then xors it with [src] at [src_off]. - @raise Invalid_argument if [Bytes.length dst - off < len]. - @raise Invalid_argument if [String.length src - off < len]. *) + @raise Invalid_argument if [dst_off < 0 || Bytes.length dst - dst_off < len]. + @raise Invalid_argument if [src_off < 0 || String.length src - src_off < len]. *) val decrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit @@ -472,7 +472,7 @@ end This may cause memory issues if the invariant is violated: {ul - {- [Bytes.length buf - off >= len].}} *) + {- [off >= 0 && Bytes.length buf - off >= len].}} *) val unsafe_encrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit @@ -480,8 +480,8 @@ end This may cause memory issues if an invariant is violated: {ul - {- [Bytes.length dst - off >= len],} - {- [String.length src - off < len].}} *) + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len],} + {- [src_off >= 0 && String.length src - src_off >= len].}} *) val unsafe_decrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit From 7805a7cbca4d8c7aeb1f0be7b4d63d064c46f785 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Tue, 18 Jun 2024 14:46:52 +0200 Subject: [PATCH 12/14] poly1305: mac_into appropriate bounds checks, also unsafe_mac_into --- src/chacha20.ml | 10 +++++----- src/mirage_crypto.mli | 5 +++++ src/poly1305.ml | 20 +++++++++++++++++++- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/chacha20.ml b/src/chacha20.ml index 07ed07d7..00119b5c 100644 --- a/src/chacha20.ml +++ b/src/chacha20.ml @@ -93,11 +93,11 @@ let mac_into ~key ~adata src ~src_off len dst ~dst_off = Bytes.unsafe_to_string data in let p1 = pad16 (String.length adata) and p2 = pad16 len in - P.mac_into ~key [ adata, 0, String.length adata ; - p1, 0, String.length p1 ; - src, src_off, len ; - p2, 0, String.length p2 ; - len_buf, 0, String.length len_buf ] + P.unsafe_mac_into ~key [ adata, 0, String.length adata ; + p1, 0, String.length p1 ; + src, src_off, len ; + p2, 0, String.length p2 ; + len_buf, 0, String.length len_buf ] dst ~dst_off let unsafe_authenticate_encrypt_into ~key ~nonce ?(adata = "") src ~src_off dst ~dst_off ~tag_off len = diff --git a/src/mirage_crypto.mli b/src/mirage_crypto.mli index e9e6e371..d85797e6 100644 --- a/src/mirage_crypto.mli +++ b/src/mirage_crypto.mli @@ -76,6 +76,11 @@ module Poly1305 : sig val mac_into : key:string -> (string * int * int) list -> bytes -> dst_off:int -> unit (** [mac_into ~key datas dst dst_off] computes the [mac] of [datas]. *) + + (**/**) + val unsafe_mac_into : key:string -> (string * int * int) list -> bytes -> dst_off:int -> unit + (** [unsafe_mac_into ~key datas dst dst_off] is {!mac_into} without bounds checks. *) + (**/**) end (** {1 Symmetric-key cryptography} *) diff --git a/src/poly1305.ml b/src/poly1305.ml index aec53354..0a2cb72d 100644 --- a/src/poly1305.ml +++ b/src/poly1305.ml @@ -12,6 +12,7 @@ module type S = sig val mac : key:string -> string -> string val maci : key:string -> string iter -> string val mac_into : key:string -> (string * int * int) list -> bytes -> dst_off:int -> unit + val unsafe_mac_into : key:string -> (string * int * int) list -> bytes -> dst_off:int -> unit end module It : S = struct @@ -54,8 +55,25 @@ module It : S = struct let maci ~key iter = feedi (empty ~key) iter |> final - let mac_into ~key datas dst ~dst_off = + let unsafe_mac_into ~key datas dst ~dst_off = let ctx = empty ~key in List.iter (fun (d, off, len) -> P.update ctx d off len) datas; P.finalize ctx dst dst_off + + let mac_into ~key datas dst ~dst_off = + if Bytes.length dst - dst_off < mac_size then + Uncommon.invalid_arg "Poly1305: dst length %u - off %u < len %u" + (Bytes.length dst) dst_off mac_size; + if dst_off < 0 then + Uncommon.invalid_arg "Poly1305: dst_off %u < 0" dst_off; + let ctx = empty ~key in + List.iter (fun (d, off, len) -> + if off < 0 then + Uncommon.invalid_arg "Poly1305: d off %u < 0" off; + if String.length d - off < len then + Uncommon.invalid_arg "Poly1305: d length %u - off %u < len %u" + (String.length d) off len; + P.update ctx d off len) + datas; + P.finalize ctx dst dst_off end From 332890e105b43abd143d6e0a7e3225f35d54199a Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Tue, 18 Jun 2024 14:51:15 +0200 Subject: [PATCH 13/14] ccm: remove maclen argument, and ensure tag_size = block_size --- src/ccm.ml | 25 +++++++++++-------------- src/cipher_block.ml | 12 ++++++------ 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/src/ccm.ml b/src/ccm.ml index 23c30cf7..a0e02ec6 100644 --- a/src/ccm.ml +++ b/src/ccm.ml @@ -74,8 +74,8 @@ let prepare_header nonce adata plen tlen = type mode = Encrypt | Decrypt -let crypto_core_into ~cipher ~mode ~key ~nonce ~maclen ~adata src ~src_off dst ~dst_off len = - let cbcheader = prepare_header nonce adata len maclen in +let crypto_core_into ~cipher ~mode ~key ~nonce ~adata src ~src_off dst ~dst_off len = + let cbcheader = prepare_header nonce adata len block_size in let small_q = 15 - String.length nonce in let ctr_flag_val = flags 0 0 (small_q - 1) in @@ -127,15 +127,12 @@ let crypto_core_into ~cipher ~mode ~key ~nonce ~maclen ~adata src ~src_off dst ~ loop iv (succ ctr) src (src_off + block_size) dst (dst_off + block_size) (len - block_size) end in - let last = loop cbcprep 1 src src_off dst dst_off len in - (* assert (maclen = Bytes.length last); *) - (* assert (block_size = maclen); *) - last + loop cbcprep 1 src src_off dst dst_off len -let crypto_core ~cipher ~mode ~key ~nonce ~maclen ~adata data = +let crypto_core ~cipher ~mode ~key ~nonce ~adata data = let datalen = String.length data in let dst = Bytes.create datalen in - let t = crypto_core_into ~cipher ~mode ~key ~nonce ~maclen ~adata data ~src_off:0 dst ~dst_off:0 datalen in + let t = crypto_core_into ~cipher ~mode ~key ~nonce ~adata data ~src_off:0 dst ~dst_off:0 datalen in dst, t let crypto_t t nonce cipher key = @@ -143,13 +140,13 @@ let crypto_t t nonce cipher key = cipher ~key (Bytes.unsafe_to_string ctr) ~src_off:0 ctr ~dst_off:0 ; unsafe_xor_into (Bytes.unsafe_to_string ctr) ~src_off:0 t ~dst_off:0 (Bytes.length t) -let unsafe_generation_encryption_into ~cipher ~key ~nonce ~maclen ~adata src ~src_off dst ~dst_off ~tag_off len = - let t = crypto_core_into ~cipher ~mode:Encrypt ~key ~nonce ~maclen ~adata src ~src_off dst ~dst_off len in +let unsafe_generation_encryption_into ~cipher ~key ~nonce ~adata src ~src_off dst ~dst_off ~tag_off len = + let t = crypto_core_into ~cipher ~mode:Encrypt ~key ~nonce ~adata src ~src_off dst ~dst_off len in crypto_t t nonce cipher key ; - Bytes.unsafe_blit t 0 dst tag_off maclen + Bytes.unsafe_blit t 0 dst tag_off block_size -let unsafe_decryption_verification_into ~cipher ~key ~nonce ~maclen ~adata src ~src_off ~tag_off dst ~dst_off len = - let tag = String.sub src tag_off maclen in - let t = crypto_core_into ~cipher ~mode:Decrypt ~key ~nonce ~maclen ~adata src ~src_off dst ~dst_off len in +let unsafe_decryption_verification_into ~cipher ~key ~nonce ~adata src ~src_off ~tag_off dst ~dst_off len = + let tag = String.sub src tag_off block_size in + let t = crypto_core_into ~cipher ~mode:Decrypt ~key ~nonce ~adata src ~src_off dst ~dst_off len in crypto_t t nonce cipher key ; Eqaf.equal tag (Bytes.unsafe_to_string t) diff --git a/src/cipher_block.ml b/src/cipher_block.ml index 249d1ee1..5ac862c5 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -363,7 +363,7 @@ module Modes = struct module GCM_of (C : Block.Core) : Block.GCM = struct - let _ = assert (C.block = 16) + assert (C.block = 16) module CTR = CTR_of (C) (Counters.C128be32) type key = { key : C.ekey ; hkey : GHASH.key } @@ -455,9 +455,9 @@ module Modes = struct module CCM16_of (C : Block.Core) : Block.CCM16 = struct - let _ = assert (C.block = 16) + assert (C.block = 16) - let tag_size = 16 + let tag_size = C.block type key = C.ekey @@ -469,8 +469,8 @@ module Modes = struct C.encrypt ~key ~blocks:1 src src_off dst dst_off let unsafe_authenticate_encrypt_into ~key ~nonce ?(adata = "") src ~src_off dst ~dst_off ~tag_off len = - Ccm.unsafe_generation_encryption_into ~cipher ~key ~nonce ~maclen:tag_size - ~adata src ~src_off dst ~dst_off ~tag_off len + Ccm.unsafe_generation_encryption_into ~cipher ~key ~nonce ~adata + src ~src_off dst ~dst_off ~tag_off len let valid_nonce nonce = let nsize = String.length nonce in @@ -496,7 +496,7 @@ module Modes = struct String.sub res 0 (String.length cs), String.sub res (String.length cs) tag_size let unsafe_authenticate_decrypt_into ~key ~nonce ?(adata = "") src ~src_off ~tag_off dst ~dst_off len = - Ccm.unsafe_decryption_verification_into ~cipher ~key ~nonce ~maclen:tag_size ~adata src ~src_off ~tag_off dst ~dst_off len + Ccm.unsafe_decryption_verification_into ~cipher ~key ~nonce ~adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = check_offset ~tag:"CCM" ~buf:"src" ~off:src_off ~len (String.length src); From 52e5105dd161684745791451132de8acaf5bf9e8 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Tue, 18 Jun 2024 15:06:44 +0200 Subject: [PATCH 14/14] add tailcall annotations, remove an argument from ccm's loop --- src/ccm.ml | 18 +++++++++--------- src/chacha20.ml | 2 +- src/cipher_block.ml | 2 +- src/cipher_stream.ml | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/ccm.ml b/src/ccm.ml index a0e02ec6..ecee28ec 100644 --- a/src/ccm.ml +++ b/src/ccm.ml @@ -10,7 +10,7 @@ let encode_len buf ~off size value = | 0 -> Bytes.set_uint8 buf off num | m -> Bytes.set_uint8 buf (off + m) (num land 0xff); - ass (num lsr 8) (pred m) + (ass [@tailcall]) (num lsr 8) (pred m) in ass value (pred size) @@ -91,25 +91,25 @@ let crypto_core_into ~cipher ~mode ~key ~nonce ~adata src ~src_off dst ~dst_off cipher ~key (Bytes.unsafe_to_string block) ~src_off:dst_off block ~dst_off in - let cbcprep = + let iv = let rec doit iv iv_off block block_off = match Bytes.length block - block_off with | 0 -> Bytes.sub iv iv_off block_size | _ -> cbc (Bytes.unsafe_to_string iv) iv_off block block_off; - doit block block_off block (block_off + block_size) + (doit [@tailcall]) block block_off block (block_off + block_size) in doit (Bytes.make block_size '\x00') 0 cbcheader 0 in - let rec loop iv ctr src src_off dst dst_off len = + let rec loop ctr src src_off dst dst_off len = let cbcblock, cbc_off = match mode with | Encrypt -> src, src_off | Decrypt -> Bytes.unsafe_to_string dst, dst_off in if len = 0 then - iv + () else if len < block_size then begin let buf = Bytes.make block_size '\x00' in Bytes.unsafe_blit dst dst_off buf 0 len ; @@ -118,16 +118,16 @@ let crypto_core_into ~cipher ~mode ~key ~nonce ~adata src ~src_off dst ~dst_off unsafe_xor_into src ~src_off dst ~dst_off len ; Bytes.unsafe_blit_string cbcblock cbc_off buf 0 len ; Bytes.unsafe_fill buf len (block_size - len) '\x00'; - cbc (Bytes.unsafe_to_string buf) cbc_off iv 0 ; - iv + cbc (Bytes.unsafe_to_string buf) cbc_off iv 0 end else begin ctrblock ctr dst ; unsafe_xor_into src ~src_off dst ~dst_off block_size ; cbc cbcblock cbc_off iv 0 ; - loop iv (succ ctr) src (src_off + block_size) dst (dst_off + block_size) (len - block_size) + (loop [@tailcall]) (succ ctr) src (src_off + block_size) dst (dst_off + block_size) (len - block_size) end in - loop cbcprep 1 src src_off dst dst_off len + loop 1 src src_off dst dst_off len; + iv let crypto_core ~cipher ~mode ~key ~nonce ~adata data = let datalen = String.length data in diff --git a/src/chacha20.ml b/src/chacha20.ml index 00119b5c..2c70251d 100644 --- a/src/chacha20.ml +++ b/src/chacha20.ml @@ -65,7 +65,7 @@ let crypt_into ~key ~nonce ~ctr src ~src_off dst ~dst_off len = chacha20_block state (dst_off + i) dst ; Native.xor_into_bytes src (src_off + i) dst (dst_off + i) block ; inc state; - loop (i + block) (n - 1) + (loop [@tailcall]) (i + block) (n - 1) in loop 0 block_count diff --git a/src/cipher_block.ml b/src/cipher_block.ml index 5ac862c5..d430492f 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -234,7 +234,7 @@ module Modes = struct | b -> Native.xor_into_bytes iv iv_i dst dst_i block ; Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string dst) dst_i dst dst_i ; - loop (Bytes.unsafe_to_string dst) dst_i (dst_i + block) (b - 1) + (loop [@tailcall]) (Bytes.unsafe_to_string dst) dst_i (dst_i + block) (b - 1) in loop iv 0 dst_off (len / block) diff --git a/src/cipher_stream.ml b/src/cipher_stream.ml index 67ee0a63..69bbae3f 100644 --- a/src/cipher_stream.ml +++ b/src/cipher_stream.ml @@ -26,7 +26,7 @@ module ARC4 = struct let j = (j + si + x) land 0xff in let sj = s.(j) in s.(i) <- sj ; s.(j) <- si ; - loop j (succ i) + (loop [@tailcall]) j (succ i) in ( loop 0 0 ; (0, 0, s) ) @@ -44,7 +44,7 @@ module ARC4 = struct s.(i) <- sj ; s.(j) <- si ; let k = s.((si + sj) land 0xff) in Bytes.set_uint8 res n (k lxor String.get_uint8 buf n); - mix i j (succ n) + (mix [@tailcall]) i j (succ n) in let key' = mix i j 0 in { key = key' ; message = Bytes.unsafe_to_string res }