diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 09326ed4..1f09c376 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -299,12 +299,12 @@ func sharedData(t *testing.T, ps envoyPorts, stdErr *bytes.Buffer) { func sharedQueue(t *testing.T, ps envoyPorts, stdErr *bytes.Buffer) { req, err := http.NewRequest("GET", fmt.Sprintf("http://localhost:%d", ps.endpoint), nil) - require.NoError(t, err, stdErr.String()) + require.NoError(t, err) count := 10 for i := 0; i < count; i++ { r, err := http.DefaultClient.Do(req) - require.NoError(t, err, stdErr.String()) + require.NoError(t, err) r.Body.Close() } diff --git a/examples/configuration_from_root/main.go b/examples/configuration_from_root/main.go index 9dbe8737..493c672c 100644 --- a/examples/configuration_from_root/main.go +++ b/examples/configuration_from_root/main.go @@ -20,13 +20,11 @@ import ( func main() { proxywasm.SetNewRootContext(newRootContext) - proxywasm.SetNewHttpContext(newHttpContext) } type rootContext struct { // you must embed the default context so that you need not to reimplement all the methods by yourself proxywasm.DefaultRootContext - config []byte } @@ -44,29 +42,14 @@ func (ctx *rootContext) OnPluginStart(pluginConfigurationSize int) bool { return true } +func (ctx *rootContext) NewHttpContext(contextID uint32) proxywasm.HttpContext { + ret := &httpContext{config: ctx.config} + proxywasm.LogInfof("read plugin config from root context: %s\n", string(ret.config)) + return ret +} + type httpContext struct { proxywasm.DefaultHttpContext config []byte } - -func newHttpContext(rootContextID, contextID uint32) proxywasm.HttpContext { - ctx := &httpContext{} - - rootCtx, err := proxywasm.GetRootContextByID(rootContextID) - if err != nil { - proxywasm.LogErrorf("unable to get root context: %v", err) - - return ctx - } - - exampleRootCtx, ok := rootCtx.(*rootContext) - if !ok { - proxywasm.LogError("could not cast root context") - } - - ctx.config = exampleRootCtx.config - - proxywasm.LogInfof("plugin config from root context: %s\n", string(ctx.config)) - return ctx -} diff --git a/examples/configuration_from_root/main_test.go b/examples/configuration_from_root/main_test.go index 4adca579..5bf55366 100644 --- a/examples/configuration_from_root/main_test.go +++ b/examples/configuration_from_root/main_test.go @@ -31,7 +31,6 @@ func TestContext_OnPluginStart(t *testing.T) { opt := proxytest.NewEmulatorOption(). WithNewRootContext(newRootContext). - WithNewHttpContext(newHttpContext). WithPluginConfiguration([]byte(pluginConfigData)) host := proxytest.NewHostEmulator(opt) defer host.Done() // release the emulation lock so that other test cases can insert their own host emulation diff --git a/examples/http_auth_random/main.go b/examples/http_auth_random/main.go index 7f40a9a8..c4489897 100644 --- a/examples/http_auth_random/main.go +++ b/examples/http_auth_random/main.go @@ -24,7 +24,17 @@ import ( const clusterName = "httpbin" func main() { - proxywasm.SetNewHttpContext(newContext) + proxywasm.SetNewRootContext(newRootContext) +} + +type rootContext struct { + proxywasm.DefaultRootContext +} + +func newRootContext(uint32) proxywasm.RootContext { return &rootContext{} } + +func (*rootContext) NewHttpContext(contextID uint32) proxywasm.HttpContext { + return &httpAuthRandom{contextID: contextID} } type httpAuthRandom struct { @@ -33,10 +43,6 @@ type httpAuthRandom struct { contextID uint32 } -func newContext(rootContextID, contextID uint32) proxywasm.HttpContext { - return &httpAuthRandom{contextID: contextID} -} - // override default func (ctx *httpAuthRandom) OnHttpRequestHeaders(numHeaders int, endOfStream bool) types.Action { hs, err := proxywasm.GetHttpRequestHeaders() diff --git a/examples/http_auth_random/main_test.go b/examples/http_auth_random/main_test.go index 1d29cc05..c3d7a3a8 100644 --- a/examples/http_auth_random/main_test.go +++ b/examples/http_auth_random/main_test.go @@ -12,7 +12,7 @@ import ( func TestHttpAuthRandom_OnHttpRequestHeaders(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewHttpContext(newContext) + WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() @@ -35,7 +35,7 @@ func TestHttpAuthRandom_OnHttpRequestHeaders(t *testing.T) { func TestHttpAuthRandom_OnHttpCallResponse(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewHttpContext(newContext) + WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() diff --git a/examples/http_body/main.go b/examples/http_body/main.go index e0c480a5..ccd74d2e 100644 --- a/examples/http_body/main.go +++ b/examples/http_body/main.go @@ -20,7 +20,17 @@ import ( ) func main() { - proxywasm.SetNewHttpContext(newContext) + proxywasm.SetNewRootContext(newContext) +} + +type rootContext struct { + proxywasm.DefaultRootContext +} + +func newContext(uint32) proxywasm.RootContext { return &rootContext{} } + +func (*rootContext) NewHttpContext(contextID uint32) proxywasm.HttpContext { + return &httpBody{contextID: contextID} } type httpBody struct { @@ -29,10 +39,6 @@ type httpBody struct { contextID uint32 } -func newContext(rootContextID, contextID uint32) proxywasm.HttpContext { - return &httpBody{contextID: contextID} -} - // override func (ctx *httpBody) OnHttpRequestBody(bodySize int, endOfStream bool) types.Action { proxywasm.LogInfof("body size: %d", bodySize) diff --git a/examples/http_body/main_test.go b/examples/http_body/main_test.go index e5824d52..028b711c 100644 --- a/examples/http_body/main_test.go +++ b/examples/http_body/main_test.go @@ -11,7 +11,7 @@ import ( func TestHttpBody_OnHttpRequestBody(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewHttpContext(newContext) + WithNewRootContext(newContext) host := proxytest.NewHostEmulator(opt) defer host.Done() diff --git a/examples/http_headers/main.go b/examples/http_headers/main.go index 20c27e66..767ad788 100644 --- a/examples/http_headers/main.go +++ b/examples/http_headers/main.go @@ -20,7 +20,17 @@ import ( ) func main() { - proxywasm.SetNewHttpContext(newContext) + proxywasm.SetNewRootContext(newRootContext) +} + +type rootContext struct { + proxywasm.DefaultRootContext +} + +func newRootContext(uint32) proxywasm.RootContext { return &rootContext{} } + +func (*rootContext) NewHttpContext(contextID uint32) proxywasm.HttpContext { + return &httpHeaders{contextID: contextID} } type httpHeaders struct { @@ -29,10 +39,6 @@ type httpHeaders struct { contextID uint32 } -func newContext(rootContextID, contextID uint32) proxywasm.HttpContext { - return &httpHeaders{contextID: contextID} -} - // override func (ctx *httpHeaders) OnHttpRequestHeaders(numHeaders int, endOfStream bool) types.Action { hs, err := proxywasm.GetHttpRequestHeaders() diff --git a/examples/http_headers/main_test.go b/examples/http_headers/main_test.go index 9e1f70a8..dfc5aa7d 100644 --- a/examples/http_headers/main_test.go +++ b/examples/http_headers/main_test.go @@ -13,7 +13,7 @@ import ( func TestHttpHeaders_OnHttpRequestHeaders(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewHttpContext(newContext) + WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() id := host.HttpFilterInitContext() @@ -33,7 +33,7 @@ func TestHttpHeaders_OnHttpRequestHeaders(t *testing.T) { func TestHttpHeaders_OnHttpResponseHeaders(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewHttpContext(newContext) + WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() id := host.HttpFilterInitContext() diff --git a/examples/metrics/main.go b/examples/metrics/main.go index 40120672..4adc5cff 100644 --- a/examples/metrics/main.go +++ b/examples/metrics/main.go @@ -21,7 +21,6 @@ import ( func main() { proxywasm.SetNewRootContext(newRootContext) - proxywasm.SetNewHttpContext(newHttpContext) } var counter proxywasm.MetricCounter @@ -43,15 +42,15 @@ func (ctx *metricRootContext) OnVMStart(vmConfigurationSize int) bool { return true } +func (*metricRootContext) NewHttpContext(contextID uint32) proxywasm.HttpContext { + return &metricHttpContext{} +} + type metricHttpContext struct { // you must embed the default context so that you need not to reimplement all the methods by yourself proxywasm.DefaultHttpContext } -func newHttpContext(uint32, uint32) proxywasm.HttpContext { - return &metricHttpContext{} -} - // override func (ctx *metricHttpContext) OnHttpRequestHeaders(numHeaders int, endOfStream bool) types.Action { prev := counter.Get() diff --git a/examples/metrics/main_test.go b/examples/metrics/main_test.go index 2a62861d..6574dbb0 100644 --- a/examples/metrics/main_test.go +++ b/examples/metrics/main_test.go @@ -12,7 +12,6 @@ import ( func TestMetric(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewHttpContext(newHttpContext). WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation diff --git a/examples/network/main.go b/examples/network/main.go index 57f49e7d..50683c84 100644 --- a/examples/network/main.go +++ b/examples/network/main.go @@ -26,7 +26,10 @@ var ( func main() { proxywasm.SetNewRootContext(newRootContext) - proxywasm.SetNewStreamContext(newNetworkContext) +} + +func newRootContext(contextID uint32) proxywasm.RootContext { + return &rootContext{} } type rootContext struct { @@ -34,24 +37,20 @@ type rootContext struct { proxywasm.DefaultRootContext } -func newRootContext(contextID uint32) proxywasm.RootContext { - return &rootContext{} -} - func (ctx *rootContext) OnVMStart(vmConfigurationSize int) bool { counter = proxywasm.DefineCounterMetric(connectionCounterName) return true } +func (ctx *rootContext) NewStreamContext(contextID uint32) proxywasm.StreamContext { + return &networkContext{} +} + type networkContext struct { // you must embed the default context so that you need not to reimplement all the methods by yourself proxywasm.DefaultStreamContext } -func newNetworkContext(rootContextID, contextID uint32) proxywasm.StreamContext { - return &networkContext{} -} - func (ctx *networkContext) OnNewConnection() types.Action { proxywasm.LogInfo("new connection!") return types.ActionContinue diff --git a/examples/network/main_test.go b/examples/network/main_test.go index 46587bef..54836553 100644 --- a/examples/network/main_test.go +++ b/examples/network/main_test.go @@ -26,7 +26,6 @@ import ( func TestNetwork_OnNewConnection(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewStreamContext(newNetworkContext). WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation @@ -41,7 +40,6 @@ func TestNetwork_OnNewConnection(t *testing.T) { func TestNetwork_OnDownstreamClose(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewStreamContext(newNetworkContext). WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation @@ -56,7 +54,6 @@ func TestNetwork_OnDownstreamClose(t *testing.T) { func TestNetwork_OnDownstreamData(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewStreamContext(newNetworkContext). WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation @@ -73,7 +70,6 @@ func TestNetwork_OnDownstreamData(t *testing.T) { func TestNetwork_OnUpstreamData(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewStreamContext(newNetworkContext). WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation @@ -90,7 +86,6 @@ func TestNetwork_OnUpstreamData(t *testing.T) { func TestNetwork_counter(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewStreamContext(newNetworkContext). WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation diff --git a/examples/shared_data/main.go b/examples/shared_data/main.go index 0e58e826..e86bc5f1 100644 --- a/examples/shared_data/main.go +++ b/examples/shared_data/main.go @@ -21,7 +21,6 @@ import ( func main() { proxywasm.SetNewRootContext(newRootContext) - proxywasm.SetNewHttpContext(newHttpContext) } type ( @@ -40,10 +39,6 @@ func newRootContext(contextID uint32) proxywasm.RootContext { return &sharedDataRootContext{} } -func newHttpContext(rootContextID, contextID uint32) proxywasm.HttpContext { - return &sharedDataHttpContext{} -} - const sharedDataKey = "shared_data_key" // override @@ -54,6 +49,11 @@ func (ctx *sharedDataRootContext) OnVMStart(vmConfigurationSize int) bool { return true } +// override +func (*sharedDataRootContext) NewHttpContext(contextID uint32) proxywasm.HttpContext { + return &sharedDataHttpContext{} +} + // override func (ctx *sharedDataHttpContext) OnHttpRequestHeaders(numHeaders int, endOfStream bool) types.Action { value, cas, err := proxywasm.GetSharedData(sharedDataKey) diff --git a/examples/shared_data/main_test.go b/examples/shared_data/main_test.go index 856b576a..5aaf600f 100644 --- a/examples/shared_data/main_test.go +++ b/examples/shared_data/main_test.go @@ -26,7 +26,6 @@ import ( func TestData(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewHttpContext(newHttpContext). WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation diff --git a/examples/shared_queue/main.go b/examples/shared_queue/main.go index 98fe4754..f05ed972 100644 --- a/examples/shared_queue/main.go +++ b/examples/shared_queue/main.go @@ -26,7 +26,6 @@ const ( func main() { proxywasm.SetNewRootContext(newRootContext) - proxywasm.SetNewHttpContext(newHttpContext) } type queueRootContext struct { @@ -70,15 +69,16 @@ func (ctx *queueRootContext) OnTick() { } } +// override +func (*queueRootContext) NewHttpContext(contextID uint32) proxywasm.HttpContext { + return &queueHttpContext{} +} + type queueHttpContext struct { // you must embed the default context so that you need not to reimplement all the methods by yourself proxywasm.DefaultHttpContext } -func newHttpContext(rootContextID, contextID uint32) proxywasm.HttpContext { - return &queueHttpContext{} -} - // override func (ctx *queueHttpContext) OnHttpRequestHeaders(int, bool) types.Action { for _, msg := range []string{"hello", "world", "hello", "proxy-wasm"} { diff --git a/examples/shared_queue/main_test.go b/examples/shared_queue/main_test.go index 5a09945a..7aa11fbb 100644 --- a/examples/shared_queue/main_test.go +++ b/examples/shared_queue/main_test.go @@ -27,7 +27,6 @@ import ( func TestQueue(t *testing.T) { opt := proxytest.NewEmulatorOption(). - WithNewHttpContext(newHttpContext). WithNewRootContext(newRootContext) host := proxytest.NewHostEmulator(opt) defer host.Done() // release the host emulation lock so that other test cases can insert their own host emulation diff --git a/proxytest/option.go b/proxytest/option.go index dcfa098f..6e8651bb 100644 --- a/proxytest/option.go +++ b/proxytest/option.go @@ -19,8 +19,6 @@ import "github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm" type EmulatorOption struct { pluginConfiguration, vmConfiguration []byte newRootContext func(uint32) proxywasm.RootContext - newStreamContext func(uint32, uint32) proxywasm.StreamContext - newHttpContext func(uint32, uint32) proxywasm.HttpContext } func NewEmulatorOption() *EmulatorOption { @@ -32,16 +30,6 @@ func (o *EmulatorOption) WithNewRootContext(f func(uint32) proxywasm.RootContext return o } -func (o *EmulatorOption) WithNewHttpContext(f func(uint32, uint32) proxywasm.HttpContext) *EmulatorOption { - o.newHttpContext = f - return o -} - -func (o *EmulatorOption) WithNewStreamContext(f func(uint32, uint32) proxywasm.StreamContext) *EmulatorOption { - o.newStreamContext = f - return o -} - func (o *EmulatorOption) WithPluginConfiguration(data []byte) *EmulatorOption { o.pluginConfiguration = data return o diff --git a/proxytest/proxytest.go b/proxytest/proxytest.go index f1f5e25d..d3137f3b 100644 --- a/proxytest/proxytest.go +++ b/proxytest/proxytest.go @@ -102,8 +102,6 @@ func NewHostEmulator(opt *EmulatorOption) HostEmulator { // set up state proxywasm.SetNewRootContext(opt.newRootContext) - proxywasm.SetNewStreamContext(opt.newStreamContext) - proxywasm.SetNewHttpContext(opt.newHttpContext) // create root context: TODO: support multiple root contexts proxywasm.ProxyOnContextCreate(RootContextID, 0) diff --git a/proxywasm/abi_lifecycle.go b/proxywasm/abi_lifecycle.go index 911a7d8b..c20dd292 100644 --- a/proxywasm/abi_lifecycle.go +++ b/proxywasm/abi_lifecycle.go @@ -18,10 +18,8 @@ package proxywasm func proxyOnContextCreate(contextID uint32, rootContextID uint32) { if rootContextID == 0 { currentState.createRootContext(contextID) - } else if currentState.newHttpContext != nil { - currentState.createHttpContext(contextID, rootContextID) - } else if currentState.newStreamContext != nil { - currentState.createStreamContext(contextID, rootContextID) + } else if currentState.createHttpContext(contextID, rootContextID) { + } else if currentState.createStreamContext(contextID, rootContextID) { } else { panic("invalid context id on proxy_on_context_create") } diff --git a/proxywasm/abi_lifecycle_test.go b/proxywasm/abi_lifecycle_test.go index 7bfb88be..55406eb2 100644 --- a/proxywasm/abi_lifecycle_test.go +++ b/proxywasm/abi_lifecycle_test.go @@ -22,39 +22,55 @@ import ( "github.com/stretchr/testify/require" ) -func Test_proxyOnContextCreate(t *testing.T) { +type testOnContextCreateRootContext struct { + DefaultRootContext + cnt int +} + +func (ctx *testOnContextCreateRootContext) NewStreamContext(contextID uint32) StreamContext { + if contextID == 100 { + ctx.cnt += 100 + return &DefaultStreamContext{} + } + return nil +} + +func (ctx *testOnContextCreateRootContext) NewHttpContext(contextID uint32) HttpContext { + if contextID == 1000 { + ctx.cnt += 1000 + return &DefaultHttpContext{} + } + return nil +} + +func Test_proxyOnContextCreateHttpContext(t *testing.T) { currentStateMux.Lock() defer currentStateMux.Unlock() - var cnt int + var rootPtr *testOnContextCreateRootContext currentState = &state{ - rootContexts: map[uint32]*rootContextState{}, - httpStreams: map[uint32]HttpContext{}, - streams: map[uint32]StreamContext{}, + rootContexts: map[uint32]*rootContextState{}, + httpStreams: map[uint32]HttpContext{}, + streams: map[uint32]StreamContext{}, + newRootContext: func(contextID uint32) RootContext { + return &testOnContextCreateRootContext{} + }, contextIDToRootID: map[uint32]uint32{}, } SetNewRootContext(func(contextID uint32) RootContext { - cnt++ - return nil + rootPtr = &testOnContextCreateRootContext{cnt: 1} + return rootPtr }) - proxyOnContextCreate(100, 0) - require.Equal(t, 1, cnt) - SetNewHttpContext(func(rootContextID, contextID uint32) HttpContext { - cnt += 100 - return nil - }) - proxyOnContextCreate(100, 100) - require.Equal(t, 101, cnt) - currentState.newHttpContext = nil + proxyOnContextCreate(1, 0) + require.Equal(t, 1, rootPtr.cnt) - SetNewStreamContext(func(rootContextID, contextID uint32) StreamContext { - cnt += 1000 - return nil - }) - proxyOnContextCreate(100, 100) - require.Equal(t, 1101, cnt) + proxyOnContextCreate(100, 1) + require.Equal(t, 101, rootPtr.cnt) + + proxyOnContextCreate(1000, 1) + require.Equal(t, 1101, rootPtr.cnt) } type lifecycleContext struct { diff --git a/proxywasm/context.go b/proxywasm/context.go index 37e08212..aa6dfdee 100644 --- a/proxywasm/context.go +++ b/proxywasm/context.go @@ -25,6 +25,10 @@ type RootContext interface { OnPluginStart(pluginConfigurationSize int) bool OnVMDone() bool OnLog() + + // Child context factories + NewStreamContext(contextID uint32) StreamContext + NewHttpContext(contextID uint32) HttpContext } type StreamContext interface { @@ -61,12 +65,14 @@ var ( ) // impl RootContext -func (*DefaultRootContext) OnQueueReady(uint32) {} -func (*DefaultRootContext) OnTick() {} -func (*DefaultRootContext) OnVMStart(int) bool { return true } -func (*DefaultRootContext) OnPluginStart(int) bool { return true } -func (*DefaultRootContext) OnVMDone() bool { return true } -func (*DefaultRootContext) OnLog() {} +func (*DefaultRootContext) OnQueueReady(uint32) {} +func (*DefaultRootContext) OnTick() {} +func (*DefaultRootContext) OnVMStart(int) bool { return true } +func (*DefaultRootContext) OnPluginStart(int) bool { return true } +func (*DefaultRootContext) OnVMDone() bool { return true } +func (*DefaultRootContext) OnLog() {} +func (*DefaultRootContext) NewStreamContext(uint32) StreamContext { return nil } +func (*DefaultRootContext) NewHttpContext(uint32) HttpContext { return nil } // impl StreamContext func (*DefaultStreamContext) OnDownstreamData(int, bool) types.Action { return types.ActionContinue } diff --git a/proxywasm/vmstate.go b/proxywasm/vmstate.go index 4d106e5d..d6747d23 100644 --- a/proxywasm/vmstate.go +++ b/proxywasm/vmstate.go @@ -34,12 +34,10 @@ type ( ) type state struct { - newRootContext func(contextID uint32) RootContext - rootContexts map[uint32]*rootContextState - newStreamContext func(rootContextID, contextID uint32) StreamContext - streams map[uint32]StreamContext - newHttpContext func(rootContextID, contextID uint32) HttpContext - httpStreams map[uint32]HttpContext + newRootContext func(contextID uint32) RootContext + rootContexts map[uint32]*rootContextState + streams map[uint32]StreamContext + httpStreams map[uint32]HttpContext contextIDToRootID map[uint32]uint32 activeContextID uint32 @@ -56,14 +54,6 @@ func SetNewRootContext(f func(contextID uint32) RootContext) { currentState.newRootContext = f } -func SetNewHttpContext(f func(rootContextID, contextID uint32) HttpContext) { - currentState.newHttpContext = f -} - -func SetNewStreamContext(f func(rootContextID, contextID uint32) StreamContext) { - currentState.newStreamContext = f -} - var ErrorRootContextNotFound = errors.New("root context not found") func GetRootContextByID(rootContextID uint32) (RootContext, error) { @@ -74,7 +64,6 @@ func GetRootContextByID(rootContextID uint32) (RootContext, error) { return rootContextState.context, nil } -//go:inline func (s *state) createRootContext(contextID uint32) { var ctx RootContext if s.newRootContext == nil { @@ -95,8 +84,9 @@ func (s *state) createRootContext(contextID uint32) { s.contextIDToRootID[contextID] = contextID } -func (s *state) createStreamContext(contextID uint32, rootContextID uint32) { - if _, ok := s.rootContexts[rootContextID]; !ok { +func (s *state) createStreamContext(contextID uint32, rootContextID uint32) bool { + root, ok := s.rootContexts[rootContextID] + if !ok { panic("invalid root context id") } @@ -104,13 +94,19 @@ func (s *state) createStreamContext(contextID uint32, rootContextID uint32) { panic("context id duplicated") } - ctx := s.newStreamContext(rootContextID, contextID) + ctx := root.context.NewStreamContext(contextID) + if ctx == nil { + // NewStreamContext is not defined by the user + return false + } s.contextIDToRootID[contextID] = rootContextID s.streams[contextID] = ctx + return true } -func (s *state) createHttpContext(contextID uint32, rootContextID uint32) { - if _, ok := s.rootContexts[rootContextID]; !ok { +func (s *state) createHttpContext(contextID uint32, rootContextID uint32) bool { + root, ok := s.rootContexts[rootContextID] + if !ok { panic("invalid root context id") } @@ -118,9 +114,14 @@ func (s *state) createHttpContext(contextID uint32, rootContextID uint32) { panic("context id duplicated") } - ctx := s.newHttpContext(rootContextID, contextID) + ctx := root.context.NewHttpContext(contextID) + if ctx == nil { + // NewHttpContext is not defined by the user + return false + } s.contextIDToRootID[contextID] = rootContextID s.httpStreams[contextID] = ctx + return true } func (s *state) registerHttpCallOut(calloutID uint32, callback HttpCalloutCallBack) { diff --git a/proxywasm/vmstate_test.go b/proxywasm/vmstate_test.go index 9fd90b6c..c6cf9ec8 100644 --- a/proxywasm/vmstate_test.go +++ b/proxywasm/vmstate_test.go @@ -38,34 +38,6 @@ func TestSetNewRootContext(t *testing.T) { assert.Equal(t, 1, cnt) } -func TestSetNewHttpContext(t *testing.T) { - currentStateMux.Lock() - defer currentStateMux.Unlock() - - var cnt int - f := func(uint32, uint32) HttpContext { - cnt++ - return nil - } - SetNewHttpContext(f) - currentState.newHttpContext(0, 0) - assert.Equal(t, 1, cnt) -} - -func TestSetNewStreamContext(t *testing.T) { - currentStateMux.Lock() - defer currentStateMux.Unlock() - - var cnt int - f := func(uint32, uint32) StreamContext { - cnt++ - return nil - } - SetNewStreamContext(f) - currentState.newStreamContext(0, 0) - assert.Equal(t, 1, cnt) -} - func TestState_createRootContext(t *testing.T) { t.Run("newRootContext exists", func(t *testing.T) { type rc struct{ DefaultRootContext } @@ -91,45 +63,68 @@ func TestState_createRootContext(t *testing.T) { }) } -func TestState_createStreamContext(t *testing.T) { - type sc struct{ DefaultStreamContext } +type ( + testStateRootContext struct{ DefaultRootContext } + testStateStreamContext struct { + contextID uint32 + DefaultStreamContext + } + testStateHttpContext struct { + contextID uint32 + DefaultHttpContext + } +) +func (ctx *testStateRootContext) NewStreamContext(contextID uint32) StreamContext { + return &testStateStreamContext{contextID: contextID} +} + +func (ctx *testStateRootContext) NewHttpContext(contextID uint32) HttpContext { + return &testStateHttpContext{contextID: contextID} +} + +func TestState_createStreamContext(t *testing.T) { var ( cid uint32 = 100 rid uint32 = 10 ) s := &state{ - rootContexts: map[uint32]*rootContextState{rid: nil}, - streams: map[uint32]StreamContext{}, - newStreamContext: func(rootContextID, contextID uint32) StreamContext { return &sc{} }, + rootContexts: map[uint32]*rootContextState{rid: nil}, + streams: map[uint32]StreamContext{}, + newRootContext: func(contextID uint32) RootContext { + return &testStateRootContext{} + }, contextIDToRootID: map[uint32]uint32{}, } + s.createRootContext(rid) s.createStreamContext(cid, rid) c, ok := s.streams[cid] require.True(t, ok) - _, ok = c.(*sc) + ctx, ok := c.(*testStateStreamContext) assert.True(t, ok) + assert.Equal(t, cid, ctx.contextID) } func TestState_createHttpContext(t *testing.T) { - type hc struct{ DefaultHttpContext } - var ( cid uint32 = 100 rid uint32 = 10 ) s := &state{ - rootContexts: map[uint32]*rootContextState{rid: nil}, - httpStreams: map[uint32]HttpContext{}, - newHttpContext: func(rootContextID, contextID uint32) HttpContext { return &hc{} }, + rootContexts: map[uint32]*rootContextState{rid: nil}, + httpStreams: map[uint32]HttpContext{}, + newRootContext: func(contextID uint32) RootContext { + return &testStateRootContext{} + }, contextIDToRootID: map[uint32]uint32{}, } + s.createRootContext(rid) s.createHttpContext(cid, rid) c, ok := s.httpStreams[cid] require.True(t, ok) - _, ok = c.(*hc) + ctx, ok := c.(*testStateHttpContext) assert.True(t, ok) - + assert.Equal(t, cid, ctx.contextID) }