diff --git a/bin/otar.ml b/bin/otar.ml index 33de886..75fedd5 100644 --- a/bin/otar.ml +++ b/bin/otar.ml @@ -13,7 +13,7 @@ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) - +(* let () = Printexc.record_backtrace true module Tar_gz = Tar_gz.Make @@ -129,3 +129,4 @@ let () = match Sys.argv with | _ -> let cmd = Filename.basename Sys.argv.(0) in Format.eprintf "%s []\n%s list \n" cmd cmd +*) diff --git a/lib/tar.ml b/lib/tar.ml index 35b2403..7fe8709 100644 --- a/lib/tar.ml +++ b/lib/tar.ml @@ -661,230 +661,157 @@ module Header = struct Int64.(div (add (pred (of_int length)) x.file_size) (of_int length)) end -module type ASYNC = sig - type 'a t - val ( >>= ): 'a t -> ('a -> 'b t) -> 'b t - val return: 'a -> 'a t -end - -module type READER = sig - type in_channel - type 'a io - val really_read: in_channel -> bytes -> unit io - val skip: in_channel -> int -> unit io -end - -module type WRITER = sig - type out_channel - type 'a io - val really_write: out_channel -> string -> unit io -end - -module type HEADERREADER = sig - type in_channel - type 'a io - val read : global:Header.Extended.t option -> in_channel -> - (Header.t * Header.Extended.t option, [ `Eof | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] ]) result io -end - -module type HEADERWRITER = sig - type out_channel - type 'a io - val write : ?level:Header.compatibility -> Header.t -> out_channel -> (unit, [> `Msg of string ]) result io - val write_global_extended_header : Header.Extended.t -> out_channel -> (unit, [> `Msg of string ]) result io -end - let longlink = "././@LongLink" -module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) = struct - open Async - open Reader - - type in_channel = Reader.in_channel - type 'a io = 'a t - - (* This is not a bind, but more a lift and bind combined. 
*) - let ( let^* ) x f = - match x with - | Ok x -> f x - | Error _ as e -> return e - - let fix_link_indicator x = - (* For backward compatibility we treat normal files ending in slash as - directories. Because [Link.of_char] treats unrecognized link indicator - values as normal files we check directly. This is not completely correct - as [Header.Link.of_char] turns unknown link indicators into - [Header.Link.Normal]. Ideally, it should only be done for '0' and - '\000'. *) - if String.length x.Header.file_name > 0 - && x.file_name.[String.length x.file_name - 1] = '/' - && x.link_indicator = Header.Link.Normal then - { x with link_indicator = Header.Link.Directory } - else - x - - let read ~global (ifd: Reader.in_channel) : (Header.t * Header.Extended.t option, [ `Eof | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] ]) result t = - (* We might need to read 2 headers at once if we encounter a Pax header *) - let buffer = Bytes.make Header.length '\000' in - let real_header_buf = Bytes.make Header.length '\000' in - - let next_block global () = - really_read ifd buffer >>= fun () -> - return (Header.unmarshal ?extended:global (Bytes.unsafe_to_string buffer)) +let fix_link_indicator x = + (* For backward compatibility we treat normal files ending in slash as + directories. Because [Link.of_char] treats unrecognized link indicator + values as normal files we check directly. This is not completely correct + as [Header.Link.of_char] turns unknown link indicators into + [Header.Link.Normal]. Ideally, it should only be done for '0' and + '\000'. 
*) + if String.length x.Header.file_name > 0 + && x.file_name.[String.length x.file_name - 1] = '/' + && x.link_indicator = Header.Link.Normal then + { x with link_indicator = Header.Link.Directory } + else + x + +type decode_state = { + global : Header.Extended.t option; + state : [ `Active of bool + | `Global_extended_header of Header.t + | `Per_file_extended_header of Header.t + | `Real_header of Header.Extended.t + | `Next_longlink of Header.t ]; + next_longlink : string option ; + next_longname : string option +} + +let decode_state ?global () = + { global ; state = `Active false ; next_longlink = None ; next_longname = None } + +let construct_header t (hdr : Header.t) = + let hdr = Option.fold ~none:hdr ~some:(fun file_name -> { hdr with file_name }) t.next_longname in + let hdr = Option.fold ~none:hdr ~some:(fun link_name -> { hdr with link_name }) t.next_longlink in + let hdr = fix_link_indicator hdr in + { t with next_longlink = None ; next_longname = None ; state = `Active false }, + hdr + +let decode t data = + match t.state with + | `Global_extended_header x -> + let* global = + (* unmarshal merges the previous global (if any) with the + discovered global (if any) and returns the new global. *) + Result.map_error (fun e -> `Fatal e) + (Header.Extended.unmarshal ~global:t.global data) in - - let rec get_hdr ~next_longname ~next_longlink global () : (Header.t * Header.Extended.t option, [> `Eof | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] ]) result t = - next_block global () >>= function - | Ok x when x.Header.link_indicator = Header.Link.GlobalExtendedHeader -> - let extra_header_buf = Bytes.make (Int64.to_int x.Header.file_size) '\000' in - really_read ifd extra_header_buf >>= fun () -> - skip ifd (Header.compute_zero_padding_length x) >>= fun () -> - (* unmarshal merges the previous global (if any) with the - discovered global (if any) and returns the new global. 
*) - let^* global = - Result.map_error - (fun e -> `Fatal e) - (Header.Extended.unmarshal ~global (Bytes.unsafe_to_string extra_header_buf)) - in - get_hdr ~next_longname ~next_longlink (Some global) () - | Ok x when x.Header.link_indicator = Header.Link.PerFileExtendedHeader -> - let extra_header_buf = Bytes.make (Int64.to_int x.Header.file_size) '\000' in - really_read ifd extra_header_buf >>= fun () -> - skip ifd (Header.compute_zero_padding_length x) >>= fun () -> - let^* extended = - Result.map_error - (fun e -> `Fatal e) - (Header.Extended.unmarshal ~global (Bytes.unsafe_to_string extra_header_buf)) - in - really_read ifd real_header_buf >>= fun () -> - let^* x = - Result.map_error - (fun _ -> `Fatal `Corrupt_pax_header) - (Header.unmarshal ~extended (Bytes.unsafe_to_string real_header_buf)) - in - let x = fix_link_indicator x in - return (Ok (x, global)) - | Ok ({ Header.link_indicator = Header.Link.LongLink | Header.Link.LongName; _ } as x) when x.Header.file_name = longlink -> - let extra_header_buf = Bytes.create (Int64.to_int x.Header.file_size) in - really_read ifd extra_header_buf >>= fun () -> - skip ifd (Header.compute_zero_padding_length x) >>= fun () -> - let name = String.sub (Bytes.unsafe_to_string extra_header_buf) 0 (Bytes.length extra_header_buf - 1) in - let next_longlink = if x.Header.link_indicator = Header.Link.LongLink then Some name else next_longlink in - let next_longname = if x.Header.link_indicator = Header.Link.LongName then Some name else next_longname in - get_hdr ~next_longname ~next_longlink global () - | Ok x -> - (* XXX: unclear how/if pax headers should interact with gnu extensions *) - let x = match next_longname with - | None -> x - | Some file_name -> { x with file_name } - in - let x = match next_longlink with - | None -> x - | Some link_name -> { x with link_name } - in - let x = fix_link_indicator x in - return (Ok (x, global)) - | Error `Zero_block -> - begin - next_block global () >>= function - | Ok x -> return (Ok 
(x, global)) - | Error `Zero_block -> return (Error `Eof) - | Error ((`Checksum_mismatch | `Unmarshal _) as e) -> return (Error (`Fatal e)) - end - | Error ((`Checksum_mismatch | `Unmarshal _) as e) -> - return (Error (`Fatal e)) + Ok ({ t with global = Some global ; state = `Active false }, + Some (`Skip (Header.compute_zero_padding_length x)), + Some global) + | `Per_file_extended_header x -> + let* extended = + Result.map_error + (fun e -> `Fatal e) + (Header.Extended.unmarshal ~global:t.global data) in - get_hdr ~next_longname:None ~next_longlink:None global () - -end - -module HeaderWriter(Async: ASYNC)(Writer: WRITER with type 'a io = 'a Async.t) = struct - open Async - open Writer - - type out_channel = Writer.out_channel - type 'a io = 'a t - - let write_unextended ?level header fd = - let level = Header.compatibility level in - let blank = {Header.file_name = longlink; file_mode = 0; user_id = 0; group_id = 0; mod_time = 0L; file_size = 0L; link_indicator = Header.Link.LongLink; link_name = ""; uname = "root"; gname = "root"; devmajor = 0; devminor = 0; extended = None} in - (if level = Header.GNU then begin - begin - if String.length header.Header.link_name > Header.sizeof_hdr_link_name then begin - let file_size = String.length header.Header.link_name + 1 in - let blank = {blank with Header.file_size = Int64.of_int file_size} in - let buffer = Bytes.make Header.length '\000' in - match - Header.marshal ~level buffer { blank with link_indicator = Header.Link.LongLink } - with - | Error _ as e -> return e - | Ok () -> - really_write fd (Bytes.unsafe_to_string buffer) >>= fun () -> - let payload = header.Header.link_name ^ "\000" in - really_write fd payload >>= fun () -> - really_write fd (Header.zero_padding blank) >>= fun () -> - return (Ok ()) - end else - return (Ok ()) - end >>= function - | Error _ as e -> return e - | Ok () -> - begin - if String.length header.Header.file_name > Header.sizeof_hdr_file_name then begin - let file_size = String.length 
header.Header.file_name + 1 in - let blank = {blank with Header.file_size = Int64.of_int file_size} in - let buffer = Bytes.make Header.length '\000' in - match - Header.marshal ~level buffer { blank with link_indicator = Header.Link.LongName } - with - | Error _ as e -> return e - | Ok () -> - really_write fd (Bytes.unsafe_to_string buffer) >>= fun () -> - let payload = header.Header.file_name ^ "\000" in - really_write fd payload >>= fun () -> - really_write fd (Header.zero_padding blank) >>= fun () -> - return (Ok ()) - end else - return (Ok ()) - end >>= function - | Error _ as e -> return e - | Ok () -> return (Ok ()) - end else - return (Ok ())) >>= function - | Error _ as e -> return e - | Ok () -> - let buffer = Bytes.make Header.length '\000' in - match Header.marshal ~level buffer header with - | Error _ as e -> return e - | Ok () -> - really_write fd (Bytes.unsafe_to_string buffer) >>= fun () -> - return (Ok ()) - - let write_extended ?level ~link_indicator hdr fd = - let link_indicator_name = match link_indicator with - | Header.Link.PerFileExtendedHeader -> "paxheader" - | Header.Link.GlobalExtendedHeader -> "pax_global_header" - | _ -> assert false + Ok ({ t with state = `Real_header extended }, + Some (`Skip (Header.compute_zero_padding_length x)), + None) + | `Real_header extended -> + let* x = + Result.map_error + (fun _ -> `Fatal `Corrupt_pax_header) (* NB better error *) + (Header.unmarshal ~extended data) in - let pax_payload = Header.Extended.marshal hdr in - let pax = Header.make ~link_indicator link_indicator_name - (Int64.of_int @@ String.length pax_payload) in - write_unextended ?level pax fd >>= function - | Error _ as e -> return e - | Ok () -> - really_write fd pax_payload >>= fun () -> - really_write fd (Header.zero_padding pax) >>= fun () -> - return (Ok ()) - - let write ?level header fd = - ( match header.Header.extended with - | None -> return (Ok ()) - | Some e -> - write_extended ?level 
~link_indicator:Header.Link.PerFileExtendedHeader e fd ) - >>= function - | Error _ as e -> return e - | Ok () -> write_unextended ?level header fd - - let write_global_extended_header global fd = - write_extended ~link_indicator:Header.Link.GlobalExtendedHeader global fd -end + let t, hdr = construct_header t x in + Ok (t, Some (`Header hdr), None) + | `Next_longlink x -> + let name = String.sub data 0 (String.length data - 1) in + let next_longlink = if x.Header.link_indicator = Header.Link.LongLink then Some name else t.next_longlink in + let next_longname = if x.Header.link_indicator = Header.Link.LongName then Some name else t.next_longname in + Ok ({ t with next_longlink ; next_longname ; state = `Active false }, + Some (`Skip (Header.compute_zero_padding_length x)), + None) + | `Active read_zero -> + match Header.unmarshal ?extended:t.global data with + | Ok x when x.Header.link_indicator = Header.Link.GlobalExtendedHeader -> + Ok ({ t with state = `Global_extended_header x }, + Some (`Read (Int64.to_int x.Header.file_size)), + None) + | Ok x when x.Header.link_indicator = Header.Link.PerFileExtendedHeader -> + Ok ({ t with state = `Per_file_extended_header x }, + Some (`Read (Int64.to_int x.Header.file_size)), + None) + | Ok ({ Header.link_indicator = Header.Link.LongLink | Header.Link.LongName; _ } as x) when x.Header.file_name = longlink -> + Ok ({ t with state = `Next_longlink x }, + Some (`Read (Int64.to_int x.Header.file_size)), + None) + | Ok x -> + let t, hdr = construct_header t x in + Ok (t, Some (`Header hdr), None) + | Error `Zero_block -> + if read_zero then + Error `Eof + else + Ok ({ t with state = `Active true }, None, None) + | Error ((`Checksum_mismatch | `Unmarshal _) as e) -> + Error (`Fatal e) + +let encode_long level link_indicator payload = + let blank = {Header.file_name = longlink; file_mode = 0; user_id = 0; group_id = 0; mod_time = 0L; file_size = 0L; link_indicator = Header.Link.LongLink; link_name = ""; uname = "root"; gname = 
"root"; devmajor = 0; devminor = 0; extended = None} in + let payload = payload ^ "\000" in + let file_size = String.length payload in + let blank = {blank with Header.file_size = Int64.of_int file_size} in + let buffer = Bytes.make Header.length '\000' in + let* () = Header.marshal ~level buffer { blank with link_indicator } in + Ok [ Bytes.unsafe_to_string buffer ; payload ; Header.zero_padding blank ] + +let encode_unextended_header ?level header = + let level = Header.compatibility level in + let* pre = + if level = Header.GNU then + let* longlink = + if String.length header.Header.link_name > Header.sizeof_hdr_link_name then + encode_long level Header.Link.LongLink header.Header.link_name + else + Ok [] + in + let* longname = + if String.length header.Header.file_name > Header.sizeof_hdr_file_name then + encode_long level Header.Link.LongName header.Header.file_name + else + Ok [] + in + Ok (longlink @ longname) + else + Ok [] + in + let buffer = Bytes.make Header.length '\000' in + let* () = Header.marshal ~level buffer header in + Ok (pre @ [ Bytes.unsafe_to_string buffer ]) + +let encode_extended_header ?level scope hdr = + let link_indicator, link_indicator_name = match scope with + | `Per_file -> Header.Link.PerFileExtendedHeader, "paxheader" + | `Global ->Header.Link.GlobalExtendedHeader, "pax_global_header" + | _ -> assert false + in + let pax_payload = Header.Extended.marshal hdr in + let pax = + Header.make ~link_indicator link_indicator_name + (Int64.of_int @@ String.length pax_payload) + in + let* pax_hdr = encode_unextended_header ?level pax in + Ok (pax_hdr @ [ pax_payload ; Header.zero_padding pax ]) + +let encode_header ?level header = + let* extended = + Option.fold ~none:(Ok []) ~some:(encode_extended_header ?level `Per_file) header.Header.extended + in + let* rest = encode_unextended_header ?level header in + Ok (extended @ rest) + +let encode_global_extended_header ?level global = + encode_extended_header ?level `Global global diff --git 
a/lib/tar.mli b/lib/tar.mli index c995969..f0a24de 100644 --- a/lib/tar.mli +++ b/lib/tar.mli @@ -139,47 +139,34 @@ module Header : sig val to_sectors: t -> int64 end -module type ASYNC = sig - type 'a t - val ( >>= ): 'a t -> ('a -> 'b t) -> 'b t - val return: 'a -> 'a t -end - -module type READER = sig - type in_channel - type 'a io - val really_read: in_channel -> bytes -> unit io - val skip: in_channel -> int -> unit io -end - -module type WRITER = sig - type out_channel - type 'a io - val really_write: out_channel -> string -> unit io -end - -module type HEADERREADER = sig - type in_channel - type 'a io - - (** Returns the next header block or error [`Eof] if two consecutive - zero-filled blocks are discovered. Assumes stream is positioned at the - possible start of a header block. - @param global Holds the current global pax extended header, if - any. Needs to be given to the next call to [read]. *) - val read : global:Header.Extended.t option -> in_channel -> - (Header.t * Header.Extended.t option, [ `Eof | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] ]) result io -end - -module type HEADERWRITER = sig - type out_channel - type 'a io - val write : ?level:Header.compatibility -> Header.t -> out_channel -> (unit, [> `Msg of string ]) result io - val write_global_extended_header : Header.Extended.t -> out_channel -> (unit, [> `Msg of string ]) result io -end - -module HeaderReader(Async: ASYNC)(Reader: READER with type 'a io = 'a Async.t) : - HEADERREADER with type in_channel = Reader.in_channel and type 'a io = 'a Async.t - -module HeaderWriter(Async: ASYNC)(Writer: WRITER with type 'a io = 'a Async.t) : - HEADERWRITER with type out_channel = Writer.out_channel and type 'a io = 'a Async.t +(** {1 Decoding and encoding of a whole archive} *) + +(** The type of the decode state. *) +type decode_state + +(** [decode_state ~global ()] constructs a decode_state. 
*) +val decode_state : ?global:Header.Extended.t -> unit -> decode_state + +(** [decode t data] decodes [data] taking the current state [t] into account. + On success it returns a new state, optionally an action ([`Read n]: pass the + next [n] bytes to [decode]; [`Skip n]: skip [n] bytes before the next header) + or a decoded [`Header], and possibly a new global PAX header. + + If no [`Read] or [`Skip] is returned, the new state should be used with + [decode] with the next [Header.length] sized string, which will lead to + further decoding until [`Eof] (or an error) occurs. *) +val decode : decode_state -> string -> + (decode_state * [ `Read of int | `Skip of int | `Header of Header.t ] option * Header.Extended.t option, + [ `Eof | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] ]) + result + +(** [encode_header ~level hdr] encodes the header with the provided [level] + (defaults to [V7]) into a list of strings to be written to the disk. + Once a header is written, the payload (padded to multiples of + [Header.length]) should follow. *) +val encode_header : ?level:Header.compatibility -> + Header.t -> (string list, [> `Msg of string ]) result + +(** [encode_global_extended_header hdr] encodes the global extended header as + a list of strings. 
*) +val encode_global_extended_header : ?level:Header.compatibility -> Header.Extended.t -> (string list, [> `Msg of string ]) result diff --git a/lib_test/dune b/lib_test/dune index e355bb1..79ed943 100644 --- a/lib_test/dune +++ b/lib_test/dune @@ -9,4 +9,5 @@ alcotest-lwt lwt tar-unix - tar-mirage)) + tar-mirage +)) diff --git a/lib_test/global_extended_headers_test.ml b/lib_test/global_extended_headers_test.ml index a5ae6de..c130382 100644 --- a/lib_test/global_extended_headers_test.ml +++ b/lib_test/global_extended_headers_test.ml @@ -1,37 +1,5 @@ let level = Tar.Header.Ustar -module Writer = struct - type out_channel = Stdlib.out_channel - type 'a io = 'a - let really_write oc str = - output_string oc str -end - -module HW = Tar.HeaderWriter - (struct type 'a t = 'a - let ( >>= ) x f = f x - let return x = x end) - (Writer) - -module Reader = struct - type in_channel = Stdlib.in_channel - type 'a io = 'a - let really_read ic buf = - really_input ic buf 0 (Bytes.length buf) - let skip ic len = - let cur = pos_in ic in - seek_in ic (cur + len) - let read ic buf = - let max = Bytes.length buf in - input ic buf 0 max -end - -module HR = Tar.HeaderReader - (struct type 'a t = 'a - let ( >>= ) x f = f x - let return x = x end) - (Reader) - let make_extended user_id = Tar.Header.Extended.make ~user_id () @@ -41,92 +9,67 @@ let make_file = let name = "file" ^ string_of_int !gen in incr gen; let hdr = Tar.Header.make name 0L in - hdr, fun cout -> - Tar.Header.zero_padding hdr - |> output_string cout + hdr + +let ( let* ) = Result.bind (* Tests that global and per-file extended headers correctly override each other. 
*) let use_global_extended_headers _test_ctxt = (* Write an archive using global and per-file pax extended headers *) begin try Sys.remove "test.tar" with _ -> () end; - let cout = open_out_bin "test.tar" in + let cout = Unix.openfile "test.tar" [ Unix.O_CREAT ; Unix.O_WRONLY ] 0o644 in let g0 = make_extended 1000 in - let hdr, f = make_file () in - match HW.write_global_extended_header g0 cout with - | Error `Msg msg -> Alcotest.failf "failed to write header %s" msg + let g1 = make_extended 3000 in + match + Fun.protect ~finally:(fun () -> Unix.close cout) + (fun () -> + let* () = Tar_unix.write_global_extended_header ~level g0 cout in + let hdr = make_file () in + let* () = Tar_unix.write_header ~level hdr cout in + let hdr = make_file () in + let hdr = { hdr with Tar.Header.extended = Some (make_extended 2000) } in + let* () = Tar_unix.write_header ~level hdr cout in + let hdr = make_file () in + let* () = Tar_unix.write_header ~level hdr cout in + let hdr = make_file () in + let* () = Tar_unix.write_global_extended_header ~level g1 cout in + let* () = Tar_unix.write_header ~level hdr cout in + Tar_unix.write_end cout) + with + | Error `Msg msg -> Alcotest.failf "failed to write something: %s" msg + | Error `Unix (err, f, a) -> + Alcotest.failf "failed to write: unix error %s %s %s" (Unix.error_message err) f a | Ok () -> - match HW.write ~level hdr cout with - | Error `Msg msg -> Alcotest.failf "failed to write header %s" msg - | Ok () -> - f cout; - let hdr, f = make_file () in - let hdr = { hdr with Tar.Header.extended = Some (make_extended 2000) } in - match HW.write ~level hdr cout with - | Error `Msg msg -> Alcotest.failf "failed to write header %s" msg - | Ok () -> - f cout; - let hdr, f = make_file () in - match HW.write ~level hdr cout with - | Error `Msg msg -> Alcotest.failf "failed to write header %s" msg - | Ok () -> - f cout; - let g1 = make_extended 3000 in - let hdr, f = make_file () in - match HW.write_global_extended_header g1 cout with - | 
Error `Msg msg -> Alcotest.failf "failed to write header %s" msg - | Ok () -> - match HW.write ~level hdr cout with - | Error `Msg msg -> Alcotest.failf "failed to write header %s" msg - | Ok () -> - f cout; - Writer.really_write cout Tar.Header.zero_block; - Writer.really_write cout Tar.Header.zero_block; - close_out cout; - (* Read the same archive, testing that headers have been squashed. *) - let cin = open_in_bin "test.tar" in - let global = ref None in - let header = - let pp ppf hdr = Fmt.pf ppf "%s" (Tar.Header.Extended.to_detailed_string hdr) in - Alcotest.testable (fun ppf hdr -> Fmt.pf ppf "%a" Fmt.(option pp) hdr) ( = ) - in - ( match HR.read ~global:!global cin with - | Ok (hdr, global') -> - Alcotest.check header "expected global header" (Some g0) global'; - global := global'; - Alcotest.(check int) "expected user" 1000 hdr.Tar.Header.user_id; - let to_skip = Tar.Header.(Int64.to_int (to_sectors hdr) * length) in - Reader.skip cin to_skip; - | Error `Eof -> failwith "Couldn't read header, end of file" - | Error (`Fatal err) -> Fmt.failwith "Couldn't read header: %a" Tar.pp_error err ); - ( match HR.read ~global:!global cin with - | Ok (hdr, global') -> - Alcotest.check header "expected global header" (Some g0) global'; - global := global'; - Alcotest.(check int) "expected user" 2000 hdr.Tar.Header.user_id; - let to_skip = Tar.Header.(Int64.to_int (to_sectors hdr) * length) in - Reader.skip cin to_skip; - | Error _ -> failwith "Couldn't read header" ); - ( match HR.read ~global:!global cin with - | Ok (hdr, global') -> - Alcotest.check header "expected global header" (Some g0) global'; - global := global'; - Alcotest.(check int) "expected user" 1000 hdr.Tar.Header.user_id; - let to_skip = Tar.Header.(Int64.to_int (to_sectors hdr) * length) in - Reader.skip cin to_skip; - | Error _ -> failwith "Couldn't read header" ); - ( match HR.read ~global:!global cin with - | Ok (hdr, global') -> - Alcotest.check header "expected global header" (Some g1) global'; 
- global := global'; - Alcotest.(check int) "expected user" 3000 hdr.Tar.Header.user_id; - let to_skip = Tar.Header.(Int64.to_int (to_sectors hdr) * length) in - Reader.skip cin to_skip; - | Error _ -> failwith "Couldn't read header" ); - ( match HR.read ~global:!global cin with - | Error `Eof -> () - | _ -> failwith "Should have found EOF"); - () + (* Read the same archive, testing that headers have been squashed. *) + let header = + let pp ppf hdr = Fmt.pf ppf "%s" (Tar.Header.Extended.to_detailed_string hdr) in + Alcotest.testable (fun ppf hdr -> Fmt.pf ppf "%a" Fmt.(option pp) hdr) ( = ) + in + let f _fd ?global hdr idx = + match idx with + | 0 -> + Alcotest.check header "expected global header" (Some g0) global; + Alcotest.(check int) "expected user" 1000 hdr.Tar.Header.user_id; + Ok 1 + | 1 -> + Alcotest.check header "expected global header" (Some g0) global; + Alcotest.(check int) "expected user" 2000 hdr.Tar.Header.user_id; + Ok 2 + | 2 -> + Alcotest.check header "expected global header" (Some g0) global; + Alcotest.(check int) "expected user" 1000 hdr.Tar.Header.user_id; + Ok 3 + | 3 -> + Alcotest.check header "expected global header" (Some g1) global; + Alcotest.(check int) "expected user" 3000 hdr.Tar.Header.user_id; + Ok 4 + | _ -> Alcotest.fail "too many headers" + in + match Tar_unix.fold f "test.tar" 0 with + | Ok 4 -> () + | Ok n -> Alcotest.failf "early abort, expected 4, received %u" n + | Error e -> Alcotest.failf "failed to read: %a" Tar_unix.pp_decode_error e let () = let suite = "tar - pax global extended headers", [ diff --git a/lib_test/parse_test.ml b/lib_test/parse_test.ml index 05ee8e2..03d0bd9 100644 --- a/lib_test/parse_test.ml +++ b/lib_test/parse_test.ml @@ -31,21 +31,15 @@ module Unix = struct if Sys.win32 then truncate (convert_path `Windows path) else truncate path end -let list fd = - let rec loop global acc = - match Tar_unix.HeaderReader.read ~global fd with - | Ok (hdr, global) -> - print_endline hdr.Tar.Header.file_name; - 
Tar_unix.skip fd - (Int64.to_int hdr.Tar.Header.file_size + Tar.Header.compute_zero_padding_length hdr); - loop global (hdr :: acc) - | Error `Eof -> - List.rev acc - | Error `Fatal e -> Alcotest.failf "unexpected error: %a" Tar.pp_error e +let list filename = + let f fd ?global:_ hdr acc = + print_endline hdr.Tar.Header.file_name; + ignore Unix.(lseek fd (Int64.to_int hdr.Tar.Header.file_size) SEEK_CUR); + Ok (hdr :: acc) in - let r = loop None [] in - List.iter (fun h -> print_endline h.Tar.Header.file_name) r; - r + match Tar_unix.fold f filename [] with + | Ok acc -> List.rev acc + | Error e -> Alcotest.failf "unexpected error: %a" Tar_unix.pp_decode_error e let pp_header f x = Fmt.pf f "%s" (Tar.Header.to_detailed_string x) let header = Alcotest.testable pp_header ( = ) @@ -104,10 +98,8 @@ let with_tar ?(level:Tar.Header.compatibility option) ?files ?(sector_size = 512 let can_read_tar () = with_tar () @@ fun tar_filename files -> - let fd = Unix.openfile tar_filename [ O_RDONLY; O_CLOEXEC ] 0 in - let files' = List.map (fun t -> t.Tar.Header.file_name) (list fd) in + let files' = List.map (fun t -> t.Tar.Header.file_name) (list tar_filename) in flush stdout; - Unix.close fd; let missing = set_difference files files' in let missing' = set_difference files' files in Alcotest.(check (list string)) "missing" [] missing; @@ -121,53 +113,45 @@ let can_write_pax () = let fd = Unix.openfile filename [ O_CREAT; O_WRONLY; O_CLOEXEC ] 0o0644 in Fun.protect (fun () -> - let hdr = Tar.Header.make ~user_id "test" 0L in - match Tar_unix.HeaderWriter.write hdr fd with + let header = Tar.Header.make ~user_id "test" 0L in + match Tar_unix.write_header header fd with | Ok () -> - Tar_unix.really_write fd Tar.Header.zero_block; - Tar_unix.really_write fd Tar.Header.zero_block; + (match Tar_unix.write_end fd with + | Ok () -> () + | Error `Msg msg -> + Alcotest.failf "error writing end %s" msg) | Error `Msg msg -> Alcotest.failf "error writing header %s" msg + | Error `Unix (e, 
f, a) -> + Alcotest.failf "error writing header - unix error %s %s %s" + (Unix.error_message e) f a ) ~finally:(fun () -> Unix.close fd); (* Read it back and verify the header was read *) - let fd = Unix.openfile filename [ O_RDONLY; O_CLOEXEC ] 0 in - Fun.protect - (fun () -> - match list fd with - | [ one ] -> Alcotest.(check int) "user_id" user_id one.Tar.Header.user_id - | xs -> Alcotest.failf "Headers = %a" (Fmt.list pp_header) xs - ) ~finally:(fun () -> Unix.close fd) - + match list filename with + | [ one ] -> Alcotest.(check int) "user_id" user_id one.Tar.Header.user_id + | xs -> Alcotest.failf "Headers = %a" (Fmt.list pp_header) xs let can_list_longlink_tar () = - let fd = Unix.openfile "lib_test/long.tar" [ O_RDONLY; O_CLOEXEC ] 0o0 in - Fun.protect - (fun () -> - let all = list fd in - let filenames = List.map (fun h -> h.Tar.Header.file_name) all in - (* List.iteri (fun i x -> Printf.fprintf stderr "%d: %s\n%!" i x) filenames; *) - let expected = [ - "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/"; - "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/BCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/"; - 
"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/BCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/CDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.txt"; - ] in - Alcotest.(check (list string)) "respects filenames" expected filenames - ) ~finally:(fun () -> Unix.close fd) + let all = list "lib_test/long.tar" in + let filenames = List.map (fun h -> h.Tar.Header.file_name) all in + (* List.iteri (fun i x -> Printf.fprintf stderr "%d: %s\n%!" i x) filenames; *) + let expected = [ + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/"; + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/BCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/"; + 
"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/BCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789/CDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.txt"; + ] in + Alcotest.(check (list string)) "respects filenames" expected filenames let can_list_long_pax_tar () = - let fd = Unix.openfile "lib_test/long-pax.tar" [ O_RDONLY; O_CLOEXEC ] 0x0 in - Fun.protect - (fun () -> - let all = list fd in - let filenames = List.map (fun h -> h.Tar.Header.file_name) all in - (* List.iteri (fun i x -> Printf.fprintf stderr "%d: %s\n%!" i x) filenames; *) - let expected = [ - "t/"; - "t/someveryveryverylonggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggname"; - "t/someveryveryverylonggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggglink"; - ] in - Alcotest.(check (list string)) "respects filenames" expected filenames - ) ~finally:(fun () -> Unix.close fd) + let all = list "lib_test/long-pax.tar" in + let filenames = List.map (fun h -> h.Tar.Header.file_name) all in + (* List.iteri (fun i x -> Printf.fprintf stderr "%d: %s\n%!" 
i x) filenames; *) + let expected = [ + "t/"; + "t/someveryveryverylonggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggname"; + "t/someveryveryverylonggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggglink"; + ] in + Alcotest.(check (list string)) "respects filenames" expected filenames (* "pax-shenanigans.tar" is an archive with a regular file "placeholder" with a pax header "path=clearly/a/directory/". The resulting header has normal link @@ -181,15 +165,14 @@ let can_list_long_pax_tar () = - Reynir *) let can_list_pax_implicit_dir () = - let fd = Unix.openfile "lib_test/pax-shenanigans.tar" [ O_RDONLY; O_CLOEXEC ] 0x0 in - Fun.protect ~finally:(fun () -> Unix.close fd) - (fun () -> - match Tar_unix.HeaderReader.read ~global:None fd with - | Error `Fatal e -> Alcotest.failf "unexpected error: %a" Tar.pp_error e - | Error `Eof -> Alcotest.fail "unexpected end of file" - | Ok (hdr, _global) -> - Alcotest.(check link) "is directory" Tar.Header.Link.Directory hdr.link_indicator; - Alcotest.(check string) "filename is patched" "clearly/a/directory/" hdr.file_name) + let f _fd ?global:_ hdr () = + Alcotest.(check link) "is directory" Tar.Header.Link.Directory hdr.Tar.Header.link_indicator; + Alcotest.(check string) "filename is patched" "clearly/a/directory/" hdr.file_name; + Ok () + in + match Tar_unix.fold f "lib_test/pax-shenanigans.tar" () with + | Ok () -> () + | Error e -> Alcotest.failf "unexpected error: %a" Tar_unix.pp_decode_error e (* Sample tar generated with commit 1583f71ea33b2836d3fb996ac7dc35d55abe2777: [let buf = @@ -203,16 +186,14 @@ let can_list_pax_implicit_dir () = Tar.Header.marshal ~level (Cstruct.shift buf 1024) hdr; buf] *) let can_list_longlink_implicit_dir () = - let fd = Unix.openfile 
"lib_test/long-implicit-dir.tar" [ O_RDONLY; O_CLOEXEC ] 0x0 in - Fun.protect ~finally:(fun () -> Unix.close fd) - (fun () -> - match Tar_unix.HeaderReader.read ~global:None fd with - | Ok (hdr, _global) -> - Alcotest.(check link) "is directory" Tar.Header.Link.Directory hdr.link_indicator; - Alcotest.(check string) "filename is patched" "some/long/name/for/a/directory/" hdr.file_name - | Error `Fatal e -> Alcotest.failf "unexpected error: %a" Tar.pp_error e - | Error `Eof -> Alcotest.fail "unexpected end of file") - + let f _fd ?global:_ hdr () = + Alcotest.(check link) "is directory" Tar.Header.Link.Directory hdr.Tar.Header.link_indicator; + Alcotest.(check string) "filename is patched" "some/long/name/for/a/directory/" hdr.file_name; + Ok () + in + match Tar_unix.fold f "lib_test/long-implicit-dir.tar" () with + | Ok () -> () + | Error e -> Alcotest.failf "unexpected error: %a" Tar_unix.pp_decode_error e let starts_with ~prefix s = let len_s = String.length s @@ -224,25 +205,39 @@ let starts_with ~prefix s = in len_s >= len_pre && aux 0 let can_transform_tar () = - (* let level = Tar.Header.Ustar in with_tar ~level () @@ fun tar_in _file_list -> - let fd_in = Unix.openfile tar_in [ O_RDONLY; O_CLOEXEC ] 0 in let tar_out = Filename.temp_file "tar-transformed" ".tar" in let fd_out = Unix.openfile tar_out [ O_WRONLY; O_CREAT; O_CLOEXEC ] 0o644 in with_tmpdir @@ fun temp_dir -> - Tar_unix.Archive.transform ~level (fun hdr -> - {hdr with Tar.Header.file_name = Filename.concat temp_dir hdr.file_name}) - fd_in fd_out; - Unix.close fd_in; - Unix.close fd_out; - let fd_in = Unix.openfile tar_out [ O_RDONLY; O_CLOEXEC ] 0 in - Tar_unix.Archive.with_next_file fd_in ~global:None (fun fd_file _global hdr -> - Alcotest.(check string) "Filename was transformed" temp_dir - (String.sub hdr.file_name 0 (min (String.length hdr.file_name) (String.length temp_dir))); - Tar_unix.skip fd_file (Int64.to_int hdr.file_size)); - Unix.close fd_in - *) () + let f fd ?global:_ hdr _ = + 
ignore Unix.(lseek fd (Int64.to_int hdr.Tar.Header.file_size) SEEK_CUR); + let hdr = + { hdr with + Tar.Header.file_name = Filename.concat temp_dir hdr.file_name; + file_size = 0L + } + in + match Tar_unix.write_header ~level hdr fd_out with + | Ok () -> Ok () + | Error _ -> Alcotest.fail "error writing header" + in + match Tar_unix.fold f tar_in () with + | Error e -> Alcotest.failf "error folding %a" Tar_unix.pp_decode_error e + | Ok () -> + match Tar_unix.write_end fd_out with + | Error _ -> Alcotest.fail "couldn't write end" + | Ok () -> + Unix.close fd_out; + let f fd ?global:_ hdr _ = + ignore Unix.(lseek fd (Int64.to_int hdr.Tar.Header.file_size) SEEK_CUR); + Alcotest.(check string) "Filename was transformed" temp_dir + (String.sub hdr.file_name 0 (min (String.length hdr.file_name) (String.length temp_dir))); + Ok () + in + match Tar_unix.fold f tar_out () with + | Error e -> Alcotest.failf "error folding2 %a" Tar_unix.pp_decode_error e + | Ok () -> () module Block4096 = struct include Block diff --git a/mirage/tar_mirage.ml b/mirage/tar_mirage.ml index 35f5b55..4e45bbc 100644 --- a/mirage/tar_mirage.ml +++ b/mirage/tar_mirage.ml @@ -75,42 +75,66 @@ module Make_KV_RO (BLOCK : Mirage_block.S) = struct in Lwt.return r - module Reader = struct - type in_channel = { - b: BLOCK.t; - (** offset in bytes *) - mutable offset: int64; - info: Mirage_block.info; - } - type 'a io = 'a Lwt.t - let really_read in_channel buffer = - let len = Bytes.length buffer in - assert(len <= 512); - (* Tar assumes 512 byte sectors, but BLOCK might have 4096 byte sectors for example *) - let sector_size = in_channel.info.Mirage_block.sector_size in - let sector' = Int64.(div in_channel.offset (of_int sector_size)) in - let sector_aligned_len = - if len mod sector_size == 0 then len else - len + (sector_size - len mod sector_size) - in - let tmp = Cstruct.create sector_aligned_len in - BLOCK.read in_channel.b sector' [ tmp ] - >>= function - | Error e -> failwith (Format.asprintf 
"Failed to read sector %Ld from block device: %a" sector' - BLOCK.pp_error e) - | Ok () -> - (* If the BLOCK sector size is big, then we need to select the 512 bytes we want *) - let offset = Int64.(to_int (sub in_channel.offset (mul sector' (of_int sector_size)))) in - in_channel.offset <- Int64.(add in_channel.offset (of_int len)); - Cstruct.blit_to_bytes tmp offset buffer 0 len; - Lwt.return_unit - let skip in_channel n = - in_channel.offset <- Int64.(add in_channel.offset (of_int n)); - Lwt.return_unit - let _get_current_tar_sector in_channel = Int64.div in_channel.offset 512L - - end - module HR = Tar.HeaderReader(Lwt)(Reader) + let read_data info b offset buffer len = + assert(len <= 512); + (* Tar assumes 512 byte sectors, but BLOCK might have 4096 byte sectors for example *) + let sector_size = info.Mirage_block.sector_size in + let sector' = Int64.(div offset (of_int sector_size)) in + let sector_aligned_len = + if len mod sector_size == 0 then + len + else + len + (sector_size - len mod sector_size) + in + let tmp = Cstruct.create sector_aligned_len in + BLOCK.read b sector' [ tmp ] >>= function + | Error e -> + Lwt.return (Error (`Msg + (Format.asprintf "Failed to read sector %Ld from block device: %a" sector' + BLOCK.pp_error e))) + | Ok () -> + (* If the BLOCK sector size is big, then we need to select the 512 bytes we want *) + let offset_in_cs = Int64.(to_int (sub offset (mul sector' (of_int sector_size)))) in + Cstruct.blit_to_bytes tmp offset_in_cs buffer 0 len; + Lwt.return (Ok ()) + + let fold info b f init = + let open Lwt_result.Infix in + let rec go t offset ?global ?data acc = + (match data with + | None -> + let buf = Bytes.make Tar.Header.length '\000' in + read_data info b offset buf Tar.Header.length >|= fun () -> + Int64.(add offset (of_int Tar.Header.length)), Bytes.unsafe_to_string buf + | Some data -> + Lwt.return (Ok (offset, data))) >>= fun (offset, data) -> + match Tar.decode t data with + | Ok (t, Some `Header hdr, g) -> + let 
global = Option.fold ~none:global ~some:(fun g -> Some g) g in + f offset ?global hdr acc >>= fun acc' -> + let off' = + Int64.(add offset (add hdr.Tar.Header.file_size + (of_int (Tar.Header.compute_zero_padding_length hdr)))) + in + go t off' ?global acc' + | Ok (t, Some `Skip n, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + let off' = Int64.(add offset (of_int n)) in + go t off' ?global acc + | Ok (t, Some `Read n, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + let buf = Bytes.make n '\000' in + read_data info b offset buf n >>= fun () -> + let data = Bytes.unsafe_to_string buf in + let off' = Int64.(add offset (of_int n)) in + go t off' ?global ~data acc + | Ok (t, None, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + go t offset ?global acc + | Error `Eof -> Lwt.return (Ok acc) + | Error `Fatal _ as e -> Lwt.return e + in + go (Tar.decode_state ()) 0L init (* [read_partial_sector t sector_start ~offset ~length dst] reads a single sector and blits [length] bytes from [offset] into [dst] @@ -255,33 +279,37 @@ module Make_KV_RO (BLOCK : Mirage_block.S) = struct let ssize = info.Mirage_block.sector_size in if ssize mod 512 <> 0 || ssize < 512 then invalid_arg "Sector size needs to be >= 512 and a multiple of 512"; - let in_channel = { Reader.b; offset = 0L; info } in - let rec loop ~global map = - HR.read ~global in_channel >>= function - | Error `Eof -> Lwt.return map - | Error `Fatal e -> - Format.kasprintf failwith "Error reading archive: %a" Tar.pp_error e - | Ok (tar, global) -> - let filename = trim_slash tar.Tar.Header.file_name in - let map = - if filename = "" then - map - else - let data_tar_offset = Int64.div in_channel.Reader.offset 512L in - let v_or_d = if is_dict filename then Dict (tar, StringMap.empty) else Value (tar, data_tar_offset) in - insert map (Mirage_kv.Key.v filename) v_or_d - in - Reader.skip in_channel (Int64.to_int tar.Tar.Header.file_size) >>= 
fun () -> - Reader.skip in_channel (Tar.Header.compute_zero_padding_length tar) >>= fun () -> - loop ~global map + let f offset ?global:_ hdr (_, map) = + let filename = trim_slash hdr.Tar.Header.file_name in + let map = + if filename = "" then + map + else + let data_tar_offset = Int64.(div offset (of_int Tar.Header.length)) in + let v_or_d = + if is_dict filename then + Dict (hdr, StringMap.empty) + else + Value (hdr, data_tar_offset) + in + insert map (Mirage_kv.Key.v filename) v_or_d + in + let eof = Int64.(add offset + (add hdr.Tar.Header.file_size + (of_int (Tar.Header.compute_zero_padding_length hdr)))) + in + Lwt.return (Ok (eof, map)) in - let root = StringMap.empty in - loop ~global:None root >>= fun map -> - (* This is after the two [zero_block]s *) - let end_of_archive = in_channel.Reader.offset in - let map = Dict (Tar.Header.make "/" 0L, map) in - let write_lock = Lwt_mutex.create () in - Lwt.return ({ b; map; info; end_of_archive; write_lock }) + fold info b f (0L, StringMap.empty) >>= function + | Error `Fatal e -> + Format.kasprintf failwith "Fatal error reading archive: %a" Tar.pp_error e + | Error `Msg msg -> + Format.kasprintf failwith "Error reading archive: %s" msg + | Ok (end_of_archive, map) -> + let end_of_archive = Int64.(add end_of_archive (of_int (2 * Tar.Header.length))) in + let map = Dict (Tar.Header.make "/" 0L, map) in + let write_lock = Lwt_mutex.create () in + Lwt.return ({ b; map; info; end_of_archive; write_lock }) let disconnect _ = Lwt.return_unit @@ -292,7 +320,14 @@ module Make_KV_RW (CLOCK : Mirage_clock.PCLOCK) (BLOCK : Mirage_block.S) = struc include Make_KV_RO(BLOCK) - type write_error = [ `Block of BLOCK.error | `Block_write of BLOCK.write_error | Mirage_kv.write_error | `Entry_already_exists | `Path_segment_is_a_value | `Append_only | `Write_header of string ] + type write_error = [ + | `Block of BLOCK.error + | `Block_write of BLOCK.write_error + | Mirage_kv.write_error + | `Entry_already_exists + | 
`Path_segment_is_a_value + | `Append_only + | `Msg of string ] let pp_write_error ppf = function | `Block e -> Fmt.pf ppf "read error while writing: %a" BLOCK.pp_error e @@ -301,7 +336,7 @@ module Make_KV_RW (CLOCK : Mirage_clock.PCLOCK) (BLOCK : Mirage_block.S) = struc | `Entry_already_exists -> Fmt.string ppf "entry already exists" | `Path_segment_is_a_value -> Fmt.string ppf "path segment is a value" | `Append_only -> Fmt.string ppf "append only" - | `Write_header msg -> Fmt.pf ppf "writing tar header failed: %s" msg + | `Msg msg -> Fmt.pf ppf "writing tar header failed: %s" msg let write t sector_start buffers = Lwt_result.map_error (fun e -> `Block_write e) @@ -364,51 +399,31 @@ module Make_KV_RW (CLOCK : Mirage_clock.PCLOCK) (BLOCK : Mirage_block.S) = struc let map = remove map key in Dict (root, map) - module Writer = struct - type out_channel = { - b: BLOCK.t; - (** offset in bytes *) - mutable offset: int64; - info: Mirage_block.info; - } - type 'a io = 'a Lwt.t - exception Read of BLOCK.error - exception Write of BLOCK.write_error - let really_write out_channel str = - assert (String.length str <= Tar.Header.length); - let data = - let cs = Cstruct.create Tar.Header.length in - Cstruct.blit_from_string str 0 cs 0 (String.length str); - cs - in - let sector_size = out_channel.info.sector_size in - let sector = Int64.(div out_channel.offset (of_int sector_size)) in - let block = Cstruct.create sector_size in - BLOCK.read out_channel.b sector [ block ] >>= function - | Error e -> raise (Read e) - | Ok () -> - let start_offset = Int64.to_int out_channel.offset mod sector_size in - Cstruct.blit data 0 block start_offset (Cstruct.length data); - BLOCK.write out_channel.b sector [ block ] >>= function - | Error e -> raise (Write e) - | Ok () -> - Lwt.return_unit - end - module HW = Tar.HeaderWriter(Lwt)(Writer) + let write_data info b offset buffer = + assert (String.length buffer <= Tar.Header.length); + let sector_size = info.Mirage_block.sector_size in + let 
sector = Int64.(div offset (of_int sector_size)) in + let block = Cstruct.create sector_size in + BLOCK.read b sector [ block ] >>= function + | Error e -> Lwt.return (Error (`Block e)) + | Ok () -> + let start_offset = Int64.to_int offset mod sector_size in + Cstruct.blit_from_string buffer 0 block start_offset (String.length buffer); + BLOCK.write b sector [ block ] >>= function + | Error e -> Lwt.return (Error (`Block_write e)) + | Ok () -> Lwt.return (Ok ()) let write_header (t : t) header_start_bytes hdr = - let hw = Writer.{ b = t.b ; offset = header_start_bytes ; info = t.info } in (* it is important we write at level [Ustar] at most as we assume the header(s) taking up exactly 512 bytes. With [GNU] level extra blocks may be used for long names. *) - Lwt.catch - (fun () -> HW.write ~level:Tar.Header.Ustar hdr hw >|= function - | Ok () -> Ok () - | Error `Msg msg -> Error (`Write_header msg)) - (function - | Writer.Read e -> Lwt.return (Error (`Block e)) - | Writer.Write e -> Lwt.return (Error (`Block_write e)) - | exn -> raise exn) + let open Lwt_result.Infix in + Lwt_result.lift (Tar.encode_header ~level:Tar.Header.Ustar hdr) >>= fun datas -> + Lwt_list.fold_left_s (fun acc buf -> + Lwt_result.lift acc >>= fun off' -> + write_data t.info t.b off' buf >|= fun () -> + Int64.(add off' (of_int (String.length buf)))) + (Ok header_start_bytes) datas let set t key data = Lwt_mutex.with_lock t.write_lock (fun () -> @@ -486,7 +501,7 @@ module Make_KV_RW (CLOCK : Mirage_clock.PCLOCK) (BLOCK : Mirage_block.S) = struc in write t (succ data_start_sector) remaining_sectors >>>= fun () -> (* finally write header and first block *) - write_header t header_start_bytes hdr >>>= fun () -> + write_header t header_start_bytes hdr >>>= fun _new_offset -> (* read in slack at beginning which could include the header *) read_partial_sector t data_start_sector first_sector ~offset:0L ~length:data_start_sector_offset >>>= fun () -> @@ -555,7 +570,7 @@ module Make_KV_RW (CLOCK : 
Mirage_clock.PCLOCK) (BLOCK : Mirage_block.S) = struc | Error _ as e -> e end >>>= fun (hdr, data_offset) -> let hdr = { hdr with Tar.Header.file_name = Mirage_kv.Key.to_string dest } in - write_header t Int64.(sub (mul data_offset (of_int Tar.Header.length)) (of_int Tar.Header.length)) hdr >>>= fun () -> + write_header t Int64.(sub (mul data_offset (of_int Tar.Header.length)) (of_int Tar.Header.length)) hdr >>>= fun _new_off -> t.map <- update_insert t.map dest hdr data_offset; t.map <- update_remove t.map source; Lwt_result.return ()) @@ -680,7 +695,7 @@ module Make_KV_RW (CLOCK : Mirage_clock.PCLOCK) (BLOCK : Mirage_block.S) = struc ~length:(sub sector_size last_sector_offset) end >>>= fun () -> write t to_zero_start_sector (Array.to_list data) >>>= fun () -> - write_header t header_start_bytes hdr >>>= fun () -> + write_header t header_start_bytes hdr >>>= fun _new_offset -> let tar_offset = div (sub t.end_of_archive (of_int Tar.Header.length)) (of_int Tar.Header.length) in t.end_of_archive <- end_bytes; t.map <- update_insert t.map key hdr tar_offset; diff --git a/unix/tar_lwt_unix.ml b/unix/tar_lwt_unix.ml index 60cf251..c0ff4a2 100644 --- a/unix/tar_lwt_unix.ml +++ b/unix/tar_lwt_unix.ml @@ -15,57 +15,251 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
*) -open Lwt.Infix - -module Io = struct - type in_channel = Lwt_unix.file_descr - type 'a io = 'a Lwt.t - let really_read fd buf = - let len = Bytes.length buf in - let rec loop idx = - if idx = len then - Lwt.return_unit - else - Lwt_unix.read fd buf idx (len - idx) >>= fun n -> - loop (n + idx) - in - loop 0 - let skip (ifd: Lwt_unix.file_descr) (n: int) = - Lwt_unix.(lseek ifd n SEEK_CUR) >|= ignore - - type out_channel = Lwt_unix.file_descr - let really_write fd buf = - let len = String.length buf in - let rec loop idx = - if idx = len then - Lwt.return_unit +type decode_error = [ + | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] + | `Unix of Unix.error * string * string + | `Unexpected_end_of_file + | `Msg of string +] + +let pp_decode_error ppf = function + | `Fatal err -> Tar.pp_error ppf err + | `Unix (err, fname, arg) -> + Format.fprintf ppf "Unix error %s (function %s, arg %s)" + (Unix.error_message err) fname arg + | `Unexpected_end_of_file -> + Format.fprintf ppf "Unexpected end of file" + | `Msg msg -> + Format.fprintf ppf "Error %s" msg + +let safe f a = + let open Lwt.Infix in + Lwt.catch + (fun () -> f a >|= fun r -> Ok r) + (function + | Unix.Unix_error (e, f, a) -> Lwt.return (Error (`Unix (e, f, a))) + | e -> Lwt.reraise e) + +let read_complete fd buf len = + let open Lwt_result.Infix in + let rec loop offset = + if offset < len then + safe (Lwt_unix.read fd buf offset) (len - offset) >>= fun read -> + if read = 0 then + Lwt.return (Error `Unexpected_end_of_file) else - Lwt_unix.write_string fd buf idx (len - idx) >>= fun n -> - loop (idx + n) - in - loop 0 -end + loop (offset + read) + else + Lwt.return (Ok ()) + in + loop 0 + +let seek fd n = + safe (Lwt_unix.lseek fd n) Unix.SEEK_CUR + +let safe_close fd = + Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) + +let fold f filename init = + let open Lwt_result.Infix in + safe Lwt_unix.(openfile filename [ O_RDONLY ]) 0 >>= fun fd -> + let rec go 
t fd ?global ?data acc = + (match data with + | None -> + let buf = Bytes.make Tar.Header.length '\000' in + read_complete fd buf Tar.Header.length >|= fun () -> + Bytes.unsafe_to_string buf + | Some data -> + Lwt.return (Ok data)) >>= fun data -> + match Tar.decode t data with + | Ok (t, Some `Header hdr, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + f fd ?global hdr acc >>= fun acc' -> + seek fd (Tar.Header.compute_zero_padding_length hdr) >>= fun _off -> + go t fd ?global acc' + | Ok (t, Some `Skip n, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + seek fd n >>= fun _off -> + go t fd ?global acc + | Ok (t, Some `Read n, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + let buf = Bytes.make n '\000' in + read_complete fd buf n >>= fun () -> + let data = Bytes.unsafe_to_string buf in + go t fd ?global ~data acc + | Ok (t, None, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + go t fd ?global acc + | Error `Eof -> Lwt.return (Ok acc) + | Error `Fatal _ as e -> Lwt.return e + in + Lwt.finalize + (fun () -> go (Tar.decode_state ()) fd init) + (fun () -> safe_close fd) + +let unix_err_to_msg = function + | `Unix (e, f, s) -> + `Msg (Format.sprintf "error %s in function %s %s" + (Unix.error_message e) f s) -include Io -module HeaderReader = Tar.HeaderReader(Lwt)(Io) -module HeaderWriter = Tar.HeaderWriter(Lwt)(Io) +let copy ~src_fd ~dst_fd len = + let open Lwt_result.Infix in + let blen = 65536 in + let buffer = Bytes.make blen '\000' in + let rec read_write ~src_fd ~dst_fd len = + if len = 0 then + Lwt.return (Ok ()) + else + let l = min blen len in + Lwt_result.map_error + (function + | `Unix _ as e -> unix_err_to_msg e + | `Unexpected_end_of_file -> + `Msg "Unexpected end of file") + (read_complete src_fd buffer l) >>= fun () -> + Lwt_result.map_error unix_err_to_msg + (safe (Lwt_unix.write dst_fd buffer 0) l) >>= fun _written -> + read_write ~src_fd 
~dst_fd (len - l) + in + read_write ~src_fd ~dst_fd len + +let extract ?(filter = fun _ -> true) ~src dst = + let open Lwt_result.Infix in + let f fd ?global:_ hdr () = + if filter hdr then + match hdr.Tar.Header.link_indicator with + | Tar.Header.Link.Normal -> + Lwt_result.map_error unix_err_to_msg + (safe Lwt_unix.(openfile (Filename.concat dst hdr.Tar.Header.file_name) + [ O_WRONLY ; O_CREAT ]) hdr.Tar.Header.file_mode) >>= fun dst -> + Lwt.finalize + (fun () -> copy ~src_fd:fd ~dst_fd:dst (Int64.to_int hdr.Tar.Header.file_size)) + (fun () -> safe_close dst) + (* TODO set owner / mode / mtime etc. *) + | _ -> + (* TODO handle directories, links, etc. *) + Lwt_result.map_error unix_err_to_msg + (seek fd (Int64.to_int hdr.Tar.Header.file_size)) >|= fun _off -> + () + else + Lwt_result.map_error unix_err_to_msg + (seek fd (Int64.to_int hdr.Tar.Header.file_size)) >|= fun _off -> + () + in + fold f src () (** Return the header needed for a particular file on disk *) -let header_of_file ?level (file: string) : Tar.Header.t Lwt.t = +let header_of_file ?level file = + let open Lwt_result.Infix in let level = Tar.Header.compatibility level in - Lwt_unix.LargeFile.stat file >>= fun stat -> - Lwt_unix.getpwuid stat.Lwt_unix.LargeFile.st_uid >>= fun pwent -> - Lwt_unix.getgrgid stat.Lwt_unix.LargeFile.st_gid >>= fun grent -> - let file_mode = stat.Lwt_unix.LargeFile.st_perm in - let user_id = stat.Lwt_unix.LargeFile.st_uid in - let group_id = stat.Lwt_unix.LargeFile.st_gid in - let file_size = stat.Lwt_unix.LargeFile.st_size in - let mod_time = Int64.of_float stat.Lwt_unix.LargeFile.st_mtime in + safe Lwt_unix.LargeFile.stat file >>= fun stat -> + let file_mode = stat.Lwt_unix.LargeFile.st_perm in + let user_id = stat.Lwt_unix.LargeFile.st_uid in + let group_id = stat.Lwt_unix.LargeFile.st_gid in + let file_size = stat.Lwt_unix.LargeFile.st_size in + let mod_time = Int64.of_float stat.Lwt_unix.LargeFile.st_mtime in let link_indicator = Tar.Header.Link.Normal in - let 
link_name = "" in - let uname = if level = V7 then "" else pwent.Lwt_unix.pw_name in - let gname = if level = V7 then "" else grent.Lwt_unix.gr_name in - let devmajor = if level = Ustar then stat.Lwt_unix.LargeFile.st_dev else 0 in - let devminor = if level = Ustar then stat.Lwt_unix.LargeFile.st_rdev else 0 in - Lwt.return (Tar.Header.make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator ~link_name - ~uname ~gname ~devmajor ~devminor file file_size) + let link_name = "" in + (if level = V7 then + Lwt.return (Ok "") + else + Lwt.catch + (fun () -> safe Lwt_unix.getpwuid stat.Lwt_unix.LargeFile.st_uid) + (function + | Not_found -> + Lwt.return (Error (`Msg ("No user entry found for UID"))) + | e -> Lwt.reraise e) >|= fun pwent -> + pwent.Lwt_unix.pw_name) >>= fun uname -> + (if level = V7 then + Lwt.return (Ok "") + else + Lwt.catch + (fun () -> safe Lwt_unix.getgrgid stat.Lwt_unix.LargeFile.st_gid) + (function + | Not_found -> + Lwt.return (Error (`Msg ("No group entry found for GID"))) + | e -> Lwt.reraise e) >|= fun grent -> + grent.Lwt_unix.gr_name) >>= fun gname -> + let devmajor = if level = Ustar then stat.Lwt_unix.LargeFile.st_dev else 0 in + let devminor = if level = Ustar then stat.Lwt_unix.LargeFile.st_rdev else 0 in + let hdr = Tar.Header.make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator ~link_name + ~uname ~gname ~devmajor ~devminor file file_size + in + Lwt.return (Ok hdr) + +let write_strings fd datas = + let open Lwt_result.Infix in + Lwt_list.fold_left_s (fun acc d -> + Lwt_result.lift acc >>= fun _written -> + Lwt_result.map_error unix_err_to_msg + (safe (Lwt_unix.write_string fd d 0) (String.length d))) + (Ok 0) datas >|= fun _written -> + () + +let write_header ?level header fd = + let open Lwt_result.Infix in + Lwt_result.lift (Tar.encode_header ?level header) >>= fun header_strings -> + write_strings fd header_strings + +let append_file ?level ?header filename fd = + let open Lwt_result.Infix in + (match header with + | None 
-> header_of_file ?level filename + | Some x -> Lwt.return (Ok x)) >>= fun header -> + write_header ?level header fd >>= fun () -> + Lwt_result.map_error unix_err_to_msg + (safe Lwt_unix.(openfile filename [ O_RDONLY ]) 0) >>= fun src -> + (* TOCTOU [also, header may not be valid for file] *) + Lwt.finalize + (fun () -> copy ~src_fd:src ~dst_fd:fd + (Int64.to_int header.Tar.Header.file_size)) + (fun () -> safe_close src) + +let write_global_extended_header ?level header fd = + let open Lwt_result.Infix in + Lwt_result.lift (Tar.encode_global_extended_header ?level header) >>= fun header_strings -> + write_strings fd header_strings + +let write_end fd = + write_strings fd [ Tar.Header.zero_block ; Tar.Header.zero_block ] + +let create ?level ?global ?(filter = fun _ -> true) ~src dst = + let open Lwt_result.Infix in + Lwt_result.map_error unix_err_to_msg + (safe Lwt_unix.(openfile dst [ O_WRONLY ; O_CREAT ]) 0o644) >>= fun dst_fd -> + Lwt.finalize + (fun () -> + (match global with + | None -> Lwt.return (Ok ()) + | Some hdr -> write_global_extended_header ?level hdr dst_fd) >>= fun () -> + let rec copy_files directory = + safe Lwt_unix.opendir directory >>= fun dir -> + Lwt.finalize + (fun () -> + let rec next () = + try + safe Lwt_unix.readdir dir >>= fun name -> + let filename = Filename.concat directory name in + header_of_file ?level filename >>= fun header -> + if filter header then + match header.Tar.Header.link_indicator with + | Normal -> + append_file ?level ~header filename dst_fd >>= fun () -> + next () + | Directory -> + (* TODO first finish curdir (and close the dir fd), then go deeper *) + copy_files filename >>= fun () -> + next () + | _ -> Lwt.return (Ok ()) (* NYI *) + else Lwt.return (Ok ()) + with End_of_file -> Lwt.return (Ok ()) + in + next ()) + (fun () -> + Lwt.catch + (fun () -> Lwt_unix.closedir dir) + (fun _ -> Lwt.return_unit)) + in + copy_files src >>= fun () -> + write_end dst_fd) + (fun () -> safe_close dst_fd) diff --git 
a/unix/tar_lwt_unix.mli b/unix/tar_lwt_unix.mli index 9b97e4d..1282b38 100644 --- a/unix/tar_lwt_unix.mli +++ b/unix/tar_lwt_unix.mli @@ -16,20 +16,63 @@ (** Lwt_unix I/O for tar-formatted data *) -val really_read: Lwt_unix.file_descr -> bytes -> unit Lwt.t -(** [really_read fd buf] fills [buf] with data from [fd] or fails - with {!Stdlib.End_of_file}. *) +type decode_error = [ + | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] + | `Unix of Unix.error * string * string + | `Unexpected_end_of_file + | `Msg of string +] -val really_write: Lwt_unix.file_descr -> string -> unit Lwt.t -(** [really_write fd buf] writes the full contents of [buf] to - [fd] or fails with {!Stdlib.End_of_file}. *) +val pp_decode_error : Format.formatter -> decode_error -> unit -val skip : Lwt_unix.file_descr -> int -> unit Lwt.t -(** [skip fd n] reads [n] bytes from [fd] and discards them. If possible, you - should use [Lwt_unix.lseek fd n Lwt_unix.SEEK_CUR] instead. *) +(** [fold f filename acc] folds over the tar archive. The function [f] is called + for each [hdr : Tar.Header.t]. It should forward the position in the file + descriptor by [hdr.Tar.Header.file_size]. *) +val fold : + (Lwt_unix.file_descr -> ?global:Tar.Header.Extended.t -> Tar.Header.t -> 'a -> + ('a, decode_error) result Lwt.t) -> + string -> 'a -> ('a, decode_error) result Lwt.t -(** Return the header needed for a particular file on disk. *) -val header_of_file : ?level:Tar.Header.compatibility -> string -> Tar.Header.t Lwt.t +(** [extract ~filter ~src dst] extracts the tar archive [src] into the + directory [dst]. If [dst] does not exist, it is created. If [filter] is + provided (defaults to [fun _ -> true]), any file where [filter hdr] returns + [false], is skipped. 
*) +val extract : + ?filter:(Tar.Header.t -> bool) -> + src:string -> string -> + (unit, decode_error) result Lwt.t -module HeaderReader : Tar.HEADERREADER with type in_channel = Lwt_unix.file_descr and type 'a io = 'a Lwt.t -module HeaderWriter : Tar.HEADERWRITER with type out_channel = Lwt_unix.file_descr and type 'a io = 'a Lwt.t +(** [create ~level ~filter ~src dst] creates a tar archive at [dst]. It uses + [src], a directory name, as input. If [filter] is provided + (defaults to [fun _ -> true]), any file where [filter hdr] returns [false] + is skipped. *) +val create : ?level:Tar.Header.compatibility -> + ?global:Tar.Header.Extended.t -> + ?filter:(Tar.Header.t -> bool) -> + src:string -> string -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result Lwt.t + +(** [header_of_file ~level filename] returns the tar header of [filename]. *) +val header_of_file : ?level:Tar.Header.compatibility -> string -> + (Tar.Header.t, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result Lwt.t + +(** [append_file ~level ~header filename fd] appends the contents of [filename] + to the tar archive [fd]. If [header] is not provided, {header_of_file} is + used for constructing a header. *) +val append_file : ?level:Tar.Header.compatibility -> ?header:Tar.Header.t -> + string -> Lwt_unix.file_descr -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result Lwt.t + +(** [write_header ~level hdr fd] writes the header [hdr] to [fd]. *) +val write_header : ?level:Tar.Header.compatibility -> + Tar.Header.t -> Lwt_unix.file_descr -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result Lwt.t + +(** [write_global_extended_header ~level hdr fd] writes the extended header [hdr] to + [fd]. 
*) +val write_global_extended_header : ?level:Tar.Header.compatibility -> + Tar.Header.Extended.t -> Lwt_unix.file_descr -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result Lwt.t + +(** [write_end fd] writes the tar end marker to [fd]. *) +val write_end : Lwt_unix.file_descr -> (unit, [ `Msg of string ]) result Lwt.t diff --git a/unix/tar_unix.ml b/unix/tar_unix.ml index a1c1548..b4c04db 100644 --- a/unix/tar_unix.ml +++ b/unix/tar_unix.ml @@ -15,60 +15,247 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) -module Direct = struct - type 'a t = 'a - let return x = x - let ( >>= ) m f = f m -end - -module Driver = struct - type 'a io = 'a Direct.t - type in_channel = Unix.file_descr - type out_channel = Unix.file_descr - - let rec with_restart op fd buf off len = - try op fd buf off len with - Unix.Unix_error (Unix.EINTR,_,_) -> - with_restart op fd buf off len - - let really_read fd buf = - let len = Bytes.length buf in - let rec loop offset = - if offset < len then - let n = with_restart Unix.read fd buf offset (len - offset) in - if n = 0 then raise End_of_file; +let ( let* ) = Result.bind + +let rec safe f a = + try Ok (f a) with + | Unix.Unix_error (Unix.EINTR, _, _) -> safe f a + | Unix.Unix_error (e, f, s) -> Error (`Unix (e, f, s)) + +let safe_close fd = + try Unix.close fd with _ -> () + +let read_complete fd buf len = + let rec loop offset = + if offset < len then + let* n = safe (Unix.read fd buf offset) (len - offset) in + if n = 0 then + Error `Unexpected_end_of_file + else loop (offset + n) - in - loop 0 + else + Ok () + in + loop 0 + +let seek fd n = + safe (Unix.lseek fd n) Unix.SEEK_CUR + +type decode_error = [ + | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] + | `Unix of Unix.error * string * string + | `Unexpected_end_of_file + | `Msg of string +] + +let pp_decode_error ppf = function + | `Fatal err -> Tar.pp_error ppf err + | `Unix (err, fname, arg) -> + 
Format.fprintf ppf "Unix error %s (function %s, arg %s)" + (Unix.error_message err) fname arg + | `Unexpected_end_of_file -> + Format.fprintf ppf "Unexpected end of file" + | `Msg msg -> + Format.fprintf ppf "Error %s" msg - let skip fd n = - ignore (Unix.lseek fd n Unix.SEEK_CUR) +let fold f filename init = + let* fd = safe Unix.(openfile filename [ O_RDONLY ]) 0 in + let rec go t fd ?global ?data acc = + let* data = match data with + | None -> + let buf = Bytes.make Tar.Header.length '\000' in + let* () = read_complete fd buf Tar.Header.length in + Ok (Bytes.unsafe_to_string buf) + | Some data -> Ok data + in + match Tar.decode t data with + | Ok (t, Some `Header hdr, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + let* acc' = f fd ?global hdr acc in + let* _off = seek fd (Tar.Header.compute_zero_padding_length hdr) in + go t fd ?global acc' + | Ok (t, Some `Skip n, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + let* _off = seek fd n in + go t fd ?global acc + | Ok (t, Some `Read n, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + let buf = Bytes.make n '\000' in + let* () = read_complete fd buf n in + let data = Bytes.unsafe_to_string buf in + go t fd ?global ~data acc + | Ok (t, None, g) -> + let global = Option.fold ~none:global ~some:(fun g -> Some g) g in + go t fd ?global acc + | Error `Eof -> Ok acc + | Error `Fatal _ as e -> e + in + Fun.protect + ~finally:(fun () -> safe_close fd) + (fun () -> go (Tar.decode_state ()) fd init) - let really_write fd buf = - let offset = ref 0 in - while !offset < String.length buf do - offset := !offset + with_restart Unix.write_substring fd buf !offset (String.length buf - !offset) - done -end +let unix_err_to_msg = function + | `Unix (e, f, s) -> + `Msg (Format.sprintf "error %s in function %s %s" + (Unix.error_message e) f s) -module HeaderReader = Tar.HeaderReader(Direct)(Driver) -module HeaderWriter = 
Tar.HeaderWriter(Direct)(Driver) +let copy ~src_fd ~dst_fd len = + let blen = 65536 in + let buffer = Bytes.make blen '\000' in + let rec read_write ~src_fd ~dst_fd len = + if len = 0 then + Ok () + else + let l = min blen len in + let* () = + Result.map_error + (function + | `Unix _ as e -> unix_err_to_msg e + | `Unexpected_end_of_file -> + `Msg "Unexpected end of file") + (read_complete src_fd buffer l) + in + let* _written = + Result.map_error unix_err_to_msg + (safe (Unix.write dst_fd buffer 0) l) + in + read_write ~src_fd ~dst_fd (len - l) + in + read_write ~src_fd ~dst_fd len -include Driver +let extract ?(filter = fun _ -> true) ~src dst = + let f fd ?global:_ hdr () = + if filter hdr then + match hdr.Tar.Header.link_indicator with + | Tar.Header.Link.Normal -> + let* dst = + Result.map_error unix_err_to_msg + (safe Unix.(openfile (Filename.concat dst hdr.Tar.Header.file_name) + [ O_WRONLY ; O_CREAT ]) hdr.Tar.Header.file_mode) + in + Fun.protect ~finally:(fun () -> safe_close dst) + (fun () -> copy ~src_fd:fd ~dst_fd:dst (Int64.to_int hdr.Tar.Header.file_size)) + (* TODO set owner / mode / mtime etc. *) + | _ -> + (* TODO handle directories, links, etc. 
*) + let* _off = + Result.map_error unix_err_to_msg + (seek fd (Int64.to_int hdr.Tar.Header.file_size)) + in + Ok () + else + let* _off = + Result.map_error unix_err_to_msg + (seek fd (Int64.to_int hdr.Tar.Header.file_size)) + in + Ok () + in + fold f src () - (** Return the header needed for a particular file on disk *) -let header_of_file ?level (file: string) : Tar.Header.t = +(** Return the header needed for a particular file on disk *) +let header_of_file ?level file = let level = Tar.Header.compatibility level in - let stat = Unix.LargeFile.lstat file in + let* stat = safe Unix.LargeFile.lstat file in let file_mode = stat.Unix.LargeFile.st_perm in let user_id = stat.Unix.LargeFile.st_uid in let group_id = stat.Unix.LargeFile.st_gid in let mod_time = Int64.of_float stat.Unix.LargeFile.st_mtime in + (* TODO evaluate stat.st_kind *) let link_indicator = Tar.Header.Link.Normal in let link_name = "" in - let uname = if level = V7 then "" else (Unix.getpwuid stat.Unix.LargeFile.st_uid).Unix.pw_name in + let* uname = + if level = V7 then + Ok "" + else + try + let* passwd_entry = safe Unix.getpwuid stat.Unix.LargeFile.st_uid in + Ok passwd_entry.Unix.pw_name + with Not_found -> Error (`Msg ("No user entry found for UID")) + in let devmajor = if level = Ustar then stat.Unix.LargeFile.st_dev else 0 in - let gname = if level = V7 then "" else (Unix.getgrgid stat.Unix.LargeFile.st_gid).Unix.gr_name in + let* gname = + if level = V7 then + Ok "" + else + try + let* passwd_entry = safe Unix.getgrgid stat.Unix.LargeFile.st_gid in + Ok passwd_entry.Unix.gr_name + with Not_found -> Error (`Msg "No group entry found for GID") + in let devminor = if level = Ustar then stat.Unix.LargeFile.st_rdev else 0 in - Tar.Header.make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator ~link_name - ~uname ~gname ~devmajor ~devminor file stat.Unix.LargeFile.st_size + Ok (Tar.Header.make ~file_mode ~user_id ~group_id ~mod_time ~link_indicator ~link_name + ~uname ~gname ~devmajor 
~devminor file stat.Unix.LargeFile.st_size) + +let write_strings fd datas = + let* _written = + List.fold_left (fun acc d -> + let* _written = acc in + Result.map_error unix_err_to_msg + (safe (Unix.write_substring fd d 0) (String.length d))) + (Ok 0) datas + in + Ok () + +let write_header ?level header fd = + let* header_strings = Tar.encode_header ?level header in + write_strings fd header_strings + +let append_file ?level ?header filename fd = + let* header = match header with + | None -> header_of_file ?level filename + | Some x -> Ok x + in + let* () = write_header ?level header fd in + let* src = + Result.map_error unix_err_to_msg + (safe Unix.(openfile filename [ O_RDONLY ]) 0) + in + (* TOCTOU [also, header may not be valid for file] *) + Fun.protect ~finally:(fun () -> safe_close src) + (fun () -> copy ~src_fd:src ~dst_fd:fd + (Int64.to_int header.Tar.Header.file_size)) + +let write_global_extended_header ?level header fd = + let* header_strings = Tar.encode_global_extended_header ?level header in + write_strings fd header_strings + +let write_end fd = + write_strings fd [ Tar.Header.zero_block ; Tar.Header.zero_block ] + +let create ?level ?global ?(filter = fun _ -> true) ~src dst = + let* dst_fd = + Result.map_error unix_err_to_msg + (safe Unix.(openfile dst [ O_WRONLY ; O_CREAT ]) 0o644) + in + Fun.protect ~finally:(fun () -> safe_close dst_fd) + (fun () -> + let* () = match global with + | None -> Ok () + | Some hdr -> write_global_extended_header ?level hdr dst_fd + in + let rec copy_files directory = + let* dir = safe Unix.opendir directory in + Fun.protect ~finally:(fun () -> try Unix.closedir dir with _ -> ()) + (fun () -> + let rec next () = + try + let* name = safe Unix.readdir dir in + let filename = Filename.concat directory name in + let* header = header_of_file ?level filename in + if filter header then + match header.Tar.Header.link_indicator with + | Normal -> + let* () = append_file ?level ~header filename dst_fd in + next () + | 
Directory -> + (* TODO first finish curdir (and close the dir fd), then go deeper *) + let* () = copy_files filename in + next () + | _ -> Ok () (* NYI *) + else Ok () + with End_of_file -> Ok () + in + next ()) + in + let* () = copy_files src in + write_end dst_fd) diff --git a/unix/tar_unix.mli b/unix/tar_unix.mli index b21ad57..3863ffd 100644 --- a/unix/tar_unix.mli +++ b/unix/tar_unix.mli @@ -16,20 +16,63 @@ (** Unix I/O for tar-formatted data. *) -val really_read: Unix.file_descr -> bytes -> unit -(** [really_read fd buf] fills [buf] with data from [fd] or raises - {!Stdlib.End_of_file}. *) +type decode_error = [ + | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] + | `Unix of Unix.error * string * string + | `Unexpected_end_of_file + | `Msg of string +] -val really_write: Unix.file_descr -> string -> unit -(** [really_write fd buf] writes the full contents of [buf] to [fd] - or {!Stdlib.End_of_file}. *) +val pp_decode_error : Format.formatter -> decode_error -> unit -val skip : Unix.file_descr -> int -> unit -(** [skip fd n] reads [n] bytes from [fd] and discards them. If possible, you - should use [Unix.lseek fd n Unix.SEEK_CUR] instead. *) +(** [fold f filename acc] folds over the tar archive. The function [f] is called + for each [hdr : Tar.Header.t]. It should forward the position in the file + descriptor by [hdr.Tar.Header.file_size]. *) +val fold : + (Unix.file_descr -> ?global:Tar.Header.Extended.t -> Tar.Header.t -> 'a -> + ('a, decode_error) result) -> + string -> 'a -> ('a, decode_error) result -(** Return the header needed for a particular file on disk. *) -val header_of_file : ?level:Tar.Header.compatibility -> string -> Tar.Header.t +(** [extract ~filter ~src dst] extracts the tar archive [src] into the + directory [dst]. If [dst] does not exist, it is created. If [filter] is + provided (defaults to [fun _ -> true]), any file where [filter hdr] returns + [false], is skipped. 
*) +val extract : + ?filter:(Tar.Header.t -> bool) -> + src:string -> string -> + (unit, decode_error) result -module HeaderReader : Tar.HEADERREADER with type in_channel = Unix.file_descr and type 'a io = 'a -module HeaderWriter : Tar.HEADERWRITER with type out_channel = Unix.file_descr and type 'a io = 'a +(** [create ~level ~filter ~src dst] creates a tar archive at [dst]. It uses + [src], a directory name, as input. If [filter] is provided + (defaults to [fun _ -> true]), any file where [filter hdr] returns [false] + is skipped. *) +val create : ?level:Tar.Header.compatibility -> + ?global:Tar.Header.Extended.t -> + ?filter:(Tar.Header.t -> bool) -> + src:string -> string -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result + +(** [header_of_file ~level filename] returns the tar header of [filename]. *) +val header_of_file : ?level:Tar.Header.compatibility -> string -> + (Tar.Header.t, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result + +(** [append_file ~level ~header filename fd] appends the contents of [filename] + to the tar archive [fd]. If [header] is not provided, {header_of_file} is + used for constructing a header. *) +val append_file : ?level:Tar.Header.compatibility -> ?header:Tar.Header.t -> + string -> Unix.file_descr -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result + +(** [write_header ~level hdr fd] writes the header [hdr] to [fd]. *) +val write_header : ?level:Tar.Header.compatibility -> + Tar.Header.t -> Unix.file_descr -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result + +(** [write_global_extended_header ~level hdr fd] writes the extended header [hdr] to + [fd]. *) +val write_global_extended_header : ?level:Tar.Header.compatibility -> + Tar.Header.Extended.t -> Unix.file_descr -> + (unit, [ `Msg of string | `Unix of (Unix.error * string * string) ]) result + +(** [write_end fd] writes the tar end marker to [fd]. 
*) +val write_end : Unix.file_descr -> (unit, [> `Msg of string ]) result