diff --git a/.golangci.yml b/.golangci.yml index a3235be..88cb4fb 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -25,17 +25,32 @@ linters-settings: - ^os.Exit$ - ^panic$ - ^print(ln)?$ + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. @@ -46,18 +61,17 @@ linters: - forcetypeassert # finds forced type assertions - gci # Gci control golang package import order and make it always deterministic. - gochecknoglobals # Checks that no globals are present in Go code - - gochecknoinits # Checks that no init functions are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - - err113 # Golang linter to check the errors handling expressions - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code @@ -65,9 +79,15 @@ linters: - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes @@ -75,28 +95,22 @@ linters: - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - depguard # Go linter that checks if package imports are in a list of acceptable packages - - containedctx # containedctx is a linter that detects struct contained context.Context field - - cyclop # checks function and package cyclomatic complexity - funlen # Tool for detection of long functions - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - gomnd # An analyzer to detect magic numbers. + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. - ireturn # Accept Interfaces, Return Concrete Types - - lll # Reports long lines - - maintidx # maintidx measures the maintainability index of each function. - - makezero # Finds slice declarations with non-zero initial length - - nakedret # Finds naked returns in functions greater than a specified function length - - nestif # Reports deeply nested if statements - - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated @@ -104,8 +118,7 @@ linters: - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - - varnamelen # checks that the length of a variable's name matches its scope + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! diff --git a/context.go b/context.go index 1303def..2884d61 100644 --- a/context.go +++ b/context.go @@ -28,7 +28,7 @@ const ( srtcpIndexSize = 4 ) -// Encrypt/Decrypt state for a single SRTP SSRC +// Encrypt/Decrypt state for a single SRTP SSRC. type srtpSSRCState struct { ssrc uint32 rolloverHasProcessed bool @@ -36,7 +36,7 @@ type srtpSSRCState struct { replayDetector replaydetector.ReplayDetector } -// Encrypt/Decrypt state for a single SRTCP SSRC +// Encrypt/Decrypt state for a single SRTCP SSRC. type srtcpSSRCState struct { srtcpIndex uint32 ssrc uint32 @@ -60,8 +60,10 @@ type Context struct { profile ProtectionProfile - sendMKI []byte // Master Key Identifier used for encrypting RTP/RTCP packets. Set to nil if MKI is not enabled. - mkis map[string]srtpCipher // Master Key Identifier to cipher mapping. Used for decrypting packets. Empty if MKI is not enabled. + // Master Key Identifier used for encrypting RTP/RTCP packets. Set to nil if MKI is not enabled. + sendMKI []byte + // Master Key Identifier to cipher mapping. Used for decrypting packets. Empty if MKI is not enabled. + mkis map[string]srtpCipher encryptSRTP bool encryptSRTCP bool @@ -74,7 +76,11 @@ type Context struct { // Following example create SRTP Context with replay protection with window size of 256. // // decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256)) -func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) { +func CreateContext( + masterKey, masterSalt []byte, + profile ProtectionProfile, + opts ...ContextOption, +) (c *Context, err error) { c = &Context{ srtpSSRCStates: map[uint32]*srtpSSRCState{}, srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, @@ -107,7 +113,8 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts return c, nil } -// AddCipherForMKI adds new MKI with associated masker key and salt. Context must be created with MasterKeyIndicator option +// AddCipherForMKI adds new MKI with associated masker key and salt. +// Context must be created with MasterKeyIndicator option // to enable MKI support. MKI must be unique and have the same length as the one used for creating Context. // Operation is not thread-safe, you need to provide synchronization with decrypting packets. func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error { @@ -126,6 +133,7 @@ func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error { return err } c.mkis[string(mki)] = cipher + return nil } @@ -149,7 +157,10 @@ func (c *Context) createCipher(mki, masterKey, masterSalt []byte, encryptSRTP, e switch c.profile { case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: return newSrtpCipherAeadAesGcm(c.profile, masterKey, masterSalt, mki, encryptSRTP, encryptSRTCP) - case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80: + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_80: return newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki, encryptSRTP, encryptSRTCP) case ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80: return newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki, false, false) @@ -168,6 +179,7 @@ func (c *Context) RemoveMKI(mki []byte) error { return errMKIAlreadyInUse } delete(c.mkis, string(mki)) + return nil } @@ -180,19 +192,20 @@ func (c *Context) SetSendMKI(mki []byte) error { } c.sendMKI = mki c.cipher = cipher + return nil } // https://tools.ietf.org/html/rfc3550#appendix-A.1 func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int32, overflow bool) { seq := int32(sequenceNumber) - localRoc := uint32(s.index >> 16) - localSeq := int32(s.index & (seqNumMax - 1)) + localRoc := uint32(s.index >> 16) //nolint:gosec // G115 + localSeq := int32(s.index & (seqNumMax - 1)) //nolint:gosec // G115 guessRoc := localRoc var difference int32 - if s.rolloverHasProcessed { + if s.rolloverHasProcessed { //nolint:nestif // When localROC is equal to 0, and entering seq-localSeq > seqNumMedian // judgment, it will cause guessRoc calculation error if s.index > seqNumMedian { @@ -226,6 +239,7 @@ func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference in if !s.rolloverHasProcessed { s.index |= uint64(sequenceNumber) s.rolloverHasProcessed = true + return } if difference > 0 { @@ -244,6 +258,7 @@ func (c *Context) getSRTPSSRCState(ssrc uint32) *srtpSSRCState { replayDetector: c.newSRTPReplayDetector(), } c.srtpSSRCStates[ssrc] = s + return s } @@ -258,6 +273,7 @@ func (c *Context) getSRTCPSSRCState(ssrc uint32) *srtcpSSRCState { replayDetector: c.newSRTCPReplayDetector(), } c.srtcpSSRCStates[ssrc] = s + return s } @@ -267,7 +283,8 @@ func (c *Context) ROC(ssrc uint32) (uint32, bool) { if !ok { return 0, false } - return uint32(s.index >> 16), true + + return uint32(s.index >> 16), true //nolint:gosec // G115 } // SetROC sets SRTP rollover counter value of specified SSRC. @@ -283,6 +300,7 @@ func (c *Context) Index(ssrc uint32) (uint32, bool) { if !ok { return 0, false } + return s.srtcpIndex, true } diff --git a/context_test.go b/context_test.go index 81b9b90..77cd3ae 100644 --- a/context_test.go +++ b/context_test.go @@ -48,33 +48,33 @@ func TestContextIndex(t *testing.T) { } func TestContextWithoutMKI(t *testing.T) { - c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR) + ctx, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR) if err != nil { t.Fatal(err) } - err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) assert.Error(t, err) - err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) assert.Error(t, err) - err = c.AddCipherForMKI(make([]byte, 4), make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(make([]byte, 4), make([]byte, 16), make([]byte, 14)) assert.Error(t, err) - err = c.SetSendMKI(nil) + err = ctx.SetSendMKI(nil) assert.Error(t, err) - err = c.SetSendMKI(make([]byte, 0)) + err = ctx.SetSendMKI(make([]byte, 0)) assert.Error(t, err) - err = c.RemoveMKI(nil) + err = ctx.RemoveMKI(nil) assert.Error(t, err) - err = c.RemoveMKI(make([]byte, 0)) + err = ctx.RemoveMKI(make([]byte, 0)) assert.Error(t, err) - err = c.RemoveMKI(make([]byte, 2)) + err = ctx.RemoveMKI(make([]byte, 2)) assert.Error(t, err) } @@ -82,28 +82,28 @@ func TestAddMKIToContextWithMKI(t *testing.T) { mki1 := []byte{1, 2, 3, 4} mki2 := []byte{2, 3, 4, 5} - c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + ctx, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) if err != nil { t.Fatal(err) } - err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) if err != nil { t.Fatal(err) } - err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) assert.Error(t, err) - err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) assert.Error(t, err) - err = c.AddCipherForMKI(make([]byte, 3), make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(make([]byte, 3), make([]byte, 16), make([]byte, 14)) assert.Error(t, err) - err = c.AddCipherForMKI(mki1, make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(mki1, make([]byte, 16), make([]byte, 14)) assert.Error(t, err) - err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) assert.Error(t, err) } @@ -111,22 +111,22 @@ func TestContextSetSendMKI(t *testing.T) { mki1 := []byte{1, 2, 3, 4} mki2 := []byte{2, 3, 4, 5} - c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + ctx, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) if err != nil { t.Fatal(err) } - err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) if err != nil { t.Fatal(err) } - err = c.SetSendMKI(mki1) + err = ctx.SetSendMKI(mki1) assert.NoError(t, err) - err = c.SetSendMKI(mki2) + err = ctx.SetSendMKI(mki2) assert.NoError(t, err) - err = c.SetSendMKI(make([]byte, 4)) + err = ctx.SetSendMKI(make([]byte, 4)) assert.Error(t, err) } @@ -135,34 +135,34 @@ func TestContextRemoveMKI(t *testing.T) { mki2 := []byte{2, 3, 4, 5} mki3 := []byte{3, 4, 5, 6} - c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + ctx, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) if err != nil { t.Fatal(err) } - err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) if err != nil { t.Fatal(err) } - err = c.AddCipherForMKI(mki3, make([]byte, 16), make([]byte, 14)) + err = ctx.AddCipherForMKI(mki3, make([]byte, 16), make([]byte, 14)) if err != nil { t.Fatal(err) } - err = c.RemoveMKI(make([]byte, 4)) + err = ctx.RemoveMKI(make([]byte, 4)) assert.Error(t, err) - err = c.RemoveMKI(mki1) + err = ctx.RemoveMKI(mki1) assert.Error(t, err) - err = c.SetSendMKI(mki3) + err = ctx.SetSendMKI(mki3) assert.NoError(t, err) - err = c.RemoveMKI(mki1) + err = ctx.RemoveMKI(mki1) assert.NoError(t, err) - err = c.RemoveMKI(mki2) + err = ctx.RemoveMKI(mki2) assert.NoError(t, err) - err = c.RemoveMKI(mki3) + err = ctx.RemoveMKI(mki3) assert.Error(t, err) } diff --git a/crypto.go b/crypto.go index 3f19130..4c82f81 100644 --- a/crypto.go +++ b/crypto.go @@ -55,5 +55,6 @@ func xorBytesCTR(block cipher.Block, iv []byte, dst, src []byte) error { } i += n } + return nil } diff --git a/errors.go b/errors.go index c22653f..0726b19 100644 --- a/errors.go +++ b/errors.go @@ -9,9 +9,9 @@ import ( ) var ( - // ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag + // ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag. ErrFailedToVerifyAuthTag = errors.New("failed to verify auth tag") - // ErrMKINotFound is returned when decryption fails due to unknown MKI value in packet + // ErrMKINotFound is returned when decryption fails due to unknown MKI value in packet. ErrMKINotFound = errors.New("MKI not found") errDuplicated = errors.New("duplicated packet") diff --git a/key_derivation.go b/key_derivation.go index f192faf..945b569 100644 --- a/key_derivation.go +++ b/key_derivation.go @@ -40,6 +40,7 @@ func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr i block.Encrypt(out[n:n+nBlockSize], prfIn) i++ } + return out[:outLen], nil } @@ -50,8 +51,12 @@ func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr i // - times the 16-bit RTP sequence number has been reset to zero after // - passing through 65,535 // i = 2^16 * ROC + SEQ -// IV = (salt*2 ^ 16) | (ssrc*2 ^ 64) | (i*2 ^ 16) -func generateCounter(sequenceNumber uint16, rolloverCounter uint32, ssrc uint32, sessionSalt []byte) (counter [16]byte) { +// IV = (salt*2 ^ 16) | (ssrc*2 ^ 64) | (i*2 ^ 16). +func generateCounter( + sequenceNumber uint16, + rolloverCounter uint32, + ssrc uint32, sessionSalt []byte, +) (counter [16]byte) { copy(counter[:], sessionSalt) counter[4] ^= byte(ssrc >> 24) diff --git a/key_derivation_test.go b/key_derivation_test.go index f0eeb95..5565e4e 100644 --- a/key_derivation_test.go +++ b/key_derivation_test.go @@ -14,9 +14,14 @@ func TestValidSessionKeys_AesCm128(t *testing.T) { masterKey := []byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39} masterSalt := []byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6} - expectedSessionKey := []byte{0xC6, 0x1E, 0x7A, 0x93, 0x74, 0x4F, 0x39, 0xEE, 0x10, 0x73, 0x4A, 0xFE, 0x3F, 0xF7, 0xA0, 0x87} + expectedSessionKey := []byte{ + 0xC6, 0x1E, 0x7A, 0x93, 0x74, 0x4F, 0x39, 0xEE, 0x10, 0x73, 0x4A, 0xFE, 0x3F, 0xF7, 0xA0, 0x87, + } expectedSessionSalt := []byte{0x30, 0xCB, 0xBC, 0x08, 0x86, 0x3D, 0x8C, 0x85, 0xD4, 0x9D, 0xB3, 0x4A, 0x9A, 0xE1} - expectedSessionAuthTag := []byte{0xCE, 0xBE, 0x32, 0x1F, 0x6F, 0xF7, 0x71, 0x6B, 0x6F, 0xD4, 0xAB, 0x49, 0xAF, 0x25, 0x6A, 0x15, 0x6D, 0x38, 0xBA, 0xA4} + expectedSessionAuthTag := []byte{ + 0xCE, 0xBE, 0x32, 0x1F, 0x6F, 0xF7, 0x71, 0x6B, 0x6F, 0xD4, + 0xAB, 0x49, 0xAF, 0x25, 0x6A, 0x15, 0x6D, 0x38, 0xBA, 0xA4, + } sessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { @@ -55,7 +60,10 @@ func TestValidSessionKeys_AesCm256(t *testing.T) { 0x1e, 0xc7, 0xfb, 0x39, 0x7f, 0x70, 0xa9, 0x60, 0x65, 0x3c, 0xaf, 0x06, 0x55, 0x4c, 0xd8, 0xc4, } expectedSessionSalt := []byte{0xfa, 0x31, 0x79, 0x16, 0x85, 0xca, 0x44, 0x4a, 0x9e, 0x07, 0xc6, 0xc6, 0x4e, 0x93} - expectedSessionAuthTag := []byte{0xfd, 0x9c, 0x32, 0xd3, 0x9e, 0xd5, 0xfb, 0xb5, 0xa9, 0xdc, 0x96, 0xb3, 0x08, 0x18, 0x45, 0x4d, 0x13, 0x13, 0xdc, 0x05} + expectedSessionAuthTag := []byte{ + 0xfd, 0x9c, 0x32, 0xd3, 0x9e, 0xd5, 0xfb, 0xb5, 0xa9, 0xdc, + 0x96, 0xb3, 0x08, 0x18, 0x45, 0x4d, 0x13, 0x13, 0xdc, 0x05, + } sessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { @@ -83,7 +91,7 @@ func TestValidSessionKeys_AesCm256(t *testing.T) { } // This test asserts that calling aesCmKeyDerivation with a non-zero indexOverKdr fails -// Currently this isn't supported, but the API makes sure we can add this in the future +// Currently this isn't supported, but the API makes sure we can add this in the future. func TestIndexOverKDR(t *testing.T) { _, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, []byte{}, []byte{}, 1, 0) assert.Error(t, err) @@ -101,6 +109,6 @@ func BenchmarkGenerateCounter(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - generateCounter(32846, uint32(s.index>>16), s.ssrc, srtpSessionSalt) + generateCounter(32846, uint32(s.index>>16), s.ssrc, srtpSessionSalt) //nolint:gosec // G115 } } diff --git a/keying.go b/keying.go index 617f4d7..c9dc183 100644 --- a/keying.go +++ b/keying.go @@ -5,7 +5,7 @@ package srtp const labelExtractorDtlsSrtp = "EXTRACTOR-dtls_srtp" -// KeyingMaterialExporter allows package SRTP to extract keying material +// KeyingMaterialExporter allows package SRTP to extract keying material. type KeyingMaterialExporter interface { ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) } @@ -46,6 +46,7 @@ func (c *Config) ExtractSessionKeysFromDTLS(exporter KeyingMaterialExporter, isC c.Keys.LocalMasterSalt = clientWriteKey[keyLen:] c.Keys.RemoteMasterKey = serverWriteKey[0:keyLen] c.Keys.RemoteMasterSalt = serverWriteKey[keyLen:] + return nil } @@ -53,5 +54,6 @@ func (c *Config) ExtractSessionKeysFromDTLS(exporter KeyingMaterialExporter, isC c.Keys.LocalMasterSalt = serverWriteKey[keyLen:] c.Keys.RemoteMasterKey = clientWriteKey[0:keyLen] c.Keys.RemoteMasterSalt = clientWriteKey[keyLen:] + return nil } diff --git a/keying_test.go b/keying_test.go index 6071ebf..b6bd68d 100644 --- a/keying_test.go +++ b/keying_test.go @@ -34,11 +34,11 @@ func TestExtractSessionKeysFromDTLS(t *testing.T) { {&Config{Profile: ProtectionProfileAes128CmHmacSha1_80}}, } - m := &mockKeyingMaterialExporter{} + mockExporter := &mockKeyingMaterialExporter{} for i, tc := range tt { // Test client - err := tc.config.ExtractSessionKeysFromDTLS(m, true) + err := tc.config.ExtractSessionKeysFromDTLS(mockExporter, true) if err != nil { t.Errorf("failed to extract keys for %d-client: %v", i, err) } @@ -49,12 +49,15 @@ func TestExtractSessionKeysFromDTLS(t *testing.T) { clientMaterial = append(clientMaterial, keys.LocalMasterSalt...) clientMaterial = append(clientMaterial, keys.RemoteMasterSalt...) - if !bytes.Equal(clientMaterial, m.exported) { - t.Errorf("material reconstruction failed for %d-client:\n%#v\nexpected\n%#v", i, clientMaterial, m.exported) + if !bytes.Equal(clientMaterial, mockExporter.exported) { + t.Errorf( + "material reconstruction failed for %d-client:\n%#v\nexpected\n%#v", + i, clientMaterial, mockExporter.exported, + ) } // Test server - err = tc.config.ExtractSessionKeysFromDTLS(m, false) + err = tc.config.ExtractSessionKeysFromDTLS(mockExporter, false) if err != nil { t.Errorf("failed to extract keys for %d-server: %v", i, err) } @@ -65,8 +68,11 @@ func TestExtractSessionKeysFromDTLS(t *testing.T) { serverMaterial = append(serverMaterial, keys.RemoteMasterSalt...) serverMaterial = append(serverMaterial, keys.LocalMasterSalt...) - if !bytes.Equal(serverMaterial, m.exported) { - t.Errorf("material reconstruction failed for %d-server:\n%#v\nexpected\n%#v", i, serverMaterial, m.exported) + if !bytes.Equal(serverMaterial, mockExporter.exported) { + t.Errorf( + "material reconstruction failed for %d-server:\n%#v\nexpected\n%#v", + i, serverMaterial, mockExporter.exported, + ) } } } diff --git a/option.go b/option.go index dac0bcf..17fc381 100644 --- a/option.go +++ b/option.go @@ -16,6 +16,7 @@ func SRTPReplayProtection(windowSize uint) ContextOption { // nolint:revive c.newSRTPReplayDetector = func() replaydetector.ReplayDetector { return replaydetector.New(windowSize, maxROC<<16|maxSequenceNumber) } + return nil } } @@ -26,6 +27,7 @@ func SRTCPReplayProtection(windowSize uint) ContextOption { c.newSRTCPReplayDetector = func() replaydetector.ReplayDetector { return replaydetector.New(windowSize, maxSRTCPIndex) } + return nil } } @@ -36,6 +38,7 @@ func SRTPNoReplayProtection() ContextOption { // nolint:revive c.newSRTPReplayDetector = func() replaydetector.ReplayDetector { return &nopReplayDetector{} } + return nil } } @@ -46,6 +49,7 @@ func SRTCPNoReplayProtection() ContextOption { c.newSRTCPReplayDetector = func() replaydetector.ReplayDetector { return &nopReplayDetector{} } + return nil } } @@ -54,6 +58,7 @@ func SRTCPNoReplayProtection() ContextOption { func SRTPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextOption { // nolint:revive return func(c *Context) error { c.newSRTPReplayDetector = fn + return nil } } @@ -62,6 +67,7 @@ func SRTPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextO func SRTCPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextOption { return func(c *Context) error { c.newSRTCPReplayDetector = fn + return nil } } @@ -81,6 +87,7 @@ func MasterKeyIndicator(mki []byte) ContextOption { c.sendMKI = make([]byte, len(mki)) copy(c.sendMKI, mki) } + return nil } } @@ -89,15 +96,18 @@ func MasterKeyIndicator(mki []byte) ContextOption { func SRTPEncryption() ContextOption { // nolint:revive return func(c *Context) error { c.encryptSRTP = true + return nil } } -// SRTPNoEncryption disables SRTP encryption. This option is useful when you want to use NullCipher for SRTP and keep authentication only. +// SRTPNoEncryption disables SRTP encryption. +// This option is useful when you want to use NullCipher for SRTP and keep authentication only. // It simplifies debugging and testing, but it is not recommended for production use. func SRTPNoEncryption() ContextOption { // nolint:revive return func(c *Context) error { c.encryptSRTP = false + return nil } } @@ -106,15 +116,18 @@ func SRTPNoEncryption() ContextOption { // nolint:revive func SRTCPEncryption() ContextOption { return func(c *Context) error { c.encryptSRTCP = true + return nil } } -// SRTCPNoEncryption disables SRTCP encryption. This option is useful when you want to use NullCipher for SRTCP and keep authentication only. +// SRTCPNoEncryption disables SRTCP encryption. +// This option is useful when you want to use NullCipher for SRTCP and keep authentication only. // It simplifies debugging and testing, but it is not recommended for production use. func SRTCPNoEncryption() ContextOption { return func(c *Context) error { c.encryptSRTCP = false + return nil } } diff --git a/protection_profile.go b/protection_profile.go index 9384bf8..181da22 100644 --- a/protection_profile.go +++ b/protection_profile.go @@ -5,19 +5,25 @@ package srtp import "fmt" -// ProtectionProfile specifies Cipher and AuthTag details, similar to TLS cipher suite +// ProtectionProfile specifies Cipher and AuthTag details, similar to TLS cipher suite. type ProtectionProfile uint16 // Supported protection profiles // See https://www.iana.org/assignments/srtp-protection/srtp-protection.xhtml // -// AES128_CM_HMAC_SHA1_80 and AES128_CM_HMAC_SHA1_32 are valid SRTP profiles, but they do not have an DTLS-SRTP Protection Profiles ID assigned -// in RFC 5764. They were in earlier draft of this RFC: https://datatracker.ietf.org/doc/html/draft-ietf-avt-dtls-srtp-03#section-4.1.2 +// AES128_CM_HMAC_SHA1_80 and AES128_CM_HMAC_SHA1_32 are valid SRTP profiles, +// but they do not have an DTLS-SRTP Protection Profiles ID assigned +// in RFC 5764. They were in earlier draft of this RFC: +// https://datatracker.ietf.org/doc/html/draft-ietf-avt-dtls-srtp-03#section-4.1.2 // Their IDs are now marked as reserved in the IANA registry. Despite this Chrome supports them: // https://chromium.googlesource.com/chromium/deps/libsrtp/+/84122798bb16927b1e676bd4f938a6e48e5bf2fe/srtp/include/srtp.h#694 // -// Null profiles disable encryption, they are used for debugging and testing. They are not recommended for production use. -// Use of them is equivalent to using ProtectionProfileAes128CmHmacSha1_NN profile with SRTPNoEncryption and SRTCPNoEncryption options. +// Null profiles disable encryption, they are used for debugging and testing. +// They are not recommended for production use. +// Use of them is equivalent to using ProtectionProfileAes128CmHmacSha1_NN +// profile with SRTPNoEncryption and SRTCPNoEncryption options. +// +//nolint:lll const ( ProtectionProfileAes128CmHmacSha1_80 ProtectionProfile = 0x0001 ProtectionProfileAes128CmHmacSha1_32 ProtectionProfile = 0x0002 @@ -29,10 +35,16 @@ const ( ProtectionProfileAeadAes256Gcm ProtectionProfile = 0x0008 ) -// KeyLen returns length of encryption key in bytes. For all profiles except NullHmacSha1_32 and NullHmacSha1_80 is is also the length of the session key. +// KeyLen returns length of encryption key in bytes. +// For all profiles except NullHmacSha1_32 and NullHmacSha1_80 is +// also the length of the session key. func (p ProtectionProfile) KeyLen() (int, error) { switch p { - case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAeadAes128Gcm, ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80: + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAeadAes128Gcm, + ProtectionProfileNullHmacSha1_32, + ProtectionProfileNullHmacSha1_80: return 16, nil case ProtectionProfileAeadAes256Gcm, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80: return 32, nil @@ -41,10 +53,17 @@ func (p ProtectionProfile) KeyLen() (int, error) { } } -// SaltLen returns length of salt key in bytes. For all profiles except NullHmacSha1_32 and NullHmacSha1_80 is is also the length of the session salt. +// SaltLen returns length of salt key in bytes. +// For all profiles except NullHmacSha1_32 and NullHmacSha1_80 +// is also the length of the session salt. func (p ProtectionProfile) SaltLen() (int, error) { switch p { - case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80: + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_80, + ProtectionProfileNullHmacSha1_32, + ProtectionProfileNullHmacSha1_80: return 14, nil case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: return 12, nil @@ -53,7 +72,8 @@ func (p ProtectionProfile) SaltLen() (int, error) { } } -// AuthTagRTPLen returns length of RTP authentication tag in bytes for AES protection profiles. For AEAD ones it returns zero. +// AuthTagRTPLen returns length of RTP authentication tag in bytes for AES protection profiles. +// For AEAD ones it returns zero. func (p ProtectionProfile) AuthTagRTPLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_80: @@ -67,10 +87,17 @@ func (p ProtectionProfile) AuthTagRTPLen() (int, error) { } } -// AuthTagRTCPLen returns length of RTCP authentication tag in bytes for AES protection profiles. For AEAD ones it returns zero. +// AuthTagRTCPLen returns length of RTCP authentication tag in bytes for AES protection profiles. +// +// For AEAD ones it returns zero. func (p ProtectionProfile) AuthTagRTCPLen() (int, error) { switch p { - case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80: + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_80, + ProtectionProfileNullHmacSha1_32, + ProtectionProfileNullHmacSha1_80: return 10, nil case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: return 0, nil @@ -79,10 +106,16 @@ func (p ProtectionProfile) AuthTagRTCPLen() (int, error) { } } -// AEADAuthTagLen returns length of authentication tag in bytes for AEAD protection profiles. For AES ones it returns zero. +// AEADAuthTagLen returns length of authentication tag in bytes for AEAD protection profiles. +// For AES ones it returns zero. func (p ProtectionProfile) AEADAuthTagLen() (int, error) { switch p { - case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80: + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_80, + ProtectionProfileNullHmacSha1_32, + ProtectionProfileNullHmacSha1_80: return 0, nil case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: return 16, nil @@ -91,10 +124,16 @@ func (p ProtectionProfile) AEADAuthTagLen() (int, error) { } } -// AuthKeyLen returns length of authentication key in bytes for AES protection profiles. For AEAD ones it returns zero. +// AuthKeyLen returns length of authentication key in bytes for AES protection profiles. +// For AEAD ones it returns zero. func (p ProtectionProfile) AuthKeyLen() (int, error) { switch p { - case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80: + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_80, + ProtectionProfileNullHmacSha1_32, + ProtectionProfileNullHmacSha1_80: return 20, nil case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: return 0, nil diff --git a/session.go b/session.go index cc4f600..ea4e598 100644 --- a/session.go +++ b/session.go @@ -58,7 +58,7 @@ type Config struct { LocalOptions, RemoteOptions []ContextOption } -// SessionKeys bundles the keys required to setup an SRTP session +// SessionKeys bundles the keys required to setup an SRTP session. type SessionKeys struct { LocalMasterKey []byte LocalMasterSalt []byte @@ -74,20 +74,21 @@ func (s *session) getOrCreateReadStream(ssrc uint32, child streamSession, proto return nil, false } - r, ok := s.readStreams[ssrc] + rStream, ok := s.readStreams[ssrc] if ok { - return r, false + return rStream, false } // Create the readStream. - r = proto() + rStream = proto() - if err := r.init(child, ssrc); err != nil { + if err := rStream.init(child, ssrc); err != nil { return nil, false } - s.readStreams[ssrc] = r - return r, true + s.readStreams[ssrc] = rStream + + return rStream, true } func (s *session) removeReadStream(ssrc uint32) { @@ -109,10 +110,15 @@ func (s *session) close() error { } <-s.closed + return nil } -func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error { +func (s *session) start( + localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, + profile ProtectionProfile, + child streamSession, +) error { var err error s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...) if err != nil { @@ -146,6 +152,7 @@ func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remote if !errors.Is(err, io.EOF) { s.log.Error(err.Error()) } + return } diff --git a/session_srtcp.go b/session_srtcp.go index 13f1a95..6104a35 100644 --- a/session_srtcp.go +++ b/session_srtcp.go @@ -16,7 +16,7 @@ const defaultSessionSRTCPReplayProtectionWindow = 64 // SessionSRTCP implements io.ReadWriteCloser and provides a bi-directional SRTCP session // SRTCP itself does not have a design like this, but it is common in most applications // for local/remote to each have their own keying material. This provides those patterns -// instead of making everyone re-implement +// instead of making everyone re-implement. type SessionSRTCP struct { session writeStream *WriteStreamSRTCP @@ -47,7 +47,7 @@ func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //n config.RemoteOptions..., ) - s := &SessionSRTCP{ + srtcpSession := &SessionSRTCP{ session: session{ nextConn: conn, localOptions: localOpts, @@ -61,37 +61,39 @@ func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //n log: loggerFactory.NewLogger("srtp"), }, } - s.writeStream = &WriteStreamSRTCP{s} + srtcpSession.writeStream = &WriteStreamSRTCP{srtcpSession} - err := s.session.start( + err := srtcpSession.session.start( config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, config.Profile, - s, + srtcpSession, ) if err != nil { return nil, err } - return s, nil + + return srtcpSession, nil } -// OpenWriteStream returns the global write stream for the Session +// OpenWriteStream returns the global write stream for the Session. func (s *SessionSRTCP) OpenWriteStream() (*WriteStreamSRTCP, error) { return s.writeStream, nil } // OpenReadStream opens a read stream for the given SSRC, it can be used -// if you want a certain SSRC, but don't want to wait for AcceptStream +// if you want a certain SSRC, but don't want to wait for AcceptStream. func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) { r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) if readStream, ok := r.(*ReadStreamSRTCP); ok { return readStream, nil } + return nil, errFailedTypeAssertion } -// AcceptStream returns a stream to handle RTCP for a single SSRC +// AcceptStream returns a stream to handle RTCP for a single SSRC. func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) { stream, ok := <-s.newStream if !ok { @@ -106,7 +108,7 @@ func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) { return readStream, stream.GetSSRC(), nil } -// Close ends the session +// Close ends the session. func (s *SessionSRTCP) Close() error { return s.session.close() } @@ -128,6 +130,7 @@ func (s *SessionSRTCP) write(buf []byte) (int, error) { if err != nil { return 0, err } + return s.session.nextConn.Write(encrypted) } diff --git a/session_srtcp_test.go b/session_srtcp_test.go index e61d660..110f3dc 100644 --- a/session_srtcp_test.go +++ b/session_srtcp_test.go @@ -28,7 +28,9 @@ func TestSessionSRTCPBadInit(t *testing.T) { } } -func buildSessionSRTCP(t *testing.T) (*SessionSRTCP, net.Conn, *Config) { +func buildSessionSRTCP(t *testing.T) (*SessionSRTCP, net.Conn, *Config) { //nolint:dupl + t.Helper() + aPipe, bPipe := net.Pipe() config := &Config{ Profile: ProtectionProfileAes128CmHmacSha1_80, @@ -51,6 +53,8 @@ func buildSessionSRTCP(t *testing.T) (*SessionSRTCP, net.Conn, *Config) { } func buildSessionSRTCPPair(t *testing.T) (*SessionSRTCP, *SessionSRTCP) { //nolint:dupl + t.Helper() + aSession, bPipe, config := buildSessionSRTCP(t) bSession, err := NewSessionSRTCP(bPipe, config) if err != nil { @@ -107,7 +111,7 @@ func TestSessionSRTCP(t *testing.T) { } } -func TestSessionSRTCPWithIODeadline(t *testing.T) { +func TestSessionSRTCPWithIODeadline(t *testing.T) { //nolint:cyclop lim := test.TimeOut(time.Second * 10) defer lim.Stop() @@ -222,7 +226,7 @@ func TestSessionSRTCPOpenReadStream(t *testing.T) { } } -func TestSessionSRTCPReplayProtection(t *testing.T) { +func TestSessionSRTCPReplayProtection(t *testing.T) { //nolint:cyclop lim := test.TimeOut(time.Second * 5) defer lim.Stop() @@ -341,6 +345,8 @@ func TestSessionSRTCPAcceptStreamTimeout(t *testing.T) { } func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) { + t.Helper() + authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.AuthTagRTCPLen() if err != nil { return 0, err @@ -354,13 +360,16 @@ func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err erro } if err != nil { t.Error(err) + return 0, err } pli := &rtcp.PictureLossIndication{} if uerr := pli.Unmarshal(readBuffer[:n]); uerr != nil { t.Error(uerr) + return 0, uerr } + return pli.SenderSSRC, nil } @@ -375,6 +384,7 @@ func encryptSRTCP(context *Context, pkt rtcp.Packet) ([]byte, error) { if eerr != nil { return nil, eerr } + return encrypted, nil } @@ -389,5 +399,6 @@ func errIsTimeout(err error) bool { case strings.Contains(s, "deadline exceeded"): // error message when timeout after go1.15. return true } + return false } diff --git a/session_srtp.go b/session_srtp.go index e07cbe2..9d12bfa 100644 --- a/session_srtp.go +++ b/session_srtp.go @@ -17,7 +17,7 @@ const defaultSessionSRTPReplayProtectionWindow = 64 // SessionSRTP implements io.ReadWriteCloser and provides a bi-directional SRTP session // SRTP itself does not have a design like this, but it is common in most applications // for local/remote to each have their own keying material. This provides those patterns -// instead of making everyone re-implement +// instead of making everyone re-implement. type SessionSRTP struct { session writeStream *WriteStreamSRTP @@ -48,7 +48,7 @@ func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nol config.RemoteOptions..., ) - s := &SessionSRTP{ + srtpSession := &SessionSRTP{ session: session{ nextConn: conn, localOptions: localOpts, @@ -62,27 +62,28 @@ func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nol log: loggerFactory.NewLogger("srtp"), }, } - s.writeStream = &WriteStreamSRTP{s} + srtpSession.writeStream = &WriteStreamSRTP{srtpSession} - err := s.session.start( + err := srtpSession.session.start( config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, config.Profile, - s, + srtpSession, ) if err != nil { return nil, err } - return s, nil + + return srtpSession, nil } -// OpenWriteStream returns the global write stream for the Session +// OpenWriteStream returns the global write stream for the Session. func (s *SessionSRTP) OpenWriteStream() (*WriteStreamSRTP, error) { return s.writeStream, nil } // OpenReadStream opens a read stream for the given SSRC, it can be used -// if you want a certain SSRC, but don't want to wait for AcceptStream +// if you want a certain SSRC, but don't want to wait for AcceptStream. func (s *SessionSRTP) OpenReadStream(ssrc uint32) (*ReadStreamSRTP, error) { r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTP) @@ -93,7 +94,7 @@ func (s *SessionSRTP) OpenReadStream(ssrc uint32) (*ReadStreamSRTP, error) { return nil, errFailedTypeAssertion } -// AcceptStream returns a stream to handle RTCP for a single SSRC +// AcceptStream returns a stream to handle RTCP for a single SSRC. func (s *SessionSRTP) AcceptStream() (*ReadStreamSRTP, uint32, error) { stream, ok := <-s.newStream if !ok { @@ -108,7 +109,7 @@ func (s *SessionSRTP) AcceptStream() (*ReadStreamSRTP, uint32, error) { return readStream, stream.GetSSRC(), nil } -// Close ends the session +// Close ends the session. func (s *SessionSRTP) Close() error { return s.session.close() } @@ -165,13 +166,13 @@ func (s *SessionSRTP) setWriteDeadline(t time.Time) error { } func (s *SessionSRTP) decrypt(buf []byte) error { - h := &rtp.Header{} - headerLen, err := h.Unmarshal(buf) + header := &rtp.Header{} + headerLen, err := header.Unmarshal(buf) if err != nil { return err } - r, isNew := s.session.getOrCreateReadStream(h.SSRC, s, newReadStreamSRTP) + r, isNew := s.session.getOrCreateReadStream(header.SSRC, s, newReadStreamSRTP) if r == nil { return nil // Session has been closed } else if isNew { @@ -186,7 +187,7 @@ func (s *SessionSRTP) decrypt(buf []byte) error { return errFailedTypeAssertion } - decrypted, err := s.remoteContext.decryptRTP(buf, buf, h, headerLen) + decrypted, err := s.remoteContext.decryptRTP(buf, buf, header, headerLen) if err != nil { return err } diff --git a/session_srtp_test.go b/session_srtp_test.go index a3462f6..e90d6ff 100644 --- a/session_srtp_test.go +++ b/session_srtp_test.go @@ -25,7 +25,9 @@ func TestSessionSRTPBadInit(t *testing.T) { } } -func buildSessionSRTP(t *testing.T) (*SessionSRTP, net.Conn, *Config) { +func buildSessionSRTP(t *testing.T) (*SessionSRTP, net.Conn, *Config) { //nolint:dupl + t.Helper() + aPipe, bPipe := net.Pipe() config := &Config{ Profile: ProtectionProfileAes128CmHmacSha1_80, @@ -48,6 +50,8 @@ func buildSessionSRTP(t *testing.T) (*SessionSRTP, net.Conn, *Config) { } func buildSessionSRTPPair(t *testing.T) (*SessionSRTP, *SessionSRTP) { //nolint:dupl + t.Helper() + aSession, bPipe, config := buildSessionSRTP(t) bSession, err := NewSessionSRTP(bPipe, config) if err != nil { @@ -106,7 +110,7 @@ func TestSessionSRTP(t *testing.T) { } } -func TestSessionSRTPWithIODeadline(t *testing.T) { +func TestSessionSRTPWithIODeadline(t *testing.T) { //nolint:cyclop lim := test.TimeOut(time.Second * 10) defer lim.Stop() @@ -272,7 +276,7 @@ func TestSessionSRTPMultiSSRC(t *testing.T) { } } -func TestSessionSRTPReplayProtection(t *testing.T) { +func TestSessionSRTPReplayProtection(t *testing.T) { //nolint:cyclop lim := test.TimeOut(time.Second * 5) defer lim.Stop() @@ -394,7 +398,14 @@ func TestSessionSRTPAcceptStreamTimeout(t *testing.T) { } } -func assertPayloadSRTP(t *testing.T, stream *ReadStreamSRTP, headerSize int, expectedPayload []byte) (seq uint16, err error) { +func assertPayloadSRTP( + t *testing.T, + stream *ReadStreamSRTP, + headerSize int, + expectedPayload []byte, +) (seq uint16, err error) { + t.Helper() + readBuffer := make([]byte, headerSize+len(expectedPayload)) n, hdr, err := stream.ReadRTP(readBuffer) if errors.Is(err, io.EOF) { @@ -402,12 +413,15 @@ func assertPayloadSRTP(t *testing.T, stream *ReadStreamSRTP, headerSize int, exp } if err != nil { t.Error(err) + return 0, err } if !bytes.Equal(expectedPayload, readBuffer[headerSize:n]) { t.Errorf("Sent buffer does not match the one received exp(%v) actual(%v)", expectedPayload, readBuffer[headerSize:n]) + return 0, errPayloadDiffers } + return hdr.SequenceNumber, nil } @@ -422,5 +436,6 @@ func encryptSRTP(context *Context, pkt *rtp.Packet) ([]byte, error) { if eerr != nil { return nil, eerr } + return encrypted, nil } diff --git a/srtcp.go b/srtcp.go index 6d1a1c1..045eddc 100644 --- a/srtcp.go +++ b/srtcp.go @@ -57,10 +57,11 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { } markAsValid() + return out, nil } -// DecryptRTCP decrypts a buffer that contains a RTCP packet +// DecryptRTCP decrypts a buffer that contains a RTCP packet. func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byte, error) { if header == nil { header = &rtcp.Header{} @@ -79,9 +80,9 @@ func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) { } ssrc := binary.BigEndian.Uint32(decrypted[4:]) - s := c.getSRTCPSSRCState(ssrc) + ssrcState := c.getSRTCPSSRCState(ssrc) - if s.srtcpIndex >= maxSRTCPIndex { + if ssrcState.srtcpIndex >= maxSRTCPIndex { // ... when 2^48 SRTP packets or 2^31 SRTCP packets have been secured with the same key // (whichever occurs before), the key management MUST be called to provide new master key(s) // (previously stored and used keys MUST NOT be used again), or the session MUST be terminated. @@ -90,12 +91,12 @@ func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) { } // We roll over early because MSB is used for marking as encrypted - s.srtcpIndex++ + ssrcState.srtcpIndex++ - return c.cipher.encryptRTCP(dst, decrypted, s.srtcpIndex, ssrc) + return c.cipher.encryptRTCP(dst, decrypted, ssrcState.srtcpIndex, ssrc) } -// EncryptRTCP Encrypts a RTCP packet +// EncryptRTCP Encrypts a RTCP packet. func (c *Context) EncryptRTCP(dst, decrypted []byte, header *rtcp.Header) ([]byte, error) { if header == nil { header = &rtcp.Header{} diff --git a/srtcp_test.go b/srtcp_test.go index ca1a809..3b2b318 100644 --- a/srtcp_test.go +++ b/srtcp_test.go @@ -193,9 +193,17 @@ func TestRTCPLifecycleInPlace(t *testing.T) { case decryptHeader.Type != pkt.pktType: t.Fatalf("DecryptRTCP failed to populate input rtcp.Header, expected: %d, got %d", pkt.pktType, decryptHeader.Type) case !bytes.Equal(decryptInput[:len(decryptInput)-(authTagLen+aeadAuthTagLen+srtcpIndexSize)], actualDecrypted): - t.Fatalf("DecryptRTP failed to decrypt in place\nexpected: %v\n got: %v", decryptInput[:len(decryptInput)-(authTagLen+srtcpIndexSize)], actualDecrypted) + t.Fatalf( + "DecryptRTP failed to decrypt in place\nexpected: %v\n got: %v", + decryptInput[:len(decryptInput)-(authTagLen+srtcpIndexSize)], + actualDecrypted, + ) } - assert.Equal(decryptInput[:len(decryptInput)-(authTagLen+aeadAuthTagLen+srtcpIndexSize)], actualDecrypted, "DecryptRTP failed to decrypt in place") + assert.Equal( + decryptInput[:len(decryptInput)-(authTagLen+aeadAuthTagLen+srtcpIndexSize)], + actualDecrypted, + "DecryptRTP failed to decrypt in place", + ) assert.Equal(pkt.decrypted, actualDecrypted, "RTCP failed to decrypt") @@ -213,7 +221,11 @@ func TestRTCPLifecycleInPlace(t *testing.T) { case encryptHeader.Type != pkt.pktType: t.Fatalf("EncryptRTCP failed to populate input rtcp.Header, expected: %d, got %d", pkt.pktType, encryptHeader.Type) } - assert.Equal(actualEncrypted[:len(actualEncrypted)-(authTagLen+aeadAuthTagLen+srtcpIndexSize)], encryptInput, "EncryptRTCP failed to encrypt in place") + assert.Equal( + actualEncrypted[:len(actualEncrypted)-(authTagLen+aeadAuthTagLen+srtcpIndexSize)], + encryptInput, + "EncryptRTCP failed to encrypt in place", + ) assert.Equal(pkt.encrypted, actualEncrypted, "RTCP failed to encrypt") } @@ -221,7 +233,7 @@ func TestRTCPLifecycleInPlace(t *testing.T) { } } -// Assert that passing a dst buffer that is too short doesn't result in a failure +// Assert that passing a dst buffer that is too short doesn't result in a failure. func TestRTCPLifecyclePartialAllocation(t *testing.T) { for caseName, testCase := range rtcpTestCases() { testCase := testCase @@ -343,6 +355,7 @@ func TestRTCPReplayDetectorSeparation(t *testing.T) { func getRTCPIndex(encrypted []byte, authTagLen int) uint32 { tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) srtcpIndexBuffer := encrypted[tailOffset : tailOffset+srtcpIndexSize] + return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) } @@ -538,7 +551,12 @@ func TestRTCPMaxPackets(t *testing.T) { t.Errorf("CreateContext failed: %v", err) } - decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, SRTCPReplayProtection(10)) + decryptContext, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.algo, + SRTCPReplayProtection(10), + ) if err != nil { t.Errorf("CreateContext failed: %v", err) } @@ -582,6 +600,7 @@ func TestRTCPReplayDetectorFactory(t *testing.T) { testCase.masterKey, testCase.masterSalt, testCase.algo, SRTCPReplayDetectorFactory(func() replaydetector.ReplayDetector { cntFactory++ + return &nopReplayDetector{} }), ) @@ -626,12 +645,22 @@ func TestRTCPInvalidMKI(t *testing.T) { for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { - encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + encryptContext, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.algo, + MasterKeyIndicator(mki1), + ) if err != nil { t.Errorf("CreateContext failed: %v", err) } - decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + decryptContext, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.algo, + MasterKeyIndicator(mki2), + ) if err != nil { t.Errorf("CreateContext failed: %v", err) } @@ -654,7 +683,7 @@ func TestRTCPInvalidMKI(t *testing.T) { } } -func TestRTCPHandleMultipleMKI(t *testing.T) { +func TestRTCPHandleMultipleMKI(t *testing.T) { //nolint:cyclop mki1 := []byte{0x01, 0x02, 0x03, 0x04} mki2 := []byte{0x02, 0x03, 0x04, 0x05} @@ -665,7 +694,12 @@ func TestRTCPHandleMultipleMKI(t *testing.T) { copy(masterKey2, testCase.masterKey) masterKey2[0] = ^masterKey2[0] - encryptContext1, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + encryptContext1, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.algo, + MasterKeyIndicator(mki1), + ) if err != nil { t.Errorf("CreateContext failed: %v", err) } @@ -674,7 +708,12 @@ func TestRTCPHandleMultipleMKI(t *testing.T) { t.Errorf("CreateContext failed: %v", err) } - decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + decryptContext, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.algo, + MasterKeyIndicator(mki1), + ) if err != nil { t.Errorf("CreateContext failed: %v", err) } @@ -710,7 +749,7 @@ func TestRTCPHandleMultipleMKI(t *testing.T) { } } -func TestRTCPSwitchMKI(t *testing.T) { +func TestRTCPSwitchMKI(t *testing.T) { //nolint:cyclop mki1 := []byte{0x01, 0x02, 0x03, 0x04} mki2 := []byte{0x02, 0x03, 0x04, 0x05} @@ -721,7 +760,12 @@ func TestRTCPSwitchMKI(t *testing.T) { copy(masterKey2, testCase.masterKey) masterKey2[0] = ^masterKey2[0] - encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + encryptContext, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.algo, + MasterKeyIndicator(mki1), + ) if err != nil { t.Errorf("CreateContext failed: %v", err) } @@ -730,7 +774,12 @@ func TestRTCPSwitchMKI(t *testing.T) { t.Errorf("AddCipherForMKI failed: %v", err) } - decryptContext1, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + decryptContext1, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.algo, + MasterKeyIndicator(mki1), + ) if err != nil { t.Errorf("CreateContext failed: %v", err) } diff --git a/srtp.go b/srtp.go index 56828bc..adb4397 100644 --- a/srtp.go +++ b/srtp.go @@ -26,10 +26,10 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL return nil, fmt.Errorf("%w: %d", errTooShortRTP, len(ciphertext)) } - s := c.getSRTPSSRCState(header.SSRC) + ssrcState := c.getSRTPSSRCState(header.SSRC) - roc, diff, _ := s.nextRolloverCount(header.SequenceNumber) - markAsValid, ok := s.replayDetector.Check( + roc, diff, _ := ssrcState.nextRolloverCount(header.SequenceNumber) + markAsValid, ok := ssrcState.replayDetector.Check( (uint64(roc) << 16) | uint64(header.SequenceNumber), ) if !ok { @@ -56,11 +56,12 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL } markAsValid() - s.updateRolloverCount(header.SequenceNumber, diff) + ssrcState.updateRolloverCount(header.SequenceNumber, diff) + return dst, nil } -// DecryptRTP decrypts a RTP packet with an encrypted payload +// DecryptRTP decrypts a RTP packet with an encrypted payload. func (c *Context) DecryptRTP(dst, encrypted []byte, header *rtp.Header) ([]byte, error) { if header == nil { header = &rtp.Header{} @@ -75,7 +76,8 @@ func (c *Context) DecryptRTP(dst, encrypted []byte, header *rtp.Header) ([]byte, } // EncryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided. -// If the dst buffer does not have the capacity to hold `len(plaintext) + 10` bytes, a new one will be allocated and returned. +// If the dst buffer does not have the capacity to hold `len(plaintext) + 10` bytes, +// a new one will be allocated and returned. // If a rtp.Header is provided, it will be Unmarshaled using the plaintext. func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) ([]byte, error) { if header == nil { diff --git a/srtp_cipher.go b/srtp_cipher.go index da745e7..93f70c4 100644 --- a/srtp_cipher.go +++ b/srtp_cipher.go @@ -6,7 +6,7 @@ package srtp import "github.com/pion/rtp" // cipher represents a implementation of one -// of the SRTP Specific ciphers +// of the SRTP Specific ciphers. type srtpCipher interface { // AuthTagRTPLen/AuthTagRTCPLen return auth key length of the cipher. // See the note below. diff --git a/srtp_cipher_aead_aes_gcm.go b/srtp_cipher_aead_aes_gcm.go index 64f890f..82c0bd8 100644 --- a/srtp_cipher_aead_aes_gcm.go +++ b/srtp_cipher_aead_aes_gcm.go @@ -28,8 +28,12 @@ type srtpCipherAeadAesGcm struct { srtpEncrypted, srtcpEncrypted bool } -func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt, mki []byte, encryptSRTP, encryptSRTCP bool) (*srtpCipherAeadAesGcm, error) { - s := &srtpCipherAeadAesGcm{ +func newSrtpCipherAeadAesGcm( + profile ProtectionProfile, + masterKey, masterSalt, mki []byte, + encryptSRTP, encryptSRTCP bool, +) (*srtpCipherAeadAesGcm, error) { + srtpCipher := &srtpCipherAeadAesGcm{ ProtectionProfile: profile, srtpEncrypted: encryptSRTP, srtcpEncrypted: encryptSRTCP, @@ -45,7 +49,7 @@ func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt, m return nil, err } - s.srtpCipher, err = cipher.NewGCM(srtpBlock) + srtpCipher.srtpCipher, err = cipher.NewGCM(srtpBlock) if err != nil { return nil, err } @@ -60,27 +64,36 @@ func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt, m return nil, err } - s.srtcpCipher, err = cipher.NewGCM(srtcpBlock) + srtpCipher.srtcpCipher, err = cipher.NewGCM(srtcpBlock) if err != nil { return nil, err } - if s.srtpSessionSalt, err = aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { + if srtpCipher.srtpSessionSalt, err = aesCmKeyDerivation( + labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt), + ); err != nil { return nil, err - } else if s.srtcpSessionSalt, err = aesCmKeyDerivation(labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { + } else if srtpCipher.srtcpSessionSalt, err = aesCmKeyDerivation( + labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt), + ); err != nil { return nil, err } mkiLen := len(mki) if mkiLen > 0 { - s.mki = make([]byte, mkiLen) - copy(s.mki, mki) + srtpCipher.mki = make([]byte, mkiLen) + copy(srtpCipher.mki, mki) } - return s, nil + return srtpCipher, nil } -func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) { +func (s *srtpCipherAeadAesGcm) encryptRTP( + dst []byte, + header *rtp.Header, + payload []byte, + roc uint32, +) (ciphertext []byte, err error) { // Grow the given buffer to fit the output. authTagLen, err := s.AEADAuthTagLen() if err != nil { @@ -110,7 +123,12 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa return dst, nil } -func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) { +func (s *srtpCipherAeadAesGcm) decryptRTP( + dst, ciphertext []byte, + header *rtp.Header, + headerLen int, + roc uint32, +) ([]byte, error) { // Grow the given buffer to fit the output. authTagLen, err := s.AEADAuthTagLen() if err != nil { @@ -143,6 +161,7 @@ func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.He } copy(dst[:headerLen], ciphertext[:headerLen]) + return dst, nil } @@ -176,6 +195,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin } copy(dst[aadPos+4:], s.mki) + return dst, nil } @@ -215,6 +235,7 @@ func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ss } copy(dst[:8], encrypted[:8]) + return dst, nil } @@ -233,6 +254,7 @@ func (s *srtpCipherAeadAesGcm) rtpInitializationVector(header *rtp.Header, roc u for i := range iv { iv[i] ^= s.srtpSessionSalt[i] } + return iv } @@ -252,6 +274,7 @@ func (s *srtpCipherAeadAesGcm) rtcpInitializationVector(srtcpIndex uint32, ssrc for i := range iv { iv[i] ^= s.srtcpSessionSalt[i] } + return iv } @@ -281,5 +304,6 @@ func (s *srtpCipherAeadAesGcm) getMKI(in []byte, _ bool) []byte { } tailOffset := len(in) - mkiLen + return in[tailOffset:] } diff --git a/srtp_cipher_aead_aes_gcm_rfc_test.go b/srtp_cipher_aead_aes_gcm_rfc_test.go index 9e50bde..b8ed14f 100644 --- a/srtp_cipher_aead_aes_gcm_rfc_test.go +++ b/srtp_cipher_aead_aes_gcm_rfc_test.go @@ -20,6 +20,7 @@ func fromHex(s string) []byte { if err != nil { panic(err) } + return b } @@ -36,7 +37,7 @@ type testRfcAeadCipher struct { authenticatedRTCPPacket []byte } -// createRfcAeadTestCiphers returns a list of test ciphers for the RFC test vectors +// createRfcAeadTestCiphers returns a list of test ciphers for the RFC test vectors. func createRfcAeadTestCiphers() []testRfcAeadCipher { tests := []testRfcAeadCipher{} @@ -125,102 +126,102 @@ func createRfcAeadTestCiphers() []testRfcAeadCipher { } func TestAeadCiphersWithRfcTestVectors(t *testing.T) { - for _, c := range createRfcAeadTestCiphers() { - t.Run(c.profile.String(), func(t *testing.T) { + for _, testCase := range createRfcAeadTestCiphers() { + t.Run(testCase.profile.String(), func(t *testing.T) { t.Run("Encrypt RTP", func(t *testing.T) { - cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(c.profile, c.keys, true, true) + cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(testCase.profile, testCase.keys, true, true) assert.NoError(t, err) - ctx, err := createContextWithCipher(c.profile, cipher) + ctx, err := createContextWithCipher(testCase.profile, cipher) assert.NoError(t, err) ctx.SetIndex(0x4d617273, 0x000005d3) - actualEncrypted, err := ctx.EncryptRTP(nil, c.decryptedRTPPacket, nil) + actualEncrypted, err := ctx.EncryptRTP(nil, testCase.decryptedRTPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.encryptedRTPPacket, actualEncrypted) + assert.Equal(t, testCase.encryptedRTPPacket, actualEncrypted) }) t.Run("Decrypt RTP", func(t *testing.T) { - cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(c.profile, c.keys, true, true) + cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(testCase.profile, testCase.keys, true, true) assert.NoError(t, err) - ctx, err := createContextWithCipher(c.profile, cipher) + ctx, err := createContextWithCipher(testCase.profile, cipher) assert.NoError(t, err) ctx.SetIndex(0x4d617273, 0x000005d3) - actualDecrypted, err := ctx.DecryptRTP(nil, c.encryptedRTPPacket, nil) + actualDecrypted, err := ctx.DecryptRTP(nil, testCase.encryptedRTPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTPPacket, actualDecrypted) }) t.Run("Encrypt RTCP", func(t *testing.T) { - cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(c.profile, c.keys, true, true) + cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(testCase.profile, testCase.keys, true, true) assert.NoError(t, err) - ctx, err := createContextWithCipher(c.profile, cipher) + ctx, err := createContextWithCipher(testCase.profile, cipher) assert.NoError(t, err) ctx.SetIndex(0x4d617273, 0x000005d3) - actualEncrypted, err := ctx.EncryptRTCP(nil, c.decryptedRTCPPacket, nil) + actualEncrypted, err := ctx.EncryptRTCP(nil, testCase.decryptedRTCPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.encryptedRTCPPacket, actualEncrypted) + assert.Equal(t, testCase.encryptedRTCPPacket, actualEncrypted) }) t.Run("Decrypt RTCP", func(t *testing.T) { - cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(c.profile, c.keys, true, true) + cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(testCase.profile, testCase.keys, true, true) assert.NoError(t, err) - ctx, err := createContextWithCipher(c.profile, cipher) + ctx, err := createContextWithCipher(testCase.profile, cipher) assert.NoError(t, err) ctx.SetIndex(0x4d617273, 0x000005d3) - actualDecrypted, err := ctx.DecryptRTCP(nil, c.encryptedRTCPPacket, nil) + actualDecrypted, err := ctx.DecryptRTCP(nil, testCase.encryptedRTCPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTCPPacket, actualDecrypted) }) t.Run("Encrypt RTP with NULL cipher", func(t *testing.T) { - cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(c.profile, c.keys, false, false) + cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(testCase.profile, testCase.keys, false, false) assert.NoError(t, err) - ctx, err := createContextWithCipher(c.profile, cipher) + ctx, err := createContextWithCipher(testCase.profile, cipher) assert.NoError(t, err) ctx.SetIndex(0x4d617273, 0x000005d3) - actualEncrypted, err := ctx.EncryptRTP(nil, c.decryptedRTPPacket, nil) + actualEncrypted, err := ctx.EncryptRTP(nil, testCase.decryptedRTPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.authenticatedRTPPacket, actualEncrypted) + assert.Equal(t, testCase.authenticatedRTPPacket, actualEncrypted) }) t.Run("Decrypt RTP with NULL cipher", func(t *testing.T) { - cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(c.profile, c.keys, false, false) + cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(testCase.profile, testCase.keys, false, false) assert.NoError(t, err) - ctx, err := createContextWithCipher(c.profile, cipher) + ctx, err := createContextWithCipher(testCase.profile, cipher) assert.NoError(t, err) ctx.SetIndex(0x4d617273, 0x000005d3) - actualDecrypted, err := ctx.DecryptRTP(nil, c.authenticatedRTPPacket, nil) + actualDecrypted, err := ctx.DecryptRTP(nil, testCase.authenticatedRTPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTPPacket, actualDecrypted) }) t.Run("Encrypt RTCP with NULL cipher", func(t *testing.T) { - cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(c.profile, c.keys, false, false) + cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(testCase.profile, testCase.keys, false, false) assert.NoError(t, err) - ctx, err := createContextWithCipher(c.profile, cipher) + ctx, err := createContextWithCipher(testCase.profile, cipher) assert.NoError(t, err) ctx.SetIndex(0x4d617273, 0x000005d3) - actualEncrypted, err := ctx.EncryptRTCP(nil, c.decryptedRTCPPacket, nil) + actualEncrypted, err := ctx.EncryptRTCP(nil, testCase.decryptedRTCPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.authenticatedRTCPPacket, actualEncrypted) + assert.Equal(t, testCase.authenticatedRTCPPacket, actualEncrypted) }) t.Run("Decrypt RTCP with NULL cipher", func(t *testing.T) { - cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(c.profile, c.keys, false, false) + cipher, err := newSrtpCipherAeadAesGcmWithDerivedKeys(testCase.profile, testCase.keys, false, false) assert.NoError(t, err) - ctx, err := createContextWithCipher(c.profile, cipher) + ctx, err := createContextWithCipher(testCase.profile, cipher) assert.NoError(t, err) ctx.SetIndex(0x4d617273, 0x000005d3) - actualDecrypted, err := ctx.DecryptRTCP(nil, c.authenticatedRTCPPacket, nil) + actualDecrypted, err := ctx.DecryptRTCP(nil, testCase.authenticatedRTCPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTCPPacket, actualDecrypted) }) }) } diff --git a/srtp_cipher_aes_cm_hmac_sha1.go b/srtp_cipher_aes_cm_hmac_sha1.go index aa673fa..0a1f4f6 100644 --- a/srtp_cipher_aes_cm_hmac_sha1.go +++ b/srtp_cipher_aes_cm_hmac_sha1.go @@ -31,13 +31,18 @@ type srtpCipherAesCmHmacSha1 struct { mki []byte } -func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt, mki []byte, encryptSRTP, encryptSRTCP bool) (*srtpCipherAesCmHmacSha1, error) { +//nolint:cyclop +func newSrtpCipherAesCmHmacSha1( + profile ProtectionProfile, + masterKey, masterSalt, mki []byte, + encryptSRTP, encryptSRTCP bool, +) (*srtpCipherAesCmHmacSha1, error) { if profile == ProtectionProfileNullHmacSha1_80 || profile == ProtectionProfileNullHmacSha1_32 { encryptSRTP = false encryptSRTCP = false } - s := &srtpCipherAesCmHmacSha1{ + srtpCipher := &srtpCipherAesCmHmacSha1{ ProtectionProfile: profile, srtpEncrypted: encryptSRTP, srtcpEncrypted: encryptSRTCP, @@ -46,20 +51,24 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { return nil, err - } else if s.srtpBlock, err = aes.NewCipher(srtpSessionKey); err != nil { + } else if srtpCipher.srtpBlock, err = aes.NewCipher(srtpSessionKey); err != nil { return nil, err } srtcpSessionKey, err := aesCmKeyDerivation(labelSRTCPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { return nil, err - } else if s.srtcpBlock, err = aes.NewCipher(srtcpSessionKey); err != nil { + } else if srtpCipher.srtcpBlock, err = aes.NewCipher(srtcpSessionKey); err != nil { return nil, err } - if s.srtpSessionSalt, err = aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { + if srtpCipher.srtpSessionSalt, err = aesCmKeyDerivation( + labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt), + ); err != nil { return nil, err - } else if s.srtcpSessionSalt, err = aesCmKeyDerivation(labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { + } else if srtpCipher.srtcpSessionSalt, err = aesCmKeyDerivation( + labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt), + ); err != nil { return nil, err } @@ -78,19 +87,24 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt return nil, err } - s.srtcpSessionAuth = hmac.New(sha1.New, srtcpSessionAuthTag) - s.srtpSessionAuth = hmac.New(sha1.New, srtpSessionAuthTag) + srtpCipher.srtcpSessionAuth = hmac.New(sha1.New, srtcpSessionAuthTag) + srtpCipher.srtpSessionAuth = hmac.New(sha1.New, srtpSessionAuthTag) mkiLen := len(mki) if mkiLen > 0 { - s.mki = make([]byte, mkiLen) - copy(s.mki, mki) + srtpCipher.mki = make([]byte, mkiLen) + copy(srtpCipher.mki, mki) } - return s, nil + return srtpCipher, nil } -func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) { +func (s *srtpCipherAesCmHmacSha1) encryptRTP( + dst []byte, + header *rtp.Header, + payload []byte, + roc uint32, +) (ciphertext []byte, err error) { // Grow the given buffer to fit the output. authTagLen, err := s.AuthTagRTPLen() if err != nil { @@ -133,7 +147,12 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay return dst, nil } -func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) { +func (s *srtpCipherAesCmHmacSha1) decryptRTP( + dst, ciphertext []byte, + header *rtp.Header, + headerLen int, + roc uint32, +) ([]byte, error) { // Split the auth tag and the cipher text into two parts. authTagLen, err := s.AuthTagRTPLen() if err != nil { @@ -171,6 +190,7 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp } else { copy(dst[headerLen:], ciphertext[headerLen:]) } + return dst, nil } @@ -179,7 +199,7 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex // Encrypt everything after header if s.srtcpEncrypted { - counter := generateCounter(uint16(srtcpIndex&0xffff), srtcpIndex>>16, ssrc, s.srtcpSessionSalt) + counter := generateCounter(uint16(srtcpIndex&0xffff), srtcpIndex>>16, ssrc, s.srtcpSessionSalt) //nolint:gosec // G115 if err := xorBytesCTR(s.srtcpBlock, counter[:], dst[8:], dst[8:]); err != nil { return nil, err } @@ -235,7 +255,7 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc isEncrypted := encrypted[tailOffset]>>7 != 0 if isEncrypted { - counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt) + counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt) //nolint:gosec // G115 err = xorBytesCTR(s.srtcpBlock, counter[:], out[8:], out[8:]) } else { copy(out[8:], encrypted[8:]) @@ -279,6 +299,7 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtpAuthTag(buf []byte, roc uint32) ([ if err != nil { return nil, err } + return s.srtpSessionAuth.Sum(nil)[0:authTagLen], nil } @@ -311,6 +332,7 @@ func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 { authTagLen, _ := s.AuthTagRTCPLen() tailOffset := len(in) - (authTagLen + srtcpIndexSize + len(s.mki)) srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize] + return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) } @@ -327,5 +349,6 @@ func (s *srtpCipherAesCmHmacSha1) getMKI(in []byte, rtp bool) []byte { authTagLen, _ = s.AuthTagRTCPLen() } tailOffset := len(in) - (authTagLen + mkiLen) + return in[tailOffset : tailOffset+mkiLen] } diff --git a/srtp_cipher_aes_cm_hmac_sha1_rfc_test.go b/srtp_cipher_aes_cm_hmac_sha1_rfc_test.go index f784d6c..0cdba17 100644 --- a/srtp_cipher_aes_cm_hmac_sha1_rfc_test.go +++ b/srtp_cipher_aes_cm_hmac_sha1_rfc_test.go @@ -15,7 +15,7 @@ type testRfcAesCipher struct { keystream []byte } -// createRfcAesTestCiphers returns a list of test ciphers for the RFC test vectors +// createRfcAesTestCiphers returns a list of test ciphers for the RFC test vectors. func createRfcAesTestCiphers() []testRfcAesCipher { tests := []testRfcAesCipher{} @@ -54,8 +54,8 @@ func createRfcAesTestCiphers() []testRfcAesCipher { } func TestAesCiphersWithRfcTestVectors(t *testing.T) { - for _, c := range createRfcAesTestCiphers() { - t.Run(c.profile.String(), func(t *testing.T) { + for _, testCase := range createRfcAesTestCiphers() { + t.Run(testCase.profile.String(), func(t *testing.T) { // Use zero SSRC and sequence number as specified in RFC rtpHeader := []byte{ 0x80, 0x0f, 0x00, 0x00, 0xde, 0xca, 0xfb, 0xad, @@ -63,20 +63,21 @@ func TestAesCiphersWithRfcTestVectors(t *testing.T) { } t.Run("Keystream generation", func(t *testing.T) { - cipher, err := newSrtpCipherAesCmHmacSha1WithDerivedKeys(c.profile, c.keys, true, true) + cipher, err := newSrtpCipherAesCmHmacSha1WithDerivedKeys(testCase.profile, testCase.keys, true, true) assert.NoError(t, err) - ctx, err := createContextWithCipher(c.profile, cipher) + ctx, err := createContextWithCipher(testCase.profile, cipher) assert.NoError(t, err) - // Generated AES keystream will be XOR'ed with zeroes in RTP packet payload, so SRTP payload will be equal to keystream - decryptedRTPPacket := make([]byte, len(rtpHeader)+len(c.keystream)) + // Generated AES keystream will be XOR'ed with zeroes in RTP packet payload, + // so SRTP payload will be equal to keystream + decryptedRTPPacket := make([]byte, len(rtpHeader)+len(testCase.keystream)) copy(decryptedRTPPacket, rtpHeader) actualEncrypted, err := ctx.EncryptRTP(nil, decryptedRTPPacket, nil) assert.NoError(t, err) assert.Equal(t, rtpHeader, actualEncrypted[:len(rtpHeader)]) - assert.Equal(t, c.keystream, actualEncrypted[len(rtpHeader):len(rtpHeader)+len(c.keystream)]) + assert.Equal(t, testCase.keystream, actualEncrypted[len(rtpHeader):len(rtpHeader)+len(testCase.keystream)]) }) }) } diff --git a/srtp_cipher_test.go b/srtp_cipher_test.go index e5806dd..45a021b 100644 --- a/srtp_cipher_test.go +++ b/srtp_cipher_test.go @@ -28,8 +28,8 @@ type testCipher struct { authenticatedRTCPPacketWithMKI []byte } -// create array of testCiphers for each supported profile -func createTestCiphers() []testCipher { +// create array of testCiphers for each supported profile. +func createTestCiphers() []testCipher { //nolint:maintidx tests := []testCipher{ { //nolint:dupl profile: ProtectionProfileAes128CmHmacSha1_32, @@ -556,7 +556,7 @@ func createTestCiphers() []testCipher { 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, } - for k, v := range tests { + for key, v := range tests { keyLen, err := v.profile.KeyLen() if err != nil { panic(err) @@ -565,199 +565,243 @@ func createTestCiphers() []testCipher { if err != nil { panic(err) } - tests[k].masterKey = masterKey[:keyLen] - tests[k].masterSalt = masterSalt[:saltLen] - tests[k].mki = mki - tests[k].decryptedRTPPacket = decryptedRTPPacket - tests[k].decryptedRTCPPacket = decryptedRTCPPacket + tests[key].masterKey = masterKey[:keyLen] + tests[key].masterSalt = masterSalt[:saltLen] + tests[key].mki = mki + tests[key].decryptedRTPPacket = decryptedRTPPacket + tests[key].decryptedRTCPPacket = decryptedRTCPPacket } return tests } func TestSrtpCipher(t *testing.T) { - for _, c := range createTestCiphers() { - t.Run(c.profile.String(), func(t *testing.T) { - assert.Equal(t, c.decryptedRTPPacket, c.authenticatedRTPPacket[:len(c.decryptedRTPPacket)]) - assert.Equal(t, c.decryptedRTCPPacket, c.authenticatedRTCPPacket[:len(c.decryptedRTCPPacket)]) + for _, testCase := range createTestCiphers() { + t.Run(testCase.profile.String(), func(t *testing.T) { + assert.Equal(t, testCase.decryptedRTPPacket, testCase.authenticatedRTPPacket[:len(testCase.decryptedRTPPacket)]) + assert.Equal(t, testCase.decryptedRTCPPacket, testCase.authenticatedRTCPPacket[:len(testCase.decryptedRTCPPacket)]) t.Run("Encrypt RTP", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile) + ctx, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.profile) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualEncrypted, err := ctx.EncryptRTP(nil, c.decryptedRTPPacket, nil) + actualEncrypted, err := ctx.EncryptRTP(nil, testCase.decryptedRTPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.encryptedRTPPacket, actualEncrypted) + assert.Equal(t, testCase.encryptedRTPPacket, actualEncrypted) }) }) t.Run("Decrypt RTP", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile) + ctx, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.profile) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualDecrypted, err := ctx.DecryptRTP(nil, c.encryptedRTPPacket, nil) + actualDecrypted, err := ctx.DecryptRTP(nil, testCase.encryptedRTPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTPPacket, actualDecrypted) }) }) t.Run("Encrypt RTCP", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile) + ctx, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.profile) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualEncrypted, err := ctx.EncryptRTCP(nil, c.decryptedRTCPPacket, nil) + actualEncrypted, err := ctx.EncryptRTCP(nil, testCase.decryptedRTCPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.encryptedRTCPPacket, actualEncrypted) + assert.Equal(t, testCase.encryptedRTCPPacket, actualEncrypted) }) }) t.Run("Decrypt RTCP", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile) + ctx, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.profile) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualDecrypted, err := ctx.DecryptRTCP(nil, c.encryptedRTCPPacket, nil) + actualDecrypted, err := ctx.DecryptRTCP(nil, testCase.encryptedRTCPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTCPPacket, actualDecrypted) }) }) t.Run("Encrypt RTP with MKI", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + ctx, err := CreateContext( + testCase.masterKey, testCase.masterSalt, testCase.profile, MasterKeyIndicator(testCase.mki), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualEncrypted, err := ctx.EncryptRTP(nil, c.decryptedRTPPacket, nil) + actualEncrypted, err := ctx.EncryptRTP(nil, testCase.decryptedRTPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.encryptedRTPPacketWithMKI, actualEncrypted) + assert.Equal(t, testCase.encryptedRTPPacketWithMKI, actualEncrypted) }) }) t.Run("Decrypt RTP with MKI", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + ctx, err := CreateContext( + testCase.masterKey, testCase.masterSalt, testCase.profile, MasterKeyIndicator(testCase.mki), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualDecrypted, err := ctx.DecryptRTP(nil, c.encryptedRTPPacketWithMKI, nil) + actualDecrypted, err := ctx.DecryptRTP(nil, testCase.encryptedRTPPacketWithMKI, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTPPacket, actualDecrypted) }) }) t.Run("Encrypt RTCP with MKI", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + ctx, err := CreateContext( + testCase.masterKey, testCase.masterSalt, testCase.profile, MasterKeyIndicator(testCase.mki), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualEncrypted, err := ctx.EncryptRTCP(nil, c.decryptedRTCPPacket, nil) + actualEncrypted, err := ctx.EncryptRTCP(nil, testCase.decryptedRTCPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.encryptedRTCPPacketWithMKI, actualEncrypted) + assert.Equal(t, testCase.encryptedRTCPPacketWithMKI, actualEncrypted) }) }) t.Run("Decrypt RTCP with MKI", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, MasterKeyIndicator(c.mki)) + ctx, err := CreateContext( + testCase.masterKey, testCase.masterSalt, testCase.profile, MasterKeyIndicator(testCase.mki), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualDecrypted, err := ctx.DecryptRTCP(nil, c.encryptedRTCPPacketWithMKI, nil) + actualDecrypted, err := ctx.DecryptRTCP(nil, testCase.encryptedRTCPPacketWithMKI, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTCPPacket, actualDecrypted) }) }) t.Run("Encrypt RTP with NULL cipher", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, SRTPNoEncryption(), SRTCPNoEncryption()) + ctx, err := CreateContext( + testCase.masterKey, testCase.masterSalt, testCase.profile, SRTPNoEncryption(), SRTCPNoEncryption(), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualEncrypted, err := ctx.EncryptRTP(nil, c.decryptedRTPPacket, nil) + actualEncrypted, err := ctx.EncryptRTP(nil, testCase.decryptedRTPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTPPacket, actualEncrypted[:len(c.decryptedRTPPacket)]) - assert.Equal(t, c.authenticatedRTPPacket, actualEncrypted) + assert.Equal(t, testCase.decryptedRTPPacket, actualEncrypted[:len(testCase.decryptedRTPPacket)]) + assert.Equal(t, testCase.authenticatedRTPPacket, actualEncrypted) }) }) t.Run("Decrypt RTP with NULL cipher", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, SRTPNoEncryption(), SRTCPNoEncryption()) + ctx, err := CreateContext( + testCase.masterKey, testCase.masterSalt, testCase.profile, SRTPNoEncryption(), SRTCPNoEncryption(), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualDecrypted, err := ctx.DecryptRTP(nil, c.authenticatedRTPPacket, nil) + actualDecrypted, err := ctx.DecryptRTP(nil, testCase.authenticatedRTPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTPPacket, actualDecrypted) }) }) t.Run("Encrypt RTCP with NULL cipher", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, SRTPNoEncryption(), SRTCPNoEncryption()) + ctx, err := CreateContext( + testCase.masterKey, testCase.masterSalt, testCase.profile, SRTPNoEncryption(), SRTCPNoEncryption(), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualEncrypted, err := ctx.EncryptRTCP(nil, c.decryptedRTCPPacket, nil) + actualEncrypted, err := ctx.EncryptRTCP(nil, testCase.decryptedRTCPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTCPPacket, actualEncrypted[:len(c.decryptedRTCPPacket)]) - assert.Equal(t, c.authenticatedRTCPPacket, actualEncrypted) + assert.Equal(t, testCase.decryptedRTCPPacket, actualEncrypted[:len(testCase.decryptedRTCPPacket)]) + assert.Equal(t, testCase.authenticatedRTCPPacket, actualEncrypted) }) }) t.Run("Decrypt RTCP with NULL cipher", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, SRTPNoEncryption(), SRTCPNoEncryption()) + ctx, err := CreateContext( + testCase.masterKey, testCase.masterSalt, testCase.profile, SRTPNoEncryption(), SRTCPNoEncryption(), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualDecrypted, err := ctx.DecryptRTCP(nil, c.authenticatedRTCPPacket, nil) + actualDecrypted, err := ctx.DecryptRTCP(nil, testCase.authenticatedRTCPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTCPPacket, actualDecrypted) }) }) t.Run("Encrypt RTP with NULL cipher and MKI", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, SRTPNoEncryption(), SRTCPNoEncryption(), MasterKeyIndicator(c.mki)) + ctx, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.profile, + SRTPNoEncryption(), + SRTCPNoEncryption(), + MasterKeyIndicator(testCase.mki), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualEncrypted, err := ctx.EncryptRTP(nil, c.decryptedRTPPacket, nil) + actualEncrypted, err := ctx.EncryptRTP(nil, testCase.decryptedRTPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTPPacket, actualEncrypted[:len(c.decryptedRTPPacket)]) - assert.Equal(t, c.authenticatedRTPPacketWithMKI, actualEncrypted) + assert.Equal(t, testCase.decryptedRTPPacket, actualEncrypted[:len(testCase.decryptedRTPPacket)]) + assert.Equal(t, testCase.authenticatedRTPPacketWithMKI, actualEncrypted) }) }) t.Run("Decrypt RTP with NULL cipher and MKI", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, SRTPNoEncryption(), SRTCPNoEncryption(), MasterKeyIndicator(c.mki)) + ctx, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.profile, + SRTPNoEncryption(), + SRTCPNoEncryption(), + MasterKeyIndicator(testCase.mki), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualDecrypted, err := ctx.DecryptRTP(nil, c.authenticatedRTPPacketWithMKI, nil) + actualDecrypted, err := ctx.DecryptRTP(nil, testCase.authenticatedRTPPacketWithMKI, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTPPacket, actualDecrypted) }) }) t.Run("Encrypt RTCP with NULL cipher and MKI", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, SRTPNoEncryption(), SRTCPNoEncryption(), MasterKeyIndicator(c.mki)) + ctx, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.profile, + SRTPNoEncryption(), + SRTCPNoEncryption(), + MasterKeyIndicator(testCase.mki), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualEncrypted, err := ctx.EncryptRTCP(nil, c.decryptedRTCPPacket, nil) + actualEncrypted, err := ctx.EncryptRTCP(nil, testCase.decryptedRTCPPacket, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTCPPacket, actualEncrypted[:len(c.decryptedRTCPPacket)]) - assert.Equal(t, c.authenticatedRTCPPacketWithMKI, actualEncrypted) + assert.Equal(t, testCase.decryptedRTCPPacket, actualEncrypted[:len(testCase.decryptedRTCPPacket)]) + assert.Equal(t, testCase.authenticatedRTCPPacketWithMKI, actualEncrypted) }) }) t.Run("Decrypt RTCP with NULL cipher and MKI", func(t *testing.T) { - ctx, err := CreateContext(c.masterKey, c.masterSalt, c.profile, SRTPNoEncryption(), SRTCPNoEncryption(), MasterKeyIndicator(c.mki)) + ctx, err := CreateContext( + testCase.masterKey, + testCase.masterSalt, + testCase.profile, + SRTPNoEncryption(), + SRTCPNoEncryption(), + MasterKeyIndicator(testCase.mki), + ) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { - actualDecrypted, err := ctx.DecryptRTCP(nil, c.authenticatedRTCPPacketWithMKI, nil) + actualDecrypted, err := ctx.DecryptRTCP(nil, testCase.authenticatedRTCPPacketWithMKI, nil) assert.NoError(t, err) - assert.Equal(t, c.decryptedRTCPPacket, actualDecrypted) + assert.Equal(t, testCase.decryptedRTCPPacket, actualDecrypted) }) }) }) diff --git a/srtp_cipher_utils_test.go b/srtp_cipher_utils_test.go index 3b28ca0..12ba6e9 100644 --- a/srtp_cipher_utils_test.go +++ b/srtp_cipher_utils_test.go @@ -10,7 +10,9 @@ import ( "crypto/sha1" // nolint:gosec ) -// deriveSessionKeys should be used in tests only. RFCs test vectors specifes derived keys to use, this struct is used to inject them into the cipher in tests. +// deriveSessionKeys should be used in tests only. +// RFCs test vectors specifes derived keys to use, +// this struct is used to inject them into the cipher in tests. type derivedSessionKeys struct { srtpSessionKey []byte srtpSessionSalt []byte @@ -20,45 +22,57 @@ type derivedSessionKeys struct { srtcpSessionAuthTag []byte } -func newSrtpCipherAesCmHmacSha1WithDerivedKeys(profile ProtectionProfile, keys derivedSessionKeys, encryptSRTP, encryptSRTCP bool) (*srtpCipherAesCmHmacSha1, error) { +func newSrtpCipherAesCmHmacSha1WithDerivedKeys( + profile ProtectionProfile, + keys derivedSessionKeys, + encryptSRTP, encryptSRTCP bool, +) (*srtpCipherAesCmHmacSha1, error) { if profile == ProtectionProfileNullHmacSha1_80 || profile == ProtectionProfileNullHmacSha1_32 { encryptSRTP = false encryptSRTCP = false } - s := &srtpCipherAesCmHmacSha1{ + srtpCipher := &srtpCipherAesCmHmacSha1{ ProtectionProfile: profile, srtpEncrypted: encryptSRTP, srtcpEncrypted: encryptSRTCP, } var err error - if s.srtpBlock, err = aes.NewCipher(keys.srtpSessionKey); err != nil { + if srtpCipher.srtpBlock, err = aes.NewCipher(keys.srtpSessionKey); err != nil { return nil, err } - if s.srtcpBlock, err = aes.NewCipher(keys.srtcpSessionKey); err != nil { + if srtpCipher.srtcpBlock, err = aes.NewCipher(keys.srtcpSessionKey); err != nil { return nil, err } - s.srtpSessionSalt = keys.srtpSessionSalt - s.srtcpSessionSalt = keys.srtcpSessionSalt + srtpCipher.srtpSessionSalt = keys.srtpSessionSalt + srtpCipher.srtcpSessionSalt = keys.srtcpSessionSalt - s.srtcpSessionAuth = hmac.New(sha1.New, keys.srtcpSessionAuthTag) - s.srtpSessionAuth = hmac.New(sha1.New, keys.srtpSessionAuthTag) + srtpCipher.srtcpSessionAuth = hmac.New(sha1.New, keys.srtcpSessionAuthTag) + srtpCipher.srtpSessionAuth = hmac.New(sha1.New, keys.srtpSessionAuthTag) - return s, nil + return srtpCipher, nil } -func newSrtpCipherAeadAesGcmWithDerivedKeys(profile ProtectionProfile, keys derivedSessionKeys, encryptSRTP, encryptSRTCP bool) (*srtpCipherAeadAesGcm, error) { - s := &srtpCipherAeadAesGcm{ProtectionProfile: profile, srtpEncrypted: encryptSRTP, srtcpEncrypted: encryptSRTCP} +func newSrtpCipherAeadAesGcmWithDerivedKeys( + profile ProtectionProfile, + keys derivedSessionKeys, + encryptSRTP, encryptSRTCP bool, +) (*srtpCipherAeadAesGcm, error) { + srtpCipher := &srtpCipherAeadAesGcm{ + ProtectionProfile: profile, + srtpEncrypted: encryptSRTP, + srtcpEncrypted: encryptSRTCP, + } srtpBlock, err := aes.NewCipher(keys.srtpSessionKey) if err != nil { return nil, err } - s.srtpCipher, err = cipher.NewGCM(srtpBlock) + srtpCipher.srtpCipher, err = cipher.NewGCM(srtpBlock) if err != nil { return nil, err } @@ -68,33 +82,34 @@ func newSrtpCipherAeadAesGcmWithDerivedKeys(profile ProtectionProfile, keys deri return nil, err } - s.srtcpCipher, err = cipher.NewGCM(srtcpBlock) + srtpCipher.srtcpCipher, err = cipher.NewGCM(srtcpBlock) if err != nil { return nil, err } - s.srtpSessionSalt = keys.srtpSessionSalt - s.srtcpSessionSalt = keys.srtcpSessionSalt + srtpCipher.srtpSessionSalt = keys.srtpSessionSalt + srtpCipher.srtcpSessionSalt = keys.srtcpSessionSalt - return s, nil + return srtpCipher, nil } // createContextWithCipher creates a new SRTP Context with a pre-created cipher. This is used for testing purposes only. func createContextWithCipher(profile ProtectionProfile, cipher srtpCipher) (*Context, error) { - c := &Context{ + ctx := &Context{ srtpSSRCStates: map[uint32]*srtpSSRCState{}, srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, profile: profile, mkis: map[string]srtpCipher{}, cipher: cipher, } - err := SRTPNoReplayProtection()(c) + err := SRTPNoReplayProtection()(ctx) if err != nil { return nil, err } - err = SRTCPNoReplayProtection()(c) + err = SRTCPNoReplayProtection()(ctx) if err != nil { return nil, err } - return c, nil + + return ctx, nil } diff --git a/srtp_test.go b/srtp_test.go index 794e474..2bc701f 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -37,6 +37,8 @@ func (tc rtpTestCase) encrypted(profile ProtectionProfile) []byte { } func testKeyLen(t *testing.T, profile ProtectionProfile) { + t.Helper() + keyLen, err := profile.KeyLen() assert.NoError(t, err) @@ -69,97 +71,99 @@ func TestValidPacketCounter(t *testing.T) { assert.NoError(t, err) s := &srtpSSRCState{ssrc: 4160032510} - expectedCounter := []byte{0xcf, 0x90, 0x1e, 0xa5, 0xda, 0xd3, 0x2c, 0x15, 0x00, 0xa2, 0x24, 0xae, 0xae, 0xaf, 0x00, 0x00} - counter := generateCounter(32846, uint32(s.index>>16), s.ssrc, srtpSessionSalt) + expectedCounter := []byte{ + 0xcf, 0x90, 0x1e, 0xa5, 0xda, 0xd3, 0x2c, 0x15, 0x00, 0xa2, 0x24, 0xae, 0xae, 0xaf, 0x00, 0x00, + } + counter := generateCounter(32846, uint32(s.index>>16), s.ssrc, srtpSessionSalt) //nolint:gosec // G115 if !bytes.Equal(counter[:], expectedCounter) { t.Errorf("Session Key % 02x does not match expected % 02x", counter, expectedCounter) } } -func TestRolloverCount(t *testing.T) { - s := &srtpSSRCState{ssrc: defaultSsrc} +func TestRolloverCount(t *testing.T) { //nolint:cyclop + ssrcState := &srtpSSRCState{ssrc: defaultSsrc} // Set initial seqnum - roc, diff, ovf := s.nextRolloverCount(65530) + roc, diff, ovf := ssrcState.nextRolloverCount(65530) if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(65530, diff) + ssrcState.updateRolloverCount(65530, diff) // Invalid packets never update ROC - s.nextRolloverCount(0) - s.nextRolloverCount(0x4000) - s.nextRolloverCount(0x8000) - s.nextRolloverCount(0xFFFF) - s.nextRolloverCount(0) + ssrcState.nextRolloverCount(0) + ssrcState.nextRolloverCount(0x4000) + ssrcState.nextRolloverCount(0x8000) + ssrcState.nextRolloverCount(0xFFFF) + ssrcState.nextRolloverCount(0) // We rolled over to 0 - roc, diff, ovf = s.nextRolloverCount(0) + roc, diff, ovf = ssrcState.nextRolloverCount(0) if roc != 1 { t.Errorf("rolloverCounter was not updated after it crossed 0") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(0, diff) + ssrcState.updateRolloverCount(0, diff) - roc, diff, ovf = s.nextRolloverCount(65530) + roc, diff, ovf = ssrcState.nextRolloverCount(65530) if roc != 0 { t.Errorf("rolloverCounter was not updated when it rolled back, failed to handle out of order") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(65530, diff) + ssrcState.updateRolloverCount(65530, diff) - roc, diff, ovf = s.nextRolloverCount(5) + roc, diff, ovf = ssrcState.nextRolloverCount(5) if roc != 1 { t.Errorf("rolloverCounter was not updated when it rolled over initial, to handle out of order") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(5, diff) + ssrcState.updateRolloverCount(5, diff) - _, diff, _ = s.nextRolloverCount(6) - s.updateRolloverCount(6, diff) - _, diff, _ = s.nextRolloverCount(7) - s.updateRolloverCount(7, diff) - roc, diff, _ = s.nextRolloverCount(8) + _, diff, _ = ssrcState.nextRolloverCount(6) + ssrcState.updateRolloverCount(6, diff) + _, diff, _ = ssrcState.nextRolloverCount(7) + ssrcState.updateRolloverCount(7, diff) + roc, diff, _ = ssrcState.nextRolloverCount(8) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } - s.updateRolloverCount(8, diff) + ssrcState.updateRolloverCount(8, diff) // valid packets never update ROC - roc, diff, ovf = s.nextRolloverCount(0x4000) + roc, diff, ovf = ssrcState.nextRolloverCount(0x4000) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(0x4000, diff) - roc, diff, ovf = s.nextRolloverCount(0x8000) + ssrcState.updateRolloverCount(0x4000, diff) + roc, diff, ovf = ssrcState.nextRolloverCount(0x8000) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(0x8000, diff) - roc, diff, ovf = s.nextRolloverCount(0xFFFF) + ssrcState.updateRolloverCount(0x8000, diff) + roc, diff, ovf = ssrcState.nextRolloverCount(0xFFFF) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(0xFFFF, diff) - roc, _, ovf = s.nextRolloverCount(0) + ssrcState.updateRolloverCount(0xFFFF, diff) + roc, _, ovf = ssrcState.nextRolloverCount(0) if roc != 2 { t.Errorf("rolloverCounter must be incremented after wrapping, got %d", roc) } @@ -237,37 +241,71 @@ func rtpTestCases() []rtpTestCase { { sequenceNumber: 5000, encryptedCTR: []byte{0x6d, 0xd3, 0x7e, 0xd5, 0x99, 0xb7, 0x2d, 0x28, 0xb1, 0xf3, 0xa1, 0xf0, 0xc, 0xfb, 0xfd, 0x8}, - encryptedGCM: []byte{0x05, 0x39, 0x62, 0xbb, 0x50, 0x2a, 0x08, 0x19, 0xc7, 0xcc, 0xc9, 0x24, 0xb8, 0xd9, 0x7a, 0xe5, 0xad, 0x99, 0x06, 0xc7, 0x3b, 0}, + encryptedGCM: []byte{ + 0x05, 0x39, 0x62, 0xbb, 0x50, 0x2a, 0x08, 0x19, 0xc7, 0xcc, 0xc9, + 0x24, 0xb8, 0xd9, 0x7a, 0xe5, 0xad, 0x99, 0x06, 0xc7, 0x3b, 0x00, + }, }, { sequenceNumber: 5001, - encryptedCTR: []byte{0xda, 0x47, 0xb, 0x2a, 0x74, 0x53, 0x65, 0xbd, 0x2f, 0xeb, 0xdc, 0x4b, 0x6d, 0x23, 0xf3, 0xde}, - encryptedGCM: []byte{0xb0, 0xbc, 0xfc, 0xb0, 0x15, 0x2c, 0xa0, 0x15, 0xb5, 0xa8, 0xcd, 0x0d, 0x65, 0xfa, 0x98, 0xb3, 0x09, 0xb1, 0xf8, 0x4b, 0x1c, 0xfa}, + encryptedCTR: []byte{ + 0xda, 0x47, 0x0b, 0x2a, 0x74, 0x53, 0x65, 0xbd, 0x2f, 0xeb, 0xdc, + 0x4b, 0x6d, 0x23, 0xf3, 0xde, + }, + encryptedGCM: []byte{ + 0xb0, 0xbc, 0xfc, 0xb0, 0x15, 0x2c, 0xa0, 0x15, 0xb5, 0xa8, 0xcd, + 0x0d, 0x65, 0xfa, 0x98, 0xb3, 0x09, 0xb1, 0xf8, 0x4b, 0x1c, 0xfa, + }, }, { sequenceNumber: 5002, - encryptedCTR: []byte{0x6e, 0xa7, 0x69, 0x8d, 0x24, 0x6d, 0xdc, 0xbf, 0xec, 0x2, 0x1c, 0xd1, 0x60, 0x76, 0xc1, 0xe}, - encryptedGCM: []byte{0x5e, 0x20, 0x6a, 0xbf, 0x58, 0x7e, 0x24, 0xc0, 0x15, 0x94, 0x7a, 0xe2, 0x49, 0x25, 0xd4, 0xd4, 0x08, 0xe2, 0xf1, 0x47, 0x7a, 0x33}, + encryptedCTR: []byte{ + 0x6e, 0xa7, 0x69, 0x8d, 0x24, 0x6d, 0xdc, 0xbf, 0xec, 0x02, 0x1c, + 0xd1, 0x60, 0x76, 0xc1, 0xe, + }, + encryptedGCM: []byte{ + 0x5e, 0x20, 0x6a, 0xbf, 0x58, 0x7e, 0x24, 0xc0, 0x15, 0x94, 0x7a, + 0xe2, 0x49, 0x25, 0xd4, 0xd4, 0x08, 0xe2, 0xf1, 0x47, 0x7a, 0x33, + }, }, { sequenceNumber: 5003, - encryptedCTR: []byte{0x24, 0x7e, 0x96, 0xc8, 0x7d, 0x33, 0xa2, 0x92, 0x8d, 0x13, 0x8d, 0xe0, 0x76, 0x9f, 0x8, 0xdc}, - encryptedGCM: []byte{0xb0, 0x63, 0x14, 0xe7, 0xd2, 0x29, 0xca, 0x92, 0x8c, 0x97, 0x25, 0xd2, 0x50, 0x69, 0x6e, 0x1b, 0x04, 0xb9, 0x37, 0xa5, 0xa1, 0xc5}, + encryptedCTR: []byte{ + 0x24, 0x7e, 0x96, 0xc8, 0x7d, 0x33, 0xa2, 0x92, 0x8d, 0x13, 0x8d, + 0xe0, 0x76, 0x9f, 0x8, 0xdc, + }, + encryptedGCM: []byte{ + 0xb0, 0x63, 0x14, 0xe7, 0xd2, 0x29, 0xca, 0x92, 0x8c, 0x97, 0x25, + 0xd2, 0x50, 0x69, 0x6e, 0x1b, 0x04, 0xb9, 0x37, 0xa5, 0xa1, 0xc5, + }, }, { sequenceNumber: 5004, - encryptedCTR: []byte{0x75, 0x43, 0x28, 0xe4, 0x3a, 0x77, 0x59, 0x9b, 0x2e, 0xdf, 0x7b, 0x12, 0x68, 0xb, 0x57, 0x49}, - encryptedGCM: []byte{0xb2, 0x4f, 0x19, 0x53, 0x79, 0x8a, 0x9b, 0x9e, 0xe5, 0x22, 0x93, 0x14, 0x50, 0x8a, 0x8c, 0xd5, 0xfc, 0x61, 0xbf, 0x95, 0xd1, 0xfb}, + encryptedCTR: []byte{ + 0x75, 0x43, 0x28, 0xe4, 0x3a, 0x77, 0x59, 0x9b, 0x2e, 0xdf, 0x7b, + 0x12, 0x68, 0xb, 0x57, 0x49, + }, + encryptedGCM: []byte{ + 0xb2, 0x4f, 0x19, 0x53, 0x79, 0x8a, 0x9b, 0x9e, 0xe5, 0x22, 0x93, + 0x14, 0x50, 0x8a, 0x8c, 0xd5, 0xfc, 0x61, 0xbf, 0x95, 0xd1, 0xfb, + }, }, { sequenceNumber: 65535, // upper boundary - encryptedCTR: []byte{0xaf, 0xf7, 0xc2, 0x70, 0x37, 0x20, 0x83, 0x9c, 0x2c, 0x63, 0x85, 0x15, 0xe, 0x44, 0xca, 0x36}, - encryptedGCM: []byte{0x40, 0x44, 0x6c, 0xd1, 0x33, 0x5f, 0xca, 0x9b, 0x2e, 0xa3, 0xe5, 0x03, 0xd7, 0x82, 0x36, 0xd8, 0xb7, 0xe8, 0x97, 0x3c, 0xe6, 0xb6}, + encryptedCTR: []byte{ + 0xaf, 0xf7, 0xc2, 0x70, 0x37, 0x20, 0x83, 0x9c, 0x2c, 0x63, 0x85, + 0x15, 0xe, 0x44, 0xca, 0x36, + }, + encryptedGCM: []byte{ + 0x40, 0x44, 0x6c, 0xd1, 0x33, 0x5f, 0xca, 0x9b, 0x2e, 0xa3, 0xe5, + 0x03, 0xd7, 0x82, 0x36, 0xd8, 0xb7, 0xe8, 0x97, 0x3c, 0xe6, 0xb6, + }, }, } } func testRTPLifecyleNewAlloc(t *testing.T, profile ProtectionProfile) { + t.Helper() assert := assert.New(t) authTagLen, err := profile.AuthTagRTPLen() @@ -284,13 +322,19 @@ func testRTPLifecyleNewAlloc(t *testing.T, profile ProtectionProfile) { t.Fatal(err) } - decryptedPkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + decryptedPkt := &rtp.Packet{ + Payload: rtpTestCaseDecrypted(), + Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}, + } decryptedRaw, err := decryptedPkt.Marshal() if err != nil { t.Fatal(err) } - encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + encryptedPkt := &rtp.Packet{ + Payload: testCase.encrypted(profile), + Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}, + } encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) @@ -318,7 +362,8 @@ func TestRTPLifecycleNewAlloc(t *testing.T) { t.Run("GCM", func(t *testing.T) { testRTPLifecyleNewAlloc(t, profileGCM) }) } -func testRTPLifecyleInPlace(t *testing.T, profile ProtectionProfile) { +func testRTPLifecyleInPlace(t *testing.T, profile ProtectionProfile) { //nolint:cyclop + t.Helper() assert := assert.New(t) for _, testCase := range rtpTestCases() { @@ -333,14 +378,20 @@ func testRTPLifecyleInPlace(t *testing.T, profile ProtectionProfile) { } decryptHeader := &rtp.Header{} - decryptedPkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + decryptedPkt := &rtp.Packet{ + Payload: rtpTestCaseDecrypted(), + Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}, + } decryptedRaw, err := decryptedPkt.Marshal() if err != nil { t.Fatal(err) } encryptHeader := &rtp.Header{} - encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + encryptedPkt := &rtp.Packet{ + Payload: testCase.encrypted(profile), + Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}, + } encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) @@ -387,7 +438,8 @@ func TestRTPLifecycleInPlace(t *testing.T) { t.Run("GCM", func(t *testing.T) { testRTPLifecyleInPlace(t, profileGCM) }) } -func testRTPReplayProtection(t *testing.T, profile ProtectionProfile) { +func testRTPReplayProtection(t *testing.T, profile ProtectionProfile) { //nolint:cyclop + t.Helper() assert := assert.New(t) for _, testCase := range rtpTestCases() { @@ -404,14 +456,20 @@ func testRTPReplayProtection(t *testing.T, profile ProtectionProfile) { } decryptHeader := &rtp.Header{} - decryptedPkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + decryptedPkt := &rtp.Packet{ + Payload: rtpTestCaseDecrypted(), + Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}, + } decryptedRaw, err := decryptedPkt.Marshal() if err != nil { t.Fatal(err) } encryptHeader := &rtp.Header{} - encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + encryptedPkt := &rtp.Packet{ + Payload: testCase.encrypted(profile), + Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}, + } encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) @@ -472,6 +530,7 @@ func TestRTPReplayDetectorFactory(t *testing.T) { decryptContext, err := buildTestContext( profile, SRTPReplayDetectorFactory(func() replaydetector.ReplayDetector { cntFactory++ + return &nopReplayDetector{} }), ) @@ -495,6 +554,8 @@ func TestRTPReplayDetectorFactory(t *testing.T) { } func benchmarkEncryptRTP(b *testing.B, profile ProtectionProfile, size int) { + b.Helper() + encryptContext, err := buildTestContext(profile) if err != nil { b.Fatal(err) @@ -533,6 +594,8 @@ func BenchmarkEncryptRTP(b *testing.B) { } func benchmarkEncryptRTPInPlace(b *testing.B, profile ProtectionProfile, size int) { + b.Helper() + encryptContext, err := buildTestContext(profile) if err != nil { b.Fatal(err) @@ -573,6 +636,8 @@ func BenchmarkEncryptRTPInPlace(b *testing.B) { } func benchmarkDecryptRTP(b *testing.B, profile ProtectionProfile) { + b.Helper() + sequenceNumber := uint16(5000) encrypted := rtpTestCases()[0].encrypted(profile) @@ -609,91 +674,91 @@ func BenchmarkDecryptRTP(b *testing.B) { b.Run("GCM", func(b *testing.B) { benchmarkDecryptRTP(b, profileGCM) }) } -func TestRolloverCount2(t *testing.T) { - s := &srtpSSRCState{ssrc: defaultSsrc} +func TestRolloverCount2(t *testing.T) { //nolint:cyclop + srtpState := &srtpSSRCState{ssrc: defaultSsrc} - roc, diff, ovf := s.nextRolloverCount(30123) + roc, diff, ovf := srtpState.nextRolloverCount(30123) if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(30123, diff) + srtpState.updateRolloverCount(30123, diff) - roc, diff, ovf = s.nextRolloverCount(62892) // 30123 + (1 << 15) + 1 + roc, diff, ovf = srtpState.nextRolloverCount(62892) // 30123 + (1 << 15) + 1 if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(62892, diff) - roc, diff, ovf = s.nextRolloverCount(204) + srtpState.updateRolloverCount(62892, diff) + roc, diff, ovf = srtpState.nextRolloverCount(204) if roc != 1 { t.Errorf("rolloverCounter was not updated after it crossed 0") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(62892, diff) - roc, diff, ovf = s.nextRolloverCount(64535) + srtpState.updateRolloverCount(62892, diff) + roc, diff, ovf = srtpState.nextRolloverCount(64535) if roc != 0 { t.Errorf("rolloverCounter was not updated when it rolled back, failed to handle out of order") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(64535, diff) - roc, diff, ovf = s.nextRolloverCount(205) + srtpState.updateRolloverCount(64535, diff) + roc, diff, ovf = srtpState.nextRolloverCount(205) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(205, diff) - roc, diff, ovf = s.nextRolloverCount(1) + srtpState.updateRolloverCount(205, diff) + roc, diff, ovf = srtpState.nextRolloverCount(1) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(1, diff) + srtpState.updateRolloverCount(1, diff) - roc, diff, ovf = s.nextRolloverCount(64532) + roc, diff, ovf = srtpState.nextRolloverCount(64532) if roc != 0 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(64532, diff) - roc, diff, ovf = s.nextRolloverCount(65534) + srtpState.updateRolloverCount(64532, diff) + roc, diff, ovf = srtpState.nextRolloverCount(65534) if roc != 0 { t.Errorf("index was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(65534, diff) - roc, diff, ovf = s.nextRolloverCount(64532) + srtpState.updateRolloverCount(65534, diff) + roc, diff, ovf = srtpState.nextRolloverCount(64532) if roc != 0 { t.Errorf("index was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(65532, diff) - roc, diff, ovf = s.nextRolloverCount(205) + srtpState.updateRolloverCount(65532, diff) + roc, diff, ovf = srtpState.nextRolloverCount(205) if roc != 1 { t.Errorf("index was not updated after it crossed 0") } if ovf { t.Error("Should not overflow") } - s.updateRolloverCount(65532, diff) + srtpState.updateRolloverCount(65532, diff) } func TestProtectionProfileAes128CmHmacSha1_32(t *testing.T) { @@ -745,7 +810,10 @@ func TestRTPDecryptShotenedPacket(t *testing.T) { t.Fatal(err) } - encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + encryptedPkt := &rtp.Packet{ + Payload: testCase.encrypted(profile), + Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}, + } encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) @@ -810,7 +878,7 @@ func TestRTPMaxPackets(t *testing.T) { } } -func TestRTPBurstLossWithSetROC(t *testing.T) { +func TestRTPBurstLossWithSetROC(t *testing.T) { //nolint:cyclop profiles := map[string]ProtectionProfile{ "CTR": profileCTR, "GCM": profileGCM, @@ -836,7 +904,7 @@ func TestRTPBurstLossWithSetROC(t *testing.T) { var pkts []*packetWithROC encryptContext.SetROC(1, 3) for i := 0x8C00; i < 0x20400; i += 0x100 { - p := &packetWithROC{ + packet := &packetWithROC{ pkt: rtp.Packet{ Payload: []byte{ byte(i >> 16), @@ -846,25 +914,25 @@ func TestRTPBurstLossWithSetROC(t *testing.T) { Header: rtp.Header{ Marker: true, SSRC: 1, - SequenceNumber: uint16(i), + SequenceNumber: uint16(i), //nolint:gosec // G115 }, }, } - b, errMarshal := p.pkt.Marshal() + b, errMarshal := packet.pkt.Marshal() if errMarshal != nil { t.Fatal(errMarshal) } - p.raw = b + packet.raw = b enc, errEnc := encryptContext.EncryptRTP(nil, b, nil) if errEnc != nil { t.Fatal(errEnc) } - p.roc, _ = encryptContext.ROC(1) + packet.roc, _ = encryptContext.ROC(1) if 0x9000 < i && i < 0x20100 { continue } - p.enc = enc - pkts = append(pkts, p) + packet.enc = enc + pkts = append(pkts, packet) } decryptContext, err := buildTestContext(profile) @@ -877,6 +945,7 @@ func TestRTPBurstLossWithSetROC(t *testing.T) { pkt, err := decryptContext.DecryptRTP(nil, p.enc, nil) if err != nil { t.Errorf("roc=%d, seq=%d: %v", p.roc, p.pkt.SequenceNumber, err) + continue } assert.Equal(p.raw, pkt) @@ -892,7 +961,10 @@ func TestDecryptInvalidSRTP(t *testing.T) { decryptContext, err := CreateContext(key, salt, ProtectionProfileAes128CmHmacSha1_80) assert.NoError(err) - packet := []byte{0x41, 0x02, 0x07, 0xf9, 0xf9, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xb5, 0x73, 0x19, 0xf6, 0x91, 0xbb, 0x3e, 0xa5, 0x21, 0x07} + packet := []byte{ + 0x41, 0x02, 0x07, 0xf9, 0xf9, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0xb5, 0x73, 0x19, 0xf6, 0x91, 0xbb, 0x3e, 0xa5, 0x21, 0x07, + } _, err = decryptContext.DecryptRTP(nil, packet, nil) assert.Error(err) } @@ -931,7 +1003,7 @@ func TestRTPInvalidMKI(t *testing.T) { } } -func TestRTPHandleMultipleMKI(t *testing.T) { +func TestRTPHandleMultipleMKI(t *testing.T) { //nolint:cyclop mki1 := []byte{0x01, 0x02, 0x03, 0x04} mki2 := []byte{0x02, 0x03, 0x04, 0x05} @@ -987,7 +1059,7 @@ func TestRTPHandleMultipleMKI(t *testing.T) { } } -func TestRTPSwitchMKI(t *testing.T) { +func TestRTPSwitchMKI(t *testing.T) { //nolint:cyclop mki1 := []byte{0x01, 0x02, 0x03, 0x04} mki2 := []byte{0x02, 0x03, 0x04, 0x05} diff --git a/stream_srtcp.go b/stream_srtcp.go index dc71c40..8fe407f 100644 --- a/stream_srtcp.go +++ b/stream_srtcp.go @@ -13,10 +13,10 @@ import ( "github.com/pion/transport/v3/packetio" ) -// Limit the buffer size to 100KB +// Limit the buffer size to 100KB. const srtcpBufferSize = 100 * 1000 -// ReadStreamSRTCP handles decryption for a single RTCP SSRC +// ReadStreamSRTCP handles decryption for a single RTCP SSRC. type ReadStreamSRTCP struct { mu sync.Mutex @@ -40,12 +40,12 @@ func (r *ReadStreamSRTCP) write(buf []byte) (n int, err error) { return n, err } -// Used by getOrCreateReadStream +// Used by getOrCreateReadStream. func newReadStreamSRTCP() readStream { return &ReadStreamSRTCP{} } -// ReadRTCP reads and decrypts full RTCP packet and its header from the nextConn +// ReadRTCP reads and decrypts full RTCP packet and its header from the nextConn. func (r *ReadStreamSRTCP) ReadRTCP(buf []byte) (int, *rtcp.Header, error) { n, err := r.Read(buf) if err != nil { @@ -61,7 +61,7 @@ func (r *ReadStreamSRTCP) ReadRTCP(buf []byte) (int, *rtcp.Header, error) { return n, header, nil } -// Read reads and decrypts full RTCP packet from the nextConn +// Read reads and decrypts full RTCP packet from the nextConn. func (r *ReadStreamSRTCP) Read(buf []byte) (int, error) { return r.buffer.Read(buf) } @@ -74,10 +74,11 @@ func (r *ReadStreamSRTCP) SetReadDeadline(t time.Time) error { }); ok { return b.SetReadDeadline(t) } + return nil } -// Close removes the ReadStream from the session and cleans up any associated state +// Close removes the ReadStream from the session and cleans up any associated state. func (r *ReadStreamSRTCP) Close() error { r.mu.Lock() defer r.mu.Unlock() @@ -96,6 +97,7 @@ func (r *ReadStreamSRTCP) Close() error { } r.session.removeReadStream(r.ssrc) + return nil } } @@ -128,17 +130,17 @@ func (r *ReadStreamSRTCP) init(child streamSession, ssrc uint32) error { return nil } -// GetSSRC returns the SSRC we are demuxing for +// GetSSRC returns the SSRC we are demuxing for. func (r *ReadStreamSRTCP) GetSSRC() uint32 { return r.ssrc } -// WriteStreamSRTCP is stream for a single Session that is used to encrypt RTCP +// WriteStreamSRTCP is stream for a single Session that is used to encrypt RTCP. type WriteStreamSRTCP struct { session *SessionSRTCP } -// WriteRTCP encrypts a RTCP header and its payload to the nextConn +// WriteRTCP encrypts a RTCP header and its payload to the nextConn. func (w *WriteStreamSRTCP) WriteRTCP(header *rtcp.Header, payload []byte) (int, error) { headerRaw, err := header.Marshal() if err != nil { @@ -148,7 +150,7 @@ func (w *WriteStreamSRTCP) WriteRTCP(header *rtcp.Header, payload []byte) (int, return w.session.write(append(headerRaw, payload...)) } -// Write encrypts and writes a full RTCP packets to the nextConn +// Write encrypts and writes a full RTCP packets to the nextConn. func (w *WriteStreamSRTCP) Write(b []byte) (int, error) { return w.session.write(b) } diff --git a/stream_srtp.go b/stream_srtp.go index cad0a38..1b34266 100644 --- a/stream_srtp.go +++ b/stream_srtp.go @@ -13,10 +13,10 @@ import ( "github.com/pion/transport/v3/packetio" ) -// Limit the buffer size to 1MB +// Limit the buffer size to 1MB. const srtpBufferSize = 1000 * 1000 -// ReadStreamSRTP handles decryption for a single RTP SSRC +// ReadStreamSRTP handles decryption for a single RTP SSRC. type ReadStreamSRTP struct { mu sync.Mutex @@ -29,7 +29,7 @@ type ReadStreamSRTP struct { buffer io.ReadWriteCloser } -// Used by getOrCreateReadStream +// Used by getOrCreateReadStream. func newReadStreamSRTP() readStream { return &ReadStreamSRTP{} } @@ -74,12 +74,12 @@ func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) { return n, err } -// Read reads and decrypts full RTP packet from the nextConn +// Read reads and decrypts full RTP packet from the nextConn. func (r *ReadStreamSRTP) Read(buf []byte) (int, error) { return r.buffer.Read(buf) } -// ReadRTP reads and decrypts full RTP packet and its header from the nextConn +// ReadRTP reads and decrypts full RTP packet and its header from the nextConn. func (r *ReadStreamSRTP) ReadRTP(buf []byte) (int, *rtp.Header, error) { n, err := r.Read(buf) if err != nil { @@ -104,10 +104,11 @@ func (r *ReadStreamSRTP) SetReadDeadline(t time.Time) error { }); ok { return b.SetReadDeadline(t) } + return nil } -// Close removes the ReadStream from the session and cleans up any associated state +// Close removes the ReadStream from the session and cleans up any associated state. func (r *ReadStreamSRTP) Close() error { r.mu.Lock() defer r.mu.Unlock() @@ -126,26 +127,27 @@ func (r *ReadStreamSRTP) Close() error { } r.session.removeReadStream(r.ssrc) + return nil } } -// GetSSRC returns the SSRC we are demuxing for +// GetSSRC returns the SSRC we are demuxing for. func (r *ReadStreamSRTP) GetSSRC() uint32 { return r.ssrc } -// WriteStreamSRTP is stream for a single Session that is used to encrypt RTP +// WriteStreamSRTP is stream for a single Session that is used to encrypt RTP. type WriteStreamSRTP struct { session *SessionSRTP } -// WriteRTP encrypts a RTP packet and writes to the connection +// WriteRTP encrypts a RTP packet and writes to the connection. func (w *WriteStreamSRTP) WriteRTP(header *rtp.Header, payload []byte) (int, error) { return w.session.writeRTP(header, payload) } -// Write encrypts and writes a full RTP packets to the nextConn +// Write encrypts and writes a full RTP packets to the nextConn. func (w *WriteStreamSRTP) Write(b []byte) (int, error) { return w.session.write(b) } diff --git a/stream_srtp_test.go b/stream_srtp_test.go index 8f8ab37..2137d90 100644 --- a/stream_srtp_test.go +++ b/stream_srtp_test.go @@ -17,15 +17,23 @@ import ( type noopConn struct{ closed chan struct{} } -func newNoopConn() *noopConn { return &noopConn{closed: make(chan struct{})} } -func (c *noopConn) Read([]byte) (n int, err error) { <-c.closed; return 0, io.EOF } +func newNoopConn() *noopConn { return &noopConn{closed: make(chan struct{})} } +func (c *noopConn) Read([]byte) (n int, err error) { + <-c.closed + + return 0, io.EOF +} func (c *noopConn) Write(b []byte) (n int, err error) { return len(b), nil } -func (c *noopConn) Close() error { close(c.closed); return nil } -func (c *noopConn) LocalAddr() net.Addr { return nil } -func (c *noopConn) RemoteAddr() net.Addr { return nil } -func (c *noopConn) SetDeadline(time.Time) error { return nil } -func (c *noopConn) SetReadDeadline(time.Time) error { return nil } -func (c *noopConn) SetWriteDeadline(time.Time) error { return nil } +func (c *noopConn) Close() error { + close(c.closed) + + return nil +} +func (c *noopConn) LocalAddr() net.Addr { return nil } +func (c *noopConn) RemoteAddr() net.Addr { return nil } +func (c *noopConn) SetDeadline(time.Time) error { return nil } +func (c *noopConn) SetReadDeadline(time.Time) error { return nil } +func (c *noopConn) SetWriteDeadline(time.Time) error { return nil } func TestBufferFactory(t *testing.T) { wg := sync.WaitGroup{} @@ -33,6 +41,7 @@ func TestBufferFactory(t *testing.T) { conn := newNoopConn() bf := func(_ packetio.BufferPacketType, _ uint32) io.ReadWriteCloser { wg.Done() + return packetio.NewBuffer() } rtpSession, err := NewSessionSRTP(conn, &Config{ @@ -65,6 +74,8 @@ func TestBufferFactory(t *testing.T) { } func benchmarkWrite(b *testing.B, profile ProtectionProfile, size int) { + b.Helper() + conn := newNoopConn() keyLen, err := profile.KeyLen() @@ -143,6 +154,8 @@ func BenchmarkWrite(b *testing.B) { } func benchmarkWriteRTP(b *testing.B, profile ProtectionProfile, size int) { + b.Helper() + conn := &noopConn{ closed: make(chan struct{}), } diff --git a/util.go b/util.go index 792175d..27e0509 100644 --- a/util.go +++ b/util.go @@ -13,10 +13,11 @@ func growBufferSize(buf []byte, size int) []byte { buf2 := make([]byte, size) copy(buf2, buf) + return buf2 } -// Check if buffers match, if not allocate a new buffer and return it +// Check if buffers match, if not allocate a new buffer and return it. func allocateIfMismatch(dst, src []byte) []byte { if dst == nil { dst = make([]byte, len(src)) @@ -24,7 +25,7 @@ func allocateIfMismatch(dst, src []byte) []byte { } else if !bytes.Equal(dst, src) { // bytes.Equal returns on ref equality, no optimization needed extraNeeded := len(src) - len(dst) if extraNeeded > 0 { - dst = append(dst, make([]byte, extraNeeded)...) + dst = append(dst, make([]byte, extraNeeded)...) //nolint:makezero // todo: fix } else if extraNeeded < 0 { dst = dst[:len(dst)+extraNeeded] }