Skip to content

Commit db76373

Browse files
committed
allow passing a Accept-Encoding header to PkgServer to fetch e.g. zstd compressed files from the storage server
for backwards compatibility, return gzipped files when no `Accept-Encoding` is provided
1 parent 73c3532 commit db76373

File tree

3 files changed

+122
-17
lines changed

3 files changed

+122
-17
lines changed

src/PkgServer.jl

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,49 @@ include("meta.jl")
2424
include("admin.jl")
2525
include("dynamic.jl")
2626

27+
# Detect compression format by reading file magic bytes
28+
function detect_file_compression(filepath::AbstractString)::Union{String,Nothing}
29+
magic = open(filepath, "r") do io
30+
read(io, min(4, filesize(io)))
31+
end
32+
33+
# Zstd
34+
if length(magic) >= 4 && magic[1:4] == [0x28, 0xB5, 0x2F, 0xFD]
35+
return "zstd"
36+
end
37+
# Gzip
38+
if length(magic) >= 2 && magic[1:2] == [0x1F, 0x8B]
39+
return "gzip"
40+
end
41+
42+
# Not enough bytes to detect or unknown format
43+
return nothing
44+
end
45+
46+
# Check if client accepts a given compression format
47+
function client_accepts_format(http::HTTP.Stream, format::String)::Bool
48+
accept_encoding = HTTP.header(http, "Accept-Encoding", "")
49+
# Empty header means only gzip is accepted (backwards compatibility)
50+
if isempty(accept_encoding)
51+
return format == "gzip"
52+
end
53+
return occursin(format, accept_encoding)
54+
end
55+
56+
# Get content type based on file format
57+
function get_content_type(filepath::AbstractString)
58+
file_format = detect_file_compression(filepath)
59+
60+
if file_format == "zstd"
61+
return "application/x-zstd"
62+
elseif file_format == "gzip"
63+
return "application/x-gzip"
64+
else
65+
# Unknown or insufficient bytes - return generic octet-stream
66+
return "application/octet-stream"
67+
end
68+
end
69+
2770
mutable struct RegistryMeta
2871
# Upstream registry URL (e.g. "https://github.com/JuliaRegistries/General")
2972
upstream_url::String
@@ -294,20 +337,30 @@ function handle_request(http::HTTP.Stream)
294337

295338
# If the user asked for something that is an actual resource, send it directly
296339
if occursin(resource_re, resource)
297-
# If the resource already exists locally, yay! Serve it and quit.
340+
# If the resource already exists locally, check if client can accept the format
298341
resource_path = resource_filepath(resource)
299342
io = try_open(resource_path)
300343
if io !== nothing
301-
hit!(config.cache, resource[2:end])
302-
serve_file(http, io, "application/x-gzip")
303-
close(io)
304-
return
344+
cached_format = detect_file_compression(resource_path)
345+
346+
# Serve from cache if format is unknown or client accepts it
347+
if cached_format === nothing || client_accepts_format(http, cached_format)
348+
hit!(config.cache, resource[2:end])
349+
content_type = get_content_type(resource_path)
350+
serve_file(http, io, content_type)
351+
close(io)
352+
return
353+
else
354+
# Cached format not acceptable, close and re-fetch in desired format
355+
close(io)
356+
# Fall through to fetch_resource below
357+
end
305358
end
306359

307360
# If it doesn't exist locally, let's request a fetch on that resource.
308361
# This will return either `nothing` (e.g. resource does not exist) or
309362
# a `DownloadState` that represents a partial download.
310-
dl_state = fetch_resource(resource, request_id)
363+
dl_state = fetch_resource(resource, request_id, http)
311364
if dl_state !== nothing
312365
HTTP.setheader(http, "X-Cache-Miss" => "miss")
313366
stream_path = temp_resource_filepath(resource)
@@ -318,7 +371,8 @@ function handle_request(http::HTTP.Stream)
318371
# Try to serve `stream_path` file
319372
stream_io = try_open(stream_path)
320373
if stream_io !== nothing
321-
serve_file(http, stream_io, "application/x-gzip";
374+
content_type = get_content_type(stream_path)
375+
serve_file(http, stream_io, content_type;
322376
content_length=dl_state.content_length,
323377
dl_task=dl_state.dl_task)
324378
close(stream_io)
@@ -329,7 +383,8 @@ function handle_request(http::HTTP.Stream)
329383
# downloading since we last checked 20 lines ago. Check again.
330384
io = try_open(resource_path)
331385
if io !== nothing
332-
serve_file(http, io, "application/x-gzip")
386+
content_type = get_content_type(resource_path)
387+
serve_file(http, io, content_type)
333388
close(io)
334389
return
335390
end
@@ -377,7 +432,7 @@ function handle_request(http::HTTP.Stream)
377432
return
378433
end
379434

380-
# precompilation
435+
# precompilation
381436
include(joinpath(dirname(@__DIR__), "deps", "precompile.jl"))
382437
if get(ENV, "PKGSERVER_GENERATING_PRECOMPILE", nothing) === nothing
383438
_precompile_()

src/resource.jl

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,20 +257,21 @@ Given a resource and a list of storage servers, return the storage server that r
257257
most quickly to a HEAD request for that storage server, as well as the HEAD response
258258
itself so that metadata such as the content-length of the resource can be inspected.
259259
"""
260-
function select_server(resource::AbstractString, servers::Vector{<:AbstractString}; timeout = 5, retries = 2)
261-
function head_req(server, resource)
260+
function select_server(resource::AbstractString, servers::Vector{<:AbstractString}, accept_encoding::AbstractString="gzip"; timeout = 5, retries = 2)
261+
function head_req(server, resource, accept_encoding)
262262
@try_printerror begin
263263
response = HTTP.head(
264264
string(server, resource);
265265
status_exception = false,
266266
timeout = timeout,
267267
retries = retries,
268+
headers = ["Accept-Encoding" => accept_encoding],
268269
)
269270
return server, response
270271
end
271272
end
272273
# Launch one task per server, performing a HEAD request
273-
tasks = [@spawn head_req(server, resource) for server in servers]
274+
tasks = [@spawn head_req(server, resource, accept_encoding) for server in servers]
274275

275276
# Wait for the first Task that gives us an HTTP 200 OK, returning that server.
276277
# If none have it, we return `nothing`. :(
@@ -314,12 +315,26 @@ are recorded and future downloads of that same resource will be skipped, until
314315
`forget_failures()` is called. The `DownloadState` object contains within it enough
315316
information to still serve a resource as it is being downloaded in the background task.
316317
"""
317-
function fetch_resource(resource::AbstractString, request_id::AbstractString; servers::Vector{String}=config.storage_servers)
318+
function fetch_resource(resource::AbstractString, request_id::AbstractString, http::Union{HTTP.Stream,Nothing}=nothing; servers::Vector{String}=config.storage_servers)
318319
if isempty(servers)
319320
@error("fetch called with no servers", resource)
320321
error("fetch called with no servers")
321322
end
322323

324+
# Determine what encoding to request from storage servers based on client preferences
325+
# For backwards compatibility, if no Accept-Encoding header, default to gzip
326+
if http !== nothing
327+
client_accept = HTTP.header(http, "Accept-Encoding", "")
328+
if isempty(client_accept)
329+
accept_encoding = "gzip" # no header means gzip
330+
else
331+
accept_encoding = client_accept
332+
end
333+
else
334+
# No client context, prefer zstd for storage efficiency
335+
accept_encoding = "zstd, gzip"
336+
end
337+
323338
# with_fetch_state() will wait for a lock
324339
with_fetch_state(resource) do state
325340
# check if this has failed to download recently
@@ -335,15 +350,15 @@ function fetch_resource(resource::AbstractString, request_id::AbstractString; se
335350
end
336351

337352
# If not, let's figure out which storage server we're going to download from:
338-
server, response = select_server(resource, servers)
353+
server, response = select_server(resource, servers, accept_encoding)
339354
if response === nothing
340355
@debug("no upstream server", resource, servers)
341356
return nothing
342357
end
343358

344359
# Launch download process in a separate task:
345360
dl_task = @async begin
346-
success = download(server, resource, content_length(response), request_id)
361+
success = download(server, resource, content_length(response), request_id, accept_encoding)
347362
lock(state.lock) do
348363
if success
349364
global fetch_hits += 1
@@ -422,8 +437,8 @@ function stream_file(io_in::IO, start_byte::Int, length::Int, dl_task::Task, io_
422437
return transmitted
423438
end
424439

425-
function download(server::AbstractString, resource::AbstractString, content_length::Int, request_id::AbstractString)
426-
@info("downloading resource", server, resource, request_id)
440+
function download(server::AbstractString, resource::AbstractString, content_length::Int, request_id::AbstractString, accept_encoding::AbstractString="gzip")
441+
@info("downloading resource", server, resource, request_id, accept_encoding)
427442
t_start = time()
428443
hash = basename(resource)
429444

@@ -465,6 +480,8 @@ function download(server::AbstractString, resource::AbstractString, content_leng
465480
req = HTTP.get(server * resource,
466481
status_exception = false,
467482
response_stream = file_io,
483+
headers = ["Accept-Encoding" => accept_encoding],
484+
decompress = false,
468485
)
469486
close(file_io)
470487
return req

test/tests.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,32 @@ end
7171
@test endswith(eager_response.request.target, ".eager")
7272

7373
# Test asking for that registry directly, unpacking it and verifying the treehash
74+
# Clients without Accept-Encoding header should get gzip
7475
mktemp() do tarball_path, tarball_io
7576
response = HTTP.get("$(server_url)/registry/$(registry_uuid)/$(registry_treehash)"; response_stream=tarball_io)
7677
close(tarball_io)
7778
@test response.status == 200
79+
80+
format = PkgServer.detect_file_compression(tarball_path)
81+
@test format == "gzip"
82+
7883
@test registry_treehash == Tar.tree_hash(open(pipeline(`cat $(tarball_path)`, `gzip -d`), read=true))
7984
end
8085

86+
# Clients that only accept zstd should get zstd
87+
mktemp() do tarball_path, tarball_io
88+
response = HTTP.get("$(server_url)/registry/$(registry_uuid)/$(registry_treehash)",
89+
["Accept-Encoding" => "zstd"];
90+
response_stream=tarball_io)
91+
close(tarball_io)
92+
@test response.status == 200
93+
94+
format = PkgServer.detect_file_compression(tarball_path)
95+
@test format == "zstd"
96+
97+
@test registry_treehash == Tar.tree_hash(open(pipeline(`zstd -d -c $(tarball_path)`), read=true))
98+
end
99+
81100
# Verify that these files exist within the cache
82101
@test isfile(joinpath(cache_dir, "..", "static", "registries.eager"))
83102
@test isfile(joinpath(cache_dir, "..", "static", "registries.conservative"))
@@ -278,6 +297,20 @@ end
278297
# Also test that it's available at its nskip hash:
279298
@test HTTP.head("$(server_url)/artifact/$(art_yskip_hash)").status == 200
280299
end
300+
301+
# Test artifact download with zstd compression
302+
mktemp() do tarball_path, tarball_io
303+
response = HTTP.get("$(server_url)/artifact/$(art_yskip_hash)",
304+
["Accept-Encoding" => "zstd"];
305+
response_stream=tarball_io)
306+
close(tarball_io)
307+
@test response.status == 200
308+
309+
format = PkgServer.detect_file_compression(tarball_path)
310+
@test format == "zstd"
311+
312+
@test art_nskip_hash == Tar.tree_hash(open(pipeline(`zstd -d -c $(tarball_path)`), read=true))
313+
end
281314
end
282315

283316
@testset "Partial Content" begin

0 commit comments

Comments
 (0)