From c3a3b7c79e23f4de61cfdaa311e2f0912d9ce9e4 Mon Sep 17 00:00:00 2001 From: Ujin Date: Fri, 5 Aug 2022 15:19:20 +0300 Subject: [PATCH] fix codec initializing --- src/WebSockets.jl | 69 ++++++++++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 27 deletions(-) diff --git a/src/WebSockets.jl b/src/WebSockets.jl index c5b180ed1..7631c9eae 100644 --- a/src/WebSockets.jl +++ b/src/WebSockets.jl @@ -90,27 +90,18 @@ function mask!(bytes::Vector{UInt8}, mask) end return end - -function compress(data::T) where T <: AbstractVector{UInt8} - compressed = transcode(DeflateCompressor, data) - push!(compressed, 0x00) - return compressed -end - -function compress(data::String) - compressed = transcode(DeflateCompressor, data) - push!(compressed, 0x00) - return String(compressed) +function final_deflate_codecs(t::Tuple) + CodecZlib.TranscodingStreams.finalize(t[1]) + CodecZlib.TranscodingStreams.finalize(t[2]) end -function decompress(data::T) where T <: AbstractVector{UInt8} - decompressed = transcode(DeflateDecompressor, append!(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) - return decompressed -end +function init_deflate_codecs() + codecco = DeflateCompressor() + CodecZlib.TranscodingStreams.initialize(codecco) + codecde = DeflateDecompressor() + CodecZlib.TranscodingStreams.initialize(codecde) -function decompress(data::String) - decompressed = transcode(DeflateDecompressor, append!(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) - return String(decompressed) + return (codecco, codecde) end @@ -316,13 +307,13 @@ mutable struct WebSocket writebuffer::Vector{UInt8} readclosed::Bool writeclosed::Bool - isdeflate::Bool + deflate::Union{Nothing, Tuple{CodecZlib.CompressorCodec, CodecZlib.DecompressorCodec}} end const DEFAULT_MAX_FRAG = 1024 WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate::Bool=false) = - WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, isdeflate) + WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, isdeflate ? init_deflate_codecs() : nothing) """ WebSockets.isclosed(ws) -> Bool @@ -330,6 +321,7 @@ WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, max Check whether a `WebSocket` has sent and received CLOSE frames. """ isclosed(ws::WebSocket) = ws.readclosed && ws.writeclosed +isdeflate(ws::WebSocket) = !isnothing(ws.deflate) # Handshake "Check whether a HTTP.Request or HTTP.Response is a websocket upgrade request/response" @@ -534,7 +526,7 @@ function Sockets.send(ws::WebSocket, x) # so we can appropriately set the FIN bit for the last fragmented frame nextstate = iterate(x, st) while true - n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? ws.isdeflate : false)) + n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? isdeflate(ws) : false)) first = false nextstate === nothing && break item, st = nextstate @@ -543,8 +535,8 @@ function Sockets.send(ws::WebSocket, x) else # single binary or text frame for message @label write_single_frame - pl = ws.isdeflate ? compress(payload(ws, x)) : payload(ws, x) - return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=ws.isdeflate)) + pl = isdeflate(ws) ? compress(ws, payload(ws, x)) : payload(ws, x) + return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=isdeflate(ws))) end end @@ -559,7 +551,7 @@ to when a PING message is received by a websocket connection. function ping(ws::WebSocket, data=UInt8[]) @require !ws.writeclosed @debugv 2 "$(ws.id): sending ping" - return writeframe(ws.io, Frame(true, PING, ws.client, payload(ws, data))) + return writeframe(ws.io, Frame(true, PING, ws.client, payload(ws, isdeflate(ws) ? compress(ws, data) : data))) end """ @@ -620,11 +612,34 @@ function Base.close(ws::WebSocket, body::CloseFrameBody=CloseFrameBody(1000, "") @assert ws.readclosed # if we're the server, it's our job to close the underlying socket !ws.client && isopen(ws.io) && close(ws.io) + final_deflate_codecs(ws.deflate) return end # Receiving messages +function compress(ws::WebSocket, data::T) where T <: AbstractVector{UInt8} + compressed = transcode(ws.deflate[1], data) + push!(compressed, 0x00) + return compressed +end + +function compress(ws::WebSocket, data::String) + compressed = transcode(ws.deflate[1], data) + push!(compressed, 0x00) + return String(compressed) +end + +function decompress(ws::WebSocket, data::T) where T <: AbstractVector{UInt8} + decompressed = transcode(ws.deflate[2], append!(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + return decompressed +end + +function decompress(ws::WebSocket, data::String) + decompressed = transcode(ws.deflate[2], append!(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + return String(decompressed) +end + # returns whether additional frames should be read # true if fragmented message or a ping/pong frame was handled @noinline control_len_check(len) = len > 125 && throw(WebSocketError(CloseFrameBody(1002, "Invalid length for control frame"))) @@ -644,7 +659,7 @@ function checkreadframe!(ws::WebSocket, frame::Frame) if !ws.writeclosed close(ws) end - throw(WebSocketError(frame.payload)) + throw(WebSocketError(isdeflate(ws) ? decompress(ws, frame.payload) : frame.payload)) elseif opcode == PING control_len_check(frame.flags.len) pong(ws, frame.payload) @@ -686,7 +701,7 @@ function receive(ws::WebSocket) done = checkreadframe!(ws, frame) # common case of reading single non-control frame if done - payload = ws.isdeflate ? decompress(frame.payload) : frame.payload + payload = isdeflate(ws) ? decompress(ws, frame.payload) : frame.payload payload isa String && utf8check(payload) return payload end @@ -704,7 +719,7 @@ function receive(ws::WebSocket) end done && break end - payload = ws.isdeflate ? decompress(payload) : payload + payload = isdeflate(ws) ? decompress(ws, payload) : payload payload isa String && utf8check(payload) @debugv 2 "Read message: $(payload[1:min(1024, sizeof(payload))])" return payload