From 9c1c12093be0c0df338183ec16e2cd8f24039a52 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Sun, 4 Feb 2024 01:05:00 +0100 Subject: [PATCH] lwt-unix --- unix/tar_lwt_unix.ml | 287 +++++++++++++++++++++++++++++++++++------- unix/tar_lwt_unix.mli | 64 ++++++++-- unix/tar_unix.ml | 41 ++++-- unix/tar_unix.mli | 2 +- 4 files changed, 324 insertions(+), 70 deletions(-) diff --git a/unix/tar_lwt_unix.ml b/unix/tar_lwt_unix.ml index 60cf251..98bfbf8 100644 --- a/unix/tar_lwt_unix.ml +++ b/unix/tar_lwt_unix.ml @@ -15,57 +15,252 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) -open Lwt.Infix +type decode_error = [ + | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] + | `Unix of Unix.error * string * string + | `Unexpected_end_of_file + | `Msg of string +] -module Io = struct - type in_channel = Lwt_unix.file_descr - type 'a io = 'a Lwt.t - let really_read fd buf = - let len = Bytes.length buf in - let rec loop idx = - if idx = len then - Lwt.return_unit - else - Lwt_unix.read fd buf idx (len - idx) >>= fun n -> - loop (n + idx) - in - loop 0 - let skip (ifd: Lwt_unix.file_descr) (n: int) = - Lwt_unix.(lseek ifd n SEEK_CUR) >|= ignore +let pp_decode_error ppf = function + | `Fatal err -> Tar.pp_error ppf err + | `Unix (err, fname, arg) -> + Format.fprintf ppf "Unix error %s (function %s, arg %s)" + (Unix.error_message err) fname arg + | `Unexpected_end_of_file -> + Format.fprintf ppf "Unexpected end of file" + | `Msg msg -> + Format.fprintf ppf "Error %s" msg + +let safe f a = + let open Lwt.Infix in + Lwt.catch + (fun () -> f a >|= fun r -> Ok r) + (function + | Unix.Unix_error (e, f, a) -> Lwt.return (Error (`Unix (e, f, a))) + | e -> Lwt.reraise e) - type out_channel = Lwt_unix.file_descr - let really_write fd buf = - let len = String.length buf in - let rec loop idx = - if idx = len then - Lwt.return_unit +let read_complete fd buf len = + let open Lwt_result.Infix in + let rec loop offset = + if offset < len then + safe (Lwt_unix.read fd buf offset) (len - offset) >>= fun read -> + if read = 0 then + Lwt.return (Error `Unexpected_end_of_file) else - Lwt_unix.write_string fd buf idx (len - idx) >>= fun n -> - loop (idx + n) - in - loop 0 -end + loop (offset + read) + else + Lwt.return (Ok ()) + in + loop 0 + +let seek fd n = + safe (Lwt_unix.lseek fd n) Unix.SEEK_CUR + +let safe_close fd = + Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) + +let fold f filename init = + let open Lwt_result.Infix in + safe Lwt_unix.(openfile filename [ O_RDONLY ]) 0 >>= fun fd -> + let rec go t fd ?global ?data acc = + (match data with + | None -> + let buf = Bytes.make Tar.Header.length '\000' in + read_complete fd buf Tar.Header.length >|= fun () -> + Bytes.unsafe_to_string buf + | Some data -> + Lwt.return (Ok data)) >>= fun data -> + match Tar.decode t data with + | Ok (t, Some `Header hdr, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + f fd ?global hdr acc >>= fun acc' -> + seek fd (Tar.Header.compute_zero_padding_length hdr) >>= fun _off -> + go t fd ?global acc' + | Ok (t, Some `Skip n, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + seek fd n >>= fun _off -> + go t fd ?global acc + | Ok (t, Some `Read n, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + let buf = Bytes.make n '\000' in + read_complete fd buf n >>= fun () -> + let data = Bytes.unsafe_to_string buf in + go t fd ?global ~data acc + | Ok (t, None, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + go t fd ?global acc + | Error `Eof -> Lwt.return (Ok acc) + | Error `Fatal _ as e -> Lwt.return e + in + Lwt.finalize + (fun () -> go (Tar.decode_state ()) fd init) + (fun () -> safe_close fd) + +let unix_err_to_msg = function + | `Unix (e, f, s) -> + `Msg (Format.sprintf "error %s in function %s %s" + (Unix.error_message e) f s) -include Io -module HeaderReader = Tar.HeaderReader(Lwt)(Io) -module HeaderWriter = Tar.HeaderWriter(Lwt)(Io) +let copy ~src_fd ~dst_fd len = + let open Lwt_result.Infix in + let blen = 65536 in + let buffer = Bytes.make blen '\000' in + let rec read_write ~src_fd ~dst_fd len = + if len = 0 then + Lwt.return (Ok ()) + else + let l = min blen len in + Lwt_result.map_error + (function + | `Unix _ as e -> unix_err_to_msg e + | `Unexpected_end_of_file -> + `Msg "Unexpected end of file") + (read_complete src_fd buffer l) >>= fun () -> + Lwt_result.map_error unix_err_to_msg + (safe (Lwt_unix.write dst_fd buffer 0) l) >>= fun _written -> + read_write ~src_fd ~dst_fd (len - l) + in + read_write ~src_fd ~dst_fd len + +let extract ?(filter = fun _ -> true) ~src dst = + let open Lwt_result.Infix in + let f fd ?global:_ hdr () = + if filter hdr then + match hdr.Tar.Header.link_indicator with + | Tar.Header.Link.Normal -> + Lwt_result.map_error unix_err_to_msg + (safe Lwt_unix.(openfile (Filename.concat dst hdr.Tar.Header.file_name) + [ O_WRONLY ; O_CREAT ]) hdr.Tar.Header.file_mode) >>= fun dst -> + Lwt.finalize + (fun () -> copy ~src_fd:fd ~dst_fd:dst (Int64.to_int hdr.Tar.Header.file_size)) + (fun () -> safe_close dst) + (* TODO set owner / mode / mtime etc. *) + | _ -> + (* TODO handle directories, links, etc. *) + Lwt_result.map_error unix_err_to_msg + (seek fd (Int64.to_int hdr.Tar.Header.file_size)) >|= fun _off -> + () + else + Lwt_result.map_error unix_err_to_msg + (seek fd (Int64.to_int hdr.Tar.Header.file_size)) >|= fun _off -> + () + in + fold f src () (** Return the header needed for a particular file on disk *) -let header_of_file ?level (file: string) : Tar.Header.t Lwt.t = +let header_of_file ?level file = + let open Lwt_result.Infix in let level = Tar.Header.compatibility level in - Lwt_unix.LargeFile.stat file >>= fun stat -> - Lwt_unix.getpwuid stat.Lwt_unix.LargeFile.st_uid >>= fun pwent -> - Lwt_unix.getgrgid stat.Lwt_unix.LargeFile.st_gid >>= fun grent -> - let file_mode = stat.Lwt_unix.LargeFile.st_perm in - let user_id = stat.Lwt_unix.LargeFile.st_uid in - let group_id = stat.Lwt_unix.LargeFile.st_gid in - let file_size = stat.Lwt_unix.LargeFile.st_size in - let mod_time = Int64.of_float stat.Lwt_unix.LargeFile.st_mtime in + safe Lwt_unix.LargeFile.stat file >>= fun stat -> + let file_mode = stat.Lwt_unix.LargeFile.st_perm in + let user_id = stat.Lwt_unix.LargeFile.st_uid in + let group_id = stat.Lwt_unix.LargeFile.st_gid in + let file_size = stat.Lwt_unix.LargeFile.st_size in + let mod_time = Int64.of_float stat.Lwt_unix.LargeFile.st_mtime in let link_indicator = Tar.Header.Link.Normal in - let link_name = "" in - let uname = if level = V7 then "" else pwent.Lwt_unix.pw_name in - let gname = if level = V7 then "" else grent.Lwt_unix.gr_name in - let devmajor = if level = Ustar then stat.Lwt_unix.LargeFile.st_dev else 0 in - let devminor = if level = Ustar then stat.Lwt_unix.LargeFile.st_rdev else 0 in - Lwt.return (Tar.Header.make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator ~link_name - ~uname ~gname ~devmajor ~devminor file file_size) + let link_name = "" in + (if level = V7 then + Lwt.return (Ok "") + else + Lwt.catch + (fun () -> safe Lwt_unix.getpwuid stat.Lwt_unix.LargeFile.st_uid) + (function + | Not_found -> + Lwt.return (Error (`Msg ("No user entry found for UID"))) + | e -> Lwt.reraise e) >|= fun pwent -> + pwent.Lwt_unix.pw_name) >>= fun uname -> + (if level = V7 then + Lwt.return (Ok "") + else + Lwt.catch + (fun () -> safe Lwt_unix.getgrgid stat.Lwt_unix.LargeFile.st_gid) + (function + | Not_found -> + Lwt.return (Error (`Msg ("No group entry found for GID"))) + | e -> Lwt.reraise e) >|= fun grent -> + grent.Lwt_unix.gr_name) >>= fun gname -> + let devmajor = if level = Ustar then stat.Lwt_unix.LargeFile.st_dev else 0 in + let devminor = if level = Ustar then stat.Lwt_unix.LargeFile.st_rdev else 0 in + let hdr = Tar.Header.make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator ~link_name + ~uname ~gname ~devmajor ~devminor file file_size + in + Lwt.return (Ok hdr) + +let append_file ?level ?header filename fd = + let open Lwt_result.Infix in + (match header with + | None -> header_of_file ?level filename + | Some x -> Lwt.return (Ok x)) >>= fun header -> + Lwt_result.lift (Tar.encode_header ?level header) >>= fun header_strings -> + Lwt_list.fold_left_s (fun acc d -> + Lwt_result.lift acc >>= fun _written -> + Lwt_result.map_error unix_err_to_msg + (safe (Lwt_unix.write_string fd d 0) (String.length d))) + (Ok 0) header_strings >>= fun _written -> + Lwt_result.map_error unix_err_to_msg + (safe Lwt_unix.(openfile filename [ O_RDONLY ]) 0) >>= fun src -> + (* TOCTOU [also, header may not be valid for file] *) + Lwt.finalize + (fun () -> copy ~src_fd:src ~dst_fd:fd + (Int64.to_int header.Tar.Header.file_size)) + (fun () -> safe_close src) + +let write_global_extended_header ?level header fd = + let open Lwt_result.Infix in + Lwt_result.lift (Tar.encode_global_extended_header ?level header) >>= fun header_strings -> + Lwt_list.fold_left_s (fun acc d -> + Lwt_result.lift acc >>= fun _written -> + Lwt_result.map_error unix_err_to_msg + (safe (Lwt_unix.write_string fd d 0) (String.length d))) + (Ok 0) header_strings >|= fun _written -> + () + +let write_end fd = + let open Lwt_result.Infix in + Lwt_result.map_error unix_err_to_msg + (safe + (Lwt_unix.write_string fd (Tar.Header.zero_block ^ Tar.Header.zero_block) 0) + (Tar.Header.length + Tar.Header.length)) >|= fun _written -> + () + +let create ?level ?global ?(filter = fun _ -> true) ~src dst = + let open Lwt_result.Infix in + Lwt_result.map_error unix_err_to_msg + (safe Lwt_unix.(openfile dst [ O_WRONLY ; O_CREAT ]) 0o644) >>= fun dst_fd -> + Lwt.finalize + (fun () -> + (match global with + | None -> Lwt.return (Ok ()) + | Some hdr -> write_global_extended_header ?level hdr dst_fd) >>= fun () -> + let rec copy_files directory = + safe Lwt_unix.opendir directory >>= fun dir -> + Lwt.finalize + (fun () -> + let rec next () = + try + safe Lwt_unix.readdir dir >>= fun name -> + let filename = Filename.concat directory name in + header_of_file ?level filename >>= fun header -> + if filter header then + match header.Tar.Header.link_indicator with + | Normal -> + append_file ?level ~header filename dst_fd >>= fun () -> + next () + | Directory -> + (* TODO first finish curdir (and close the dir fd), then go deeper *) + copy_files filename >>= fun () -> + next () + | _ -> Lwt.return (Ok ()) (* NYI *) + else Lwt.return (Ok ()) + with End_of_file -> Lwt.return (Ok ()) + in + next ()) + (fun () -> + Lwt.catch + (fun () -> Lwt_unix.closedir dir) + (fun _ -> Lwt.return_unit)) + in + copy_files src >>= fun () -> + write_end dst_fd) + (fun () -> safe_close dst_fd) diff --git a/unix/tar_lwt_unix.mli b/unix/tar_lwt_unix.mli index 9b97e4d..a4c3d47 100644 --- a/unix/tar_lwt_unix.mli +++ b/unix/tar_lwt_unix.mli @@ -16,20 +16,58 @@ (** Lwt_unix I/O for tar-formatted data *) -val really_read: Lwt_unix.file_descr -> bytes -> unit Lwt.t -(** [really_read fd buf] fills [buf] with data from [fd] or fails - with {!Stdlib.End_of_file}. *) +type decode_error = [ + | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] + | `Unix of Unix.error * string * string + | `Unexpected_end_of_file + | `Msg of string +] -val really_write: Lwt_unix.file_descr -> string -> unit Lwt.t -(** [really_write fd buf] writes the full contents of [buf] to - [fd] or fails with {!Stdlib.End_of_file}. *) +val pp_decode_error : Format.formatter -> decode_error -> unit -val skip : Lwt_unix.file_descr -> int -> unit Lwt.t -(** [skip fd n] reads [n] bytes from [fd] and discards them. If possible, you - should use [Lwt_unix.lseek fd n Lwt_unix.SEEK_CUR] instead. *) +(** [fold f filename acc] folds over the tar archive. The function [f] is called + for each [hdr : Tar.Header.t]. It should forward the position in the file + descriptor by [hdr.Tar.Header.file_size]. *) +val fold : + (Lwt_unix.file_descr -> ?global:Tar.Header.Extended.t -> Tar.Header.t -> 'a -> + ('a, decode_error) result Lwt.t) -> + string -> 'a -> ('a, decode_error) result Lwt.t -(** Return the header needed for a particular file on disk. *) -val header_of_file : ?level:Tar.Header.compatibility -> string -> Tar.Header.t Lwt.t +(** [extract ~filter ~src dst] extracts the tar archive [src] into the + directory [dst]. If [dst] does not exist, it is created. If [filter] is + provided (defaults to [fun _ -> true]), any file where [filter hdr] returns + [false], is skipped. *) +val extract : + ?filter:(Tar.Header.t -> bool) -> + src:string -> string -> + (unit, decode_error) result Lwt.t -module HeaderReader : Tar.HEADERREADER with type in_channel = Lwt_unix.file_descr and type 'a io = 'a Lwt.t -module HeaderWriter : Tar.HEADERWRITER with type out_channel = Lwt_unix.file_descr and type 'a io = 'a Lwt.t +(** [create ~level ~filter ~src dst] creates a tar archive at [dst]. It uses + [src], a directory name, as input. If [filter] is provided + (defaults to [fun _ -> true]), any file where [filter hdr] returns [false] + is skipped. *) +val create : ?level:Tar.Header.compatibility -> + ?global:Tar.Header.Extended.t -> + ?filter:(Tar.Header.t -> bool) -> + src:string -> string -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result Lwt.t + +(** [header_of_file ~level filename] returns the tar header of [filename]. *) +val header_of_file : ?level:Tar.Header.compatibility -> string -> + (Tar.Header.t, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result Lwt.t + +(** [append_file ~level ~header filename fd] appends the contents of [filename] + to the tar archive [fd]. If [header] is not provided, {header_of_file} is + used for constructing a header. *) +val append_file : ?level:Tar.Header.compatibility -> ?header:Tar.Header.t -> + string -> Lwt_unix.file_descr -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result Lwt.t + +(** [write_global_extended_header ~level hdr fd] writes the extended header [hdr] to + [fd]. *) +val write_global_extended_header : ?level:Tar.Header.compatibility -> + Tar.Header.Extended.t -> Lwt_unix.file_descr -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result Lwt.t + +(** [write_end fd] writes the tar end marker to [fd]. *) +val write_end : Lwt_unix.file_descr -> (unit, [ `Msg of string ]) result Lwt.t diff --git a/unix/tar_unix.ml b/unix/tar_unix.ml index 340462e..394fd60 100644 --- a/unix/tar_unix.ml +++ b/unix/tar_unix.ml @@ -28,7 +28,7 @@ let safe_close fd = let read_complete fd buf len = let rec loop offset = if offset < len then - let* n = safe (Unix.read fd buf offset) (len - offset) in + let* n = safe (Unix.read fd buf offset) (len - offset) in if n = 0 then Error `Unexpected_end_of_file else @@ -136,7 +136,13 @@ let extract ?(filter = fun _ -> true) ~src dst = Fun.protect ~finally:(fun () -> safe_close dst) (fun () -> copy ~src_fd:fd ~dst_fd:dst (Int64.to_int hdr.Tar.Header.file_size)) (* TODO set owner / mode / mtime etc. *) - | _ -> Error (`Msg "not yet handled") + | _ -> + (* TODO handle directories, links, etc. *) + let* _off = + Result.map_error unix_err_to_msg + (seek fd (Int64.to_int hdr.Tar.Header.file_size)) + in + Ok () else let* _off = Result.map_error unix_err_to_msg @@ -157,9 +163,25 @@ let header_of_file ?level file = (* TODO evaluate stat.st_kind *) let link_indicator = Tar.Header.Link.Normal in let link_name = "" in - let uname = if level = V7 then "" else (Unix.getpwuid stat.Unix.LargeFile.st_uid).Unix.pw_name in + let* uname = + if level = V7 then + Ok "" + else + try + let* passwd_entry = safe Unix.getpwuid stat.Unix.LargeFile.st_uid in + Ok passwd_entry.Unix.pw_name + with Not_found -> Error (`Msg ("No user entry found for UID")) + in let devmajor = if level = Ustar then stat.Unix.LargeFile.st_dev else 0 in - let gname = if level = V7 then "" else (Unix.getgrgid stat.Unix.LargeFile.st_gid).Unix.gr_name in + let* gname = + if level = V7 then + Ok "" + else + try + let* passwd_entry = safe Unix.getgrgid stat.Unix.LargeFile.st_gid in + Ok passwd_entry.Unix.gr_name + with Not_found -> Error (`Msg "No group entry found for GID") + in let devminor = if level = Ustar then stat.Unix.LargeFile.st_rdev else 0 in Ok (Tar.Header.make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator ~link_name ~uname ~gname ~devmajor ~devminor file stat.Unix.LargeFile.st_size) @@ -170,9 +192,9 @@ let append_file ?level ?header filename fd = | Some x -> Ok x in let* header_strings = Tar.encode_header ?level header in - let* _off = + let* _written = List.fold_left (fun acc d -> - let* _off = acc in + let* _written = acc in Result.map_error unix_err_to_msg (safe (Unix.write_substring fd d 0) (String.length d))) (Ok 0) header_strings @@ -188,9 +210,9 @@ let append_file ?level ?header filename fd = let write_global_extended_header ?level header fd = let* header_strings = Tar.encode_global_extended_header ?level header in - let* _off = + let* _written = List.fold_left (fun acc d -> - let* _off = acc in + let* _written = acc in Result.map_error unix_err_to_msg (safe (Unix.write_substring fd d 0) (String.length d))) (Ok 0) header_strings @@ -215,8 +237,7 @@ let create ?level ?global ?(filter = fun _ -> true) ~src dst = (fun () -> let* () = match global with | None -> Ok () - | Some hdr -> - write_global_extended_header ?level hdr dst_fd + | Some hdr -> write_global_extended_header ?level hdr dst_fd in let rec copy_files directory = let* dir = safe Unix.opendir directory in diff --git a/unix/tar_unix.mli b/unix/tar_unix.mli index 357efd3..58423e5 100644 --- a/unix/tar_unix.mli +++ b/unix/tar_unix.mli @@ -54,7 +54,7 @@ val create : ?level:Tar.Header.compatibility -> (** [header_of_file ~level filename] returns the tar header of [filename]. *) val header_of_file : ?level:Tar.Header.compatibility -> string -> - (Tar.Header.t, [ `Unix of (Unix.error * string * string) ]) result + (Tar.Header.t, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result (** [append_file ~level ~header filename fd] appends the contents of [filename] to the tar archive [fd]. If [header] is not provided, {header_of_file} is