diff --git a/go.mod b/go.mod index b2a87ed3..5bc9afb8 100644 --- a/go.mod +++ b/go.mod @@ -4,30 +4,30 @@ go 1.25.0 require ( github.com/OvyFlash/telegram-bot-api v0.0.0-20251112155921-e82db5fd534b - github.com/fatih/color v1.18.0 + github.com/fatih/color v1.19.0 github.com/forPelevin/gomoji v1.4.1 - github.com/fsnotify/fsnotify v1.9.0 + github.com/fsnotify/fsnotify v1.10.1 github.com/go-pkgz/expirable-cache/v3 v3.1.0 github.com/go-pkgz/fileutils v0.4.0 - github.com/go-pkgz/lgr v0.12.1 + github.com/go-pkgz/lgr v0.12.3 github.com/go-pkgz/repeater v1.2.0 - github.com/go-pkgz/rest v1.20.6 + github.com/go-pkgz/rest v1.21.0 github.com/go-pkgz/routegroup v1.6.0 github.com/go-pkgz/testutils v0.6.0 github.com/hashicorp/go-multierror v1.1.1 github.com/jessevdk/go-flags v1.6.1 github.com/jmoiron/sqlx v1.4.0 - github.com/lib/pq v1.10.9 - github.com/playwright-community/playwright-go v0.5200.1 + github.com/lib/pq v1.12.3 + github.com/playwright-community/playwright-go v0.5700.1 github.com/sandwich-go/gpt3-encoder v0.0.0-20230203030618-cd99729dd0dd github.com/sashabaranov/go-openai v1.41.2 github.com/sony/gobreaker/v2 v2.4.0 github.com/stretchr/testify v1.11.1 - github.com/yuin/gopher-lua v1.1.1 + github.com/yuin/gopher-lua v1.1.2 golang.org/x/image v0.38.0 golang.org/x/text v0.36.0 golang.org/x/time v0.11.0 - google.golang.org/genai v1.52.1 + google.golang.org/genai v1.57.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 modernc.org/sqlite v1.49.1 ) @@ -65,7 +65,7 @@ require ( github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/deckarep/golang-set/v2 v2.7.0 // indirect + github.com/deckarep/golang-set/v2 v2.8.0 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/docker/go-connections v0.7.0 // indirect diff --git a/go.sum b/go.sum index 08b71c4b..6e8a605b 100644 --- a/go.sum +++ b/go.sum @@ -77,8 +77,8 @@ github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfv github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/deckarep/golang-set/v2 v2.7.0 h1:gIloKvD7yH2oip4VLhsv3JyLLFnC0Y2mlusgcvJYW5k= -github.com/deckarep/golang-set/v2 v2.7.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= +github.com/deckarep/golang-set/v2 v2.8.0 h1:swm0rlPCmdWn9mESxKOjWk8hXSqoxOp+ZlfuyaAdFlQ= +github.com/deckarep/golang-set/v2 v2.8.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= @@ -95,14 +95,14 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= -github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/fatih/color v1.19.0 h1:Zp3PiM21/9Ld6FzSKyL5c/BULoe/ONr9KlbYVOfG8+w= +github.com/fatih/color v1.19.0/go.mod h1:zNk67I0ZUT1bEGsSGyCZYZNrHuTkJJB+r6Q9VuMi0LE= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/forPelevin/gomoji v1.4.1 h1:7U+Bl8o6RV/dOQz7coQFWj/jX6Ram6/cWFOuFDEPEUo= github.com/forPelevin/gomoji v1.4.1/go.mod h1:mM6GtmCgpoQP2usDArc6GjbXrti5+FffolyQfGgPboQ= -github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= -github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho= +github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo= github.com/go-jose/go-jose/v3 v3.0.5 h1:BLLJWbC4nMZOfuPVxoZIxeYsn6Nl2r1fITaJ78UQlVQ= github.com/go-jose/go-jose/v3 v3.0.5/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -117,12 +117,12 @@ github.com/go-pkgz/expirable-cache/v3 v3.1.0 h1:s05P851/O6QJ6Mc+7o2bh9aGtD3romB1 github.com/go-pkgz/expirable-cache/v3 v3.1.0/go.mod h1:6pVgNleydKPj0J2/mzrI02/RDo4ivKx5v2XlNmIjhjo= github.com/go-pkgz/fileutils v0.4.0 h1:v3CEj/nMiei2FLK4+JODBMNGHXWjjgh8b9LFMWDLgkg= github.com/go-pkgz/fileutils v0.4.0/go.mod h1:Fz25H/5U4P8aoxVPA22lJqK0WvRvS3YwT4MV4NzhtJc= -github.com/go-pkgz/lgr v0.12.1 h1:8GVfG2rSARq3Eaj5PP158rtBR2LHVGkwioIkQBGbvKg= -github.com/go-pkgz/lgr v0.12.1/go.mod h1:A4AxjOthFVFK6jRnVYMeusno5SeDAxcLVHd0kI/lN/Y= +github.com/go-pkgz/lgr v0.12.3 h1:QDug7kRkEsuQtruT9fNF5PVT2kZUqCDPc4GmsgS3fP8= +github.com/go-pkgz/lgr v0.12.3/go.mod h1:lpCDgVvCIxBHZp8+sGCj9MPctIzKZyZ3QdE19ddqd54= github.com/go-pkgz/repeater v1.2.0 h1:oJFvjyKdTDd5RCzpzxlzYIZFFj6Zfl17rE1aUfu6UjQ= github.com/go-pkgz/repeater v1.2.0/go.mod h1:vypP6xamA53MFmafnGUucqOmALKk36xgKu2hSG73LHM= -github.com/go-pkgz/rest v1.20.6 h1:O/IhQ3I2cS4bJYvL1TLcy63w2OcXZTTBG3R+wTIqPS4= -github.com/go-pkgz/rest v1.20.6/go.mod h1:NY+MX1is2kJckJt+nHDNovS/5j9jmF4yQuSno4qg7XU= +github.com/go-pkgz/rest v1.21.0 h1:Y/C4d/TpclJJDxqnH1RAcS6Hmox0RIReAlkwMcUWXK4= +github.com/go-pkgz/rest v1.21.0/go.mod h1:+AHzjHazq7Z3Tk/kRWOhbbAz/YZlUV40feC1Hf4NtbE= github.com/go-pkgz/routegroup v1.6.0 h1:44XHZgF6JIIldRlv+zjg6SygULASmjifnfIQjwCT0e4= github.com/go-pkgz/routegroup v1.6.0/go.mod h1:Pmu04fhgWhRtBMIJ8HXppnnzOPjnL/IEPBIdO2zmeqg= github.com/go-pkgz/testutils v0.6.0 h1:+hHdikZAZm7EISWrbJ0Od42eQIAUSwoFFsTBtq3ZRKI= @@ -191,8 +191,9 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= +github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/lufia/plan9stats v0.0.0-20260330125221-c963978e514e h1:Q6MvJtQK/iRcRtzAscm/zF23XxJlbECiGPyRicsX+Ak= github.com/lufia/plan9stats v0.0.0-20260330125221-c963978e514e/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= @@ -203,8 +204,6 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc= -github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/go-archive v0.2.0 h1:zg5QDUM2mi0JIM9fdQZWC7U8+2ZfixfTYoHL7rWUcP8= @@ -233,8 +232,8 @@ github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJw github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU= github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA= -github.com/playwright-community/playwright-go v0.5200.1 h1:Sm2oOuhqt0M5Y4kUi/Qh9w4cyyi3ZIWTBeGKImc2UVo= -github.com/playwright-community/playwright-go v0.5200.1/go.mod h1:UnnyQZaqUOO5ywAZu60+N4EiWReUqX1MQBBA3Oofvf8= +github.com/playwright-community/playwright-go v0.5700.1 h1:PNFb1byWqrTT720rEO0JL88C6Ju0EmUnR5deFLvtP/U= +github.com/playwright-community/playwright-go v0.5700.1/go.mod h1:MlSn1dZrx8rszbCxY6x3qK89ZesJUYVx21B2JnkoNF0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= @@ -284,8 +283,8 @@ github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gi github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= -github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= +github.com/yuin/gopher-lua v1.1.2 h1:yF/FjE3hD65tBbt0VXLE13HWS9h34fdzJmrWRXwobGA= +github.com/yuin/gopher-lua v1.1.2/go.mod h1:7aRmXIWl37SqRf0koeyylBEzJ+aPt8A+mmkQ4f1ntR8= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.mongodb.org/mongo-driver v1.17.9 h1:IexDdCuuNJ3BHrELgBlyaH9p60JXAvdzWR128q+U5tU= @@ -395,8 +394,8 @@ gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genai v1.52.1 h1:dYoljKtLDXMiBdVaClSJ/ZPwZ7j1N0lGjMhwOKOQUlk= -google.golang.org/genai v1.52.1/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= +google.golang.org/genai v1.57.0 h1:qTyG2ynz5dQy2jF4CvZdLHHVslhR0heMue+zM1a4GNM= +google.golang.org/genai v1.57.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= diff --git a/vendor/github.com/deckarep/golang-set/v2/README.md b/vendor/github.com/deckarep/golang-set/v2/README.md index bb691b1c..e4713680 100644 --- a/vendor/github.com/deckarep/golang-set/v2/README.md +++ b/vendor/github.com/deckarep/golang-set/v2/README.md @@ -9,6 +9,11 @@ The missing `generic` set collection for the Go language. Until Go has sets bui ## Psst * Hi there, 👋! Do you use or have interest in the [Zig programming language](https://ziglang.org/) created by Andrew Kelley? If so, the golang-set project has a new sibling project: [ziglang-set](https://github.com/deckarep/ziglang-set)! Come check it out! +## Update 3/14/2025 +* Packaged version: `2.8.0` introduces support for true iterators for Go 1.23+. Please see [issue #141](https://github.com/deckarep/golang-set/issues/141) +for further details on the implications of how iterations work between older Go versions vs newer Go versions. Additionally, this +release has a minor unit-test spelling fix. + ## Update 12/3/2024 * Packaged version: `2.7.0` fixes a long-standing bug with *JSON Unmarshaling*. A large refactor in the interest of performance introduced this bug and there was no way around it but to revert the code back to how it was previously. The performance diff --git a/vendor/github.com/deckarep/golang-set/v2/set.go b/vendor/github.com/deckarep/golang-set/v2/set.go index 292089dc..e9409aa8 100644 --- a/vendor/github.com/deckarep/golang-set/v2/set.go +++ b/vendor/github.com/deckarep/golang-set/v2/set.go @@ -73,6 +73,10 @@ type Set[T comparable] interface { // given items are in the set. ContainsAny(val ...T) bool + // ContainsAnyElement returns whether at least one of the + // given element are in the set. + ContainsAnyElement(other Set[T]) bool + // Difference returns the difference between this set // and other. The returned set will contain // all elements of this set that are not also @@ -253,3 +257,13 @@ func NewThreadUnsafeSetFromMapKeys[T comparable, V any](val map[T]V) Set[T] { return s } + +// Elements returns an iterator that yields the elements of the set. Starting +// with Go 1.23, users can use a for loop to iterate over it. +func Elements[T comparable](s Set[T]) func(func(element T) bool) { + return func(yield func(element T) bool) { + s.Each(func(t T) bool { + return !yield(t) + }) + } +} diff --git a/vendor/github.com/deckarep/golang-set/v2/threadsafe.go b/vendor/github.com/deckarep/golang-set/v2/threadsafe.go index 93f20c86..664fc611 100644 --- a/vendor/github.com/deckarep/golang-set/v2/threadsafe.go +++ b/vendor/github.com/deckarep/golang-set/v2/threadsafe.go @@ -82,6 +82,19 @@ func (t *threadSafeSet[T]) ContainsAny(v ...T) bool { return ret } +func (t *threadSafeSet[T]) ContainsAnyElement(other Set[T]) bool { + o := other.(*threadSafeSet[T]) + + t.RLock() + o.RLock() + + ret := t.uss.ContainsAnyElement(o.uss) + + t.RUnlock() + o.RUnlock() + return ret +} + func (t *threadSafeSet[T]) IsEmpty() bool { return t.Cardinality() == 0 } diff --git a/vendor/github.com/deckarep/golang-set/v2/threadunsafe.go b/vendor/github.com/deckarep/golang-set/v2/threadunsafe.go index 7e3243b2..c95d32b4 100644 --- a/vendor/github.com/deckarep/golang-set/v2/threadunsafe.go +++ b/vendor/github.com/deckarep/golang-set/v2/threadunsafe.go @@ -109,6 +109,26 @@ func (s *threadUnsafeSet[T]) ContainsAny(v ...T) bool { return false } +func (s *threadUnsafeSet[T]) ContainsAnyElement(other Set[T]) bool { + o := other.(*threadUnsafeSet[T]) + + // loop over smaller set + if s.Cardinality() < other.Cardinality() { + for elem := range *s { + if o.contains(elem) { + return true + } + } + } else { + for elem := range *o { + if s.contains(elem) { + return true + } + } + } + return false +} + // private version of Contains for a single element v func (s *threadUnsafeSet[T]) contains(v T) (ok bool) { _, ok = (*s)[v] diff --git a/vendor/github.com/fatih/color/color.go b/vendor/github.com/fatih/color/color.go index ee39b408..d3906bfb 100644 --- a/vendor/github.com/fatih/color/color.go +++ b/vendor/github.com/fatih/color/color.go @@ -19,15 +19,15 @@ var ( // set (regardless of its value). This is a global option and affects all // colors. For more control over each color block use the methods // DisableColor() individually. - NoColor = noColorIsSet() || os.Getenv("TERM") == "dumb" || - (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) + NoColor = noColorIsSet() || os.Getenv("TERM") == "dumb" || !stdoutIsTerminal() // Output defines the standard output of the print functions. By default, - // os.Stdout is used. - Output = colorable.NewColorableStdout() + // stdOut() is used. + Output = stdOut() - // Error defines a color supporting writer for os.Stderr. - Error = colorable.NewColorableStderr() + // Error defines the standard error of the print functions. By default, + // stdErr() is used. + Error = stdErr() // colorsCache is used to reduce the count of created Color objects and // allows to reuse already created objects with required Attribute. @@ -40,6 +40,33 @@ func noColorIsSet() bool { return os.Getenv("NO_COLOR") != "" } +// stdoutIsTerminal returns true if os.Stdout is a terminal. +// Returns false if os.Stdout is nil (e.g., when running as a Windows service). +func stdoutIsTerminal() bool { + if os.Stdout == nil { + return false + } + return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) +} + +// stdOut returns a writer for color output. +// Returns io.Discard if os.Stdout is nil (e.g., when running as a Windows service). +func stdOut() io.Writer { + if os.Stdout == nil { + return io.Discard + } + return colorable.NewColorableStdout() +} + +// stdErr returns a writer for color error output. +// Returns io.Discard if os.Stderr is nil (e.g., when running as a Windows service). +func stdErr() io.Writer { + if os.Stderr == nil { + return io.Discard + } + return colorable.NewColorableStderr() +} + // Color defines a custom color object which is defined by SGR parameters. type Color struct { params []Attribute @@ -220,26 +247,30 @@ func (c *Color) unset() { // a low-level function, and users should use the higher-level functions, such // as color.Fprint, color.Print, etc. func (c *Color) SetWriter(w io.Writer) *Color { + _, _ = c.setWriter(w) + return c +} + +func (c *Color) setWriter(w io.Writer) (int, error) { if c.isNoColorSet() { - return c + return 0, nil } - fmt.Fprint(w, c.format()) - return c + return fmt.Fprint(w, c.format()) } // UnsetWriter resets all escape attributes and clears the output with the give // io.Writer. Usually should be called after SetWriter(). func (c *Color) UnsetWriter(w io.Writer) { - if c.isNoColorSet() { - return - } + _, _ = c.unsetWriter(w) +} - if NoColor { - return +func (c *Color) unsetWriter(w io.Writer) (int, error) { + if c.isNoColorSet() { + return 0, nil } - fmt.Fprintf(w, "%s[%dm", escape, Reset) + return fmt.Fprintf(w, "%s[%dm", escape, Reset) } // Add is used to chain SGR parameters. Use as many as parameters to combine @@ -255,10 +286,20 @@ func (c *Color) Add(value ...Attribute) *Color { // On Windows, users should wrap w with colorable.NewColorable() if w is of // type *os.File. func (c *Color) Fprint(w io.Writer, a ...interface{}) (n int, err error) { - c.SetWriter(w) - defer c.UnsetWriter(w) + n, err = c.setWriter(w) + if err != nil { + return n, err + } + + nn, err := fmt.Fprint(w, a...) + n += nn + if err != nil { + return + } - return fmt.Fprint(w, a...) + nn, err = c.unsetWriter(w) + n += nn + return n, err } // Print formats using the default formats for its operands and writes to @@ -278,10 +319,20 @@ func (c *Color) Print(a ...interface{}) (n int, err error) { // On Windows, users should wrap w with colorable.NewColorable() if w is of // type *os.File. func (c *Color) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) { - c.SetWriter(w) - defer c.UnsetWriter(w) + n, err = c.setWriter(w) + if err != nil { + return n, err + } + + nn, err := fmt.Fprintf(w, format, a...) + n += nn + if err != nil { + return + } - return fmt.Fprintf(w, format, a...) + nn, err = c.unsetWriter(w) + n += nn + return n, err } // Printf formats according to a format specifier and writes to standard output. @@ -475,27 +526,24 @@ func (c *Color) Equals(c2 *Color) bool { if c == nil || c2 == nil { return false } + if len(c.params) != len(c2.params) { return false } + counts := make(map[Attribute]int, len(c.params)) for _, attr := range c.params { - if !c2.attrExists(attr) { - return false - } + counts[attr]++ } - return true -} - -func (c *Color) attrExists(a Attribute) bool { - for _, attr := range c.params { - if attr == a { - return true + for _, attr := range c2.params { + if counts[attr] == 0 { + return false } + counts[attr]-- } - return false + return true } func boolPtr(v bool) *bool { diff --git a/vendor/github.com/fatih/color/color_windows.go b/vendor/github.com/fatih/color/color_windows.go index be01c558..97e5a765 100644 --- a/vendor/github.com/fatih/color/color_windows.go +++ b/vendor/github.com/fatih/color/color_windows.go @@ -9,6 +9,9 @@ import ( func init() { // Opt-in for ansi color support for current process. // https://learn.microsoft.com/en-us/windows/console/console-virtual-terminal-sequences#output-sequences + if os.Stdout == nil { + return + } var outMode uint32 out := windows.Handle(os.Stdout.Fd()) if err := windows.GetConsoleMode(out, &outMode); err != nil { diff --git a/vendor/github.com/fsnotify/fsnotify/.cirrus.yml b/vendor/github.com/fsnotify/fsnotify/.cirrus.yml deleted file mode 100644 index 7f257e99..00000000 --- a/vendor/github.com/fsnotify/fsnotify/.cirrus.yml +++ /dev/null @@ -1,14 +0,0 @@ -freebsd_task: - name: 'FreeBSD' - freebsd_instance: - image_family: freebsd-14-2 - install_script: - - pkg update -f - - pkg install -y go - test_script: - # run tests as user "cirrus" instead of root - - pw useradd cirrus -m - - chown -R cirrus:cirrus . - - FSNOTIFY_BUFFER=4096 sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race ./... - - sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race ./... - - FSNOTIFY_DEBUG=1 sudo --preserve-env=FSNOTIFY_BUFFER -u cirrus go test -parallel 1 -race -v ./... diff --git a/vendor/github.com/fsnotify/fsnotify/CHANGELOG.md b/vendor/github.com/fsnotify/fsnotify/CHANGELOG.md index 6468d2cf..3027f3c6 100644 --- a/vendor/github.com/fsnotify/fsnotify/CHANGELOG.md +++ b/vendor/github.com/fsnotify/fsnotify/CHANGELOG.md @@ -1,5 +1,54 @@ # Changelog +1.10.1 2026-05-04 +----------------- + +### Changes and fixes + +- inotify: don't remove sibling watches sharing a path prefix ([#754]) + +- inotify, windows: don't rename sibling watches sharing a path prefix + ([#755]) + + +[#754]: https://github.com/fsnotify/fsnotify/pull/754 +[#755]: https://github.com/fsnotify/fsnotify/pull/755 + + +1.10.0 2026-04-30 +----------------- +This version of fsnotify needs Go 1.23. + +### Changes and fixes + +- inotify: improve initialization error message ([#731]) + +- inotify: send Rename event if recursive watch is renamed ([#696]) + +- inotify: avoid copying event buffers when reading names ([#741]) + +- kqueue: skip dangling symlinks (ENOENT) in watchDirectoryFiles, so a + bad entry no longer aborts Watcher.Add for the whole directory ([#748]) + +- kqueue: drop watches directly in Close() to fix a file descriptor leak + when recycling watchers ([#740]) + +- windows: fix nil pointer dereference in remWatch ([#736]) + +- windows: lock watch field updates against concurrent WatchList to fix + a race introduced in v1.9.0 ([#709], [#749]) + + +[#696]: https://github.com/fsnotify/fsnotify/pull/696 +[#709]: https://github.com/fsnotify/fsnotify/pull/709 +[#731]: https://github.com/fsnotify/fsnotify/pull/731 +[#736]: https://github.com/fsnotify/fsnotify/pull/736 +[#740]: https://github.com/fsnotify/fsnotify/pull/740 +[#741]: https://github.com/fsnotify/fsnotify/pull/741 +[#748]: https://github.com/fsnotify/fsnotify/pull/748 +[#749]: https://github.com/fsnotify/fsnotify/pull/749 + + 1.9.0 2024-04-04 ---------------- diff --git a/vendor/github.com/fsnotify/fsnotify/CONTRIBUTING.md b/vendor/github.com/fsnotify/fsnotify/CONTRIBUTING.md index 4cc40fa5..cd0ee612 100644 --- a/vendor/github.com/fsnotify/fsnotify/CONTRIBUTING.md +++ b/vendor/github.com/fsnotify/fsnotify/CONTRIBUTING.md @@ -77,6 +77,8 @@ End-of-line escapes with `\` are not supported. debug [yes/no] # Enable/disable FSNOTIFY_DEBUG (tests are run in parallel by default, so -parallel=1 is probably a good idea). + state # Print internal state to stderr (exact output differs + # per backend). print [any strings] # Print text to stdout; for debugging. touch path diff --git a/vendor/github.com/fsnotify/fsnotify/README.md b/vendor/github.com/fsnotify/fsnotify/README.md index 1f4eb583..2e56ef4c 100644 --- a/vendor/github.com/fsnotify/fsnotify/README.md +++ b/vendor/github.com/fsnotify/fsnotify/README.md @@ -1,7 +1,7 @@ fsnotify is a Go library to provide cross-platform filesystem notifications on Windows, Linux, macOS, BSD, and illumos. -Go 1.17 or newer is required; the full documentation is at +Go 1.23 or newer is required; the full documentation is at https://pkg.go.dev/github.com/fsnotify/fsnotify --- @@ -12,7 +12,7 @@ Platform support: | :-------------------- | :--------- | :------------------------------------------------------------------------ | | inotify | Linux | Supported | | kqueue | BSD, macOS | Supported | -| ReadDirectoryChangesW | Windows | Supported | +| ReadDirectoryChangesW | Windows | Supported ([excluding `Chmod` operations][#487]) | | FEN | illumos | Supported | | fanotify | Linux 5.9+ | [Not yet](https://github.com/fsnotify/fsnotify/issues/114) | | FSEvents | macOS | [Needs support in x/sys/unix][fsevents] | @@ -22,6 +22,7 @@ Platform support: Linux and illumos should include Android and Solaris, but these are currently untested. +[#487]: https://github.com/fsnotify/fsnotify/issues/487 [fsevents]: https://github.com/fsnotify/fsnotify/issues/11#issuecomment-1279133120 [usn]: https://github.com/fsnotify/fsnotify/issues/53#issuecomment-1279829847 @@ -126,7 +127,7 @@ settings* until we have a native FSEvents implementation (see [#11]). ### Watching a file doesn't work well Watching individual files (rather than directories) is generally not recommended as many programs (especially editors) update files atomically: it will write to -a temporary file which is then moved to to destination, overwriting the original +a temporary file which is then moved to a destination, overwriting the original (or some variant thereof). The watcher on the original file is now lost, as that no longer exists. @@ -151,26 +152,57 @@ This is the event that inotify sends, so not much can be changed about this. The `fs.inotify.max_user_watches` sysctl variable specifies the upper limit for the number of watches per user, and `fs.inotify.max_user_instances` specifies the maximum number of inotify instances per user. Every Watcher you create is an -"instance", and every path you add is a "watch". +"instance", and every path you add is a "watch". Reaching the limit will result +in a "no space left on device" or "too many open files" error. These are also exposed in `/proc` as `/proc/sys/fs/inotify/max_user_watches` and -`/proc/sys/fs/inotify/max_user_instances` +`/proc/sys/fs/inotify/max_user_instances`. The default values differ per distro +and available memory. To increase them you can use `sysctl` or write the value to proc file: - # The default values on Linux 5.18 - sysctl fs.inotify.max_user_watches=124983 - sysctl fs.inotify.max_user_instances=128 + sysctl fs.inotify.max_user_watches=200000 + sysctl fs.inotify.max_user_instances=256 To make the changes persist on reboot edit `/etc/sysctl.conf` or `/usr/lib/sysctl.d/50-default.conf` (details differ per Linux distro; check your distro's documentation): - fs.inotify.max_user_watches=124983 - fs.inotify.max_user_instances=128 + fs.inotify.max_user_watches=200000 + fs.inotify.max_user_instances=256 + +### Windows +Recursive watching is not currently enabled through fsnotify's public API +(see the FAQ "Are subdirectories watched?" above). The notes below +describe Windows backend behavior observed when recursive watching is +enabled internally (for example, in fsnotify's own tests). They are kept +here as a reference for maintainers and contributors who encounter the +behavior, since the recursive code path still exists in the backend. + +When recursive watching is enabled and you watch a directory, you may +receive a `Write` event for an intermediate directory whenever a child +entry inside it is created, renamed, or removed. For example, with a +recursive watch on `/a` and a new file `/a/b/c`, you will receive +`Create /a/b/c` and may also receive `Write /a/b`. + +This happens because, on NTFS-backed volumes, modifying the entries of a +directory updates that directory's last-write time, and the Windows +backend requests `FILE_NOTIFY_CHANGE_LAST_WRITE` to support `Write` events +on files. The same `Write` filter therefore picks up the directory's +metadata update. + +kqueue has the same "directory `Write` = directory contents changed" +semantics, so portable code that treats `Write` on a directory as +"something inside it changed" works on Windows and BSD/macOS, but not on +Linux (inotify uses `Write` only for file-content changes). If you only +care about file content, filter out `Write` events whose path refers to a +directory. + +Whether the directory `Write` is actually delivered alongside the child +events is not guaranteed: it depends on `ReadDirectoryChangesW` buffering, +NTFS metadata update timing, and event coalescing, none of which fsnotify +controls. -Reaching the limit will result in a "no space left on device" or "too many open -files" error. ### kqueue (macOS, all BSD systems) kqueue requires opening a file descriptor for every file that's being watched; diff --git a/vendor/github.com/fsnotify/fsnotify/backend_fen.go b/vendor/github.com/fsnotify/fsnotify/backend_fen.go index 57fc6928..e43c6d08 100644 --- a/vendor/github.com/fsnotify/fsnotify/backend_fen.go +++ b/vendor/github.com/fsnotify/fsnotify/backend_fen.go @@ -158,7 +158,9 @@ func (w *fen) readEvents() { pevents := make([]unix.PortEvent, 8) for { - count, err := w.port.Get(pevents, 1, nil) + count, err := internal.IgnoringEINTR(func() (int, error) { + return w.port.Get(pevents, 1, nil) + }) if err != nil && err != unix.ETIME { // Interrupted system call (count should be 0) ignore and continue if errors.Is(err, unix.EINTR) && count == 0 { diff --git a/vendor/github.com/fsnotify/fsnotify/backend_inotify.go b/vendor/github.com/fsnotify/fsnotify/backend_inotify.go index a36cb89d..4c3f6f7c 100644 --- a/vendor/github.com/fsnotify/fsnotify/backend_inotify.go +++ b/vendor/github.com/fsnotify/fsnotify/backend_inotify.go @@ -55,10 +55,10 @@ type ( path map[string]uint32 // pathname → wd } watch struct { - wd uint32 // Watch descriptor (as returned by the inotify_add_watch() syscall) - flags uint32 // inotify flags of this watch (see inotify(7) for the list of valid flags) - path string // Watch path. - recurse bool // Recursion with ./...? + wd uint32 // Watch descriptor (as returned by the inotify_add_watch() syscall) + flags uint32 // inotify flags of this watch (see inotify(7) for the list of valid flags) + path string // Watch path. + watchFlags watchFlag } koekje struct { cookie uint32 @@ -66,6 +66,9 @@ type ( } ) +func (w watch) byUser() bool { return w.watchFlags&flagByUser != 0 } +func (w watch) recurse() bool { return w.watchFlags&flagRecurse != 0 } + func newWatches() *watches { return &watches{ wd: make(map[uint32]*watch), @@ -79,6 +82,13 @@ func (w *watches) len() int { return len(w.wd) } func (w *watches) add(ww *watch) { w.wd[ww.wd] = ww; w.path[ww.path] = ww.wd } func (w *watches) remove(watch *watch) { delete(w.path, watch.path); delete(w.wd, watch.wd) } +func isSameOrDescendantPath(path, root string) bool { + if path == root { + return true + } + return strings.HasPrefix(path, root+string(os.PathSeparator)) +} + func (w *watches) removePath(path string) ([]uint32, error) { path, recurse := recursivePath(path) wd, ok := w.path[path] @@ -87,20 +97,20 @@ func (w *watches) removePath(path string) ([]uint32, error) { } watch := w.wd[wd] - if recurse && !watch.recurse { + if recurse && !watch.recurse() { return nil, fmt.Errorf("can't use /... with non-recursive watch %q", path) } delete(w.path, path) delete(w.wd, wd) - if !watch.recurse { + if !watch.recurse() { return []uint32{wd}, nil } wds := make([]uint32, 0, 8) wds = append(wds, wd) for p, rwd := range w.path { - if strings.HasPrefix(p, path) { + if isSameOrDescendantPath(p, path) { delete(w.path, p) delete(w.wd, rwd) wds = append(wds, rwd) @@ -139,7 +149,7 @@ func newBackend(ev chan Event, errs chan error) (backend, error) { // I/O operations won't terminate on close. fd, errno := unix.InotifyInit1(unix.IN_CLOEXEC | unix.IN_NONBLOCK) if fd == -1 { - return nil, errno + return nil, fmt.Errorf("couldn't initialize inotify: %w", errno) } w := &inotify{ @@ -188,11 +198,8 @@ func (w *inotify) AddWith(path string, opts ...addOpt) error { return fmt.Errorf("%w: %s", xErrUnsupported, with.op) } - add := func(path string, with withOpts, recurse bool) error { + add := func(path string, with withOpts, wf watchFlag) error { var flags uint32 - if with.noFollow { - flags |= unix.IN_DONT_FOLLOW - } if with.op.Has(Create) { flags |= unix.IN_CREATE } @@ -220,7 +227,7 @@ func (w *inotify) AddWith(path string, opts ...addOpt) error { if with.op.Has(xUnportableCloseRead) { flags |= unix.IN_CLOSE_NOWRITE } - return w.register(path, flags, recurse) + return w.register(path, flags, wf) } w.mu.Lock() @@ -248,14 +255,18 @@ func (w *inotify) AddWith(path string, opts ...addOpt) error { w.sendEvent(Event{Name: root, Op: Create}) } - return add(root, with, true) + wf := flagRecurse + if root == path { + wf |= flagByUser + } + return add(root, with, wf) }) } - return add(path, with, false) + return add(path, with, 0) } -func (w *inotify) register(path string, flags uint32, recurse bool) error { +func (w *inotify) register(path string, flags uint32, wf watchFlag) error { return w.watches.updatePath(path, func(existing *watch) (*watch, error) { if existing != nil { flags |= existing.flags | unix.IN_MASK_ADD @@ -272,10 +283,10 @@ func (w *inotify) register(path string, flags uint32, recurse bool) error { if existing == nil { return &watch{ - wd: uint32(wd), - path: path, - flags: flags, - recurse: recurse, + wd: uint32(wd), + path: path, + flags: flags, + watchFlags: wf, }, nil } @@ -425,11 +436,7 @@ func (w *inotify) handleEvent(inEvent *unix.InotifyEvent, buf *[65536]byte, offs nameLen = uint32(inEvent.Len) ) if nameLen > 0 { - /// Point "bytes" at the first byte of the filename - bb := *buf - bytes := (*[unix.PathMax]byte)(unsafe.Pointer(&bb[offset+unix.SizeofInotifyEvent]))[:nameLen:nameLen] - /// The filename is padded with NULL bytes. TrimRight() gets rid of those. - name += "/" + strings.TrimRight(string(bytes[0:nameLen]), "\x00") + name += "/" + inotifyEventName(buf, offset, nameLen) } if debug { @@ -450,7 +457,9 @@ func (w *inotify) handleEvent(inEvent *unix.InotifyEvent, buf *[65536]byte, offs // We can't really update the state when a watched path is moved; only // IN_MOVE_SELF is sent and not IN_MOVED_{FROM,TO}. So remove the watch. if inEvent.Mask&unix.IN_MOVE_SELF == unix.IN_MOVE_SELF { - if watch.recurse { // Do nothing + // Watch is set up as part of recurse: do nothing as the move gets + // registered from the parent directory. + if watch.recurse() && !watch.byUser() { return Event{}, true } @@ -460,6 +469,10 @@ func (w *inotify) handleEvent(inEvent *unix.InotifyEvent, buf *[65536]byte, offs return Event{}, false } } + + if watch.recurse() { + return Event{Name: watch.path, Op: Rename}, true + } } /// Skip if we're watching both this path and the parent; the parent will @@ -473,11 +486,11 @@ func (w *inotify) handleEvent(inEvent *unix.InotifyEvent, buf *[65536]byte, offs ev := w.newEvent(name, inEvent.Mask, inEvent.Cookie) // Need to update watch path for recurse. - if watch.recurse { + if watch.recurse() { isDir := inEvent.Mask&unix.IN_ISDIR == unix.IN_ISDIR /// New directory created: set up watch on it. if isDir && ev.Has(Create) { - err := w.register(ev.Name, watch.flags, true) + err := w.register(ev.Name, watch.flags, flagRecurse) if !w.sendError(err) { return Event{}, false } @@ -495,7 +508,7 @@ func (w *inotify) handleEvent(inEvent *unix.InotifyEvent, buf *[65536]byte, offs if k == watch.wd || ww.path == ev.Name { continue } - if strings.HasPrefix(ww.path, ev.renamedFrom) { + if isSameOrDescendantPath(ww.path, ev.renamedFrom) { ww.path = strings.Replace(ww.path, ev.renamedFrom, ev.Name, 1) w.watches.wd[k] = ww } @@ -507,12 +520,13 @@ func (w *inotify) handleEvent(inEvent *unix.InotifyEvent, buf *[65536]byte, offs return ev, true } -func (w *inotify) isRecursive(path string) bool { - ww := w.watches.byPath(path) - if ww == nil { // path could be a file, so also check the Dir. - ww = w.watches.byPath(filepath.Dir(path)) +func inotifyEventName(buf *[65536]byte, offset, nameLen uint32) string { + start := int(offset + unix.SizeofInotifyEvent) + bytes := (*[unix.PathMax]byte)(unsafe.Pointer(&buf[start]))[:nameLen:nameLen] + for nameLen > 0 && bytes[nameLen-1] == 0 { + nameLen-- } - return ww != nil && ww.recurse + return string(bytes[:nameLen]) } func (w *inotify) newEvent(name string, mask, cookie uint32) Event { @@ -578,6 +592,6 @@ func (w *inotify) state() { w.mu.Lock() defer w.mu.Unlock() for wd, ww := range w.watches.wd { - fmt.Fprintf(os.Stderr, "%4d: recurse=%t %q\n", wd, ww.recurse, ww.path) + fmt.Fprintf(os.Stderr, "%4d: %q watchFlags=0x%x\n", wd, ww.path, ww.watchFlags) } } diff --git a/vendor/github.com/fsnotify/fsnotify/backend_kqueue.go b/vendor/github.com/fsnotify/fsnotify/backend_kqueue.go index 340aeec0..d2c8cfb6 100644 --- a/vendor/github.com/fsnotify/fsnotify/backend_kqueue.go +++ b/vendor/github.com/fsnotify/fsnotify/backend_kqueue.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "runtime" + "sort" "sync" "time" @@ -245,9 +246,26 @@ func (w *kqueue) Close() error { return nil } + // Snapshot and drop all watches directly. w.Remove -> w.remove + // short-circuits on isClosed() (which is already true after + // w.shared.close() above), so calling Remove here in the happy path + // leaked every watched directory + file descriptor. On macOS a + // single directory watch opens an fd for every file in the dir, so + // long-running processes that recreate watchers (hot-reload dev + // servers, etc.) ran out of fds with EMFILE (#732). pathsToRemove := w.watches.listPaths(false) for _, name := range pathsToRemove { - w.Remove(name) + info, ok := w.watches.byPath(name) + if !ok { + // w.path has an entry for name but w.wd doesn't -- + // drop the stale lookup entry so the map state is + // consistent after Close. + w.watches.remove(0, name) + continue + } + _ = w.register([]int{info.wd}, unix.EV_DELETE, 0) + unix.Close(info.wd) + w.watches.remove(info.wd, name) } unix.Close(w.closepipe[1]) // Send "quit" message to readEvents @@ -376,19 +394,12 @@ func (w *kqueue) addWatch(name string, flags uint32, listDir bool) (string, erro } } - // Retry on EINTR; open() can return EINTR in practice on macOS. - // See #354, and Go issues 11180 and 39237. - for { - info.wd, err = unix.Open(name, openMode, 0) - if err == nil { - break - } - if errors.Is(err, unix.EINTR) { - continue - } + info.wd, err = internal.IgnoringEINTR(func() (int, error) { + return unix.Open(name, openMode, 0) + }) + if err != nil { return "", err } - info.isDir = fi.IsDir() } @@ -436,9 +447,10 @@ func (w *kqueue) readEvents() { eventBuffer := make([]unix.Kevent_t, 10) for { - kevents, err := w.read(eventBuffer) - // EINTR is okay, the syscall was interrupted before timeout expired. - if err != nil && err != unix.EINTR { + kevents, err := internal.IgnoringEINTR(func() ([]unix.Kevent_t, error) { + return w.read(eventBuffer) + }) + if err != nil { if !w.sendError(fmt.Errorf("fsnotify.readEvents: %w", err)) { return } @@ -583,12 +595,14 @@ func (w *kqueue) watchDirectoryFiles(dirPath string) error { cleanPath, err := w.internalWatch(path, fi) if err != nil { - // No permission to read the file; that's not a problem: just skip. - // But do add it to w.fileExists to prevent it from being picked up - // as a "new" file later (it still shows up in the directory + // No permission, or the entry resolved to a missing target + // (e.g. a dangling symlink): not a problem, just skip. But + // do mark it as seen to prevent it from being picked up as + // a "new" file later (it still shows up in the directory // listing). switch { - case errors.Is(err, unix.EACCES) || errors.Is(err, unix.EPERM): + case errors.Is(err, unix.EACCES) || errors.Is(err, unix.EPERM) || + errors.Is(err, os.ErrNotExist): cleanPath = filepath.Clean(path) default: return fmt.Errorf("%q: %w", path, err) @@ -703,3 +717,19 @@ func (w *kqueue) xSupports(op Op) bool { } return true } + +func (w *kqueue) state() { + w.watches.mu.Lock() + defer w.watches.mu.Unlock() + + all := make([]int, 0, len(w.watches.wd)) + for wd := range w.watches.wd { + all = append(all, wd) + } + sort.Ints(all) + + for _, wd := range all { + ww := w.watches.wd[wd] + fmt.Fprintf(os.Stderr, "%4d %q linkname=%q\n", wd, ww.name, ww.linkName) + } +} diff --git a/vendor/github.com/fsnotify/fsnotify/backend_windows.go b/vendor/github.com/fsnotify/fsnotify/backend_windows.go index 3433642d..fb9210f2 100644 --- a/vendor/github.com/fsnotify/fsnotify/backend_windows.go +++ b/vendor/github.com/fsnotify/fsnotify/backend_windows.go @@ -11,7 +11,6 @@ import ( "fmt" "os" "path/filepath" - "reflect" "runtime" "strings" "sync" @@ -37,6 +36,13 @@ type readDirChangesW struct { var defaultBufferSize = 50 +func isSameOrDescendantPath(path, root string) bool { + if path == root { + return true + } + return strings.HasPrefix(path, root+string(os.PathSeparator)) +} + func newBackend(ev chan Event, errs chan error) (backend, error) { port, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { @@ -359,22 +365,26 @@ func (w *readDirChangesW) addWatch(pathname string, flags uint64, bufsize int) e } else { windows.CloseHandle(ino.handle) } + w.mu.Lock() if pathname == dir { watchEntry.mask |= flags } else { watchEntry.names[filepath.Base(pathname)] |= flags } + w.mu.Unlock() err = w.startRead(watchEntry) if err != nil { return err } + w.mu.Lock() if pathname == dir { watchEntry.mask &= ^provisional } else { watchEntry.names[filepath.Base(pathname)] &= ^provisional } + w.mu.Unlock() return nil } @@ -394,8 +404,13 @@ func (w *readDirChangesW) remWatch(pathname string) error { w.mu.Lock() watch := w.watches.get(ino) w.mu.Unlock() + if watch == nil { + windows.CloseHandle(ino.handle) + return fmt.Errorf("%w: %s", ErrNonExistentWatch, pathname) + } if recurse && !watch.recurse { + windows.CloseHandle(ino.handle) return fmt.Errorf("can't use \\... with non-recursive watch %q", pathname) } @@ -403,16 +418,19 @@ func (w *readDirChangesW) remWatch(pathname string) error { if err != nil { w.sendError(os.NewSyscallError("CloseHandle", err)) } - if watch == nil { - return fmt.Errorf("%w: %s", ErrNonExistentWatch, pathname) - } if pathname == dir { - w.sendEvent(watch.path, "", watch.mask&sysFSIGNORED) + w.mu.Lock() + mask := watch.mask watch.mask = 0 + w.mu.Unlock() + w.sendEvent(watch.path, "", mask&sysFSIGNORED) } else { name := filepath.Base(pathname) - w.sendEvent(filepath.Join(watch.path, name), "", watch.names[name]&sysFSIGNORED) + w.mu.Lock() + mask := watch.names[name] delete(watch.names, name) + w.mu.Unlock() + w.sendEvent(filepath.Join(watch.path, name), "", mask&sysFSIGNORED) } return w.startRead(watch) @@ -420,17 +438,23 @@ func (w *readDirChangesW) remWatch(pathname string) error { // Must run within the I/O thread. func (w *readDirChangesW) deleteWatch(watch *watch) { - for name, mask := range watch.names { - if mask&provisional == 0 { - w.sendEvent(filepath.Join(watch.path, name), "", mask&sysFSIGNORED) + // Snapshot+clear under the lock so concurrent WatchList() readers see a + // consistent state. sendEvent must run outside the lock since it can + // block on the user-facing Events channel. + w.mu.Lock() + names := watch.names + watch.names = make(map[string]uint64) + mask := watch.mask + watch.mask = 0 + w.mu.Unlock() + + for name, m := range names { + if m&provisional == 0 { + w.sendEvent(filepath.Join(watch.path, name), "", m&sysFSIGNORED) } - delete(watch.names, name) } - if watch.mask != 0 { - if watch.mask&provisional == 0 { - w.sendEvent(watch.path, "", watch.mask&sysFSIGNORED) - } - watch.mask = 0 + if mask != 0 && mask&provisional == 0 { + w.sendEvent(watch.path, "", mask&sysFSIGNORED) } } @@ -457,9 +481,8 @@ func (w *readDirChangesW) startRead(watch *watch) error { } // We need to pass the array, rather than the slice. - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&watch.buf)) rdErr := windows.ReadDirectoryChanges(watch.ino.handle, - (*byte)(unsafe.Pointer(hdr.Data)), uint32(hdr.Len), + unsafe.SliceData(watch.buf), uint32(len(watch.buf)), watch.recurse, mask, nil, &watch.ov, 0) if rdErr != nil { err := os.NewSyscallError("ReadDirectoryChanges", rdErr) @@ -565,12 +588,7 @@ func (w *readDirChangesW) readEvents() { // Create a buf that is the size of the path name size := int(raw.FileNameLength / 2) - var buf []uint16 - // TODO: Use unsafe.Slice in Go 1.17; https://stackoverflow.com/questions/51187973 - sh := (*reflect.SliceHeader)(unsafe.Pointer(&buf)) - sh.Data = uintptr(unsafe.Pointer(&raw.FileName)) - sh.Len = size - sh.Cap = size + buf := unsafe.Slice(&raw.FileName, size) name := windows.UTF16ToString(buf) fullname := filepath.Join(watch.path, name) @@ -587,31 +605,35 @@ func (w *readDirChangesW) readEvents() { case windows.FILE_ACTION_RENAMED_OLD_NAME: watch.rename = name case windows.FILE_ACTION_RENAMED_NEW_NAME: - // Update saved path of all sub-watches. + // Update saved path of all sub-watches and rename the + // names entry under the lock so WatchList() can't observe + // a torn state. old := filepath.Join(watch.path, watch.rename) w.mu.Lock() for _, watchMap := range w.watches { for _, ww := range watchMap { - if strings.HasPrefix(ww.path, old) { + if isSameOrDescendantPath(ww.path, old) { ww.path = filepath.Join(fullname, strings.TrimPrefix(ww.path, old)) } } } - w.mu.Unlock() - if watch.names[watch.rename] != 0 { watch.names[name] |= watch.names[watch.rename] delete(watch.names, watch.rename) mask = sysFSMOVESELF } + w.mu.Unlock() } if raw.Action != windows.FILE_ACTION_RENAMED_NEW_NAME { w.sendEvent(fullname, "", watch.names[name]&mask) } if raw.Action == windows.FILE_ACTION_REMOVED { - w.sendEvent(fullname, "", watch.names[name]&sysFSIGNORED) + w.mu.Lock() + ignored := watch.names[name] & sysFSIGNORED delete(watch.names, name) + w.mu.Unlock() + w.sendEvent(fullname, "", ignored) } if watch.rename != "" && raw.Action == windows.FILE_ACTION_RENAMED_NEW_NAME { diff --git a/vendor/github.com/fsnotify/fsnotify/fsnotify.go b/vendor/github.com/fsnotify/fsnotify/fsnotify.go index f64be4bf..38cb4dd4 100644 --- a/vendor/github.com/fsnotify/fsnotify/fsnotify.go +++ b/vendor/github.com/fsnotify/fsnotify/fsnotify.go @@ -51,26 +51,25 @@ import ( // The fs.inotify.max_user_watches sysctl variable specifies the upper limit // for the number of watches per user, and fs.inotify.max_user_instances // specifies the maximum number of inotify instances per user. Every Watcher you -// create is an "instance", and every path you add is a "watch". +// create is an "instance", and every path you add is a "watch". Reaching the +// limit will result in a "no space left on device" or "too many open files" +// error. // // These are also exposed in /proc as /proc/sys/fs/inotify/max_user_watches and -// /proc/sys/fs/inotify/max_user_instances +// /proc/sys/fs/inotify/max_user_instances. The default values differ per distro +// and available memory. // // To increase them you can use sysctl or write the value to the /proc file: // -// # Default values on Linux 5.18 -// sysctl fs.inotify.max_user_watches=124983 -// sysctl fs.inotify.max_user_instances=128 +// sysctl fs.inotify.max_user_watches=200000 +// sysctl fs.inotify.max_user_instances=256 // // To make the changes persist on reboot edit /etc/sysctl.conf or // /usr/lib/sysctl.d/50-default.conf (details differ per Linux distro; check // your distro's documentation): // -// fs.inotify.max_user_watches=124983 -// fs.inotify.max_user_instances=128 -// -// Reaching the limit will result in a "no space left on device" or "too many open -// files" error. +// fs.inotify.max_user_watches=200000 +// fs.inotify.max_user_instances=256 // // # kqueue notes (macOS, BSD) // @@ -93,6 +92,28 @@ import ( // Sometimes it will send events for all files, sometimes it will send no // events, and often only for some files. // +// Recursive watching is not currently enabled through fsnotify's public +// API; the recursive code path is gated and only exercised by fsnotify's +// own tests. The note below describes backend behavior observed when +// recursive watching is enabled internally, and is kept here as a +// reference for maintainers and contributors who encounter it. +// +// When recursive watching is enabled and you watch a directory, you may +// receive a Write event for an intermediate directory whenever a child +// entry inside it is created, renamed, or removed. For example, with a +// recursive watch on /a and a new file /a/b/c, you will receive +// Create /a/b/c and may also receive Write /a/b. +// +// This happens because, on NTFS-backed volumes, modifying the entries of a +// directory updates that directory's last-write time, and the Windows +// backend requests FILE_NOTIFY_CHANGE_LAST_WRITE to support Write events +// on files. The same Write filter therefore picks up the directory's +// metadata update. +// +// Whether the directory Write is actually delivered alongside the child +// events is not guaranteed; it depends on ReadDirectoryChangesW buffering, +// NTFS metadata update timing, and event coalescing. +// // The default ReadDirectoryChangesW() buffer size is 64K, which is the largest // value that is guaranteed to work with SMB filesystems. If you have many // events in quick succession this may not be enough, and you will have to use @@ -129,8 +150,12 @@ type Watcher struct { // want to wait until you've stopped receiving them // (see the dedup example in cmd/fsnotify). // - // Some systems may send Write event for directories - // when the directory content changes. + // Some systems also send Write events for directories + // when the directory contents change. This is the + // case for kqueue, and on Windows for the directory + // that contains a created, renamed, or removed child + // entry. It does not happen on inotify. See the + // per-platform notes on [Watcher]. // // fsnotify.Chmod Attributes were changed. On Linux this is also sent // when a file is removed (or more accurately, when a @@ -179,7 +204,9 @@ const ( Create Op = 1 << iota // The pathname was written to; this does *not* mean the write has finished, - // and a write can be followed by more writes. + // and a write can be followed by more writes. On Windows and kqueue, a + // Write on a directory can also indicate that its contents changed; see + // the per-platform notes on [Watcher]. Write // The path was removed; any watches on it will be removed. Some "remove" @@ -220,7 +247,7 @@ const ( // File opened for reading was closed. // - // Only works on Linux and FreeBSD. + // Only works on Linux. xUnportableCloseRead ) @@ -410,7 +437,6 @@ type ( withOpts struct { bufsize int op Op - noFollow bool sendCreate bool } ) @@ -469,12 +495,6 @@ func withOps(op Op) addOpt { return func(opt *withOpts) { opt.op = op } } -// WithNoFollow disables following symlinks, so the symlinks themselves are -// watched. -func withNoFollow() addOpt { - return func(opt *withOpts) { opt.noFollow = true } -} - // "Internal" option for recursive watches on inotify. func withCreate() addOpt { return func(opt *withOpts) { opt.sendCreate = true } @@ -494,3 +514,13 @@ func recursivePath(path string) (string, bool) { } return path, false } + +type watchFlag uint8 + +const ( + // Added by user with Add(), rather than an internal watch. + flagByUser = watchFlag(0x01) + // Part of recursive watch; as the top-level path added by the user or an + // "internal" watch. + flagRecurse = watchFlag(0x02) +) diff --git a/vendor/github.com/fsnotify/fsnotify/internal/darwin.go b/vendor/github.com/fsnotify/fsnotify/internal/darwin.go index 0b01bc18..6721aa60 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/darwin.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/darwin.go @@ -15,25 +15,6 @@ var ( var maxfiles uint64 -func SetRlimit() { - // Go 1.19 will do this automatically: https://go-review.googlesource.com/c/go/+/393354/ - var l syscall.Rlimit - err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &l) - if err == nil && l.Cur != l.Max { - l.Cur = l.Max - syscall.Setrlimit(syscall.RLIMIT_NOFILE, &l) - } - maxfiles = l.Cur - - if n, err := syscall.SysctlUint32("kern.maxfiles"); err == nil && uint64(n) < maxfiles { - maxfiles = uint64(n) - } - - if n, err := syscall.SysctlUint32("kern.maxfilesperproc"); err == nil && uint64(n) < maxfiles { - maxfiles = uint64(n) - } -} - func Maxfiles() uint64 { return maxfiles } func Mkfifo(path string, mode uint32) error { return unix.Mkfifo(path, mode) } func Mknod(path string, mode uint32, dev int) error { return unix.Mknod(path, mode, dev) } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_darwin.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_darwin.go index 928319fb..76001807 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/debug_darwin.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_darwin.go @@ -6,52 +6,10 @@ var names = []struct { n string m uint32 }{ - {"NOTE_ABSOLUTE", unix.NOTE_ABSOLUTE}, {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, - {"NOTE_BACKGROUND", unix.NOTE_BACKGROUND}, - {"NOTE_CHILD", unix.NOTE_CHILD}, - {"NOTE_CRITICAL", unix.NOTE_CRITICAL}, {"NOTE_DELETE", unix.NOTE_DELETE}, - {"NOTE_EXEC", unix.NOTE_EXEC}, - {"NOTE_EXIT", unix.NOTE_EXIT}, - {"NOTE_EXITSTATUS", unix.NOTE_EXITSTATUS}, - {"NOTE_EXIT_CSERROR", unix.NOTE_EXIT_CSERROR}, - {"NOTE_EXIT_DECRYPTFAIL", unix.NOTE_EXIT_DECRYPTFAIL}, - {"NOTE_EXIT_DETAIL", unix.NOTE_EXIT_DETAIL}, - {"NOTE_EXIT_DETAIL_MASK", unix.NOTE_EXIT_DETAIL_MASK}, - {"NOTE_EXIT_MEMORY", unix.NOTE_EXIT_MEMORY}, - {"NOTE_EXIT_REPARENTED", unix.NOTE_EXIT_REPARENTED}, {"NOTE_EXTEND", unix.NOTE_EXTEND}, - {"NOTE_FFAND", unix.NOTE_FFAND}, - {"NOTE_FFCOPY", unix.NOTE_FFCOPY}, - {"NOTE_FFCTRLMASK", unix.NOTE_FFCTRLMASK}, - {"NOTE_FFLAGSMASK", unix.NOTE_FFLAGSMASK}, - {"NOTE_FFNOP", unix.NOTE_FFNOP}, - {"NOTE_FFOR", unix.NOTE_FFOR}, - {"NOTE_FORK", unix.NOTE_FORK}, - {"NOTE_FUNLOCK", unix.NOTE_FUNLOCK}, - {"NOTE_LEEWAY", unix.NOTE_LEEWAY}, {"NOTE_LINK", unix.NOTE_LINK}, - {"NOTE_LOWAT", unix.NOTE_LOWAT}, - {"NOTE_MACHTIME", unix.NOTE_MACHTIME}, - {"NOTE_MACH_CONTINUOUS_TIME", unix.NOTE_MACH_CONTINUOUS_TIME}, - {"NOTE_NONE", unix.NOTE_NONE}, - {"NOTE_NSECONDS", unix.NOTE_NSECONDS}, - {"NOTE_OOB", unix.NOTE_OOB}, - //{"NOTE_PCTRLMASK", unix.NOTE_PCTRLMASK}, -0x100000 (?!) - {"NOTE_PDATAMASK", unix.NOTE_PDATAMASK}, - {"NOTE_REAP", unix.NOTE_REAP}, {"NOTE_RENAME", unix.NOTE_RENAME}, - {"NOTE_REVOKE", unix.NOTE_REVOKE}, - {"NOTE_SECONDS", unix.NOTE_SECONDS}, - {"NOTE_SIGNAL", unix.NOTE_SIGNAL}, - {"NOTE_TRACK", unix.NOTE_TRACK}, - {"NOTE_TRACKERR", unix.NOTE_TRACKERR}, - {"NOTE_TRIGGER", unix.NOTE_TRIGGER}, - {"NOTE_USECONDS", unix.NOTE_USECONDS}, - {"NOTE_VM_ERROR", unix.NOTE_VM_ERROR}, - {"NOTE_VM_PRESSURE", unix.NOTE_VM_PRESSURE}, - {"NOTE_VM_PRESSURE_SUDDEN_TERMINATE", unix.NOTE_VM_PRESSURE_SUDDEN_TERMINATE}, - {"NOTE_VM_PRESSURE_TERMINATE", unix.NOTE_VM_PRESSURE_TERMINATE}, {"NOTE_WRITE", unix.NOTE_WRITE}, } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_dragonfly.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_dragonfly.go index 3186b0c3..76001807 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/debug_dragonfly.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_dragonfly.go @@ -7,27 +7,9 @@ var names = []struct { m uint32 }{ {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, - {"NOTE_CHILD", unix.NOTE_CHILD}, {"NOTE_DELETE", unix.NOTE_DELETE}, - {"NOTE_EXEC", unix.NOTE_EXEC}, - {"NOTE_EXIT", unix.NOTE_EXIT}, {"NOTE_EXTEND", unix.NOTE_EXTEND}, - {"NOTE_FFAND", unix.NOTE_FFAND}, - {"NOTE_FFCOPY", unix.NOTE_FFCOPY}, - {"NOTE_FFCTRLMASK", unix.NOTE_FFCTRLMASK}, - {"NOTE_FFLAGSMASK", unix.NOTE_FFLAGSMASK}, - {"NOTE_FFNOP", unix.NOTE_FFNOP}, - {"NOTE_FFOR", unix.NOTE_FFOR}, - {"NOTE_FORK", unix.NOTE_FORK}, {"NOTE_LINK", unix.NOTE_LINK}, - {"NOTE_LOWAT", unix.NOTE_LOWAT}, - {"NOTE_OOB", unix.NOTE_OOB}, - {"NOTE_PCTRLMASK", unix.NOTE_PCTRLMASK}, - {"NOTE_PDATAMASK", unix.NOTE_PDATAMASK}, {"NOTE_RENAME", unix.NOTE_RENAME}, - {"NOTE_REVOKE", unix.NOTE_REVOKE}, - {"NOTE_TRACK", unix.NOTE_TRACK}, - {"NOTE_TRACKERR", unix.NOTE_TRACKERR}, - {"NOTE_TRIGGER", unix.NOTE_TRIGGER}, {"NOTE_WRITE", unix.NOTE_WRITE}, } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_freebsd.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_freebsd.go index f69fdb93..b9e45f55 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/debug_freebsd.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_freebsd.go @@ -6,37 +6,15 @@ var names = []struct { n string m uint32 }{ - {"NOTE_ABSTIME", unix.NOTE_ABSTIME}, - {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, - {"NOTE_CHILD", unix.NOTE_CHILD}, - {"NOTE_CLOSE", unix.NOTE_CLOSE}, - {"NOTE_CLOSE_WRITE", unix.NOTE_CLOSE_WRITE}, {"NOTE_DELETE", unix.NOTE_DELETE}, - {"NOTE_EXEC", unix.NOTE_EXEC}, - {"NOTE_EXIT", unix.NOTE_EXIT}, + {"NOTE_WRITE", unix.NOTE_WRITE}, {"NOTE_EXTEND", unix.NOTE_EXTEND}, - {"NOTE_FFAND", unix.NOTE_FFAND}, - {"NOTE_FFCOPY", unix.NOTE_FFCOPY}, - {"NOTE_FFCTRLMASK", unix.NOTE_FFCTRLMASK}, - {"NOTE_FFLAGSMASK", unix.NOTE_FFLAGSMASK}, - {"NOTE_FFNOP", unix.NOTE_FFNOP}, - {"NOTE_FFOR", unix.NOTE_FFOR}, - {"NOTE_FILE_POLL", unix.NOTE_FILE_POLL}, - {"NOTE_FORK", unix.NOTE_FORK}, + {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, {"NOTE_LINK", unix.NOTE_LINK}, - {"NOTE_LOWAT", unix.NOTE_LOWAT}, - {"NOTE_MSECONDS", unix.NOTE_MSECONDS}, - {"NOTE_NSECONDS", unix.NOTE_NSECONDS}, - {"NOTE_OPEN", unix.NOTE_OPEN}, - {"NOTE_PCTRLMASK", unix.NOTE_PCTRLMASK}, - {"NOTE_PDATAMASK", unix.NOTE_PDATAMASK}, - {"NOTE_READ", unix.NOTE_READ}, {"NOTE_RENAME", unix.NOTE_RENAME}, {"NOTE_REVOKE", unix.NOTE_REVOKE}, - {"NOTE_SECONDS", unix.NOTE_SECONDS}, - {"NOTE_TRACK", unix.NOTE_TRACK}, - {"NOTE_TRACKERR", unix.NOTE_TRACKERR}, - {"NOTE_TRIGGER", unix.NOTE_TRIGGER}, - {"NOTE_USECONDS", unix.NOTE_USECONDS}, - {"NOTE_WRITE", unix.NOTE_WRITE}, + {"NOTE_OPEN", unix.NOTE_OPEN}, + {"NOTE_CLOSE", unix.NOTE_CLOSE}, + {"NOTE_CLOSE_WRITE", unix.NOTE_CLOSE_WRITE}, + {"NOTE_READ", unix.NOTE_READ}, } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_kqueue.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_kqueue.go index 607e683b..5d811643 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/debug_kqueue.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_kqueue.go @@ -27,6 +27,6 @@ func Debug(name string, kevent *unix.Kevent_t) { if unknown > 0 { l = append(l, fmt.Sprintf("0x%x", unknown)) } - fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s %10d:%-60s → %q\n", + fmt.Fprintf(os.Stderr, "FSNOTIFY_DEBUG: %s %10d:%-20s → %q\n", time.Now().Format("15:04:05.000000000"), mask, strings.Join(l, " | "), name) } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_netbsd.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_netbsd.go index e5b3b6f6..76001807 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/debug_netbsd.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_netbsd.go @@ -7,19 +7,9 @@ var names = []struct { m uint32 }{ {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, - {"NOTE_CHILD", unix.NOTE_CHILD}, {"NOTE_DELETE", unix.NOTE_DELETE}, - {"NOTE_EXEC", unix.NOTE_EXEC}, - {"NOTE_EXIT", unix.NOTE_EXIT}, {"NOTE_EXTEND", unix.NOTE_EXTEND}, - {"NOTE_FORK", unix.NOTE_FORK}, {"NOTE_LINK", unix.NOTE_LINK}, - {"NOTE_LOWAT", unix.NOTE_LOWAT}, - {"NOTE_PCTRLMASK", unix.NOTE_PCTRLMASK}, - {"NOTE_PDATAMASK", unix.NOTE_PDATAMASK}, {"NOTE_RENAME", unix.NOTE_RENAME}, - {"NOTE_REVOKE", unix.NOTE_REVOKE}, - {"NOTE_TRACK", unix.NOTE_TRACK}, - {"NOTE_TRACKERR", unix.NOTE_TRACKERR}, {"NOTE_WRITE", unix.NOTE_WRITE}, } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/debug_openbsd.go b/vendor/github.com/fsnotify/fsnotify/internal/debug_openbsd.go index 1dd455bc..871766d6 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/debug_openbsd.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/debug_openbsd.go @@ -7,22 +7,10 @@ var names = []struct { m uint32 }{ {"NOTE_ATTRIB", unix.NOTE_ATTRIB}, - // {"NOTE_CHANGE", unix.NOTE_CHANGE}, // Not on 386? - {"NOTE_CHILD", unix.NOTE_CHILD}, {"NOTE_DELETE", unix.NOTE_DELETE}, - {"NOTE_EOF", unix.NOTE_EOF}, - {"NOTE_EXEC", unix.NOTE_EXEC}, - {"NOTE_EXIT", unix.NOTE_EXIT}, {"NOTE_EXTEND", unix.NOTE_EXTEND}, - {"NOTE_FORK", unix.NOTE_FORK}, {"NOTE_LINK", unix.NOTE_LINK}, - {"NOTE_LOWAT", unix.NOTE_LOWAT}, - {"NOTE_PCTRLMASK", unix.NOTE_PCTRLMASK}, - {"NOTE_PDATAMASK", unix.NOTE_PDATAMASK}, {"NOTE_RENAME", unix.NOTE_RENAME}, - {"NOTE_REVOKE", unix.NOTE_REVOKE}, - {"NOTE_TRACK", unix.NOTE_TRACK}, - {"NOTE_TRACKERR", unix.NOTE_TRACKERR}, {"NOTE_TRUNCATE", unix.NOTE_TRUNCATE}, {"NOTE_WRITE", unix.NOTE_WRITE}, } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/freebsd.go b/vendor/github.com/fsnotify/fsnotify/internal/freebsd.go index 5ac8b507..758a2490 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/freebsd.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/freebsd.go @@ -15,17 +15,6 @@ var ( var maxfiles uint64 -func SetRlimit() { - // Go 1.19 will do this automatically: https://go-review.googlesource.com/c/go/+/393354/ - var l syscall.Rlimit - err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &l) - if err == nil && l.Cur != l.Max { - l.Cur = l.Max - syscall.Setrlimit(syscall.RLIMIT_NOFILE, &l) - } - maxfiles = uint64(l.Cur) -} - func Maxfiles() uint64 { return maxfiles } func Mkfifo(path string, mode uint32) error { return unix.Mkfifo(path, mode) } func Mknod(path string, mode uint32, dev int) error { return unix.Mknod(path, mode, uint64(dev)) } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/unix.go b/vendor/github.com/fsnotify/fsnotify/internal/unix.go index b251fb80..9c66f5d3 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/unix.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/unix.go @@ -15,17 +15,6 @@ var ( var maxfiles uint64 -func SetRlimit() { - // Go 1.19 will do this automatically: https://go-review.googlesource.com/c/go/+/393354/ - var l syscall.Rlimit - err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &l) - if err == nil && l.Cur != l.Max { - l.Cur = l.Max - syscall.Setrlimit(syscall.RLIMIT_NOFILE, &l) - } - maxfiles = uint64(l.Cur) -} - func Maxfiles() uint64 { return maxfiles } func Mkfifo(path string, mode uint32) error { return unix.Mkfifo(path, mode) } func Mknod(path string, mode uint32, dev int) error { return unix.Mknod(path, mode, dev) } diff --git a/vendor/github.com/fsnotify/fsnotify/internal/unix2.go b/vendor/github.com/fsnotify/fsnotify/internal/unix2.go index 37dfeddc..b2d89592 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/unix2.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/unix2.go @@ -2,6 +2,24 @@ package internal +import "syscall" + func HasPrivilegesForSymlink() bool { return true } + +// IgnoringEINTR makes a function call and repeats it if it returns an +// EINTR error. This appears to be required even though we install all +// signal handlers with SA_RESTART: see #22838, #38033, #38836, #40846. +// Also #20400 and #36644 are issues in which a signal handler is +// installed without setting SA_RESTART. None of these are the common case, +// but there are enough of them that it seems that we can't avoid +// an EINTR loop. +func IgnoringEINTR[T any](fn func() (T, error)) (T, error) { + for { + v, err := fn() + if err != syscall.EINTR { + return v, err + } + } +} diff --git a/vendor/github.com/fsnotify/fsnotify/internal/windows.go b/vendor/github.com/fsnotify/fsnotify/internal/windows.go index 896bc2e5..e24d5692 100644 --- a/vendor/github.com/fsnotify/fsnotify/internal/windows.go +++ b/vendor/github.com/fsnotify/fsnotify/internal/windows.go @@ -14,7 +14,6 @@ var ( ErrUnixEACCES = errors.New("dummy") ) -func SetRlimit() {} func Maxfiles() uint64 { return 1<<64 - 1 } func Mkfifo(path string, mode uint32) error { return errors.New("no FIFOs on Windows") } func Mknod(path string, mode uint32, dev int) error { return errors.New("no device nodes on Windows") } diff --git a/vendor/github.com/go-pkgz/lgr/CLAUDE.md b/vendor/github.com/go-pkgz/lgr/CLAUDE.md new file mode 100644 index 00000000..76818245 --- /dev/null +++ b/vendor/github.com/go-pkgz/lgr/CLAUDE.md @@ -0,0 +1,28 @@ +# Go-PKGZ/LGR Development Guidelines + +## Build & Test Commands +- Build: `go build -race` +- Test all: `go test -timeout=60s -race -covermode=atomic -coverprofile=profile.cov` +- Test single file: `go test -run TestName` +- Benchmark: `go test -bench=. -run=Bench` +- Lint: `golangci-lint run` + +## Code Style Guidelines +- Go 1.21 compatibility required +- Maximum line length: 140 characters +- No package names with underscores +- Use early returns (enforced by prealloc linter) +- Test files use testify for assertions: `require` for fatal assertions, `assert` for non-fatal ones +- Indent with tabs, not spaces + +## Error Handling +- FATAL logs to stderr and calls os.Exit(1) +- ERROR logs to both stdout and stderr +- PANIC logs stack trace and runtime info to stderr +- Stack traces for ERROR level can be enabled with StackTraceOnError option + +## Project Conventions +- Public API follows interface-based design (`lgr.L` interface) +- Avoid global loggers, prefer dependency injection +- Functional options pattern for logger configuration +- Secret logging sanitization with `lgr.Secret` option \ No newline at end of file diff --git a/vendor/github.com/go-pkgz/lgr/logger.go b/vendor/github.com/go-pkgz/lgr/logger.go index 2d8471ca..14ae39ca 100644 --- a/vendor/github.com/go-pkgz/lgr/logger.go +++ b/vendor/github.com/go-pkgz/lgr/logger.go @@ -166,9 +166,10 @@ func (l *Logger) logf(format string, args ...interface{}) { // if slog handler is set, use it if l.slogHandler != nil { - // use NewRecord for consistency with adapter setup - // skip=0 because we don't need caller information from this context - record := slog.NewRecord(l.now(), stringToLevel(lv), msg, 0) + // get the caller's PC so slog handlers can resolve source info when AddSource is enabled + var pcs [1]uintptr + runtime.Callers(3+l.callerDepth, pcs[:]) // skip runtime.Callers, logf, Logf (+ any extra depth) + record := slog.NewRecord(l.now(), stringToLevel(lv), msg, pcs[0]) _ = l.slogHandler.Handle(context.Background(), record) // handle FATAL and PANIC levels as they have special behavior diff --git a/vendor/github.com/go-pkgz/lgr/slog.go b/vendor/github.com/go-pkgz/lgr/slog.go index b68af07a..dd9ac3ab 100644 --- a/vendor/github.com/go-pkgz/lgr/slog.go +++ b/vendor/github.com/go-pkgz/lgr/slog.go @@ -23,13 +23,13 @@ func FromSlogHandler(h slog.Handler) L { // SetupWithSlog sets up the global logger with a slog logger func SetupWithSlog(logger *slog.Logger) { options := []Option{SlogHandler(logger.Handler())} - + // check if the slog handler is enabled for debug level // if so, enable debug mode in lgr to prevent filtering if logger.Handler().Enabled(context.Background(), slog.LevelDebug) { options = append(options, Debug) } - + Setup(options...) } @@ -59,12 +59,6 @@ func (h *lgrSlogHandler) Handle(_ context.Context, record slog.Record) error { // build message with attributes msg := record.Message - // add time if record has it, otherwise current time is used by lgr - var timeStr string - if !record.Time.IsZero() { - timeStr = record.Time.Format("2006/01/02 15:04:05.000 ") - } - // format attributes as key=value pairs var attrs strings.Builder if len(h.attrs) > 0 || record.NumAttrs() > 0 { @@ -82,8 +76,8 @@ func (h *lgrSlogHandler) Handle(_ context.Context, record slog.Record) error { return true }) - // combine everything into final message - logMsg := fmt.Sprintf("%s%s %s%s", timeStr, level, msg, attrs.String()) + // combine level prefix and message; lgr.Logf adds its own timestamp and level formatting + logMsg := fmt.Sprintf("%s %s%s", level, msg, attrs.String()) h.lgr.Logf(logMsg) return nil } @@ -115,39 +109,15 @@ type slogLgrAdapter struct { // Logf implements lgr.L interface func (a *slogLgrAdapter) Logf(format string, args ...interface{}) { - // parse log level from the beginning of the message msg := fmt.Sprintf(format, args...) level, msg := extractLevel(msg) - // create a record with caller information - // skip level is critical: - // - 0 = this line - // - 1 = this function (Logf) - // - 2 = caller of Logf (user code) - // - // note: We use PC=0 to ensure slog.Record.PC() returns 0, - // which causes slog to skip obtaining the caller info itself - record := slog.NewRecord(time.Now(), stringToLevel(level), msg, 2) - - // we need to manually add the source information ourselves, since - // slog.Handler might have AddSource=true but won't get the caller - // right due to how we're adapting lgr → slog - pc, file, line, ok := runtime.Caller(2) // skip to caller of Logf - if ok { - // only add source info if we can find it - funcName := runtime.FuncForPC(pc).Name() - record.AddAttrs( - slog.Group("source", - slog.String("function", funcName), - slog.String("file", file), - slog.Int("line", line), - ), - ) - } + // get the caller's PC so slog handlers can resolve source info when AddSource is enabled + var pcs [1]uintptr + runtime.Callers(2, pcs[:]) // skip runtime.Callers and Logf + record := slog.NewRecord(time.Now(), stringToLevel(level), msg, pcs[0]) - // handle the record if err := a.handler.Handle(context.Background(), record); err != nil { - // if handling fails, fallback to stderr fmt.Fprintf(os.Stderr, "slog handler error: %v\n", err) } } diff --git a/vendor/github.com/go-pkgz/rest/.golangci.yml b/vendor/github.com/go-pkgz/rest/.golangci.yml index bd6fb4ac..7e71eafe 100644 --- a/vendor/github.com/go-pkgz/rest/.golangci.yml +++ b/vendor/github.com/go-pkgz/rest/.golangci.yml @@ -19,6 +19,8 @@ linters: - unconvert - unparam - unused + - modernize + - testifylint settings: goconst: min-len: 2 diff --git a/vendor/github.com/go-pkgz/rest/README.md b/vendor/github.com/go-pkgz/rest/README.md index 1260cf80..749e0dac 100644 --- a/vendor/github.com/go-pkgz/rest/README.md +++ b/vendor/github.com/go-pkgz/rest/README.md @@ -143,6 +143,22 @@ The `Rewrite` middleware is designed to rewrite the URL path based on a given ru For example, `Rewrite("^/sites/(.*)/settings/$", "/sites/settings/$1")` will change request's URL from `/sites/id1/settings/` to `/sites/settings/id1` +### CleanPath middleware + +Cleans double slashes from URL path. For example, requests to `/users//1` or `//users////1` will be cleaned to `/users/1` before routing. Trailing slashes are preserved: `/api//v1/` becomes `/api/v1/`. Note: dot segments (`.` and `..`) are intentionally not cleaned to preserve routing semantics. + +```go +router.Use(rest.CleanPath) +``` + +### StripSlashes middleware + +Removes trailing slashes from URL path. For example, `/users/` becomes `/users`. The root path `/` is preserved. + +```go +router.Use(rest.StripSlashes) +``` + ### NoCache middleware Sets a number of HTTP headers to prevent a router (handler's) response from being cached by an upstream proxy and/or client. @@ -166,6 +182,131 @@ RealIP is a middleware that sets a http.Request's RemoteAddr to the results of p Only public IPs are accepted from headers; private/loopback/link-local IPs are skipped. This makes the middleware compatible with CDN setups like Cloudflare where the leftmost IP in `X-Forwarded-For` is the actual client. +### CORS middleware + +Handles Cross-Origin Resource Sharing, allowing controlled access from different origins. + +```go +// allow all origins (default) +router.Use(rest.CORS()) + +// specific origins with credentials +router.Use(rest.CORS( + rest.CorsAllowedOrigins("https://app.example.com", "https://admin.example.com"), + rest.CorsAllowCredentials(true), + rest.CorsMaxAge(86400), +)) + +// full configuration +router.Use(rest.CORS( + rest.CorsAllowedOrigins("https://app.example.com"), + rest.CorsAllowedMethods("GET", "POST", "PUT", "DELETE"), + rest.CorsAllowedHeaders("Authorization", "Content-Type", "X-Custom-Header"), + rest.CorsExposedHeaders("X-Request-Id", "X-Total-Count"), + rest.CorsAllowCredentials(true), + rest.CorsMaxAge(3600), +)) +``` + +Features: +- Automatic preflight (OPTIONS) handling +- Origin validation with case-insensitive matching +- Credentials support (reflects origin instead of `*`) +- Configurable cache duration for preflight results + +Available options: +- `CorsAllowedOrigins(origins...)` - allowed origins (default: `*`) +- `CorsAllowedMethods(methods...)` - allowed HTTP methods (default: GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD) +- `CorsAllowedHeaders(headers...)` - allowed request headers (default: Accept, Content-Type, Authorization, X-Requested-With) +- `CorsExposedHeaders(headers...)` - headers exposed to client +- `CorsAllowCredentials(bool)` - enable credentials (cookies, auth headers) +- `CorsMaxAge(seconds)` - preflight cache duration + +### Secure middleware + +Adds security headers to responses. By default sets: `X-Frame-Options`, `X-Content-Type-Options`, `Referrer-Policy`, `X-XSS-Protection`, and `Strict-Transport-Security` (for HTTPS only). + +```go +// with sensible defaults +router.Use(rest.Secure()) + +// with full security headers for web apps (adds CSP and Permissions-Policy) +router.Use(rest.Secure(rest.SecAllHeaders())) + +// with custom options +router.Use(rest.Secure( + rest.SecFrameOptions("SAMEORIGIN"), + rest.SecReferrerPolicy("no-referrer"), + rest.SecHSTS(86400, true, true), + rest.SecContentSecurityPolicy("default-src 'self'"), + rest.SecPermissionsPolicy("geolocation=(), camera=()"), +)) +``` + +Default headers: +- `X-Frame-Options: DENY` - prevents clickjacking +- `X-Content-Type-Options: nosniff` - prevents MIME-type sniffing +- `Referrer-Policy: strict-origin-when-cross-origin` - controls referrer information +- `X-XSS-Protection: 1; mode=block` - enables XSS filtering (legacy browsers) +- `Strict-Transport-Security: max-age=31536000; includeSubDomains` - enforces HTTPS (only sent over HTTPS) + +Available options: +- `SecFrameOptions(value)` - set X-Frame-Options (DENY, SAMEORIGIN) +- `SecContentTypeNosniff(enable)` - enable/disable nosniff +- `SecReferrerPolicy(policy)` - set Referrer-Policy +- `SecContentSecurityPolicy(policy)` - set Content-Security-Policy +- `SecPermissionsPolicy(policy)` - set Permissions-Policy +- `SecHSTS(maxAge, includeSubdomains, preload)` - configure HSTS +- `SecXSSProtection(value)` - set X-XSS-Protection +- `SecAllHeaders()` - convenience option that sets CSP and Permissions-Policy with restrictive defaults + +### CSRF middleware + +Provides Cross-Site Request Forgery protection using modern browser Fetch metadata headers (`Sec-Fetch-Site`, `Origin`). For Go 1.25+, this wraps the stdlib's `http.CrossOriginProtection`. For earlier versions, a compatible custom implementation is used. + +```go +// basic protection +protection := rest.NewCrossOriginProtection() +router.Use(protection.Handler) + +// with trusted origins for cross-origin requests +protection := rest.NewCrossOriginProtection() +if err := protection.AddTrustedOrigin("https://mobile.example.com"); err != nil { + log.Fatal(err) +} +if err := protection.AddTrustedOrigin("https://admin.example.com"); err != nil { + log.Fatal(err) +} +router.Use(protection.Handler) + +// with bypass patterns for webhooks or OAuth +protection := rest.NewCrossOriginProtection() +protection.AddBypassPattern("/api/webhook") +protection.AddBypassPattern("/oauth/") +router.Use(protection.Handler) + +// with custom deny handler +protection := rest.NewCrossOriginProtection() +protection.SetDenyHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "CSRF validation failed", http.StatusForbidden) +})) +router.Use(protection.Handler) +``` + +How it works: +- Safe methods (GET, HEAD, OPTIONS) are always allowed +- Checks `Sec-Fetch-Site` header for "same-origin" or "none" +- Falls back to comparing `Origin` header with request `Host` +- Requests without these headers are assumed same-origin (non-browser clients) + +Available methods: +- `NewCrossOriginProtection()` - creates new CSRF protection middleware +- `AddTrustedOrigin(origin)` - adds origin allowed for cross-origin requests (format: "scheme://host[:port]") +- `AddBypassPattern(pattern)` - adds URL pattern that bypasses protection (for webhooks, OAuth, etc.) +- `SetDenyHandler(handler)` - sets custom handler for rejected requests (default: 403 Forbidden) +- `Check(request)` - manually validates a request, returns error if blocked +- `Handler(handler)` - wraps an http.Handler with CSRF protection + ### Maybe middleware Maybe middleware allows changing the flow of the middleware stack execution depending on the return diff --git a/vendor/github.com/go-pkgz/rest/benchmarks.go b/vendor/github.com/go-pkgz/rest/benchmarks.go index 24e1dea9..e422d0ea 100644 --- a/vendor/github.com/go-pkgz/rest/benchmarks.go +++ b/vendor/github.com/go-pkgz/rest/benchmarks.go @@ -156,10 +156,7 @@ func (b *Benchmarks) Stats(interval time.Duration) BenchmarkStats { } // ensure we calculate rate based on actual interval - actualInterval := fnInterval.Sub(stInterval) - if actualInterval < time.Second { - actualInterval = time.Second - } + actualInterval := max(fnInterval.Sub(stInterval), time.Second) return BenchmarkStats{ Requests: requests, diff --git a/vendor/github.com/go-pkgz/rest/cors.go b/vendor/github.com/go-pkgz/rest/cors.go new file mode 100644 index 00000000..df7d8f9c --- /dev/null +++ b/vendor/github.com/go-pkgz/rest/cors.go @@ -0,0 +1,180 @@ +package rest + +import ( + "net/http" + "strconv" + "strings" +) + +// CORSConfig defines CORS middleware configuration. +// Use CorsOpt functions to customize. +type CORSConfig struct { + // AllowedOrigins is a list of origins that may access the resource. + // use "*" to allow all origins (not recommended with credentials). + // default: ["*"] + AllowedOrigins []string + // AllowedMethods is a list of methods the client is allowed to use. + // default: GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD + AllowedMethods []string + // AllowedHeaders is a list of headers the client is allowed to send. + // default: Accept, Content-Type, Authorization, X-Requested-With + AllowedHeaders []string + // ExposedHeaders is a list of headers that are safe to expose to the client. + // default: empty + ExposedHeaders []string + // AllowCredentials indicates whether the request can include credentials. + // when true, AllowedOrigins cannot be "*" (browser security restriction). + // default: false + AllowCredentials bool + // MaxAge indicates how long (in seconds) the results of a preflight can be cached. + // default: 0 (no caching) + MaxAge int +} + +// CorsOpt is a functional option for CORSConfig +type CorsOpt func(*CORSConfig) + +// defaultCORSConfig returns config with sensible defaults +func defaultCORSConfig() CORSConfig { + return CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"}, + AllowedHeaders: []string{"Accept", "Content-Type", "Authorization", "X-Requested-With"}, + ExposedHeaders: []string{}, + AllowCredentials: false, + MaxAge: 0, + } +} + +// CorsAllowedOrigins sets the list of allowed origins. +// Use "*" to allow all origins (not recommended with credentials). +func CorsAllowedOrigins(origins ...string) CorsOpt { + return func(c *CORSConfig) { + c.AllowedOrigins = origins + } +} + +// CorsAllowedMethods sets the list of allowed HTTP methods. +func CorsAllowedMethods(methods ...string) CorsOpt { + return func(c *CORSConfig) { + c.AllowedMethods = methods + } +} + +// CorsAllowedHeaders sets the list of allowed request headers. +func CorsAllowedHeaders(headers ...string) CorsOpt { + return func(c *CORSConfig) { + c.AllowedHeaders = headers + } +} + +// CorsExposedHeaders sets the list of headers exposed to the client. +func CorsExposedHeaders(headers ...string) CorsOpt { + return func(c *CORSConfig) { + c.ExposedHeaders = headers + } +} + +// CorsAllowCredentials enables or disables credentials. +// When true, AllowedOrigins cannot be "*". +func CorsAllowCredentials(allow bool) CorsOpt { + return func(c *CORSConfig) { + c.AllowCredentials = allow + } +} + +// CorsMaxAge sets how long (in seconds) preflight results can be cached. +func CorsMaxAge(seconds int) CorsOpt { + return func(c *CORSConfig) { + c.MaxAge = seconds + } +} + +// CORS is middleware that handles Cross-Origin Resource Sharing. +// It handles preflight OPTIONS requests and sets appropriate headers. +// By default allows all origins with common methods and headers. +func CORS(opts ...CorsOpt) func(http.Handler) http.Handler { + cfg := defaultCORSConfig() + for _, opt := range opts { + opt(&cfg) + } + + // pre-compute joined strings for performance + methodsStr := strings.Join(cfg.AllowedMethods, ", ") + headersStr := strings.Join(cfg.AllowedHeaders, ", ") + exposedStr := strings.Join(cfg.ExposedHeaders, ", ") + + // check if wildcard is used + allowAll := len(cfg.AllowedOrigins) == 1 && cfg.AllowedOrigins[0] == "*" + + // build origin lookup for O(1) check (only when not allowing all) + var originSet map[string]bool + if !allowAll { + originSet = make(map[string]bool, len(cfg.AllowedOrigins)) + for _, o := range cfg.AllowedOrigins { + originSet[strings.ToLower(o)] = true + } + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + // no origin header means same-origin or non-browser request + if origin == "" { + next.ServeHTTP(w, r) + return + } + + // check if origin is allowed + var allowed bool + if allowAll { + allowed = true + } else { + allowed = originSet[strings.ToLower(origin)] + } + if !allowed { + // origin not allowed, proceed without CORS headers + next.ServeHTTP(w, r) + return + } + + // set Vary header for caching + w.Header().Add("Vary", "Origin") + + // set allowed origin + if allowAll && !cfg.AllowCredentials { + w.Header().Set("Access-Control-Allow-Origin", "*") + } else { + // reflect the specific origin (required for credentials) + w.Header().Set("Access-Control-Allow-Origin", origin) + } + + // set credentials header if enabled + if cfg.AllowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + + // handle preflight request + if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { + // preflight request + w.Header().Set("Access-Control-Allow-Methods", methodsStr) + w.Header().Set("Access-Control-Allow-Headers", headersStr) + + if cfg.MaxAge > 0 { + w.Header().Set("Access-Control-Max-Age", strconv.Itoa(cfg.MaxAge)) + } + + w.WriteHeader(http.StatusNoContent) + return + } + + // actual request - set exposed headers + if exposedStr != "" { + w.Header().Set("Access-Control-Expose-Headers", exposedStr) + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/vendor/github.com/go-pkgz/rest/csrf.go b/vendor/github.com/go-pkgz/rest/csrf.go new file mode 100644 index 00000000..0f1f930b --- /dev/null +++ b/vendor/github.com/go-pkgz/rest/csrf.go @@ -0,0 +1,207 @@ +//go:build !go1.25 + +package rest + +import ( + "fmt" + "net/http" + "net/url" + "strings" + "sync" +) + +// CrossOriginProtection provides CSRF protection using modern browser Fetch metadata. +// It validates requests using Sec-Fetch-Site and Origin headers, rejecting cross-origin +// state-changing requests. Safe methods (GET, HEAD, OPTIONS) are always allowed. +// +// For Go 1.25+, this wraps the stdlib http.CrossOriginProtection. +// For earlier versions, it provides an equivalent custom implementation. +type CrossOriginProtection struct { + mu sync.RWMutex + trustedOrigins map[string]bool + bypassPatterns []string + denyHandler http.Handler +} + +// NewCrossOriginProtection creates a new CSRF protection middleware. +func NewCrossOriginProtection() *CrossOriginProtection { + return &CrossOriginProtection{ + trustedOrigins: make(map[string]bool), + } +} + +// AddTrustedOrigin adds an origin that should be allowed to make cross-origin requests. +// The origin must be in the format "scheme://host" or "scheme://host:port". +// Returns an error if the origin format is invalid. +func (c *CrossOriginProtection) AddTrustedOrigin(origin string) error { + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("invalid origin: %w", err) + } + if u.Scheme == "" || u.Host == "" { + return fmt.Errorf("origin must have scheme and host: %s", origin) + } + if u.Path != "" && u.Path != "/" { + return fmt.Errorf("origin must not have path: %s", origin) + } + if u.RawQuery != "" || u.Fragment != "" { + return fmt.Errorf("origin must not have query or fragment: %s", origin) + } + + normalized := strings.ToLower(u.Scheme) + "://" + strings.ToLower(u.Host) + + c.mu.Lock() + c.trustedOrigins[normalized] = true + c.mu.Unlock() + return nil +} + +// AddBypassPattern adds a URL pattern that should bypass CSRF protection. +// Patterns follow the same syntax as http.ServeMux (e.g., "/api/webhook", "/oauth/"). +// Use sparingly and only for endpoints that have alternative authentication. +func (c *CrossOriginProtection) AddBypassPattern(pattern string) { + c.mu.Lock() + c.bypassPatterns = append(c.bypassPatterns, pattern) + c.mu.Unlock() +} + +// SetDenyHandler sets a custom handler for rejected requests. +// If not set, rejected requests receive a 403 Forbidden response. +func (c *CrossOriginProtection) SetDenyHandler(h http.Handler) { + c.mu.Lock() + c.denyHandler = h + c.mu.Unlock() +} + +// Check validates a request against CSRF protection rules. +// Returns nil if the request is allowed, or an error describing why it was rejected. +func (c *CrossOriginProtection) Check(r *http.Request) error { + // safe methods are always allowed + if isSafeMethod(r.Method) { + return nil + } + + // check bypass patterns + if c.matchesBypassPattern(r.URL.Path) { + return nil + } + + // check Sec-Fetch-Site header (modern browsers) + secFetchSite := r.Header.Get("Sec-Fetch-Site") + if secFetchSite != "" { + switch secFetchSite { + case "same-origin", "none": + return nil + case "cross-site", "same-site": + // check if origin is trusted + origin := r.Header.Get("Origin") + if origin != "" && c.isOriginTrusted(origin) { + return nil + } + return fmt.Errorf("cross-origin request blocked: Sec-Fetch-Site=%s", secFetchSite) + } + } + + // fallback: check Origin header against Host + origin := r.Header.Get("Origin") + if origin != "" { + // check if origin is trusted + if c.isOriginTrusted(origin) { + return nil + } + + // compare origin host with request host + originURL, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("invalid Origin header: %w", err) + } + + requestHost := r.Host + if requestHost == "" { + requestHost = r.URL.Host + } + + // normalize hosts for comparison + originHost := strings.ToLower(originURL.Host) + requestHost = strings.ToLower(requestHost) + + if originHost != requestHost { + return fmt.Errorf("cross-origin request blocked: origin %s does not match host %s", originHost, requestHost) + } + return nil + } + + // no Sec-Fetch-Site or Origin headers - assume same-origin or non-browser request + return nil +} + +// Handler wraps an http.Handler with CSRF protection. +// Rejected requests receive a 403 Forbidden response (or custom deny handler). +func (c *CrossOriginProtection) Handler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := c.Check(r); err != nil { + c.mu.RLock() + deny := c.denyHandler + c.mu.RUnlock() + + if deny != nil { + deny.ServeHTTP(w, r) + return + } + http.Error(w, "Forbidden - CSRF check failed", http.StatusForbidden) + return + } + h.ServeHTTP(w, r) + }) +} + +// isSafeMethod returns true for HTTP methods that don't modify state. +func isSafeMethod(method string) bool { + switch method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + return true + } + return false +} + +// isOriginTrusted checks if the origin is in the trusted list. +func (c *CrossOriginProtection) isOriginTrusted(origin string) bool { + u, err := url.Parse(origin) + if err != nil { + return false + } + normalized := strings.ToLower(u.Scheme) + "://" + strings.ToLower(u.Host) + + c.mu.RLock() + trusted := c.trustedOrigins[normalized] + c.mu.RUnlock() + return trusted +} + +// matchesBypassPattern checks if the path matches any bypass pattern. +func (c *CrossOriginProtection) matchesBypassPattern(path string) bool { + c.mu.RLock() + patterns := make([]string, len(c.bypassPatterns)) + copy(patterns, c.bypassPatterns) + c.mu.RUnlock() + + for _, pattern := range patterns { + if matchPattern(pattern, path) { + return true + } + } + return false +} + +// matchPattern implements simple pattern matching similar to http.ServeMux. +func matchPattern(pattern, path string) bool { + // exact match + if pattern == path { + return true + } + // prefix match for patterns ending with / + if strings.HasSuffix(pattern, "/") && strings.HasPrefix(path, pattern) { + return true + } + return false +} diff --git a/vendor/github.com/go-pkgz/rest/csrf_go125.go b/vendor/github.com/go-pkgz/rest/csrf_go125.go new file mode 100644 index 00000000..d229f0d8 --- /dev/null +++ b/vendor/github.com/go-pkgz/rest/csrf_go125.go @@ -0,0 +1,111 @@ +//go:build go1.25 + +package rest + +import ( + "fmt" + "net/http" + "net/url" + "strings" +) + +// CrossOriginProtection provides CSRF protection using modern browser Fetch metadata. +// It validates requests using Sec-Fetch-Site and Origin headers, rejecting cross-origin +// state-changing requests. Safe methods (GET, HEAD, OPTIONS) are always allowed. +// +// For Go 1.25+, this wraps the stdlib http.CrossOriginProtection. +// For earlier versions, it provides an equivalent custom implementation. +type CrossOriginProtection struct { + stdlib *http.CrossOriginProtection +} + +// NewCrossOriginProtection creates a new CSRF protection middleware. +func NewCrossOriginProtection() *CrossOriginProtection { + return &CrossOriginProtection{ + stdlib: http.NewCrossOriginProtection(), + } +} + +// AddTrustedOrigin adds an origin that should be allowed to make cross-origin requests. +// The origin must be in the format "scheme://host" or "scheme://host:port". +// Returns an error if the origin format is invalid. +func (c *CrossOriginProtection) AddTrustedOrigin(origin string) error { + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("invalid origin: %w", err) + } + if u.Scheme == "" || u.Host == "" { + return fmt.Errorf("origin must have scheme and host: %s", origin) + } + if u.Path != "" && u.Path != "/" { + return fmt.Errorf("origin must not have path: %s", origin) + } + if u.RawQuery != "" || u.Fragment != "" { + return fmt.Errorf("origin must not have query or fragment: %s", origin) + } + + // normalize to lowercase for consistent case-insensitive matching + normalized := strings.ToLower(u.Scheme) + "://" + strings.ToLower(u.Host) + return c.stdlib.AddTrustedOrigin(normalized) +} + +// AddBypassPattern adds a URL pattern that should bypass CSRF protection. +// Patterns follow the same syntax as http.ServeMux (e.g., "/api/webhook", "/oauth/"). +// Use sparingly and only for endpoints that have alternative authentication. +func (c *CrossOriginProtection) AddBypassPattern(pattern string) { + c.stdlib.AddInsecureBypassPattern(pattern) +} + +// SetDenyHandler sets a custom handler for rejected requests. +// If not set, rejected requests receive a 403 Forbidden response. +func (c *CrossOriginProtection) SetDenyHandler(h http.Handler) { + c.stdlib.SetDenyHandler(h) +} + +// Check validates a request against CSRF protection rules. +// Returns nil if the request is allowed, or an error describing why it was rejected. +func (c *CrossOriginProtection) Check(r *http.Request) error { + // the stdlib Check method panics or returns void, so we use our own check logic + // by creating a test handler and seeing if it gets called + + // safe methods are always allowed + switch r.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + return nil + } + + // use a test to determine if request would be allowed + allowed := false + testHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + allowed = true + }) + + // create a response recorder to capture the result + rec := &discardResponseWriter{} + c.stdlib.Handler(testHandler).ServeHTTP(rec, r) + + if !allowed { + return fmt.Errorf("cross-origin request blocked by CSRF protection") + } + return nil +} + +// Handler wraps an http.Handler with CSRF protection. +// Rejected requests receive a 403 Forbidden response (or custom deny handler). +func (c *CrossOriginProtection) Handler(h http.Handler) http.Handler { + return c.stdlib.Handler(h) +} + +// discardResponseWriter is a minimal ResponseWriter for testing. +type discardResponseWriter struct { + header http.Header +} + +func (d *discardResponseWriter) Header() http.Header { + if d.header == nil { + d.header = make(http.Header) + } + return d.header +} +func (d *discardResponseWriter) Write(b []byte) (int, error) { return len(b), nil } +func (d *discardResponseWriter) WriteHeader(_ int) {} diff --git a/vendor/github.com/go-pkgz/rest/gzip.go b/vendor/github.com/go-pkgz/rest/gzip.go index a7328b05..67cac265 100644 --- a/vendor/github.com/go-pkgz/rest/gzip.go +++ b/vendor/github.com/go-pkgz/rest/gzip.go @@ -20,7 +20,7 @@ var gzDefaultContentTypes = []string{ } var gzPool = sync.Pool{ - New: func() interface{} { return gzip.NewWriter(io.Discard) }, + New: func() any { return gzip.NewWriter(io.Discard) }, } type gzipResponseWriter struct { diff --git a/vendor/github.com/go-pkgz/rest/logger/logger.go b/vendor/github.com/go-pkgz/rest/logger/logger.go index d93342c5..df409320 100644 --- a/vendor/github.com/go-pkgz/rest/logger/logger.go +++ b/vendor/github.com/go-pkgz/rest/logger/logger.go @@ -32,7 +32,7 @@ type Middleware struct { // Backend is logging backend type Backend interface { - Logf(format string, args ...interface{}) + Logf(format string, args ...any) } type logParts struct { @@ -51,7 +51,7 @@ type logParts struct { type stdBackend struct{} -func (s stdBackend) Logf(format string, args ...interface{}) { +func (s stdBackend) Logf(format string, args ...any) { log.Printf(format, args...) } diff --git a/vendor/github.com/go-pkgz/rest/middleware.go b/vendor/github.com/go-pkgz/rest/middleware.go index c130a462..59e9f069 100644 --- a/vendor/github.com/go-pkgz/rest/middleware.go +++ b/vendor/github.com/go-pkgz/rest/middleware.go @@ -36,14 +36,17 @@ func AppInfo(app, author, version string) func(http.Handler) http.Handler { return f } -// Ping middleware response with pong to /ping. Stops chain if ping request detected +// Ping middleware response with pong to /ping. Stops chain if ping request detected. +// Handles both GET and HEAD methods - HEAD returns headers only without body, +// which is useful for lightweight health checks by monitoring tools. func Ping(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - - if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/ping") { + if (r.Method == "GET" || r.Method == "HEAD") && strings.HasSuffix(strings.ToLower(r.URL.Path), "/ping") { w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("pong")) + if r.Method == "GET" { + _, _ = w.Write([]byte("pong")) + } return } next.ServeHTTP(w, r) diff --git a/vendor/github.com/go-pkgz/rest/realip/real.go b/vendor/github.com/go-pkgz/rest/realip/real.go index ad10149f..cb492b19 100644 --- a/vendor/github.com/go-pkgz/rest/realip/real.go +++ b/vendor/github.com/go-pkgz/rest/realip/real.go @@ -58,8 +58,7 @@ func Get(r *http.Request) (string, error) { // check X-Forwarded-For, find leftmost public IP if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - addresses := strings.Split(xff, ",") - for _, addr := range addresses { + for addr := range strings.SplitSeq(xff, ",") { ip := strings.TrimSpace(addr) if parsedIP := net.ParseIP(ip); isPublicIP(parsedIP) { return ip, nil diff --git a/vendor/github.com/go-pkgz/rest/rest.go b/vendor/github.com/go-pkgz/rest/rest.go index d8181f3e..94b3e9d8 100644 --- a/vendor/github.com/go-pkgz/rest/rest.go +++ b/vendor/github.com/go-pkgz/rest/rest.go @@ -13,7 +13,7 @@ import ( type JSON map[string]any // RenderJSON sends data as json -func RenderJSON(w http.ResponseWriter, data interface{}) { +func RenderJSON(w http.ResponseWriter, data any) { buf := &bytes.Buffer{} enc := json.NewEncoder(buf) enc.SetEscapeHTML(true) @@ -35,9 +35,8 @@ func RenderJSONFromBytes(w http.ResponseWriter, r *http.Request, data []byte) er } // RenderJSONWithHTML allows html tags and forces charset=utf-8 -func RenderJSONWithHTML(w http.ResponseWriter, r *http.Request, v interface{}) error { - - encodeJSONWithHTML := func(v interface{}) ([]byte, error) { +func RenderJSONWithHTML(w http.ResponseWriter, r *http.Request, v any) error { + encodeJSONWithHTML := func(v any) ([]byte, error) { buf := &bytes.Buffer{} enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) @@ -55,7 +54,7 @@ func RenderJSONWithHTML(w http.ResponseWriter, r *http.Request, v interface{}) e } // renderJSONWithStatus sends data as json and enforces status code -func renderJSONWithStatus(w http.ResponseWriter, data interface{}, code int) { +func renderJSONWithStatus(w http.ResponseWriter, data any, code int) { buf := &bytes.Buffer{} enc := json.NewEncoder(buf) enc.SetEscapeHTML(true) diff --git a/vendor/github.com/go-pkgz/rest/rewrite.go b/vendor/github.com/go-pkgz/rest/rewrite.go index f3286bd2..d9e9537c 100644 --- a/vendor/github.com/go-pkgz/rest/rewrite.go +++ b/vendor/github.com/go-pkgz/rest/rewrite.go @@ -9,6 +9,79 @@ import ( "strings" ) +// CleanPath middleware cleans double slashes from URL path. +// For example, if a request is made to /users//1 or //users////1, +// it will be cleaned to /users/1 before routing. +// Trailing slashes are preserved: /users//1/ becomes /users/1/. +// Dot segments (. and ..) are intentionally NOT cleaned to preserve routing semantics. +func CleanPath(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rctx := r.Context() + // skip if already cleaned + if _, ok := rctx.Value(contextKey("cleanpath")).(bool); ok { + next.ServeHTTP(w, r) + return + } + + p := r.URL.Path + cleaned := cleanDoubleSlashes(p) + + if cleaned != p { + r.URL.Path = cleaned + if r.URL.RawPath != "" { + // clean double slashes in RawPath separately to preserve percent-encoding + r.URL.RawPath = cleanDoubleSlashes(r.URL.RawPath) + } + rctx = context.WithValue(rctx, contextKey("cleanpath"), true) + r = r.WithContext(rctx) + } + next.ServeHTTP(w, r) + }) +} + +// cleanDoubleSlashes removes consecutive slashes from path while preserving +// trailing slashes and dot segments (. and ..). +func cleanDoubleSlashes(p string) string { + if p == "" || p == "/" { + return p + } + + var b strings.Builder + b.Grow(len(p)) + + prevSlash := false + for i := 0; i < len(p); i++ { + c := p[i] + if c == '/' { + if !prevSlash { + b.WriteByte(c) + } + prevSlash = true + } else { + b.WriteByte(c) + prevSlash = false + } + } + + return b.String() +} + +// StripSlashes middleware removes trailing slashes from URL path. +// For example, /users/1/ becomes /users/1. +// The root path "/" is preserved. +func StripSlashes(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p := r.URL.Path + if len(p) > 1 && p[len(p)-1] == '/' { + r.URL.Path = p[:len(p)-1] + if r.URL.RawPath != "" { + r.URL.RawPath = strings.TrimSuffix(r.URL.RawPath, "/") + } + } + next.ServeHTTP(w, r) + }) +} + // Rewrite middleware with from->to rule. Supports regex (like nginx) and prevents multiple rewrites // example: Rewrite(`^/sites/(.*)/settings/$`, `/sites/settings/$1` func Rewrite(from, to string) func(http.Handler) http.Handler { diff --git a/vendor/github.com/go-pkgz/rest/secure.go b/vendor/github.com/go-pkgz/rest/secure.go new file mode 100644 index 00000000..19d4eb04 --- /dev/null +++ b/vendor/github.com/go-pkgz/rest/secure.go @@ -0,0 +1,207 @@ +package rest + +import ( + "net/http" + "strconv" + "strings" +) + +// SecureConfig defines security headers configuration. +// Use SecOpt functions to customize. +type SecureConfig struct { + // xFrameOptions sets X-Frame-Options header. Default: DENY + XFrameOptions string + // xContentTypeOptions sets X-Content-Type-Options. Default: nosniff + XContentTypeOptions string + // ReferrerPolicy sets Referrer-Policy header. Default: strict-origin-when-cross-origin + ReferrerPolicy string + // ContentSecurityPolicy sets Content-Security-Policy header. Default: empty (not set) + ContentSecurityPolicy string + // PermissionsPolicy sets Permissions-Policy header. Default: empty (not set) + PermissionsPolicy string + // sTSSeconds sets max-age for Strict-Transport-Security. 0 disables. + // only sent when request uses HTTPS. Default: 31536000 (1 year) + STSSeconds int + // sTSIncludeSubdomains adds includeSubDomains to HSTS. Default: true + STSIncludeSubdomains bool + // sTSPreload adds preload flag to HSTS. Default: false + STSPreload bool + // xSSProtection sets X-XSS-Protection header. Default: 1; mode=block + // note: this header is deprecated in modern browsers but still useful for older ones + XSSProtection string +} + +// SecOpt is a functional option for SecureConfig +type SecOpt func(*SecureConfig) + +// defaultSecureConfig returns config with sensible defaults +func defaultSecureConfig() SecureConfig { + return SecureConfig{ + XFrameOptions: "DENY", + XContentTypeOptions: "nosniff", + ReferrerPolicy: "strict-origin-when-cross-origin", + STSSeconds: 31536000, // 1 year + STSIncludeSubdomains: true, + STSPreload: false, + XSSProtection: "1; mode=block", + } +} + +// SecFrameOptions sets X-Frame-Options header. +// Common values: "DENY", "SAMEORIGIN" +func SecFrameOptions(value string) SecOpt { + return func(c *SecureConfig) { + c.XFrameOptions = value + } +} + +// SecContentTypeNosniff enables or disables X-Content-Type-Options: nosniff +func SecContentTypeNosniff(enable bool) SecOpt { + return func(c *SecureConfig) { + if enable { + c.XContentTypeOptions = "nosniff" + } else { + c.XContentTypeOptions = "" + } + } +} + +// SecReferrerPolicy sets Referrer-Policy header. +// Common values: "no-referrer", "same-origin", "strict-origin", "strict-origin-when-cross-origin" +func SecReferrerPolicy(policy string) SecOpt { + return func(c *SecureConfig) { + c.ReferrerPolicy = policy + } +} + +// SecContentSecurityPolicy sets Content-Security-Policy header. +// Example: "default-src 'self'; script-src 'self'" +func SecContentSecurityPolicy(policy string) SecOpt { + return func(c *SecureConfig) { + c.ContentSecurityPolicy = policy + } +} + +// SecPermissionsPolicy sets Permissions-Policy header. +// Example: "geolocation=(), microphone=()" +func SecPermissionsPolicy(policy string) SecOpt { + return func(c *SecureConfig) { + c.PermissionsPolicy = policy + } +} + +// SecHSTS configures Strict-Transport-Security header. +// maxAge is in seconds (0 disables HSTS), includeSubdomains and preload are optional flags. +// Note: HSTS header is only sent when the request is over HTTPS. +func SecHSTS(maxAge int, includeSubdomains, preload bool) SecOpt { + return func(c *SecureConfig) { + c.STSSeconds = maxAge + c.STSIncludeSubdomains = includeSubdomains + c.STSPreload = preload + } +} + +// SecXSSProtection sets X-XSS-Protection header. +// Set to empty string to disable. Common values: "0", "1", "1; mode=block" +func SecXSSProtection(value string) SecOpt { + return func(c *SecureConfig) { + c.XSSProtection = value + } +} + +// SecAllHeaders is a convenience option to set common headers for secure web applications. +// Sets CSP with self-only policy and restrictive permissions. +func SecAllHeaders() SecOpt { + return func(c *SecureConfig) { + c.ContentSecurityPolicy = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self'; form-action 'self'; frame-ancestors 'none'" + c.PermissionsPolicy = "geolocation=(), microphone=(), camera=()" + } +} + +// Secure is middleware that adds security headers to responses. +// By default it sets: X-Frame-Options, X-Content-Type-Options, Referrer-Policy, +// X-XSS-Protection, and Strict-Transport-Security (for HTTPS only). +// Use SecOpt functions to customize the configuration. +func Secure(opts ...SecOpt) func(http.Handler) http.Handler { + cfg := defaultSecureConfig() + for _, opt := range opts { + opt(&cfg) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // set security headers + if cfg.XFrameOptions != "" { + w.Header().Set("X-Frame-Options", cfg.XFrameOptions) + } + if cfg.XContentTypeOptions != "" { + w.Header().Set("X-Content-Type-Options", cfg.XContentTypeOptions) + } + if cfg.ReferrerPolicy != "" { + w.Header().Set("Referrer-Policy", cfg.ReferrerPolicy) + } + if cfg.XSSProtection != "" { + w.Header().Set("X-XSS-Protection", cfg.XSSProtection) + } + if cfg.ContentSecurityPolicy != "" { + w.Header().Set("Content-Security-Policy", cfg.ContentSecurityPolicy) + } + if cfg.PermissionsPolicy != "" { + w.Header().Set("Permissions-Policy", cfg.PermissionsPolicy) + } + + // HSTS only for HTTPS connections + if cfg.STSSeconds > 0 && isHTTPS(r) { + sts := "max-age=" + strconv.Itoa(cfg.STSSeconds) + if cfg.STSIncludeSubdomains { + sts += "; includeSubDomains" + } + if cfg.STSPreload { + sts += "; preload" + } + w.Header().Set("Strict-Transport-Security", sts) + } + + next.ServeHTTP(w, r) + }) + } +} + +// isHTTPS checks if the request is over HTTPS by examining TLS state and common proxy headers +func isHTTPS(r *http.Request) bool { + // direct TLS connection + if r.TLS != nil { + return true + } + // check common proxy headers (case-insensitive) + if strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") { + return true + } + // check RFC 7239 Forwarded header + if forwarded := r.Header.Get("Forwarded"); forwarded != "" { + if forwardedProtoIsHTTPS(forwarded) { + return true + } + } + return false +} + +// forwardedProtoIsHTTPS parses RFC 7239 Forwarded header to check for proto=https. +// The header format is: Forwarded: for=1.2.3.4;proto=https;by=proxy, for=5.6.7.8 +// Parameters are separated by semicolons, multiple forwarded elements by commas. +func forwardedProtoIsHTTPS(header string) bool { + // split by comma to get individual forwarded elements + for element := range strings.SplitSeq(header, ",") { + // split by semicolon to get parameters within element + for param := range strings.SplitSeq(element, ";") { + param = strings.TrimSpace(param) + // check for proto=https (case-insensitive per RFC 7239) + if len(param) > 6 && strings.EqualFold(param[:6], "proto=") { + if strings.EqualFold(strings.TrimSpace(param[6:]), "https") { + return true + } + } + } + } + return false +} diff --git a/vendor/github.com/lib/pq/.gitattributes b/vendor/github.com/lib/pq/.gitattributes new file mode 100644 index 00000000..dfdb8b77 --- /dev/null +++ b/vendor/github.com/lib/pq/.gitattributes @@ -0,0 +1 @@ +*.sh text eol=lf diff --git a/vendor/github.com/lib/pq/CHANGELOG.md b/vendor/github.com/lib/pq/CHANGELOG.md new file mode 100644 index 00000000..f0c46a96 --- /dev/null +++ b/vendor/github.com/lib/pq/CHANGELOG.md @@ -0,0 +1,265 @@ +unreleased +---------- + + +v1.12.3 (2026-04-03) +-------------------- +- Send datestyle startup parameter, improving compatbility with database engines + that use a different default datestyle such as EnterpriseDB ([#1312]). + +[#1312]: https://github.com/lib/pq/pull/1312 + +v1.12.2 (2026-04-02) +-------------------- + +- Treat io.ErrUnexpectedEOF as driver.ErrBadConn so database/sql discards the + connection. Since v1.12.0 this could result in permanently broken connections, + especially with CockroachDB which frequently sends partial messages ([#1299]). + +[#1299]: https://github.com/lib/pq/pull/1299 + +v1.12.1 (2026-03-30) +-------------------- + +- Look for pgpass file in ~/.pgpass instead of ~/.postgresql/pgpass ([#1300]). + +- Don't clear password if directly set on pq.Config ([#1302]). + +[#1300]: https://github.com/lib/pq/pull/1300 +[#1302]: https://github.com/lib/pq/pull/1302 + +v1.12.0 (2026-03-18) +-------------------- + +- The next release may change the default sslmode from `require` to `prefer`. + See [#1271] for details. + +- `CopyIn()` and `CopyInToSchema()` have been marked as deprecated. These are + simple query builders and not needed for `COPY [..] FROM STDIN` support (which + is *not* deprecated). ([#1279]) + + // Old + tx.Prepare(CopyIn("temp", "num", "text", "blob", "nothing")) + + // Replacement + tx.Prepare(`copy temp (num, text, blob, nothing) from stdin`) + +### Features + +- Support protocol 3.2, and the `min_protocol_version` and + `max_protocol_version` DSN parameters ([#1258]). + +- Support `sslmode=prefer` and `sslmode=allow` ([#1270]). + +- Support `ssl_min_protocol_version` and `ssl_max_protocol_version` ([#1277]). + +- Support connection service file to load connection details ([#1285]). + +- Support `sslrootcert=system` and use `~/.postgresql/root.crt` as the default + value of sslrootcert ([#1280], [#1281]). + +- Add a new `pqerror` package with PostgreSQL error codes ([#1275]). + + For example, to test if an error is a UNIQUE constraint violation: + + if pqErr, ok := errors.AsType[*pq.Error](err); ok && pqErr.Code == pqerror.UniqueViolation { + log.Fatalf("email %q already exsts", email) + } + + To make this a bit more convenient, it also adds a `pq.As()` function: + + pqErr := pq.As(err, pqerror.UniqueViolation) + if pqErr != nil { + log.Fatalf("email %q already exsts", email) + } + +### Fixes + +- Fix SSL key permission check to allow modes stricter than 0600/0640#1265 ([#1265]). + +- Fix Hstore to work with binary parameters ([#1278]). + +- Clearer error when starting a new query while pq is still processing another + query ([#1272]). + +- Send intermediate CAs with client certificates, so they can be signed by an + intermediate CA ([#1267]). + +- Use `time.UTC` for UTC aliases such as `Etc/UTC` ([#1282]). + +[#1258]: https://github.com/lib/pq/pull/1258 +[#1265]: https://github.com/lib/pq/pull/1265 +[#1267]: https://github.com/lib/pq/pull/1267 +[#1270]: https://github.com/lib/pq/pull/1270 +[#1271]: https://github.com/lib/pq/pull/1271 +[#1272]: https://github.com/lib/pq/pull/1272 +[#1275]: https://github.com/lib/pq/pull/1275 +[#1277]: https://github.com/lib/pq/pull/1277 +[#1278]: https://github.com/lib/pq/pull/1278 +[#1279]: https://github.com/lib/pq/pull/1279 +[#1280]: https://github.com/lib/pq/pull/1280 +[#1281]: https://github.com/lib/pq/pull/1281 +[#1282]: https://github.com/lib/pq/pull/1282 +[#1283]: https://github.com/lib/pq/pull/1283 +[#1285]: https://github.com/lib/pq/pull/1285 + +v1.11.2 (2026-02-10) +-------------------- +This fixes two regressions: + +- Don't send startup parameters if there is no value, improving compatibility + with Supavisor ([#1260]). + +- Don't send `dbname` as a startup parameter if `database=[..]` is used in the + connection string. It's recommended to use dbname=, as database= is not a + libpq option, and only worked by accident previously. ([#1261]) + +[#1260]: https://github.com/lib/pq/pull/1260 +[#1261]: https://github.com/lib/pq/pull/1261 + +v1.11.1 (2026-01-29) +-------------------- +This fixes two regressions present in the v1.11.0 release: + +- Fix build on 32bit systems, Windows, and Plan 9 ([#1253]). + +- Named []byte types and pointers to []byte (e.g. `*[]byte`, `json.RawMessage`) + would be treated as an array instead of bytea ([#1252]). + +[#1252]: https://github.com/lib/pq/pull/1252 +[#1253]: https://github.com/lib/pq/pull/1253 + +v1.11.0 (2026-01-28) +-------------------- +This version of pq requires Go 1.21 or newer. + +pq now supports only maintained PostgreSQL releases, which is PostgreSQL 14 and +newer. Previously PostgreSQL 8.4 and newer were supported. + +### Features + +- The `pq.Error.Error()` text includes the position of the error (if reported + by PostgreSQL) and SQLSTATE code ([#1219], [#1224]): + + pq: column "columndoesntexist" does not exist at column 8 (42703) + pq: syntax error at or near ")" at position 2:71 (42601) + +- The `pq.Error.ErrorWithDetail()` method prints a more detailed multiline + message, with the Detail, Hint, and error position (if any) ([#1219]): + + ERROR: syntax error at or near ")" (42601) + CONTEXT: line 12, column 1: + + 10 | name varchar, + 11 | version varchar, + 12 | ); + ^ + +- Add `Config`, `NewConfig()`, and `NewConnectorConfig()` to supply connection + details in a more structured way ([#1240]). + +- Support `hostaddr` and `$PGHOSTADDR` ([#1243]). + +- Support multiple values in `host`, `port`, and `hostaddr`, which are each + tried in order, or randomly if `load_balance_hosts=random` is set ([#1246]). + +- Support `target_session_attrs` connection parameter ([#1246]). + +- Support [`sslnegotiation`] to use SSL without negotiation ([#1180]). + +- Allow using a custom `tls.Config`, for example for encrypted keys ([#1228]). + +- Add `PQGO_DEBUG=1` print the communication with PostgreSQL to stderr, to aid + in debugging, testing, and bug reports ([#1223]). + +- Add support for NamedValueChecker interface ([#1125], [#1238]). + + +### Fixes + +- Match HOME directory lookup logic with libpq: prefer $HOME over /etc/passwd, + ignore ENOTDIR errors, and use APPDATA on Windows ([#1214]). + +- Fix `sslmode=verify-ca` verifying the hostname anyway when connecting to a DNS + name (rather than IP) ([#1226]). + +- Correctly detect pre-protocol errors such as the server not being able to fork + or running out of memory ([#1248]). + +- Fix build with wasm ([#1184]), appengine ([#745]), and Plan 9 ([#1133]). + +- Deprecate and type alias `pq.NullTime` to `sql.NullTime` ([#1211]). + +- Enforce integer limits of the Postgres wire protocol ([#1161]). + +- Accept the `passfile` connection parameter to override `PGPASSFILE` ([#1129]). + +- Fix connecting to socket on Windows systems ([#1179]). + +- Don't perform a permission check on the .pgpass file on Windows ([#595]). + +- Warn about incorrect .pgpass permissions ([#595]). + +- Don't set extra_float_digits ([#1212]). + +- Decode bpchar into a string ([#949]). + +- Fix panic in Ping() by not requiring CommandComplete or EmptyQueryResponse in + simpleQuery() ([#1234]) + +- Recognize bit/varbit ([#743]) and float types ([#1166]) in ColumnTypeScanType(). + +- Accept `PGGSSLIB` and `PGKRBSRVNAME` environment variables ([#1143]). + +- Handle ErrorResponse in readReadyForQuery and return proper error ([#1136]). + +- Detect COPY even if the query starts with whitespace or comments ([#1198]). + +- CopyIn() and CopyInSchema() now work if the list of columns is empty, in which + case it will copy all columns ([#1239]). + +- Treat nil []byte in query parameters as nil/NULL rather than `""` ([#838]). + +- Accept multiple authentication methods before checking AuthOk, which improves + compatibility with PgPool-II ([#1188]). + +[`sslnegotiation`]: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLNEGOTIATION +[#595]: https://github.com/lib/pq/pull/595 +[#745]: https://github.com/lib/pq/pull/745 +[#743]: https://github.com/lib/pq/pull/743 +[#838]: https://github.com/lib/pq/pull/838 +[#949]: https://github.com/lib/pq/pull/949 +[#1125]: https://github.com/lib/pq/pull/1125 +[#1129]: https://github.com/lib/pq/pull/1129 +[#1133]: https://github.com/lib/pq/pull/1133 +[#1136]: https://github.com/lib/pq/pull/1136 +[#1143]: https://github.com/lib/pq/pull/1143 +[#1161]: https://github.com/lib/pq/pull/1161 +[#1166]: https://github.com/lib/pq/pull/1166 +[#1179]: https://github.com/lib/pq/pull/1179 +[#1180]: https://github.com/lib/pq/pull/1180 +[#1184]: https://github.com/lib/pq/pull/1184 +[#1188]: https://github.com/lib/pq/pull/1188 +[#1198]: https://github.com/lib/pq/pull/1198 +[#1211]: https://github.com/lib/pq/pull/1211 +[#1212]: https://github.com/lib/pq/pull/1212 +[#1214]: https://github.com/lib/pq/pull/1214 +[#1219]: https://github.com/lib/pq/pull/1219 +[#1223]: https://github.com/lib/pq/pull/1223 +[#1224]: https://github.com/lib/pq/pull/1224 +[#1226]: https://github.com/lib/pq/pull/1226 +[#1228]: https://github.com/lib/pq/pull/1228 +[#1234]: https://github.com/lib/pq/pull/1234 +[#1238]: https://github.com/lib/pq/pull/1238 +[#1239]: https://github.com/lib/pq/pull/1239 +[#1240]: https://github.com/lib/pq/pull/1240 +[#1243]: https://github.com/lib/pq/pull/1243 +[#1246]: https://github.com/lib/pq/pull/1246 +[#1248]: https://github.com/lib/pq/pull/1248 + + +v1.10.9 (2023-04-26) +-------------------- +- Fixes backwards incompat bug with 1.13. + +- Fixes pgpass issue diff --git a/vendor/github.com/lib/pq/LICENSE b/vendor/github.com/lib/pq/LICENSE new file mode 100644 index 00000000..6a77dc4f --- /dev/null +++ b/vendor/github.com/lib/pq/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2011-2013, 'pq' Contributors. Portions Copyright (c) 2011 Blake Mizerany + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/lib/pq/LICENSE.md b/vendor/github.com/lib/pq/LICENSE.md deleted file mode 100644 index 5773904a..00000000 --- a/vendor/github.com/lib/pq/LICENSE.md +++ /dev/null @@ -1,8 +0,0 @@ -Copyright (c) 2011-2013, 'pq' Contributors -Portions Copyright (C) 2011 Blake Mizerany - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/lib/pq/README.md b/vendor/github.com/lib/pq/README.md index 126ee5d3..159b8672 100644 --- a/vendor/github.com/lib/pq/README.md +++ b/vendor/github.com/lib/pq/README.md @@ -1,36 +1,312 @@ -# pq - A pure Go postgres driver for Go's database/sql package +pq is a Go PostgreSQL driver for database/sql. -[![GoDoc](https://godoc.org/github.com/lib/pq?status.svg)](https://pkg.go.dev/github.com/lib/pq?tab=doc) +All [maintained versions of PostgreSQL] are supported. Older versions may work, +but this is not tested. [API docs]. -## Install +[maintained versions of PostgreSQL]: https://www.postgresql.org/support/versioning +[API docs]: https://pkg.go.dev/github.com/lib/pq - go get github.com/lib/pq +Connecting +---------- +Use the `postgres` driver name in the `sql.Open()` call: -## Features +```go +package main -* SSL -* Handles bad connections for `database/sql` -* Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`) -* Scan binary blobs correctly (i.e. `bytea`) -* Package for `hstore` support -* COPY FROM support -* pq.ParseURL for converting urls to connection strings for sql.Open. -* Many libpq compatible environment variables -* Unix socket support -* Notifications: `LISTEN`/`NOTIFY` -* pgpass support -* GSS (Kerberos) auth +import ( + "database/sql" + "log" -## Tests + _ "github.com/lib/pq" // To register the driver. +) -`go test` is used for testing. See [TESTS.md](TESTS.md) for more details. +func main() { + // Or as URL: postgresql://localhost/pqgo + db, err := sql.Open("postgres", "host=localhost dbname=pqgo") + if err != nil { + log.Fatal(err) + } + defer db.Close() -## Status + // db.Open() only creates a connection pool, and doesn't actually establish + // a connection. To ensure the connection works you need to do *something* + // with a connection. + err = db.Ping() + if err != nil { + log.Fatal(err) + } +} +``` -This package is currently in maintenance mode, which means: -1. It generally does not accept new features. -2. It does accept bug fixes and version compatability changes provided by the community. -3. Maintainers usually do not resolve reported issues. -4. Community members are encouraged to help each other with reported issues. +You can also use the `pq.Config` struct: -For users that require new features or reliable resolution of reported bugs, we recommend using [pgx](https://github.com/jackc/pgx) which is under active development. +```go +cfg := pq.Config{ + Host: "localhost", + Port: 5432, + User: "pqgo", +} +// Or: create a new Config from the defaults, environment, and DSN. +// cfg, err := pq.NewConfig("host=postgres dbname=pqgo") +// if err != nil { +// log.Fatal(err) +// } + +c, err := pq.NewConnectorConfig(cfg) +if err != nil { + log.Fatal(err) +} + +// Create connection pool. +db := sql.OpenDB(c) +defer db.Close() + +// Make sure it works. +err = db.Ping() +if err != nil { + log.Fatal(err) +} +``` + +The DSN is identical to PostgreSQL's libpq; most parameters are supported and +should behave identical. Both key=value and postgres:// URL-style connection +strings are supported. See the doc comments on the [Config struct] for the full +list and documentation. + +The most notable difference is that you can use any [run-time parameter] such as +`search_path` or `work_mem` in the connection string. This is different from +libpq, which uses the `options` parameter for this (which also works in pq). + +For example: + + sql.Open("postgres", "dbname=pqgo work_mem=100kB search_path=xyz") + +The libpq way (which also works in pq) is to use `options='-c k=v'` like so: + + sql.Open("postgres", "dbname=pqgo options='-c work_mem=100kB -c search_path=xyz'") + +[Config struct]: https://pkg.go.dev/github.com/lib/pq#Config +[run-time parameter]: http://www.postgresql.org/docs/current/static/runtime-config.html + +Errors +------ +Errors from PostgreSQL are returned as [pq.Error]; [pq.As] can be used to +convert an error to `pq.Error`: + +```go +pqErr := pq.As(err, pqerror.UniqueViolation) +if pqErr != nil { + return fmt.Errorf("email %q already exsts", email) +} +``` + +the Error() string contains the error message and code: + + pq: duplicate key value violates unique constraint "users_lower_idx" (23505) + +The ErrorWithDetail() string also contains the DETAIL and CONTEXT fields, if +present. For example for the above error this helpfully contains the duplicate +value: + + ERROR: duplicate key value violates unique constraint "users_lower_idx" (23505) + DETAIL: Key (lower(email))=(a@example.com) already exists. + +Or for an invalid syntax error like this: + + pq: invalid input syntax for type json (22P02) + +It contains the context where this error occurred: + + ERROR: invalid input syntax for type json (22P02) + DETAIL: Token "asd" is invalid. + CONTEXT: line 5, column 8: + + 3 | 'def', + 4 | 123, + 5 | 'foo', 'asd'::jsonb + ^ + +[pq.Error]: https://pkg.go.dev/github.com/lib/pq#Error +[pq.As]: https://pkg.go.dev/github.com/lib/pq#As + +PostgreSQL features +------------------- + +### Authentication +pq supports PASSWORD, MD5, and SCRAM-SHA256 authentication out of the box. If +you need GSS/Kerberos authentication you'll need to import the `auth/kerberos` +module: package: + + import "github.com/lib/pq/auth/kerberos" + + func init() { + pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() }) + } + +This is in a separate module so that users who don't need Kerberos (i.e. most +users) don't have to add unnecessary dependencies. + +Reading a [password file] (pgpass) is also supported. + +[password file]: http://www.postgresql.org/docs/current/static/libpq-pgpass.html + +### Bulk imports with `COPY [..] FROM STDIN` +You can perform bulk imports by preparing a `COPY [..] FROM STDIN` statement +inside a transaction. The returned `sql.Stmt` can then be repeatedly executed to +copy data. After all data has been processed you should call Exec() once with no +arguments to flush all buffered data. + +[Further documentation][copy-doc] and [example][copy-ex]. + +[copy-doc]: https://pkg.go.dev/github.com/lib/pq#hdr-Bulk_imports +[copy-ex]: https://pkg.go.dev/github.com/lib/pq#example-package-CopyFromStdin + +### NOTICE errors +PostgreSQL has "NOTICE" errors for informational messages. For example from the +psql CLI: + + pqgo=# drop table if exists doesnotexist; + NOTICE: table "doesnotexist" does not exist, skipping + DROP TABLE + +These errors are not returned because they're not really errors but, well, +notices. + +You can register a callback for these notices with [ConnectorWithNoticeHandler] + +[ConnectorWithNoticeHandler]: https://pkg.go.dev/github.com/lib/pq#ConnectorWithNoticeHandler + +### Using `LISTEN`/`NOTIFY` +With [pq.Listener] notifications are send on a channel. For example: + +```go +l := pq.NewListener("dbname=pqgo", time.Second, time.Minute, nil) +defer l.Close() + +err := l.Listen("coconut") +if err != nil { + log.Fatal(err) +} + +for { + n := <-l.Notify: + if n == nil { + fmt.Println("nil notify: closing Listener") + return + } + fmt.Printf("notification on %q with data %q\n", n.Channel, n.Extra) +} +``` + +And you'll get a notification for every `notify coconut`. + +See the API docs for a more complete example. + +[pq.Listener]: https://pkg.go.dev/github.com/lib/pq#Listener + + +Caveats +------- +### LastInsertId +sql.Result.LastInsertId() is not supported, because the PostgreSQL protocol does +not have this facility. Use `insert [..] returning [cols]` instead: + + db.QueryRow(`insert into tbl [..] returning id_col`).Scan(..) + // Or multiple rows: + db.Query(`insert into tbl (row1), (row2) returning id_col`) + +This will also work in SQLite and MariaDB with the same syntax. MS-SQL and +Oracle have a similar facility (with a different syntax). + +### timestamps +For timestamps with a timezone (`timestamptz`/`timestamp with time zone`), pq +uses the timezone configured in the server, as libpq. You can change this with +`timestamp=[..]` in the connection string. It's generally recommended to use +UTC. + +For timestamps without a timezone (`timestamp`/`timestamp without time zone`), +pq always uses `time.FixedZone("", 0)` as the timezone; the timestamp parameter +has no effect here. This is intentionally not equal to time.UTC, as it's not a +UTC time: it's a time without a timezone. Go's time package does not really +support this concept, so this is the best we can do This will print `+0000` +twice (e.g. `2026-03-15 17:45:47 +0000 +0000`; having a clearer name would have +been better, but is not compatible change). See [this comment][ts] for some +options on how to deal with this. + +Also see the examples for [timestamptz] and [timestamp] + +[ts]: https://github.com/lib/pq/issues/329#issuecomment-4025733506 +[timestamptz]: https://pkg.go.dev/github.com/lib/pq#example-package-TimestampWithTimezone +[timestamp]: https://pkg.go.dev/github.com/lib/pq#example-package-TimestampWithoutTimezone + +### bytea with copy +All `[]byte` parameters are encoded as `bytea` when using `copy [..] from +stdin`, which may result in errors for e.g. `jsonb` columns. The solution is to +use a string instead of []byte. See #1023 + +Development +----------- +### Running tests +Tests need to be run against a PostgreSQL database; you can use Docker compose +to start one: + + docker compose up -d + +This starts the latest PostgreSQL; use `docker compose up -d pg«v»` to start a +different version. + +In addition, your `/etc/hosts` needs an entry: + + 127.0.0.1 postgres postgres-invalid + +Or you can use any other PostgreSQL instance; see +`testdata/postgres/docker-entrypoint-initdb.d` for the required setup. You can use +the standard `PG*` environment variables to control the connection details; it +uses the following defaults: + + PGHOST=localhost + PGDATABASE=pqgo + PGUSER=pqgo + PGSSLMODE=disable + PGCONNECT_TIMEOUT=20 + +`PQTEST_BINARY_PARAMETERS` can be used to add `binary_parameters=yes` to all +connection strings: + + PQTEST_BINARY_PARAMETERS=1 go test + +Tests can be run against pgbouncer with: + + docker compose up -d pgbouncer pg18 + PGPORT=6432 go test ./... + +and pgpool with: + + docker compose up -d pgpool pg18 + PGPORT=7432 go test ./... + +### Protocol debug output +You can use PQGO_DEBUG=1 to make the driver print the communication with +PostgreSQL to stderr; this works anywhere (test or applications) and can be +useful to debug protocol problems. + +For example: + + % PQGO_DEBUG=1 go test -run TestSimpleQuery + CLIENT → Startup 69 "\x00\x03\x00\x00database\x00pqgo\x00user [..]" + SERVER ← (R) AuthRequest 4 "\x00\x00\x00\x00" + SERVER ← (S) ParamStatus 19 "in_hot_standby\x00off\x00" + [..] + SERVER ← (Z) ReadyForQuery 1 "I" + START conn.query + START conn.simpleQuery + CLIENT → (Q) Query 9 "select 1\x00" + SERVER ← (T) RowDescription 29 "\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x17\x00\x04\xff\xff\xff\xff\x00\x00" + SERVER ← (D) DataRow 7 "\x00\x01\x00\x00\x00\x011" + END conn.simpleQuery + END conn.query + SERVER ← (C) CommandComplete 9 "SELECT 1\x00" + SERVER ← (Z) ReadyForQuery 1 "I" + CLIENT → (X) Terminate 0 "" + PASS + ok github.com/lib/pq 0.010s diff --git a/vendor/github.com/lib/pq/TESTS.md b/vendor/github.com/lib/pq/TESTS.md deleted file mode 100644 index f0502111..00000000 --- a/vendor/github.com/lib/pq/TESTS.md +++ /dev/null @@ -1,33 +0,0 @@ -# Tests - -## Running Tests - -`go test` is used for testing. A running PostgreSQL -server is required, with the ability to log in. The -database to connect to test with is "pqgotest," on -"localhost" but these can be overridden using [environment -variables](https://www.postgresql.org/docs/9.3/static/libpq-envars.html). - -Example: - - PGHOST=/run/postgresql go test - -## Benchmarks - -A benchmark suite can be run as part of the tests: - - go test -bench . - -## Example setup (Docker) - -Run a postgres container: - -``` -docker run --expose 5432:5432 postgres -``` - -Run tests: - -``` -PGHOST=localhost PGPORT=5432 PGUSER=postgres PGSSLMODE=disable PGDATABASE=postgres go test -``` diff --git a/vendor/github.com/lib/pq/array.go b/vendor/github.com/lib/pq/array.go index 39c8f7e2..4a532868 100644 --- a/vendor/github.com/lib/pq/array.go +++ b/vendor/github.com/lib/pq/array.go @@ -19,14 +19,15 @@ var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // slice of any dimension. // // For example: -// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) // -// var x []sql.NullInt64 -// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// +// var x []sql.NullInt64 +// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) // // Scanning multi-dimensional arrays is not supported. Arrays where the lower // bound is not one (such as `[0:0]={1}') are not supported. -func Array(a interface{}) interface { +func Array(a any) interface { driver.Valuer sql.Scanner } { @@ -76,7 +77,7 @@ type ArrayDelimiter interface { type BoolArray []bool // Scan implements the sql.Scanner interface. -func (a *BoolArray) Scan(src interface{}) error { +func (a *BoolArray) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -150,7 +151,7 @@ func (a BoolArray) Value() (driver.Value, error) { type ByteaArray [][]byte // Scan implements the sql.Scanner interface. -func (a *ByteaArray) Scan(src interface{}) error { +func (a *ByteaArray) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -176,7 +177,7 @@ func (a *ByteaArray) scanBytes(src []byte) error { for i, v := range elems { b[i], err = parseBytea(v) if err != nil { - return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) + return fmt.Errorf("could not parse bytea array index %d: %w", i, err) } } *a = b @@ -222,7 +223,7 @@ func (a ByteaArray) Value() (driver.Value, error) { type Float64Array []float64 // Scan implements the sql.Scanner interface. -func (a *Float64Array) Scan(src interface{}) error { +func (a *Float64Array) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -246,8 +247,9 @@ func (a *Float64Array) scanBytes(src []byte) error { } else { b := make(Float64Array, len(elems)) for i, v := range elems { - if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + b[i], err = strconv.ParseFloat(string(v), 64) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) } } *a = b @@ -284,7 +286,7 @@ func (a Float64Array) Value() (driver.Value, error) { type Float32Array []float32 // Scan implements the sql.Scanner interface. -func (a *Float32Array) Scan(src interface{}) error { +func (a *Float32Array) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -308,9 +310,9 @@ func (a *Float32Array) scanBytes(src []byte) error { } else { b := make(Float32Array, len(elems)) for i, v := range elems { - var x float64 - if x, err = strconv.ParseFloat(string(v), 32); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + x, err := strconv.ParseFloat(string(v), 32) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) } b[i] = float32(x) } @@ -345,7 +347,7 @@ func (a Float32Array) Value() (driver.Value, error) { // GenericArray implements the driver.Valuer and sql.Scanner interfaces for // an array or slice of any dimension. -type GenericArray struct{ A interface{} } +type GenericArray struct{ A any } func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { var assign func([]byte, reflect.Value) error @@ -354,7 +356,7 @@ func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]b // TODO calculate the assign function for other types // TODO repeat this section on the element type of arrays or slices (multidimensional) { - if reflect.PtrTo(rt).Implements(typeSQLScanner) { + if reflect.PointerTo(rt).Implements(typeSQLScanner) { // dest is always addressable because it is an element of a slice. assign = func(src []byte, dest reflect.Value) (err error) { ss := dest.Addr().Interface().(sql.Scanner) @@ -383,10 +385,10 @@ FoundType: } // Scan implements the sql.Scanner interface. -func (a GenericArray) Scan(src interface{}) error { +func (a GenericArray) Scan(src any) error { dpv := reflect.ValueOf(a.A) switch { - case dpv.Kind() != reflect.Ptr: + case dpv.Kind() != reflect.Pointer: return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) case dpv.IsNil(): return fmt.Errorf("pq: destination %T is nil", a.A) @@ -449,8 +451,9 @@ func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) for i, e := range elems { - if err := assign(e, values.Index(i)); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + err := assign(e, values.Index(i)) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) } } @@ -483,7 +486,7 @@ func (a GenericArray) Value() (driver.Value, error) { } case reflect.Array: default: - return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) + return nil, fmt.Errorf("pq: unable to convert %T to array", a.A) } if n := rv.Len(); n > 0 { @@ -502,7 +505,7 @@ func (a GenericArray) Value() (driver.Value, error) { type Int64Array []int64 // Scan implements the sql.Scanner interface. -func (a *Int64Array) Scan(src interface{}) error { +func (a *Int64Array) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -526,8 +529,9 @@ func (a *Int64Array) scanBytes(src []byte) error { } else { b := make(Int64Array, len(elems)) for i, v := range elems { - if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + b[i], err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) } } *a = b @@ -563,7 +567,7 @@ func (a Int64Array) Value() (driver.Value, error) { type Int32Array []int32 // Scan implements the sql.Scanner interface. -func (a *Int32Array) Scan(src interface{}) error { +func (a *Int32Array) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -589,7 +593,7 @@ func (a *Int32Array) scanBytes(src []byte) error { for i, v := range elems { x, err := strconv.ParseInt(string(v), 10, 32) if err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) } b[i] = int32(x) } @@ -626,7 +630,7 @@ func (a Int32Array) Value() (driver.Value, error) { type StringArray []string // Scan implements the sql.Scanner interface. -func (a *StringArray) Scan(src interface{}) error { +func (a *StringArray) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -683,10 +687,10 @@ func (a StringArray) Value() (driver.Value, error) { return "{}", nil } -// appendArray appends rv to the buffer, returning the extended buffer and -// the delimiter used between elements. +// appendArray appends rv to the buffer, returning the extended buffer and the +// delimiter used between elements. // -// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. +// Returns an error when n <= 0 or rv is not a reflect.Array or reflect.Slice. func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { var del string var err error @@ -728,7 +732,7 @@ func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { var del = "," var err error - var iv interface{} = rv.Interface() + var iv = rv.Interface() if ad, ok := iv.(ArrayDelimiter); ok { del = ad.ArrayDelimiter() @@ -769,7 +773,11 @@ func appendArrayQuotedBytes(b, v []byte) []byte { } func appendValue(b []byte, v driver.Value) ([]byte, error) { - return append(b, encode(nil, v, 0)...), nil + enc, err := encode(v, 0) + if err != nil { + return nil, err + } + return append(b, enc...), nil } // parseArray extracts the dimensions and elements of an array represented in diff --git a/vendor/github.com/lib/pq/as.go b/vendor/github.com/lib/pq/as.go new file mode 100644 index 00000000..1ea0ee5b --- /dev/null +++ b/vendor/github.com/lib/pq/as.go @@ -0,0 +1,26 @@ +//go:build !go1.26 + +package pq + +import ( + "errors" + "slices" +) + +// As asserts that the given error is [pq.Error] and returns it, returning nil +// if it's not a pq.Error. +// +// It will return nil if the pq.Error is not one of the given error codes. If no +// codes are given it will always return the Error. +// +// This is safe to call with a nil error. +func As(err error, codes ...ErrorCode) *Error { + if err == nil { // Not strictly needed, but prevents alloc for nil errors. + return nil + } + pqErr := new(Error) + if errors.As(err, &pqErr) && (len(codes) == 0 || slices.Contains(codes, pqErr.Code)) { + return pqErr + } + return nil +} diff --git a/vendor/github.com/lib/pq/as_go126.go b/vendor/github.com/lib/pq/as_go126.go new file mode 100644 index 00000000..18ffbc37 --- /dev/null +++ b/vendor/github.com/lib/pq/as_go126.go @@ -0,0 +1,23 @@ +//go:build go1.26 + +package pq + +import ( + "errors" + "github.com/lib/pq/pqerror" + "slices" +) + +// As asserts that the given error is [pq.Error] and returns it, returning nil +// if it's not a pq.Error. +// +// It will return nil if the pq.Error is not one of the given error codes. If no +// codes are given it will always return the Error. +// +// This is safe to call with a nil error. +func As(err error, codes ...pqerror.Code) *Error { + if pqErr, ok := errors.AsType[*Error](err); ok && (len(codes) == 0 || slices.Contains(codes, pqErr.Code)) { + return pqErr + } + return nil +} diff --git a/vendor/github.com/lib/pq/buf.go b/vendor/github.com/lib/pq/buf.go index 4b0a0a8f..67ca60cc 100644 --- a/vendor/github.com/lib/pq/buf.go +++ b/vendor/github.com/lib/pq/buf.go @@ -3,7 +3,10 @@ package pq import ( "bytes" "encoding/binary" + "errors" + "fmt" + "github.com/lib/pq/internal/proto" "github.com/lib/pq/oid" ) @@ -31,7 +34,7 @@ func (b *readBuf) int16() (n int) { func (b *readBuf) string() string { i := bytes.IndexByte(*b, 0) if i < 0 { - errorf("invalid message format; expected string terminator") + panic(errors.New("pq: invalid message format; expected string terminator")) } s := (*b)[:i] *b = (*b)[i+1:] @@ -69,8 +72,8 @@ func (b *writeBuf) string(s string) { b.buf = append(append(b.buf, s...), '\000') } -func (b *writeBuf) byte(c byte) { - b.buf = append(b.buf, c) +func (b *writeBuf) byte(c proto.RequestCode) { + b.buf = append(b.buf, byte(c)) } func (b *writeBuf) bytes(v []byte) { @@ -79,13 +82,19 @@ func (b *writeBuf) bytes(v []byte) { func (b *writeBuf) wrap() []byte { p := b.buf[b.pos:] + if len(p) > proto.MaxUint32 { + panic(fmt.Errorf("pq: message too large (%d > math.MaxUint32)", len(p))) + } binary.BigEndian.PutUint32(p, uint32(len(p))) return b.buf } -func (b *writeBuf) next(c byte) { +func (b *writeBuf) next(c proto.RequestCode) { p := b.buf[b.pos:] + if len(p) > proto.MaxUint32 { + panic(fmt.Errorf("pq: message too large (%d > math.MaxUint32)", len(p))) + } binary.BigEndian.PutUint32(p, uint32(len(p))) b.pos = len(b.buf) + 1 - b.buf = append(b.buf, c, 0, 0, 0, 0) + b.buf = append(b.buf, byte(c), 0, 0, 0, 0) } diff --git a/vendor/github.com/lib/pq/compose.yaml b/vendor/github.com/lib/pq/compose.yaml new file mode 100644 index 00000000..0092b7e3 --- /dev/null +++ b/vendor/github.com/lib/pq/compose.yaml @@ -0,0 +1,89 @@ +name: 'pqgo' + +services: + pgbouncer: + profiles: ['pgbouncer'] + image: 'cleanstart/pgbouncer:latest' + ports: ['127.0.0.1:6432:6432'] + command: ['/init/pgbouncer.ini'] + volumes: ['./testdata/pgbouncer:/init', './testdata/ssl:/ssl'] + environment: + 'PGBOUNCER_DATABASE': 'pqgo' + + pgpool: + profiles: ['pgpool'] + image: 'pgpool/pgpool:4.4.3' + ports: ['127.0.0.1:7432:7432'] + volumes: ['./testdata/pgpool:/init', './testdata/ssl:/ssl'] + entrypoint: '/init/entry.sh' + environment: + 'PGPOOL_PARAMS_PORT': '7432' + 'PGPOOL_PARAMS_BACKEND_HOSTNAME0': 'pg18' + + cockroach: + profiles: ['cockroach'] + image: 'cockroachdb/cockroach:latest-v26.1' + ports: ['127.0.0.1:26257:26257'] + volumes: ['./testdata/cockroach:/docker-entrypoint-initdb.d', './testdata/ssl:/ssl'] + command: ['start-single-node', '--accept-sql-without-tls', '--certs-dir=/ssl2'] + healthcheck: {test: ['CMD-SHELL', '/cockroach/cockroach node status --insecure --user=pqgo'], start_period: '30s', start_interval: '1s'} + + pg18: + image: 'postgres:18' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/postgres:/init', './testdata/ssl:/ssl'] + shm_size: '128mb' + healthcheck: {test: ['CMD-SHELL', 'pg_isready -U pqgo -d pqgo'], start_period: '30s', start_interval: '1s'} + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg17: + profiles: ['pg17'] + image: 'postgres:17' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/postgres:/init', './testdata/ssl:/ssl'] + shm_size: '128mb' + healthcheck: {test: ['CMD-SHELL', 'pg_isready -U pqgo -d pqgo'], start_period: '30s', start_interval: '1s'} + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg16: + profiles: ['pg16'] + image: 'postgres:16' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/postgres:/init', './testdata/ssl:/ssl'] + shm_size: '128mb' + healthcheck: {test: ['CMD-SHELL', 'pg_isready -U pqgo -d pqgo'], start_period: '30s', start_interval: '1s'} + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg15: + profiles: ['pg15'] + image: 'postgres:15' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/postgres:/init', './testdata/ssl:/ssl'] + shm_size: '128mb' + healthcheck: {test: ['CMD-SHELL', 'pg_isready -U pqgo -d pqgo'], start_period: '30s', start_interval: '1s'} + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg14: + profiles: ['pg14'] + image: 'postgres:14' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/postgres:/init', './testdata/ssl:/ssl'] + shm_size: '128mb' + healthcheck: {test: ['CMD-SHELL', 'pg_isready -U pqgo -d pqgo'], start_period: '30s', start_interval: '1s'} + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index da4ff9de..688329cd 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -2,7 +2,6 @@ package pq import ( "bufio" - "bytes" "context" "crypto/md5" "crypto/sha256" @@ -12,31 +11,34 @@ import ( "errors" "fmt" "io" + "math" "net" "os" - "os/user" - "path" - "path/filepath" + "reflect" "strconv" "strings" "sync" + "sync/atomic" "time" - "unicode" + "github.com/lib/pq/internal/pgpass" + "github.com/lib/pq/internal/pqsql" + "github.com/lib/pq/internal/pqutil" + "github.com/lib/pq/internal/proto" "github.com/lib/pq/oid" "github.com/lib/pq/scram" ) // Common error types var ( - ErrNotSupported = errors.New("pq: Unsupported command") - ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") + ErrNotSupported = errors.New("pq: unsupported command") + ErrInFailedTransaction = errors.New("pq: could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") - ErrSSLKeyUnknownOwnership = errors.New("pq: Could not get owner information for private key, may not be properly protected") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key has world access. Permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less") - - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") + ErrCouldNotDetectUsername = errors.New("pq: could not detect default username; please provide one explicitly") + ErrSSLKeyUnknownOwnership = pqutil.ErrSSLKeyUnknownOwnership + ErrSSLKeyHasWorldPermissions = pqutil.ErrSSLKeyHasWorldPermissions + errQueryInProgress = errors.New("pq: there is already a query being processed on this connection") errUnexpectedReady = errors.New("unexpected ReadyForQuery") errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") @@ -44,9 +46,32 @@ var ( // Compile time validation that our types implement the expected interfaces var ( - _ driver.Driver = Driver{} + _ driver.Driver = Driver{} + _ driver.ConnBeginTx = (*conn)(nil) + _ driver.ConnPrepareContext = (*conn)(nil) + _ driver.Execer = (*conn)(nil) //lint:ignore SA1019 x + _ driver.ExecerContext = (*conn)(nil) + _ driver.NamedValueChecker = (*conn)(nil) + _ driver.Pinger = (*conn)(nil) + _ driver.Queryer = (*conn)(nil) //lint:ignore SA1019 x + _ driver.QueryerContext = (*conn)(nil) + _ driver.SessionResetter = (*conn)(nil) + _ driver.Validator = (*conn)(nil) + _ driver.StmtExecContext = (*stmt)(nil) + _ driver.StmtQueryContext = (*stmt)(nil) ) +func init() { + sql.Register("postgres", &Driver{}) +} + +var debugProto = func() bool { + // Check for exactly "1" (rather than mere existence) so we can add + // options/flags in the future. I don't know if we ever want that, but it's + // nice to leave the option open. + return os.Getenv("PQGO_DEBUG") == "1" +}() + // Driver is the Postgres database driver. type Driver struct{} @@ -57,19 +82,27 @@ func (d Driver) Open(name string) (driver.Conn, error) { return Open(name) } -func init() { - sql.Register("postgres", &Driver{}) +// Parameters sent by PostgreSQL on startup. +type parameterStatus struct { + serverVersion int + currentLocation *time.Location + inHotStandby, defaultTransactionReadOnly sql.NullBool } -type parameterStatus struct { - // server version in the same format as server_version_num, or 0 if - // unavailable - serverVersion int +type format int - // the current location based on the TimeZone value of the session, if - // available - currentLocation *time.Location -} +const ( + formatText format = 0 + formatBinary format = 1 +) + +var ( + // One result-column format code with the value 1 (i.e. all binary). + colFmtDataAllBinary = []byte{0, 1, 0, 1} + + // No result-column format codes (i.e. all text). + colFmtDataAllText = []byte{0, 0} +) type transactionStatus byte @@ -88,10 +121,8 @@ func (s transactionStatus) String() string { case txnStatusInFailedTransaction: return "in a failed transaction" default: - errorf("unknown transactionStatus %d", s) + panic(fmt.Sprintf("pq: unknown transactionStatus %d", s)) } - - panic("not reached") } // Dialer is the dialer interface. It can be used to obtain more control over @@ -113,13 +144,13 @@ type defaultDialer struct { func (d defaultDialer) Dial(network, address string) (net.Conn, error) { return d.d.Dial(network, address) } -func (d defaultDialer) DialTimeout( - network, address string, timeout time.Duration, -) (net.Conn, error) { + +func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() return d.DialContext(ctx, network, address) } + func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { return d.d.DialContext(ctx, network, address) } @@ -133,43 +164,24 @@ type conn struct { txnFinish func() // Save connection arguments to use during CancelRequest. - dialer Dialer - opts values - - // Cancellation key data for use with CancelRequest messages. - processID int - secretKey int - + dialer Dialer + cfg Config parameterStatus parameterStatus - saveMessageType byte + saveMessageType proto.ResponseCode saveMessageBuffer []byte - // If an error is set, this connection is bad and all public-facing + // If an error is set this connection is bad and all public-facing // functions should return the appropriate error by calling get() // (ErrBadConn) or getForNext(). err syncErr - // If set, this connection should never use the binary format when - // receiving query results from prepared statements. Only provided for - // debugging. - disablePreparedBinaryResult bool - - // Whether to always send []byte parameters over as binary. Enables single - // round-trip mode for non-prepared Query calls. - binaryParameters bool - - // If true this connection is in the middle of a COPY - inCopy bool - - // If not nil, notices will be synchronously sent here - noticeHandler func(*Error) - - // If not nil, notifications will be synchronously sent here - notificationHandler func(*Notification) - - // GSSAPI context - gss GSS + secretKey []byte // Cancellation key for CancelRequest messages. + pid int // Cancellation PID. + inProgress atomic.Bool // This connection is in the middle of a processing a request. + noticeHandler func(*Error) // If not nil, notices will be synchronously sent here + notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here + gss GSS // GSSAPI context } type syncErr struct { @@ -206,125 +218,16 @@ func (e *syncErr) set(err error) { } } -// Handle driver-side settings in parsed connection string. -func (cn *conn) handleDriverSettings(o values) (err error) { - boolSetting := func(key string, val *bool) error { - if value, ok := o[key]; ok { - if value == "yes" { - *val = true - } else if value == "no" { - *val = false - } else { - return fmt.Errorf("unrecognized value %q for %s", value, key) - } - } - return nil - } - - err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) - if err != nil { - return err - } - return boolSetting("binary_parameters", &cn.binaryParameters) -} - -func (cn *conn) handlePgpass(o values) { - // if a password was supplied, do not process .pgpass - if _, ok := o["password"]; ok { - return - } - filename := os.Getenv("PGPASSFILE") - if filename == "" { - // XXX this code doesn't work on Windows where the default filename is - // XXX %APPDATA%\postgresql\pgpass.conf - // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 - userHome := os.Getenv("HOME") - if userHome == "" { - user, err := user.Current() - if err != nil { - return - } - userHome = user.HomeDir - } - filename = filepath.Join(userHome, ".pgpass") - } - fileinfo, err := os.Stat(filename) - if err != nil { - return - } - mode := fileinfo.Mode() - if mode&(0x77) != 0 { - // XXX should warn about incorrect .pgpass permissions as psql does - return - } - file, err := os.Open(filename) - if err != nil { - return - } - defer file.Close() - scanner := bufio.NewScanner(io.Reader(file)) - // From: https://github.com/tg/pgpass/blob/master/reader.go - for scanner.Scan() { - if scanText(scanner.Text(), o) { - break - } - } -} - -// GetFields is a helper function for scanText. -func getFields(s string) []string { - fs := make([]string, 0, 5) - f := make([]rune, 0, len(s)) - - var esc bool - for _, c := range s { - switch { - case esc: - f = append(f, c) - esc = false - case c == '\\': - esc = true - case c == ':': - fs = append(fs, string(f)) - f = f[:0] - default: - f = append(f, c) - } - } - return append(fs, string(f)) -} - -// ScanText assists HandlePgpass in it's objective. -func scanText(line string, o values) bool { - hostname := o["host"] - ntw, _ := network(o) - port := o["port"] - db := o["dbname"] - username := o["user"] - if len(line) == 0 || line[0] == '#' { - return false - } - split := getFields(line) - if len(split) != 5 { - return false - } - if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { - o["password"] = split[4] - return true - } - return false -} - -func (cn *conn) writeBuf(b byte) *writeBuf { - cn.scratch[0] = b +func (cn *conn) writeBuf(b proto.RequestCode) *writeBuf { + cn.scratch[0] = byte(b) return &writeBuf{ buf: cn.scratch[:5], pos: 1, } } -// Open opens a new connection to the database. dsn is a connection string. -// Most users should only use it through database/sql package from the standard +// Open opens a new connection to the database. dsn is a connection string. Most +// users should only use it through database/sql package from the standard // library. func Open(dsn string) (_ driver.Conn, err error) { return DialOpen(defaultDialer{}, dsn) @@ -340,86 +243,211 @@ func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { return c.open(context.Background()) } -func (c *Connector) open(ctx context.Context) (cn *conn, err error) { - // Handle any panics during connection initialization. Note that we - // specifically do *not* want to use errRecover(), as that would turn any - // connection errors into ErrBadConns, hiding the real error message from - // the user. - defer errRecoverNoErrBadConn(&err) +func (c *Connector) open(ctx context.Context) (*conn, error) { + tsa := c.cfg.TargetSessionAttrs +restartAll: + var ( + errs []error + app = func(err error, cfg Config) bool { + if err != nil { + if debugProto { + fmt.Fprintln(os.Stderr, "CONNECT (error)", err) + } + errs = append(errs, fmt.Errorf("connecting to %s:%d: %w", cfg.Host, cfg.Port, err)) + } + return err != nil + } + ) + for _, cfg := range c.cfg.hosts() { + mode := cfg.SSLMode + restartHost: + if debugProto { + fmt.Fprintln(os.Stderr, "CONNECT ", cfg.string()) + } + + cfg.SSLMode = mode + cn := &conn{cfg: cfg, dialer: c.dialer} + cn.cfg.Password = pgpass.PasswordFromPgpass(cn.cfg.Passfile, cn.cfg.User, cn.cfg.Password, + cn.cfg.Host, strconv.Itoa(int(cn.cfg.Port)), cn.cfg.Database) + + var err error + cn.c, err = dial(ctx, c.dialer, cn.cfg) + if app(err, cfg) { + continue + } - // Create a new values map (copy). This makes it so maps in different - // connections do not reference the same underlying data structure, so it - // is safe for multiple connections to concurrently write to their opts. - o := make(values) - for k, v := range c.opts { - o[k] = v + err = cn.ssl(cn.cfg, mode) + if err != nil && mode == SSLModePrefer { + mode = SSLModeDisable + goto restartHost + } + if app(err, cfg) { + if cn.c != nil { + _ = cn.c.Close() + } + continue + } + + cn.buf = bufio.NewReader(cn.c) + err = cn.startup(cn.cfg) + if err != nil && mode == SSLModeAllow { + mode = SSLModeRequire + goto restartHost + } + if app(err, cfg) { + _ = cn.c.Close() + continue + } + + // Reset the deadline, in case one was set (see dial) + if cn.cfg.ConnectTimeout > 0 { + err := cn.c.SetDeadline(time.Time{}) + if app(err, cfg) { + _ = cn.c.Close() + continue + } + } + + err = cn.checkTSA(tsa) + if app(err, cfg) { + _ = cn.c.Close() + continue + } + + return cn, nil } - cn = &conn{ - opts: o, - dialer: c.dialer, + // target_session_attrs=prefer-standby is treated as standby in checkTSA; we + // ran out of hosts so none are on standby. Clear the setting and try again. + if c.cfg.TargetSessionAttrs == TargetSessionAttrsPreferStandby { + tsa = TargetSessionAttrsAny + goto restartAll } - err = cn.handleDriverSettings(o) - if err != nil { - return nil, err + + if len(c.cfg.Multi) == 0 { + // Remove the "connecting to [..]" when we have just one host, so the + // error is identical to what we had before. + return nil, errors.Unwrap(errs[0]) } - cn.handlePgpass(o) + return nil, fmt.Errorf("pq: could not connect to any of the hosts:\n%w", errors.Join(errs...)) +} - cn.c, err = dial(ctx, c.dialer, o) +func (cn *conn) getBool(query string) (bool, error) { + res, err := cn.simpleQuery(query) if err != nil { - return nil, err + return false, err } + defer res.Close() - err = cn.ssl(o) + v := make([]driver.Value, 1) + err = res.Next(v) if err != nil { - if cn.c != nil { - cn.c.Close() - } - return nil, err + return false, err } - // cn.startup panics on error. Make sure we don't leak cn.c. - panicking := true - defer func() { - if panicking { - cn.c.Close() + switch vv := v[0].(type) { + default: + return false, fmt.Errorf("parseBool: unknown type %T: %[1]v", v[0]) + case bool: + return vv, nil + case string: + vv, ok := v[0].(string) + if !ok { + return false, err } - }() - - cn.buf = bufio.NewReader(cn.c) - cn.startup(o) - - // reset the deadline, in case one was set (see dial) - if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { - err = cn.c.SetDeadline(time.Time{}) + return vv == "on", nil } - panicking = false - return cn, err } -func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { - network, address := network(o) +func (cn *conn) checkTSA(tsa TargetSessionAttrs) error { + var ( + geths = func() (hs bool, err error) { + hs = cn.parameterStatus.inHotStandby.Bool + if !cn.parameterStatus.inHotStandby.Valid { + hs, err = cn.getBool("select pg_catalog.pg_is_in_recovery()") + } + return hs, err + } + getro = func() (ro bool, err error) { + ro = cn.parameterStatus.defaultTransactionReadOnly.Bool + if !cn.parameterStatus.defaultTransactionReadOnly.Valid { + ro, err = cn.getBool("show transaction_read_only") + } + return ro, err + } + ) - // Zero or not specified means wait indefinitely. - if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { - seconds, err := strconv.ParseInt(timeout, 10, 0) + switch tsa { + default: + panic("unreachable") + case "", TargetSessionAttrsAny: + return nil + case TargetSessionAttrsReadWrite, TargetSessionAttrsReadOnly: + readonly, err := getro() if err != nil { - return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) + return err } - duration := time.Duration(seconds) * time.Second + if !cn.parameterStatus.defaultTransactionReadOnly.Valid { + var err error + readonly, err = cn.getBool("show transaction_read_only") + if err != nil { + return err + } + } + switch { + case tsa == TargetSessionAttrsReadOnly && !readonly: + return errors.New("session is not read-only") + case tsa == TargetSessionAttrsReadWrite: + if readonly { + return errors.New("session is read-only") + } + hs, err := geths() + if err != nil { + return err + } + if hs { + return errors.New("server is in hot standby mode") + } + return nil + default: + return nil + } + case TargetSessionAttrsPrimary, TargetSessionAttrsStandby, TargetSessionAttrsPreferStandby: + hs, err := geths() + if err != nil { + return err + } + switch { + case (tsa == TargetSessionAttrsStandby || tsa == TargetSessionAttrsPreferStandby) && !hs: + return errors.New("server is not in hot standby mode") + case tsa == TargetSessionAttrsPrimary && hs: + return errors.New("server is in hot standby mode") + default: + return nil + } + } +} +func dial(ctx context.Context, d Dialer, cfg Config) (net.Conn, error) { + network, address := cfg.network() + + // Zero or not specified means wait indefinitely. + if cfg.ConnectTimeout > 0 { // connect_timeout should apply to the entire connection establishment // procedure, so we both use a timeout for the TCP connection - // establishment and set a deadline for doing the initial handshake. - // The deadline is then reset after startup() is done. - deadline := time.Now().Add(duration) - var conn net.Conn + // establishment and set a deadline for doing the initial handshake. The + // deadline is then reset after startup() is done. + var ( + deadline = time.Now().Add(cfg.ConnectTimeout) + conn net.Conn + err error + ) if dctx, ok := d.(DialerContext); ok { - ctx, cancel := context.WithTimeout(ctx, duration) + ctx, cancel := context.WithTimeout(ctx, cfg.ConnectTimeout) defer cancel() conn, err = dctx.DialContext(ctx, network, address) } else { - conn, err = d.DialTimeout(network, address, duration) + conn, err = d.DialTimeout(network, address, cfg.ConnectTimeout) } if err != nil { return nil, err @@ -433,140 +461,17 @@ func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { return d.Dial(network, address) } -func network(o values) (string, string) { - host := o["host"] - - if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o["port"]) - return "unix", sockPath - } - - return "tcp", net.JoinHostPort(host, o["port"]) -} - -type values map[string]string - -// scanner implements a tokenizer for libpq-style option strings. -type scanner struct { - s []rune - i int -} - -// newScanner returns a new scanner initialized with the option string s. -func newScanner(s string) *scanner { - return &scanner{[]rune(s), 0} -} - -// Next returns the next rune. -// It returns 0, false if the end of the text has been reached. -func (s *scanner) Next() (rune, bool) { - if s.i >= len(s.s) { - return 0, false - } - r := s.s[s.i] - s.i++ - return r, true -} - -// SkipSpaces returns the next non-whitespace rune. -// It returns 0, false if the end of the text has been reached. -func (s *scanner) SkipSpaces() (rune, bool) { - r, ok := s.Next() - for unicode.IsSpace(r) && ok { - r, ok = s.Next() - } - return r, ok -} - -// parseOpts parses the options from name and adds them to the values. -// -// The parsing code is based on conninfo_parse from libpq's fe-connect.c -func parseOpts(name string, o values) error { - s := newScanner(name) - - for { - var ( - keyRunes, valRunes []rune - r rune - ok bool - ) - - if r, ok = s.SkipSpaces(); !ok { - break - } - - // Scan the key - for !unicode.IsSpace(r) && r != '=' { - keyRunes = append(keyRunes, r) - if r, ok = s.Next(); !ok { - break - } - } - - // Skip any whitespace if we're not at the = yet - if r != '=' { - r, ok = s.SkipSpaces() - } - - // The current character should be = - if r != '=' || !ok { - return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) - } - - // Skip any whitespace after the = - if r, ok = s.SkipSpaces(); !ok { - // If we reach the end here, the last value is just an empty string as per libpq. - o[string(keyRunes)] = "" - break - } - - if r != '\'' { - for !unicode.IsSpace(r) { - if r == '\\' { - if r, ok = s.Next(); !ok { - return fmt.Errorf(`missing character after backslash`) - } - } - valRunes = append(valRunes, r) - - if r, ok = s.Next(); !ok { - break - } - } - } else { - quote: - for { - if r, ok = s.Next(); !ok { - return fmt.Errorf(`unterminated quoted string literal in connection string`) - } - switch r { - case '\'': - break quote - case '\\': - r, _ = s.Next() - fallthrough - default: - valRunes = append(valRunes, r) - } - } - } - - o[string(keyRunes)] = string(valRunes) - } - - return nil -} - func (cn *conn) isInTransaction() bool { return cn.txnStatus == txnStatusIdleInTransaction || cn.txnStatus == txnStatusInFailedTransaction } -func (cn *conn) checkIsInTransaction(intxn bool) { +func (cn *conn) checkIsInTransaction(intxn bool) error { if cn.isInTransaction() != intxn { cn.err.set(driver.ErrBadConn) - errorf("unexpected transaction status %v", cn.txnStatus) + return fmt.Errorf("pq: unexpected transaction status %v", cn.txnStatus) } + return nil } func (cn *conn) Begin() (_ driver.Tx, err error) { @@ -577,12 +482,13 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if err := cn.err.get(); err != nil { return nil, err } - defer cn.errRecover(&err) + if err := cn.checkIsInTransaction(false); err != nil { + return nil, err + } - cn.checkIsInTransaction(false) _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { - return nil, err + return nil, cn.handleError(err) } if commandTag != "BEGIN" { cn.err.set(driver.ErrBadConn) @@ -601,14 +507,15 @@ func (cn *conn) closeTxn() { } } -func (cn *conn) Commit() (err error) { +func (cn *conn) Commit() error { defer cn.closeTxn() if err := cn.err.get(); err != nil { return err } - defer cn.errRecover(&err) + if err := cn.checkIsInTransaction(true); err != nil { + return err + } - cn.checkIsInTransaction(true) // We don't want the client to think that everything is okay if it tries // to commit a failed transaction. However, no matter what we return, // database/sql will release this connection back into the free connection @@ -627,27 +534,33 @@ func (cn *conn) Commit() (err error) { if cn.isInTransaction() { cn.err.set(driver.ErrBadConn) } - return err + return cn.handleError(err) } if commandTag != "COMMIT" { cn.err.set(driver.ErrBadConn) return fmt.Errorf("unexpected command tag %s", commandTag) } - cn.checkIsInTransaction(false) - return nil + return cn.checkIsInTransaction(false) } -func (cn *conn) Rollback() (err error) { +func (cn *conn) Rollback() error { defer cn.closeTxn() if err := cn.err.get(); err != nil { return err } - defer cn.errRecover(&err) - return cn.rollback() + + err := cn.rollback() + if err != nil { + return cn.handleError(err) + } + return nil } func (cn *conn) rollback() (err error) { - cn.checkIsInTransaction(true) + if err := cn.checkIsInTransaction(true); err != nil { + return err + } + _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { if cn.isInTransaction() { @@ -658,8 +571,7 @@ func (cn *conn) rollback() (err error) { if commandTag != "ROLLBACK" { return fmt.Errorf("unexpected command tag %s", commandTag) } - cn.checkIsInTransaction(false) - return nil + return cn.checkIsInTransaction(false) } func (cn *conn) gname() string { @@ -667,126 +579,136 @@ func (cn *conn) gname() string { return strconv.FormatInt(int64(cn.namei), 10) } -func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { - b := cn.writeBuf('Q') +func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, resErr error) { + if debugProto { + fmt.Fprintln(os.Stderr, " START conn.simpleExec") + defer fmt.Fprintln(os.Stderr, " END conn.simpleExec") + } + + b := cn.writeBuf(proto.Query) b.string(q) - cn.send(b) + err := cn.send(b) + if err != nil { + return nil, "", err + } for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, "", err + } switch t { - case 'C': - res, commandTag = cn.parseComplete(r.string()) - case 'Z': + case proto.CommandComplete: + res, commandTag, err = cn.parseComplete(r.string()) + if err != nil { + return nil, "", err + } + case proto.ReadyForQuery: cn.processReadyForQuery(r) - if res == nil && err == nil { - err = errUnexpectedReady + if res == nil && resErr == nil { + resErr = errUnexpectedReady } - // done - return - case 'E': - err = parseError(r) - case 'I': + return res, commandTag, resErr + case proto.ErrorResponse: + resErr = parseError(r, q) + case proto.EmptyQueryResponse: res = emptyRows - case 'T', 'D': + case proto.RowDescription, proto.DataRow: // ignore any results default: cn.err.set(driver.ErrBadConn) - errorf("unknown response for simple query: %q", t) + return nil, "", fmt.Errorf("pq: unknown response for simple query: %q", t) } } } -func (cn *conn) simpleQuery(q string) (res *rows, err error) { - defer cn.errRecover(&err) +func (cn *conn) simpleQuery(q string) (*rows, error) { + if debugProto { + fmt.Fprintln(os.Stderr, " START conn.simpleQuery") + defer fmt.Fprintln(os.Stderr, " END conn.simpleQuery") + } - b := cn.writeBuf('Q') + b := cn.writeBuf(proto.Query) b.string(q) - cn.send(b) + err := cn.send(b) + if err != nil { + return nil, cn.handleError(err, q) + } + var ( + res *rows + resErr error + ) for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, cn.handleError(err, q) + } switch t { - case 'C', 'I': + case proto.CommandComplete, proto.EmptyQueryResponse: // We allow queries which don't return any results through Query as - // well as Exec. We still have to give database/sql a rows object + // well as Exec. We still have to give database/sql a rows object // the user can close, though, to avoid connections from being - // leaked. A "rows" with done=true works fine for that purpose. - if err != nil { + // leaked. A "rows" with done=true works fine for that purpose. + if resErr != nil { cn.err.set(driver.ErrBadConn) - errorf("unexpected message %q in simple query execution", t) + return nil, fmt.Errorf("pq: unexpected message %q in simple query execution", t) } if res == nil { - res = &rows{ - cn: cn, - } + res = &rows{cn: cn} } // Set the result and tag to the last command complete if there wasn't a // query already run. Although queries usually return from here and cede // control to Next, a query with zero results does not. - if t == 'C' { - res.result, res.tag = cn.parseComplete(r.string()) + if t == proto.CommandComplete { + res.result, res.tag, err = cn.parseComplete(r.string()) + if err != nil { + return nil, cn.handleError(err, q) + } if res.colNames != nil { - return + return res, cn.handleError(resErr, q) } } res.done = true - case 'Z': + case proto.ReadyForQuery: cn.processReadyForQuery(r) - // done - return - case 'E': + if err == nil && res == nil { + res = &rows{done: true} + } + return res, cn.handleError(resErr, q) // done + case proto.ErrorResponse: res = nil - err = parseError(r) - case 'D': + resErr = parseError(r, q) + case proto.DataRow: if res == nil { cn.err.set(driver.ErrBadConn) - errorf("unexpected DataRow in simple query execution") + return nil, fmt.Errorf("pq: unexpected DataRow in simple query execution") } - // the query didn't fail; kick off to Next - cn.saveMessage(t, r) - return - case 'T': + return res, cn.saveMessage(t, r) // The query didn't fail; kick off to Next + case proto.RowDescription: // res might be non-nil here if we received a previous - // CommandComplete, but that's fine; just overwrite it - res = &rows{cn: cn} - res.rowsHeader = parsePortalRowDescribe(r) + // CommandComplete, but that's fine and just overwrite it. + res = &rows{cn: cn, rowsHeader: parsePortalRowDescribe(r)} // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. default: cn.err.set(driver.ErrBadConn) - errorf("unknown response for simple query: %q", t) + return nil, fmt.Errorf("pq: unknown response for simple query: %q", t) } } } -type noRows struct{} - -var emptyRows noRows - -var _ driver.Result = noRows{} - -func (noRows) LastInsertId() (int64, error) { - return 0, errNoLastInsertID -} - -func (noRows) RowsAffected() (int64, error) { - return 0, errNoRowsAffected -} - // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. -func decideColumnFormats( - colTyps []fieldDesc, forceText bool, -) (colFmts []format, colFmtData []byte) { +func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte, _ error) { if len(colTyps) == 0 { - return nil, colFmtDataAllText + return nil, colFmtDataAllText, nil } colFmts = make([]format, len(colTyps)) if forceText { - return colFmts, colFmtDataAllText + return colFmts, colFmtDataAllText, nil } allBinary := true @@ -807,95 +729,172 @@ func decideColumnFormats( case oid.T_uuid: colFmts[i] = formatBinary allText = false - default: allBinary = false } } if allBinary { - return colFmts, colFmtDataAllBinary + return colFmts, colFmtDataAllBinary, nil } else if allText { - return colFmts, colFmtDataAllText + return colFmts, colFmtDataAllText, nil } else { colFmtData = make([]byte, 2+len(colFmts)*2) + if len(colFmts) > math.MaxUint16 { + return nil, nil, fmt.Errorf("pq: too many columns (%d > math.MaxUint16)", len(colFmts)) + } binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) for i, v := range colFmts { binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) } - return colFmts, colFmtData + return colFmts, colFmtData, nil } } -func (cn *conn) prepareTo(q, stmtName string) *stmt { +func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) { + if debugProto { + fmt.Fprintln(os.Stderr, " START conn.prepareTo") + defer fmt.Fprintln(os.Stderr, " END conn.prepareTo") + } + st := &stmt{cn: cn, name: stmtName} - b := cn.writeBuf('P') + b := cn.writeBuf(proto.Parse) b.string(st.name) b.string(q) b.int16(0) - b.next('D') - b.byte('S') + b.next(proto.Describe) + b.byte(proto.Sync) b.string(st.name) - b.next('S') - cn.send(b) + b.next(proto.Sync) + err := cn.send(b) + if err != nil { + return nil, err + } - cn.readParseResponse() - st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() - st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) - cn.readReadyForQuery() - return st + err = cn.readParseResponse() + if err != nil { + return nil, err + } + st.paramTyps, st.colNames, st.colTyps, err = cn.readStatementDescribeResponse() + if err != nil { + return nil, err + } + st.colFmts, st.colFmtData, err = decideColumnFormats(st.colTyps, cn.cfg.DisablePreparedBinaryResult) + if err != nil { + return nil, err + } + + err = cn.readReadyForQuery() + if err != nil { + return nil, err + } + return st, nil } -func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { +func (cn *conn) Prepare(q string) (driver.Stmt, error) { if err := cn.err.get(); err != nil { return nil, err } - defer cn.errRecover(&err) - if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { + if pqsql.StartsWithCopy(q) { s, err := cn.prepareCopyIn(q) if err == nil { - cn.inCopy = true + cn.inProgress.Store(true) } - return s, err + return s, cn.handleError(err, q) + } + s, err := cn.prepareTo(q, cn.gname()) + if err != nil { + return nil, cn.handleError(err, q) + } + return s, nil +} + +func (cn *conn) Close() error { + // Don't go through send(); ListenerConn relies on us not scribbling on the + // scratch buffer of this connection. + err := cn.sendSimpleMessage(proto.Terminate) + if err != nil { + _ = cn.c.Close() // Ensure that cn.c.Close is always run. + return cn.handleError(err) } - return cn.prepareTo(q, cn.gname()), nil + return cn.c.Close() } -func (cn *conn) Close() (err error) { - // Skip cn.bad return here because we always want to close a connection. - defer cn.errRecover(&err) +func toNamedValue(v []driver.Value) []driver.NamedValue { + v2 := make([]driver.NamedValue, len(v)) + for i := range v { + v2[i] = driver.NamedValue{Ordinal: i + 1, Value: v[i]} + } + return v2 +} - // Ensure that cn.c.Close is always run. Since error handling is done with - // panics and cn.errRecover, the Close must be in a defer. - defer func() { - cerr := cn.c.Close() - if err == nil { - err = cerr +// CheckNamedValue implements [driver.NamedValueChecker]. +func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error { + if cn.cfg.BinaryParameters { + if bin, ok := nv.Value.(interface{ BinaryValue() ([]byte, error) }); ok { + var err error + nv.Value, err = bin.BinaryValue() + return err } - }() + } - // Don't go through send(); ListenerConn relies on us not scribbling on the - // scratch buffer of this connection. - return cn.sendSimpleMessage('X') + // Ignore Valuer, for backward compatibility with pq.Array(). + if _, ok := nv.Value.(driver.Valuer); ok { + return driver.ErrSkip + } + + v := reflect.ValueOf(nv.Value) + if !v.IsValid() { + return driver.ErrSkip + } + t := v.Type() + for t.Kind() == reflect.Pointer { + t, v = t.Elem(), v.Elem() + } + + // Ignore []byte and related types: *[]byte, json.RawMessage, etc. + if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 { + return driver.ErrSkip + } + + switch v.Kind() { + default: + return driver.ErrSkip + case reflect.Slice: + var err error + nv.Value, err = Array(v.Interface()).Value() + return err + case reflect.Uint64: + value := v.Uint() + if value >= math.MaxInt64 { + nv.Value = strconv.FormatUint(value, 10) + } else { + nv.Value = int64(value) + } + return nil + } } // Implement the "Queryer" interface func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { - return cn.query(query, args) + return cn.query(query, toNamedValue(args)) } -func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { +func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) { + if debugProto { + fmt.Fprintln(os.Stderr, " START conn.query") + defer fmt.Fprintln(os.Stderr, " END conn.query") + } if err := cn.err.get(); err != nil { return nil, err } - if cn.inCopy { - return nil, errCopyInProgress + if !cn.inProgress.CompareAndSwap(false, true) { + return nil, errQueryInProgress } - defer cn.errRecover(&err) // Check to see if we can use the "simpleQuery" interface, which is // *much* faster than going through prepare/exec @@ -903,18 +902,40 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { return cn.simpleQuery(query) } - if cn.binaryParameters { - cn.sendBinaryModeQuery(query, args) + if cn.cfg.BinaryParameters { + err := cn.sendBinaryModeQuery(query, args) + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readParseResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readBindResponse() + if err != nil { + return nil, cn.handleError(err, query) + } - cn.readParseResponse() - cn.readBindResponse() rows := &rows{cn: cn} - rows.rowsHeader = cn.readPortalDescribeResponse() - cn.postExecuteWorkaround() + rows.rowsHeader, err = cn.readPortalDescribeResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.postExecuteWorkaround() + if err != nil { + return nil, cn.handleError(err, query) + } return rows, nil } - st := cn.prepareTo(query, "") - st.exec(args) + + st, err := cn.prepareTo(query, "") + if err != nil { + return nil, cn.handleError(err, query) + } + err = st.exec(args) + if err != nil { + return nil, cn.handleError(err, query) + } return &rows{ cn: cn, rowsHeader: st.rowsHeader, @@ -922,69 +943,99 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { } // Implement the optional "Execer" interface for one-shot queries -func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { +func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { if err := cn.err.get(); err != nil { return nil, err } - defer cn.errRecover(&err) + if !cn.inProgress.CompareAndSwap(false, true) { + return nil, errQueryInProgress + } - // Check to see if we can use the "simpleExec" interface, which is - // *much* faster than going through prepare/exec + // Check to see if we can use the "simpleExec" interface, which is *much* + // faster than going through prepare/exec if len(args) == 0 { // ignore commandTag, our caller doesn't care r, _, err := cn.simpleExec(query) - return r, err + return r, cn.handleError(err, query) + } + + if cn.cfg.BinaryParameters { + err := cn.sendBinaryModeQuery(query, toNamedValue(args)) + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readParseResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readBindResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + + _, err = cn.readPortalDescribeResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.postExecuteWorkaround() + if err != nil { + return nil, cn.handleError(err, query) + } + res, _, err := cn.readExecuteResponse("Execute") + return res, cn.handleError(err, query) } - if cn.binaryParameters { - cn.sendBinaryModeQuery(query, args) - - cn.readParseResponse() - cn.readBindResponse() - cn.readPortalDescribeResponse() - cn.postExecuteWorkaround() - res, _, err = cn.readExecuteResponse("Execute") - return res, err + // Use the unnamed statement to defer planning until bind time, or else + // value-based selectivity estimates cannot be used. + st, err := cn.prepareTo(query, "") + if err != nil { + return nil, cn.handleError(err, query) } - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st := cn.prepareTo(query, "") r, err := st.Exec(args) if err != nil { - panic(err) + return nil, cn.handleError(err, query) } - return r, err + return r, nil } -type safeRetryError struct { - Err error -} +type safeRetryError struct{ Err error } -func (se *safeRetryError) Error() string { - return se.Err.Error() -} +func (se *safeRetryError) Error() string { return se.Err.Error() } -func (cn *conn) send(m *writeBuf) { - n, err := cn.c.Write(m.wrap()) - if err != nil { - if n == 0 { - err = &safeRetryError{Err: err} +func (cn *conn) send(m *writeBuf) error { + if debugProto { + w := m.wrap() + for len(w) > 0 { // Can contain multiple messages. + c := proto.RequestCode(w[0]) + l := int(binary.BigEndian.Uint32(w[1:5])) - 4 + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", c, l, w[5:l+5]) + w = w[l+5:] } - panic(err) } + + n, err := cn.c.Write(m.wrap()) + if err != nil && n == 0 { + err = &safeRetryError{Err: err} + } + return err } func (cn *conn) sendStartupPacket(m *writeBuf) error { + if debugProto { + w := m.wrap() + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", "Startup", int(binary.BigEndian.Uint32(w[1:5]))-4, w[5:]) + } _, err := cn.c.Write((m.wrap())[1:]) return err } -// Send a message of type typ to the server on the other end of cn. The -// message should have no payload. This method does not use the scratch -// buffer. -func (cn *conn) sendSimpleMessage(typ byte) (err error) { - _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) +// Send a message of type typ to the server on the other end of cn. The message +// should have no payload. This method does not use the scratch buffer. +func (cn *conn) sendSimpleMessage(typ proto.RequestCode) error { + if debugProto { + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", typ, 0, []byte{}) + } + _, err := cn.c.Write([]byte{byte(typ), '\x00', '\x00', '\x00', '\x04'}) return err } @@ -993,18 +1044,19 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) { // method is useful in cases where you have to see what the next message is // going to be (e.g. to see whether it's an error or not) but you can't handle // the message yourself. -func (cn *conn) saveMessage(typ byte, buf *readBuf) { +func (cn *conn) saveMessage(typ proto.ResponseCode, buf *readBuf) error { if cn.saveMessageType != 0 { cn.err.set(driver.ErrBadConn) - errorf("unexpected saveMessageType %d", cn.saveMessageType) + return fmt.Errorf("unexpected saveMessageType %d", cn.saveMessageType) } cn.saveMessageType = typ cn.saveMessageBuffer = *buf + return nil } // recvMessage receives any message from the backend, or returns an error if // a problem occurred while reading the message. -func (cn *conn) recvMessage(r *readBuf) (byte, error) { +func (cn *conn) recvMessage(r *readBuf) (proto.ResponseCode, error) { // workaround for a QueryRow bug, see exec if cn.saveMessageType != 0 { t := cn.saveMessageType @@ -1020,9 +1072,25 @@ func (cn *conn) recvMessage(r *readBuf) (byte, error) { return 0, err } - // read the type and length of the message that follows - t := x[0] + // Read the type and length of the message that follows. + t := proto.ResponseCode(x[0]) n := int(binary.BigEndian.Uint32(x[1:])) - 4 + + if proto.ResponseCode(t) == proto.ReadyForQuery { + cn.inProgress.Store(false) + } + + // When PostgreSQL cannot start a backend (e.g., an external process limit), + // it sends plain text like "Ecould not fork new process [..]", which + // doesn't use the standard encoding for the Error message. + // + // libpq checks "if ErrorResponse && (msgLength < 8 || msgLength > MAX_ERRLEN)", + // but check < 4 since n represents bytes remaining to be read after length. + if t == proto.ErrorResponse && (n < 4 || n > proto.MaxErrlen) { + msg, _ := cn.buf.ReadString('\x00') + return 0, fmt.Errorf("pq: server error: %s%s", string(x[1:]), strings.TrimSuffix(msg, "\x00")) + } + var y []byte if n <= len(cn.scratch) { y = cn.scratch[:n] @@ -1034,445 +1102,338 @@ func (cn *conn) recvMessage(r *readBuf) (byte, error) { return 0, err } *r = y + if debugProto { + fmt.Fprintf(os.Stderr, "SERVER ← %-20s %5d %q\n", t, n, y) + } return t, nil } -// recv receives a message from the backend, but if an error happened while -// reading the message or the received message was an ErrorResponse, it panics. -// NoticeResponses are ignored. This function should generally be used only +// recv receives a message from the backend, returning an error if an error +// happened while reading the message or the received message an ErrorResponse. +// NoticeResponses are ignored. This function should generally be used only // during the startup sequence. -func (cn *conn) recv() (t byte, r *readBuf) { +func (cn *conn) recv() (proto.ResponseCode, *readBuf, error) { for { - var err error - r = &readBuf{} - t, err = cn.recvMessage(r) + r := new(readBuf) + t, err := cn.recvMessage(r) if err != nil { - panic(err) + return 0, nil, err } switch t { - case 'E': - panic(parseError(r)) - case 'N': + case proto.ErrorResponse: + return 0, nil, parseError(r, "") + case proto.NoticeResponse: if n := cn.noticeHandler; n != nil { - n(parseError(r)) + n(parseError(r, "")) } - case 'A': + case proto.NotificationResponse: if n := cn.notificationHandler; n != nil { n(recvNotification(r)) } default: - return + return t, r, nil } } } // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by // the caller to avoid an allocation. -func (cn *conn) recv1Buf(r *readBuf) byte { +func (cn *conn) recv1Buf(r *readBuf) (proto.ResponseCode, error) { for { t, err := cn.recvMessage(r) if err != nil { - panic(err) + return 0, err } switch t { - case 'A': + case proto.NotificationResponse: if n := cn.notificationHandler; n != nil { n(recvNotification(r)) } - case 'N': + case proto.NoticeResponse: if n := cn.noticeHandler; n != nil { - n(parseError(r)) + n(parseError(r, "")) } - case 'S': + case proto.ParameterStatus: cn.processParameterStatus(r) default: - return t + return t, nil } } } -// recv1 receives a message from the backend, panicking if an error occurs -// while attempting to read it. All asynchronous messages are ignored, with -// the exception of ErrorResponse. -func (cn *conn) recv1() (t byte, r *readBuf) { - r = &readBuf{} - t = cn.recv1Buf(r) - return t, r +// recv1 receives a message from the backend, returning an error if an error +// happened while reading the message or the received message an ErrorResponse. +// All asynchronous messages are ignored, with the exception of ErrorResponse. +func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) { + r := new(readBuf) + t, err := cn.recv1Buf(r) + if err != nil { + return 0, nil, err + } + return t, r, nil } -func (cn *conn) ssl(o values) error { - upgrade, err := ssl(o) +// Don't refer to Config.SSLMode here, as the mode in arguments may be different +// in case of sslmode=allow or prefer. +func (cn *conn) ssl(cfg Config, mode SSLMode) error { + upgrade, err := ssl(cfg, mode) if err != nil { return err } - if upgrade == nil { - // Nothing to do - return nil + return nil // Nothing to do } - w := cn.writeBuf(0) - w.int32(80877103) - if err = cn.sendStartupPacket(w); err != nil { - return err - } + // Only negotiate the ssl handshake if requested (which is the default). + // sslnegotiation=direct is supported by pg17 and above. + if cfg.SSLNegotiation != SSLNegotiationDirect { + w := cn.writeBuf(0) + w.int32(proto.NegotiateSSLCode) + if err = cn.sendStartupPacket(w); err != nil { + return err + } - b := cn.scratch[:1] - _, err = io.ReadFull(cn.c, b) - if err != nil { - return err - } + b := cn.scratch[:1] + _, err = io.ReadFull(cn.c, b) + if err != nil { + return err + } - if b[0] != 'S' { - return ErrSSLNotSupported + if b[0] != 'S' { + return ErrSSLNotSupported + } } cn.c, err = upgrade(cn.c) return err } -// isDriverSetting returns true iff a setting is purely for configuring the -// driver's options and should not be sent to the server in the connection -// startup packet. -func isDriverSetting(key string) bool { - switch key { - case "host", "port": - return true - case "password": - return true - case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni": - return true - case "fallback_application_name": - return true - case "connect_timeout": - return true - case "disable_prepared_binary_result": - return true - case "binary_parameters": - return true - case "krbsrvname": - return true - case "krbspn": - return true - default: - return false - } -} - -func (cn *conn) startup(o values) { +func (cn *conn) startup(cfg Config) error { w := cn.writeBuf(0) - w.int32(196608) - // Send the backend the name of the database we want to connect to, and the - // user we want to connect as. Additionally, we send over any run-time - // parameters potentially included in the connection string. If the server - // doesn't recognize any of them, it will reply with an error. - for k, v := range o { - if isDriverSetting(k) { - // skip options which can't be run-time parameters - continue - } - // The protocol requires us to supply the database name as "database" - // instead of "dbname". - if k == "dbname" { - k = "database" - } + // Send maximum protocol version in startup; if the server doesn't support + // this version it responds with NegotiateProtocolVersion and the maximum + // version it supports (and will use). + w.int32(cfg.MaxProtocolVersion.proto()) + + if cfg.User != "" { + w.string("user") + w.string(cfg.User) + } + if cfg.Database != "" { + w.string("database") + w.string(cfg.Database) + } + // w.string("replication") // Sent by libpq, but we don't support that. + if cfg.Options != "" { + w.string("options") + w.string(cfg.Options) + } + if cfg.ApplicationName != "" { + w.string("application_name") + w.string(cfg.ApplicationName) + } + if cfg.ClientEncoding != "" { + w.string("client_encoding") + w.string(cfg.ClientEncoding) + } + if cfg.Datestyle != "" { + w.string("datestyle") + w.string(cfg.Datestyle) + } + for k, v := range cfg.Runtime { w.string(k) w.string(v) } + w.string("") if err := cn.sendStartupPacket(w); err != nil { - panic(err) + return err } for { - t, r := cn.recv() + t, r, err := cn.recv() + if err != nil { + return err + } switch t { - case 'K': - cn.processBackendKeyData(r) - case 'S': + case proto.BackendKeyData: + cn.pid = r.int32() + if len(*r) > 256 { + return fmt.Errorf("pq: cancellation key longer than 256 bytes: %d bytes", len(*r)) + } + cn.secretKey = make([]byte, len(*r)) + copy(cn.secretKey, *r) + case proto.ParameterStatus: cn.processParameterStatus(r) - case 'R': - cn.auth(r, o) - case 'Z': + case proto.AuthenticationRequest: + err := cn.auth(r, cfg) + if err != nil { + return err + } + case proto.NegotiateProtocolVersion: + newestMinor := r.int32() + serverVersion := proto.ProtocolVersion30&0xFFFF0000 | newestMinor + if serverVersion < cfg.MinProtocolVersion.proto() { + return fmt.Errorf("pq: protocol version mismatch: min_protocol_version=%s; server supports up to 3.%d", cfg.MinProtocolVersion, newestMinor) + } + case proto.ReadyForQuery: cn.processReadyForQuery(r) - return + return nil default: - errorf("unknown response for startup: %q", t) + return fmt.Errorf("pq: unknown response for startup: %q", t) } } } -func (cn *conn) auth(r *readBuf, o values) { - switch code := r.int32(); code { - case 0: - // OK - case 3: - w := cn.writeBuf('p') - w.string(o["password"]) - cn.send(w) +func (cn *conn) auth(r *readBuf, cfg Config) error { + switch code := proto.AuthCode(r.int32()); code { + default: + return fmt.Errorf("pq: unknown authentication response: %s", code) + case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI: + return fmt.Errorf("pq: unsupported authentication method: %s", code) + case proto.AuthReqOk: + return nil - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } + case proto.AuthReqPassword: + w := cn.writeBuf(proto.PasswordMessage) + w.string(cfg.Password) + // Don't need to check AuthOk response here; auth() is called in a loop, + // which catches the errors and AuthReqOk responses. + return cn.send(w) - if r.int32() != 0 { - errorf("unexpected authentication response: %q", t) - } - case 5: + case proto.AuthReqMD5: s := string(r.next(4)) - w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) - cn.send(w) - - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } + w := cn.writeBuf(proto.PasswordMessage) + w.string("md5" + md5s(md5s(cfg.Password+cfg.User)+s)) + // Same here. + return cn.send(w) - if r.int32() != 0 { - errorf("unexpected authentication response: %q", t) - } - case 7: // GSSAPI, startup + case proto.AuthReqGSS: // GSSAPI, startup if newGss == nil { - errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)") + return fmt.Errorf("pq: kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos)") } cli, err := newGss() if err != nil { - errorf("kerberos error: %s", err.Error()) + return fmt.Errorf("pq: kerberos error: %w", err) } var token []byte - - if spn, ok := o["krbspn"]; ok { - // Use the supplied SPN if provided.. - token, err = cli.GetInitTokenFromSpn(spn) + if cfg.KrbSpn != "" { + // Use the supplied SPN if provided. + token, err = cli.GetInitTokenFromSpn(cfg.KrbSpn) } else { - // Allow the kerberos service name to be overridden + // Allow the kerberos service name to be overridden. service := "postgres" - if val, ok := o["krbsrvname"]; ok { - service = val + if cfg.KrbSrvname != "" { + service = cfg.KrbSrvname } - - token, err = cli.GetInitToken(o["host"], service) + token, err = cli.GetInitToken(cfg.Host, service) } - if err != nil { - errorf("failed to get Kerberos ticket: %q", err) + return fmt.Errorf("pq: failed to get Kerberos ticket: %w", err) } - w := cn.writeBuf('p') + w := cn.writeBuf(proto.GSSResponse) w.bytes(token) - cn.send(w) + err = cn.send(w) + if err != nil { + return err + } // Store for GSSAPI continue message cn.gss = cli + return nil - case 8: // GSSAPI continue - + case proto.AuthReqGSSCont: // GSSAPI continue if cn.gss == nil { - errorf("GSSAPI protocol error") + return errors.New("pq: GSSAPI protocol error") } - b := []byte(*r) - - done, tokOut, err := cn.gss.Continue(b) + done, tokOut, err := cn.gss.Continue([]byte(*r)) if err == nil && !done { - w := cn.writeBuf('p') + w := cn.writeBuf(proto.SASLInitialResponse) w.bytes(tokOut) - cn.send(w) + err = cn.send(w) + if err != nil { + return err + } } - // Errors fall through and read the more detailed message - // from the server.. + // Errors fall through and read the more detailed message from the + // server. + return nil - case 10: - sc := scram.NewClient(sha256.New, o["user"], o["password"]) + case proto.AuthReqSASL: + sc := scram.NewClient(sha256.New, cfg.User, cfg.Password) sc.Step(nil) if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) } scOut := sc.Out() - w := cn.writeBuf('p') + w := cn.writeBuf(proto.SASLResponse) w.string("SCRAM-SHA-256") w.int32(len(scOut)) w.bytes(scOut) - cn.send(w) + err := cn.send(w) + if err != nil { + return err + } - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) + t, r, err := cn.recv() + if err != nil { + return err + } + if t != proto.AuthenticationRequest { + return fmt.Errorf("pq: unexpected password response: %q", t) } - if r.int32() != 11 { - errorf("unexpected authentication response: %q", t) + if r.int32() != int(proto.AuthReqSASLCont) { + return fmt.Errorf("pq: unexpected authentication response: %q", t) } nextStep := r.next(len(*r)) sc.Step(nextStep) if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) } scOut = sc.Out() - w = cn.writeBuf('p') + w = cn.writeBuf(proto.SASLResponse) w.bytes(scOut) - cn.send(w) + err = cn.send(w) + if err != nil { + return err + } - t, r = cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) + t, r, err = cn.recv() + if err != nil { + return err + } + if t != proto.AuthenticationRequest { + return fmt.Errorf("pq: unexpected password response: %q", t) } - if r.int32() != 12 { - errorf("unexpected authentication response: %q", t) + if r.int32() != int(proto.AuthReqSASLFin) { + return fmt.Errorf("pq: unexpected authentication response: %q", t) } nextStep = r.next(len(*r)) sc.Step(nextStep) if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) } - default: - errorf("unknown authentication response: %d", code) - } -} - -type format int - -const formatText format = 0 -const formatBinary format = 1 - -// One result-column format code with the value 1 (i.e. all binary). -var colFmtDataAllBinary = []byte{0, 1, 0, 1} - -// No result-column format codes (i.e. all text). -var colFmtDataAllText = []byte{0, 0} - -type stmt struct { - cn *conn - name string - rowsHeader - colFmtData []byte - paramTyps []oid.Oid - closed bool -} - -func (st *stmt) Close() (err error) { - if st.closed { return nil } - if err := st.cn.err.get(); err != nil { - return err - } - defer st.cn.errRecover(&err) - - w := st.cn.writeBuf('C') - w.byte('S') - w.string(st.name) - st.cn.send(w) - - st.cn.send(st.cn.writeBuf('S')) - - t, _ := st.cn.recv1() - if t != '3' { - st.cn.err.set(driver.ErrBadConn) - errorf("unexpected close response: %q", t) - } - st.closed = true - - t, r := st.cn.recv1() - if t != 'Z' { - st.cn.err.set(driver.ErrBadConn) - errorf("expected ready for query, but got: %q", t) - } - st.cn.processReadyForQuery(r) - - return nil -} - -func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - return st.query(v) -} - -func (st *stmt) query(v []driver.Value) (r *rows, err error) { - if err := st.cn.err.get(); err != nil { - return nil, err - } - defer st.cn.errRecover(&err) - - st.exec(v) - return &rows{ - cn: st.cn, - rowsHeader: st.rowsHeader, - }, nil -} - -func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if err := st.cn.err.get(); err != nil { - return nil, err - } - defer st.cn.errRecover(&err) - - st.exec(v) - res, _, err = st.cn.readExecuteResponse("simple query") - return res, err -} - -func (st *stmt) exec(v []driver.Value) { - if len(v) >= 65536 { - errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) - } - if len(v) != len(st.paramTyps) { - errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) - } - - cn := st.cn - w := cn.writeBuf('B') - w.byte(0) // unnamed portal - w.string(st.name) - - if cn.binaryParameters { - cn.sendBinaryParameters(w, v) - } else { - w.int16(0) - w.int16(len(v)) - for i, x := range v { - if x == nil { - w.int32(-1) - } else { - b := encode(&cn.parameterStatus, x, st.paramTyps[i]) - w.int32(len(b)) - w.bytes(b) - } - } - } - w.bytes(st.colFmtData) - - w.next('E') - w.byte(0) - w.int32(0) - - w.next('S') - cn.send(w) - - cn.readBindResponse() - cn.postExecuteWorkaround() - -} - -func (st *stmt) NumInput() int { - return len(st.paramTyps) } // parseComplete parses the "command tag" from a CommandComplete message, and -// returns the number of rows affected (if applicable) and a string -// identifying only the command that was executed, e.g. "ALTER TABLE". If the -// command tag could not be parsed, parseComplete panics. -func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { +// returns the number of rows affected (if applicable) and a string identifying +// only the command that was executed, e.g. "ALTER TABLE". Returns an error if +// the command can cannot be parsed. +func (cn *conn) parseComplete(commandTag string) (driver.Result, string, error) { commandsWithAffectedRows := []string{ "SELECT ", // INSERT is handled below @@ -1492,218 +1453,29 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { break } } - // INSERT also includes the oid of the inserted row in its command tag. - // Oids in user tables are deprecated, and the oid is only returned when - // exactly one row is inserted, so it's unlikely to be of value to any - // real-world application and we can ignore it. + // INSERT also includes the oid of the inserted row in its command tag. Oids + // in user tables are deprecated, and the oid is only returned when exactly + // one row is inserted, so it's unlikely to be of value to any real-world + // application and we can ignore it. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { parts := strings.Split(commandTag, " ") if len(parts) != 3 { cn.err.set(driver.ErrBadConn) - errorf("unexpected INSERT command tag %s", commandTag) + return nil, "", fmt.Errorf("pq: unexpected INSERT command tag %s", commandTag) } affectedRows = &parts[len(parts)-1] commandTag = "INSERT" } // There should be no affected rows attached to the tag, just return it if affectedRows == nil { - return driver.RowsAffected(0), commandTag + return driver.RowsAffected(0), commandTag, nil } n, err := strconv.ParseInt(*affectedRows, 10, 64) if err != nil { cn.err.set(driver.ErrBadConn) - errorf("could not parse commandTag: %s", err) - } - return driver.RowsAffected(n), commandTag -} - -type rowsHeader struct { - colNames []string - colTyps []fieldDesc - colFmts []format -} - -type rows struct { - cn *conn - finish func() - rowsHeader - done bool - rb readBuf - result driver.Result - tag string - - next *rowsHeader -} - -func (rs *rows) Close() error { - if finish := rs.finish; finish != nil { - defer finish() - } - // no need to look at cn.bad as Next() will - for { - err := rs.Next(nil) - switch err { - case nil: - case io.EOF: - // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row - // description, used with HasNextResultSet). We need to fetch messages until - // we hit a 'Z', which is done by waiting for done to be set. - if rs.done { - return nil - } - default: - return err - } - } -} - -func (rs *rows) Columns() []string { - return rs.colNames -} - -func (rs *rows) Result() driver.Result { - if rs.result == nil { - return emptyRows - } - return rs.result -} - -func (rs *rows) Tag() string { - return rs.tag -} - -func (rs *rows) Next(dest []driver.Value) (err error) { - if rs.done { - return io.EOF - } - - conn := rs.cn - if err := conn.err.getForNext(); err != nil { - return err - } - defer conn.errRecover(&err) - - for { - t := conn.recv1Buf(&rs.rb) - switch t { - case 'E': - err = parseError(&rs.rb) - case 'C', 'I': - if t == 'C' { - rs.result, rs.tag = conn.parseComplete(rs.rb.string()) - } - continue - case 'Z': - conn.processReadyForQuery(&rs.rb) - rs.done = true - if err != nil { - return err - } - return io.EOF - case 'D': - n := rs.rb.int16() - if err != nil { - conn.err.set(driver.ErrBadConn) - errorf("unexpected DataRow after error %s", err) - } - if n < len(dest) { - dest = dest[:n] - } - for i := range dest { - l := rs.rb.int32() - if l == -1 { - dest[i] = nil - continue - } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) - } - return - case 'T': - next := parsePortalRowDescribe(&rs.rb) - rs.next = &next - return io.EOF - default: - errorf("unexpected message after execute: %q", t) - } - } -} - -func (rs *rows) HasNextResultSet() bool { - hasNext := rs.next != nil && !rs.done - return hasNext -} - -func (rs *rows) NextResultSet() error { - if rs.next == nil { - return io.EOF - } - rs.rowsHeader = *rs.next - rs.next = nil - return nil -} - -// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be -// used as part of an SQL statement. For example: -// -// tblname := "my_table" -// data := "my_data" -// quoted := pq.QuoteIdentifier(tblname) -// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) -// -// Any double quotes in name will be escaped. The quoted identifier will be -// case sensitive when used in a query. If the input string contains a zero -// byte, the result will be truncated immediately before it. -func QuoteIdentifier(name string) string { - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - return `"` + strings.Replace(name, `"`, `""`, -1) + `"` -} - -// BufferQuoteIdentifier satisfies the same purpose as QuoteIdentifier, but backed by a -// byte buffer. -func BufferQuoteIdentifier(name string, buffer *bytes.Buffer) { - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - buffer.WriteRune('"') - buffer.WriteString(strings.Replace(name, `"`, `""`, -1)) - buffer.WriteRune('"') -} - -// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal -// to DDL and other statements that do not accept parameters) to be used as part -// of an SQL statement. For example: -// -// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") -// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) -// -// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be -// replaced by two backslashes (i.e. "\\") and the C-style escape identifier -// that PostgreSQL provides ('E') will be prepended to the string. -func QuoteLiteral(literal string) string { - // This follows the PostgreSQL internal algorithm for handling quoted literals - // from libpq, which can be found in the "PQEscapeStringInternal" function, - // which is found in the libpq/fe-exec.c source file: - // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c - // - // substitute any single-quotes (') with two single-quotes ('') - literal = strings.Replace(literal, `'`, `''`, -1) - // determine if the string has any backslashes (\) in it. - // if it does, replace any backslashes (\) with two backslashes (\\) - // then, we need to wrap the entire string with a PostgreSQL - // C-style escape. Per how "PQEscapeStringInternal" handles this case, we - // also add a space before the "E" - if strings.Contains(literal, `\`) { - literal = strings.Replace(literal, `\`, `\\`, -1) - literal = ` E'` + literal + `'` - } else { - // otherwise, we can just wrap the literal with a pair of single quotes - literal = `'` + literal + `'` + return nil, "", fmt.Errorf("pq: could not parse commandTag: %w", err) } - return literal + return driver.RowsAffected(n), commandTag, nil } func md5s(s string) string { @@ -1712,13 +1484,12 @@ func md5s(s string) string { return fmt.Sprintf("%x", h.Sum(nil)) } -func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { - // Do one pass over the parameters to see if we're going to send any of - // them over in binary. If we are, create a paramFormats array at the - // same time. +func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.NamedValue) error { + // Do one pass over the parameters to see if we're going to send any of them + // over in binary. If we are, create a paramFormats array at the same time. var paramFormats []int for i, x := range args { - _, ok := x.([]byte) + _, ok := x.Value.([]byte) if ok { if paramFormats == nil { paramFormats = make([]int, len(args)) @@ -1737,64 +1508,86 @@ func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { b.int16(len(args)) for _, x := range args { - if x == nil { + if x.Value == nil { + b.int32(-1) + } else if xx, ok := x.Value.([]byte); ok && xx == nil { b.int32(-1) } else { - datum := binaryEncode(&cn.parameterStatus, x) + datum, err := binaryEncode(x.Value) + if err != nil { + return err + } b.int32(len(datum)) b.bytes(datum) } } + return nil } -func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { +func (cn *conn) sendBinaryModeQuery(query string, args []driver.NamedValue) error { if len(args) >= 65536 { - errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) + return fmt.Errorf("pq: got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) } - b := cn.writeBuf('P') + b := cn.writeBuf(proto.Parse) b.byte(0) // unnamed statement b.string(query) b.int16(0) - b.next('B') + b.next(proto.Bind) b.int16(0) // unnamed portal and statement - cn.sendBinaryParameters(b, args) + err := cn.sendBinaryParameters(b, args) + if err != nil { + return err + } b.bytes(colFmtDataAllText) - b.next('D') - b.byte('P') + b.next(proto.Describe) + b.byte(proto.Parse) b.byte(0) // unnamed portal - b.next('E') + b.next(proto.Execute) b.byte(0) b.int32(0) - b.next('S') - cn.send(b) + b.next(proto.Sync) + return cn.send(b) } func (cn *conn) processParameterStatus(r *readBuf) { - var err error - - param := r.string() - switch param { + switch r.string() { + default: + // ignore case "server_version": - var major1 int - var major2 int - _, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) + var major1, major2 int + _, err := fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) if err == nil { cn.parameterStatus.serverVersion = major1*10000 + major2*100 } - case "TimeZone": - cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) - if err != nil { - cn.parameterStatus.currentLocation = nil + switch tz := r.string(); tz { + case "UTC", "Etc/UTC", "Etc/Universal", "Etc/Zulu", "Etc/UCT": + cn.parameterStatus.currentLocation = time.UTC + default: + var err error + cn.parameterStatus.currentLocation, err = time.LoadLocation(tz) + if err != nil { + cn.parameterStatus.currentLocation = nil + } + } + // Use sql.NullBool so we can distinguish between false and not sent. If + // it's not sent we use a query to get the value – I don't know when these + // parameters are not sent, but this is what libpq does. + case "in_hot_standby": + b, err := pqutil.ParseBool(r.string()) + if err == nil { + cn.parameterStatus.inHotStandby = sql.NullBool{Valid: true, Bool: b} + } + case "default_transaction_read_only": + b, err := pqutil.ParseBool(r.string()) + if err == nil { + cn.parameterStatus.defaultTransactionReadOnly = sql.NullBool{Valid: true, Bool: b} } - - default: - // ignore } } @@ -1802,102 +1595,111 @@ func (cn *conn) processReadyForQuery(r *readBuf) { cn.txnStatus = transactionStatus(r.byte()) } -func (cn *conn) readReadyForQuery() { - t, r := cn.recv1() +func (cn *conn) readReadyForQuery() error { + t, r, err := cn.recv1() + if err != nil { + return err + } switch t { - case 'Z': + case proto.ReadyForQuery: cn.processReadyForQuery(r) - return + return nil + case proto.ErrorResponse: + err := parseError(r, "") + cn.err.set(driver.ErrBadConn) + return err default: cn.err.set(driver.ErrBadConn) - errorf("unexpected message %q; expected ReadyForQuery", t) + return fmt.Errorf("pq: unexpected message %q; expected ReadyForQuery", t) } } -func (cn *conn) processBackendKeyData(r *readBuf) { - cn.processID = r.int32() - cn.secretKey = r.int32() -} - -func (cn *conn) readParseResponse() { - t, r := cn.recv1() +func (cn *conn) readParseResponse() error { + t, r, err := cn.recv1() + if err != nil { + return err + } switch t { - case '1': - return - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) + case proto.ParseComplete: + return nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return err default: cn.err.set(driver.ErrBadConn) - errorf("unexpected Parse response %q", t) + return fmt.Errorf("pq: unexpected Parse response %q", t) } } -func (cn *conn) readStatementDescribeResponse() ( - paramTyps []oid.Oid, - colNames []string, - colTyps []fieldDesc, -) { +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc, _ error) { for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, nil, nil, err + } switch t { - case 't': + case proto.ParameterDescription: nparams := r.int16() paramTyps = make([]oid.Oid, nparams) for i := range paramTyps { paramTyps[i] = r.oid() } - case 'n': - return paramTyps, nil, nil - case 'T': + case proto.NoData: + return paramTyps, nil, nil, nil + case proto.RowDescription: colNames, colTyps = parseStatementRowDescribe(r) - return paramTyps, colNames, colTyps - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) + return paramTyps, colNames, colTyps, nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return nil, nil, nil, err default: cn.err.set(driver.ErrBadConn) - errorf("unexpected Describe statement response %q", t) + return nil, nil, nil, fmt.Errorf("pq: unexpected Describe statement response %q", t) } } } -func (cn *conn) readPortalDescribeResponse() rowsHeader { - t, r := cn.recv1() +func (cn *conn) readPortalDescribeResponse() (rowsHeader, error) { + t, r, err := cn.recv1() + if err != nil { + return rowsHeader{}, err + } switch t { - case 'T': - return parsePortalRowDescribe(r) - case 'n': - return rowsHeader{} - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) + case proto.RowDescription: + return parsePortalRowDescribe(r), nil + case proto.NoData: + return rowsHeader{}, nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return rowsHeader{}, err default: cn.err.set(driver.ErrBadConn) - errorf("unexpected Describe response %q", t) + return rowsHeader{}, fmt.Errorf("pq: unexpected Describe response %q", t) } - panic("not reached") } -func (cn *conn) readBindResponse() { - t, r := cn.recv1() +func (cn *conn) readBindResponse() error { + t, r, err := cn.recv1() + if err != nil { + return err + } switch t { - case '2': - return - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) + case proto.BindComplete: + return nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return err default: cn.err.set(driver.ErrBadConn) - errorf("unexpected Bind response %q", t) + return fmt.Errorf("pq: unexpected Bind response %q", t) } } -func (cn *conn) postExecuteWorkaround() { +func (cn *conn) postExecuteWorkaround() error { // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores // any errors from rows.Next, which masks errors that happened during the // execution of the query. To avoid the problem in common cases, we wait @@ -1908,56 +1710,62 @@ func (cn *conn) postExecuteWorkaround() { // However, if it's an error, we wait until ReadyForQuery and then return // the error to our caller. for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return err + } switch t { - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - case 'C', 'D', 'I': + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return err + case proto.CommandComplete, proto.DataRow, proto.EmptyQueryResponse: // the query didn't fail, but we can't process this message - cn.saveMessage(t, r) - return + return cn.saveMessage(t, r) default: cn.err.set(driver.ErrBadConn) - errorf("unexpected message during extended query execution: %q", t) + return fmt.Errorf("pq: unexpected message during extended query execution: %q", t) } } } // Only for Exec(), since we ignore the returned data -func (cn *conn) readExecuteResponse( - protocolState string, -) (res driver.Result, commandTag string, err error) { +func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, resErr error) { for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, "", err + } switch t { - case 'C': - if err != nil { + case proto.CommandComplete: + if resErr != nil { cn.err.set(driver.ErrBadConn) - errorf("unexpected CommandComplete after error %s", err) + return nil, "", fmt.Errorf("pq: unexpected CommandComplete after error %s", resErr) } - res, commandTag = cn.parseComplete(r.string()) - case 'Z': + res, commandTag, err = cn.parseComplete(r.string()) + if err != nil { + return nil, "", err + } + case proto.ReadyForQuery: cn.processReadyForQuery(r) - if res == nil && err == nil { - err = errUnexpectedReady + if res == nil && resErr == nil { + resErr = errUnexpectedReady } - return res, commandTag, err - case 'E': - err = parseError(r) - case 'T', 'D', 'I': - if err != nil { + return res, commandTag, resErr + case proto.ErrorResponse: + resErr = parseError(r, "") + case proto.RowDescription, proto.DataRow, proto.EmptyQueryResponse: + if resErr != nil { cn.err.set(driver.ErrBadConn) - errorf("unexpected %q after error %s", t, err) + return nil, "", fmt.Errorf("pq: unexpected %q after error %s", t, resErr) } - if t == 'I' { + if t == proto.EmptyQueryResponse { res = emptyRows } // ignore any results default: cn.err.set(driver.ErrBadConn) - errorf("unknown %s response: %q", protocolState, t) + return nil, "", fmt.Errorf("pq: unknown %s response: %q", protocolState, t) } } } @@ -1998,108 +1806,6 @@ func parsePortalRowDescribe(r *readBuf) rowsHeader { } } -// parseEnviron tries to mimic some of libpq's environment handling -// -// To ease testing, it does not directly reference os.Environ, but is -// designed to accept its output. -// -// Environment-set connection information is intended to have a higher -// precedence than a library default but lower than any explicitly -// passed information (such as in the URL or connection string). -func parseEnviron(env []string) (out map[string]string) { - out = make(map[string]string) - - for _, v := range env { - parts := strings.SplitN(v, "=", 2) - - accrue := func(keyname string) { - out[keyname] = parts[1] - } - unsupported := func() { - panic(fmt.Sprintf("setting %v not supported", parts[0])) - } - - // The order of these is the same as is seen in the - // PostgreSQL 9.1 manual. Unsupported but well-defined - // keys cause a panic; these should be unset prior to - // execution. Options which pq expects to be set to a - // certain value are allowed, but must be set to that - // value if present (they can, of course, be absent). - switch parts[0] { - case "PGHOST": - accrue("host") - case "PGHOSTADDR": - unsupported() - case "PGPORT": - accrue("port") - case "PGDATABASE": - accrue("dbname") - case "PGUSER": - accrue("user") - case "PGPASSWORD": - accrue("password") - case "PGSERVICE", "PGSERVICEFILE", "PGREALM": - unsupported() - case "PGOPTIONS": - accrue("options") - case "PGAPPNAME": - accrue("application_name") - case "PGSSLMODE": - accrue("sslmode") - case "PGSSLCERT": - accrue("sslcert") - case "PGSSLKEY": - accrue("sslkey") - case "PGSSLROOTCERT": - accrue("sslrootcert") - case "PGSSLSNI": - accrue("sslsni") - case "PGREQUIRESSL", "PGSSLCRL": - unsupported() - case "PGREQUIREPEER": - unsupported() - case "PGKRBSRVNAME", "PGGSSLIB": - unsupported() - case "PGCONNECT_TIMEOUT": - accrue("connect_timeout") - case "PGCLIENTENCODING": - accrue("client_encoding") - case "PGDATESTYLE": - accrue("datestyle") - case "PGTZ": - accrue("timezone") - case "PGGEQO": - accrue("geqo") - case "PGSYSCONFDIR", "PGLOCALEDIR": - unsupported() - } - } - - return out -} - -// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". -func isUTF8(name string) bool { - // Recognize all sorts of silly things as "UTF-8", like Postgres does - s := strings.Map(alnumLowerASCII, name) - return s == "utf8" || s == "unicode" -} - -func alnumLowerASCII(ch rune) rune { - if 'A' <= ch && ch <= 'Z' { - return ch + ('a' - 'A') - } - if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { - return ch - } - return -1 // discard -} - -// The database/sql/driver package says: -// All Conn implementations should implement the following interfaces: Pinger, SessionResetter, and Validator. -var _ driver.Pinger = &conn{} -var _ driver.SessionResetter = &conn{} - func (cn *conn) ResetSession(ctx context.Context) error { // Ensure bad connections are reported: From database/sql/driver: // If a connection is never returned to the connection pool but immediately reused, then diff --git a/vendor/github.com/lib/pq/conn_go115.go b/vendor/github.com/lib/pq/conn_go115.go deleted file mode 100644 index f4ef030f..00000000 --- a/vendor/github.com/lib/pq/conn_go115.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build go1.15 -// +build go1.15 - -package pq - -import "database/sql/driver" - -var _ driver.Validator = &conn{} diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go index 63d4ca6a..16de38eb 100644 --- a/vendor/github.com/lib/pq/conn_go18.go +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -6,22 +6,17 @@ import ( "database/sql/driver" "fmt" "io" - "io/ioutil" "time" -) -const ( - watchCancelDialContextTimeout = time.Second * 10 + "github.com/lib/pq/internal/proto" ) +const watchCancelDialContextTimeout = 10 * time.Second + // Implement the "QueryerContext" interface func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } finish := cn.watchCancel(ctx) - r, err := cn.query(query, list) + r, err := cn.query(query, args) if err != nil { if finish != nil { finish() @@ -57,7 +52,6 @@ func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, // Implement the "ConnBeginTx" interface func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { var mode string - switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: // Don't touch mode: use the server's default @@ -72,7 +66,6 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, default: return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) } - if opts.ReadOnly { mode += " READ ONLY" } else { @@ -93,9 +86,9 @@ func (cn *conn) Ping(ctx context.Context) error { } rows, err := cn.simpleQuery(";") if err != nil { - return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger + return driver.ErrBadConn } - rows.Close() + _ = rows.Close() return nil } @@ -131,7 +124,7 @@ func (cn *conn) watchCancel(ctx context.Context) func() { select { case <-finished: cn.err.set(ctx.Err()) - cn.Close() + _ = cn.Close() case finished <- struct{}{}: } } @@ -140,55 +133,39 @@ func (cn *conn) watchCancel(ctx context.Context) func() { } func (cn *conn) cancel(ctx context.Context) error { - // Create a new values map (copy). This makes sure the connection created - // in this method cannot write to the same underlying data, which could - // cause a concurrent map write panic. This is necessary because cancel - // is called from a goroutine in watchCancel. - o := make(values) - for k, v := range cn.opts { - o[k] = v - } + // Use a copy since a new connection is created here. This is necessary + // because cancel is called from a goroutine in watchCancel. + cfg := cn.cfg.Clone() - c, err := dial(ctx, cn.dialer, o) + c, err := dial(ctx, cn.dialer, cfg) if err != nil { return err } - defer c.Close() + defer func() { _ = c.Close() }() - { - can := conn{ - c: c, - } - err = can.ssl(o) - if err != nil { - return err - } - - w := can.writeBuf(0) - w.int32(80877102) // cancel request code - w.int32(cn.processID) - w.int32(cn.secretKey) - - if err := can.sendStartupPacket(w); err != nil { - return err - } + cn2 := conn{c: c} + err = cn2.ssl(cfg, cfg.SSLMode) + if err != nil { + return err } - // Read until EOF to ensure that the server received the cancel. - { - _, err := io.Copy(ioutil.Discard, c) + w := cn2.writeBuf(0) + w.int32(proto.CancelRequestCode) + w.int32(cn.pid) + w.bytes(cn.secretKey) + if err := cn2.sendStartupPacket(w); err != nil { return err } + + // Read until EOF to ensure that the server received the cancel. + _, err = io.Copy(io.Discard, c) + return err } // Implement the "StmtQueryContext" interface func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } finish := st.watchCancel(ctx) - r, err := st.query(list) + r, err := st.query(args) if err != nil { if finish != nil { finish() @@ -201,16 +178,19 @@ func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (dri // Implement the "StmtExecContext" interface func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } - if finish := st.watchCancel(ctx); finish != nil { defer finish() } + if err := st.cn.err.get(); err != nil { + return nil, err + } - return st.Exec(list) + err := st.exec(args) + if err != nil { + return nil, st.cn.handleError(err) + } + res, _, err := st.cn.readExecuteResponse("simple query") + return res, st.cn.handleError(err) } // watchCancel is implemented on stmt in order to not mark the parent conn as bad @@ -220,10 +200,9 @@ func (st *stmt) watchCancel(ctx context.Context) func() { go func() { select { case <-done: - // At this point the function level context is canceled, - // so it must not be used for the additional network - // request to cancel the query. - // Create a new context to pass into the dial. + // At this point the function level context is canceled, so it + // must not be used for the additional network request to cancel + // the query. Create a new context to pass into the dial. ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) defer cancel() diff --git a/vendor/github.com/lib/pq/connector.go b/vendor/github.com/lib/pq/connector.go index 1145e122..a2a8fb28 100644 --- a/vendor/github.com/lib/pq/connector.go +++ b/vendor/github.com/lib/pq/connector.go @@ -2,82 +2,597 @@ package pq import ( "context" + "crypto/tls" "database/sql/driver" - "errors" "fmt" + "math/rand" + "net" + "net/netip" + neturl "net/url" "os" + "path/filepath" + "reflect" + "slices" + "sort" + "strconv" "strings" + "time" + "unicode" + + "github.com/lib/pq/internal/pgservice" + "github.com/lib/pq/internal/pqutil" + "github.com/lib/pq/internal/proto" +) + +type ( + // SSLMode is a sslmode setting. + SSLMode string + + // SSLNegotiation is a sslnegotiation setting. + SSLNegotiation string + + // TargetSessionAttrs is a target_session_attrs setting. + TargetSessionAttrs string + + // LoadBalanceHosts is a load_balance_hosts setting. + LoadBalanceHosts string + + // ProtocolVersion is a min_protocol_version or max_protocol_version + // setting. + ProtocolVersion string + + // SSLProtocolVersion is a ssl_min_protocol_version or + // ssl_max_protocol_version setting. + SSLProtocolVersion string +) + +// Values for [SSLMode] that pq supports. +const ( + // No SSL + SSLModeDisable = SSLMode("disable") + + // First try a non-SSL connection and if that fails try an SSL connection. + SSLModeAllow = SSLMode("allow") + + // First try an SSL connection and if that fails try a non-SSL connection. + SSLModePrefer = SSLMode("prefer") + + // Require SSL, but skip verification. This is the default. + SSLModeRequire = SSLMode("require") + + // Require SSL and verify that the certificate was signed by a trusted CA. + SSLModeVerifyCA = SSLMode("verify-ca") + + // Require SSL and verify that the certificate was signed by a trusted CA + // and the server host name matches the one in the certificate. + SSLModeVerifyFull = SSLMode("verify-full") +) + +var sslModes = []SSLMode{SSLModeDisable, SSLModeAllow, SSLModePrefer, SSLModeRequire, + SSLModeVerifyFull, SSLModeVerifyCA} + +func (s SSLMode) useSSL() bool { + switch s { + case SSLModePrefer, SSLModeRequire, SSLModeVerifyCA, SSLModeVerifyFull: + return true + } + return false +} + +// Values for [SSLNegotiation] that pq supports. +const ( + // Negotiate whether SSL should be used. This is the default. + SSLNegotiationPostgres = SSLNegotiation("postgres") + + // Always use SSL, don't try to negotiate. + SSLNegotiationDirect = SSLNegotiation("direct") +) + +var sslNegotiations = []SSLNegotiation{SSLNegotiationPostgres, SSLNegotiationDirect} + +// Values for [TargetSessionAttrs] that pq supports. +const ( + // Any successful connection is acceptable. This is the default. + TargetSessionAttrsAny = TargetSessionAttrs("any") + + // Session must accept read-write transactions by default: the server must + // not be in hot standby mode and default_transaction_read_only must be + // off. + TargetSessionAttrsReadWrite = TargetSessionAttrs("read-write") + + // Session must not accept read-write transactions by default. + TargetSessionAttrsReadOnly = TargetSessionAttrs("read-only") + + // Server must not be in hot standby mode. + TargetSessionAttrsPrimary = TargetSessionAttrs("primary") + + // Server must be in hot standby mode. + TargetSessionAttrsStandby = TargetSessionAttrs("standby") + + // First try to find a standby server, but if none of the listed hosts is a + // standby server, try again in any mode. + TargetSessionAttrsPreferStandby = TargetSessionAttrs("prefer-standby") +) + +var targetSessionAttrs = []TargetSessionAttrs{TargetSessionAttrsAny, + TargetSessionAttrsReadWrite, TargetSessionAttrsReadOnly, TargetSessionAttrsPrimary, + TargetSessionAttrsStandby, TargetSessionAttrsPreferStandby} + +// Values for [LoadBalanceHosts] that pq supports. +const ( + // Don't load balance; try hosts in the order in which they're provided. + // This is the default. + LoadBalanceHostsDisable = LoadBalanceHosts("disable") + + // Hosts are tried in random order to balance connections across multiple + // PostgreSQL servers. + // + // When using this value it's recommended to also configure a reasonable + // value for connect_timeout. Because then, if one of the nodes that are + // used for load balancing is not responding, a new node will be tried. + LoadBalanceHostsRandom = LoadBalanceHosts("random") ) +var loadBalanceHosts = []LoadBalanceHosts{LoadBalanceHostsDisable, LoadBalanceHostsRandom} + +// Values for [ProtocolVersion] that pq supports. +const ( + // ProtocolVersion30 is the default protocol version, supported in + // PostgreSQL 3.0 and newer. + ProtocolVersion30 = ProtocolVersion("3.0") + + // ProtocolVersion32 uses a longer secret key length for query cancellation, + // supported in PostgreSQL 18 and newer. + ProtocolVersion32 = ProtocolVersion("3.2") + + // ProtocolVersionLatest is the latest protocol version that pq supports + // (which may not be supported by the server). + ProtocolVersionLatest = ProtocolVersion("latest") +) + +var protocolVersions = []ProtocolVersion{ProtocolVersion30, ProtocolVersion32, ProtocolVersionLatest} + +// Values for [SSLProtocolVersion] that pq supports. +const ( + SSLProtocolVersionTLS10 = SSLProtocolVersion("TLSv1.0") + SSLProtocolVersionTLS11 = SSLProtocolVersion("TLSv1.1") + SSLProtocolVersionTLS12 = SSLProtocolVersion("TLSv1.2") + SSLProtocolVersionTLS13 = SSLProtocolVersion("TLSv1.3") +) + +var sslProtocolVersions = []SSLProtocolVersion{SSLProtocolVersionTLS10, SSLProtocolVersionTLS11, + SSLProtocolVersionTLS12, SSLProtocolVersionTLS13} + +func (s SSLProtocolVersion) tlsconf() uint16 { + switch s { + case SSLProtocolVersionTLS10: + return tls.VersionTLS10 + case SSLProtocolVersionTLS11: + return tls.VersionTLS11 + case SSLProtocolVersionTLS12: + return tls.VersionTLS12 + case SSLProtocolVersionTLS13: + return tls.VersionTLS13 + default: + return 0 + } +} + // Connector represents a fixed configuration for the pq driver with a given -// name. Connector satisfies the database/sql/driver Connector interface and -// can be used to create any number of DB Conn's via the database/sql OpenDB -// function. -// -// See https://golang.org/pkg/database/sql/driver/#Connector. -// See https://golang.org/pkg/database/sql/#OpenDB. +// dsn. Connector satisfies the [database/sql/driver.Connector] interface and +// can be used to create any number of DB Conn's via [sql.OpenDB]. type Connector struct { - opts values + cfg Config dialer Dialer } -// Connect returns a connection to the database using the fixed configuration -// of this Connector. Context is not used. -func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { - return c.open(ctx) +// NewConnector returns a connector for the pq driver in a fixed configuration +// with the given dsn. The returned connector can be used to create any number +// of equivalent Conn's. The returned connector is intended to be used with +// [sql.OpenDB]. +func NewConnector(dsn string) (*Connector, error) { + cfg, err := NewConfig(dsn) + if err != nil { + return nil, err + } + return NewConnectorConfig(cfg) +} + +// NewConnectorConfig returns a connector for the pq driver in a fixed +// configuration with the given [Config]. The returned connector can be used to +// create any number of equivalent Conn's. The returned connector is intended to +// be used with [sql.OpenDB]. +func NewConnectorConfig(cfg Config) (*Connector, error) { + return &Connector{cfg: cfg, dialer: defaultDialer{}}, nil } +// Connect returns a connection to the database using the fixed configuration of +// this Connector. Context is not used. +func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { return c.open(ctx) } + // Dialer allows change the dialer used to open connections. -func (c *Connector) Dialer(dialer Dialer) { - c.dialer = dialer -} +func (c *Connector) Dialer(dialer Dialer) { c.dialer = dialer } // Driver returns the underlying driver of this Connector. -func (c *Connector) Driver() driver.Driver { - return &Driver{} +func (c *Connector) Driver() driver.Driver { return &Driver{} } + +func (p ProtocolVersion) proto() int { + switch p { + default: + return proto.ProtocolVersion30 + case ProtocolVersion32, ProtocolVersionLatest: + return proto.ProtocolVersion32 + } } -// NewConnector returns a connector for the pq driver in a fixed configuration -// with the given dsn. The returned connector can be used to create any number -// of equivalent Conn's. The returned connector is intended to be used with -// database/sql.OpenDB. +// Config holds options pq supports when connecting to PostgreSQL. // -// See https://golang.org/pkg/database/sql/driver/#Connector. -// See https://golang.org/pkg/database/sql/#OpenDB. -func NewConnector(dsn string) (*Connector, error) { - var err error - o := make(values) - - // A number of defaults are applied here, in this order: - // - // * Very low precedence defaults applied in every situation - // * Environment variables - // * Explicitly passed connection information - o["host"] = "localhost" - o["port"] = "5432" - // N.B.: Extra float digits should be set to 3, but that breaks - // Postgres 8.4 and older, where the max is 2. - o["extra_float_digits"] = "2" - for k, v := range parseEnviron(os.Environ()) { - o[k] = v +// The postgres struct tag is used for the value from the DSN (e.g. +// "dbname=abc"), and the env struct tag is used for the environment variable +// (e.g. "PGDATABASE=abc") +type Config struct { + // The host to connect to. Absolute paths and values that start with @ are + // for unix domain sockets. Defaults to localhost. + // + // A comma-separated list of host names is also accepted, in which case each + // host name in the list is tried in order or randomly if load_balance_hosts + // is set. An empty item selects the default of localhost. The + // target_session_attrs option controls properties the host must have to be + // considered acceptable. + Host string `postgres:"host" env:"PGHOST"` + + // IPv4 or IPv6 address to connect to. Using hostaddr allows the application + // to avoid a host name lookup, which might be important in applications + // with time constraints. A hostname is required for sslmode=verify-full and + // the GSSAPI or SSPI authentication methods. + // + // The following rules are used: + // + // - If host is given without hostaddr, a host name lookup occurs. + // + // - If hostaddr is given without host, the value for hostaddr gives the + // server network address. The connection attempt will fail if the + // authentication method requires a host name. + // + // - If both host and hostaddr are given, the value for hostaddr gives the + // server network address. The value for host is ignored unless the + // authentication method requires it, in which case it will be used as the + // host name. + // + // A comma-separated list of hostaddr values is also accepted, in which case + // each host in the list is tried in order or randonly if load_balance_hosts + // is set. An empty item causes the corresponding host name to be used, or + // the default host name if that is empty as well. The target_session_attrs + // option controls properties the host must have to be considered + // acceptable. + Hostaddr netip.Addr `postgres:"hostaddr" env:"PGHOSTADDR"` + + // The port to connect to. Defaults to 5432. + // + // If multiple hosts were given in the host or hostaddr parameters, this + // parameter may specify a comma-separated list of ports of the same length + // as the host list, or it may specify a single port number to be used for + // all hosts. An empty string, or an empty item in a comma-separated list, + // specifies the default of 5432. + Port uint16 `postgres:"port" env:"PGPORT"` + + // The name of the database to connect to. + Database string `postgres:"dbname" env:"PGDATABASE"` + + // The user to sign in as. Defaults to the current user. + User string `postgres:"user" env:"PGUSER"` + + // The user's password. + Password string `postgres:"password" env:"PGPASSWORD"` + + // Path to [pgpass] file to store passwords; overrides Password. + // + // [pgpass]: http://www.postgresql.org/docs/current/static/libpq-pgpass.html + Passfile string `postgres:"passfile" env:"PGPASSFILE"` + + // Commandline options to send to the server at connection start. + Options string `postgres:"options" env:"PGOPTIONS"` + + // Application name, displayed in pg_stat_activity and log entries. + ApplicationName string `postgres:"application_name" env:"PGAPPNAME"` + + // Used if application_name is not given. Specifying a fallback name is + // useful in generic utility programs that wish to set a default application + // name but allow it to be overridden by the user. + FallbackApplicationName string `postgres:"fallback_application_name" env:"-"` + + // Whether to use SSL. Defaults to "require" (different from libpq's default + // of "prefer"). + // + // [RegisterTLSConfig] can be used to registers a custom [tls.Config], which + // can be used by setting sslmode=pqgo-«key» in the connection string. + SSLMode SSLMode `postgres:"sslmode" env:"PGSSLMODE"` + + // When set to "direct" it will use SSL without negotiation (PostgreSQL ≥17 only). + SSLNegotiation SSLNegotiation `postgres:"sslnegotiation" env:"PGSSLNEGOTIATION"` + + // Path to client SSL certificate. The file must contain PEM encoded data. + // + // Defaults to ~/.postgresql/postgresql.crt + SSLCert string `postgres:"sslcert" env:"PGSSLCERT"` + + // Path to secret key for sslcert. The file must contain PEM encoded data. + // + // Defaults to ~/.postgresql/postgresql.key + SSLKey string `postgres:"sslkey" env:"PGSSLKEY"` + + // Path to root certificate. The file must contain PEM encoded data. + // + // The special value "system" can be used to load the system's root + // certificates ([x509.SystemCertPool]). This will change the default + // sslmode to verify-full and issue an error if a lower setting is used – as + // anyone can register a valid certificate hostname verification becomes + // essential. + // + // Defaults to ~/.postgresql/root.crt. + SSLRootCert string `postgres:"sslrootcert" env:"PGSSLROOTCERT"` + + // By default SNI is on, any value which is not starting with "1" disables + // SNI. + SSLSNI bool `postgres:"sslsni" env:"PGSSLSNI"` + + // Minimum SSL/TLS protocol version to allow for the connection. + // + // The default is determined by [tls.Config.MinVersion], which is TLSv1.2 at + // the time of writing. + SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"SSLPGMINPROTOCOLVERSION"` + + // Maximum SSL/TLS protocol version to allow for the connection. If not set, + // this parameter is ignored and the connection will use the maximum bound + // defined by the backend, if set. Setting the maximum protocol version is + // mainly useful for testing or if some component has issues working with a + // newer protocol. + SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"SSLPGMAXPROTOCOLVERSION"` + + // Interpert sslcert and sslkey as PEM encoded data, rather than a path to a + // PEM file. This is a pq extension, not supported in libpq. + SSLInline bool `postgres:"sslinline" env:"-"` + + // GSS (Kerberos) service name when constructing the SPN (default is + // postgres). This will be combined with the host to form the full SPN: + // krbsrvname/host. + KrbSrvname string `postgres:"krbsrvname" env:"PGKRBSRVNAME"` + + // GSS (Kerberos) SPN. This takes priority over krbsrvname if present. This + // is a pq extension, not supported in libpq. + KrbSpn string `postgres:"krbspn" env:"-"` + + // Maximum time to wait while connecting, in seconds. Zero, negative, or not + // specified means wait indefinitely + ConnectTimeout time.Duration `postgres:"connect_timeout" env:"PGCONNECT_TIMEOUT"` + + // Whether to always send []byte parameters over as binary. Enables single + // round-trip mode for non-prepared Query calls. This is a pq extension, not + // supported in libpq. + BinaryParameters bool `postgres:"binary_parameters" env:"-"` + + // This connection should never use the binary format when receiving query + // results from prepared statements. Only provided for debugging. This is a + // pq extension, not supported in libpq. + DisablePreparedBinaryResult bool `postgres:"disable_prepared_binary_result" env:"-"` + + // Client encoding; pq only supports UTF8 and this must be blank or "UTF8". + ClientEncoding string `postgres:"client_encoding" env:"PGCLIENTENCODING"` + + // Date/time representation to use; pq only supports "ISO, MDY" and this + // must be blank or "ISO, MDY". + Datestyle string `postgres:"datestyle" env:"PGDATESTYLE"` + + // Default time zone. + TZ string `postgres:"tz" env:"PGTZ"` + + // Default mode for the genetic query optimizer. + Geqo string `postgres:"geqo" env:"PGGEQO"` + + // Determine whether the session must have certain properties to be + // acceptable. It's typically used in combination with multiple host names + // to select the first acceptable alternative among several hosts. + TargetSessionAttrs TargetSessionAttrs `postgres:"target_session_attrs" env:"PGTARGETSESSIONATTRS"` + + // Controls the order in which the client tries to connect to the available + // hosts. Once a connection attempt is successful no other hosts will be + // tried. This parameter is typically used in combination with multiple host + // names. + // + // This parameter can be used in combination with target_session_attrs to, + // for example, load balance over standby servers only. Once successfully + // connected, subsequent queries on the returned connection will all be sent + // to the same server. + LoadBalanceHosts LoadBalanceHosts `postgres:"load_balance_hosts" env:"PGLOADBALANCEHOSTS"` + + // Minimum acceptable PostgreSQL protocol version. If the server does not + // support at least this version, the connection will fail. Defaults to + // "3.0". + MinProtocolVersion ProtocolVersion `postgres:"min_protocol_version" env:"PGMINPROTOCOLVERSION"` + + // Maximum PostgreSQL protocol version to request from the server. Defaults to "3.0". + MaxProtocolVersion ProtocolVersion `postgres:"max_protocol_version" env:"PGMAXPROTOCOLVERSION"` + + // Load connection parameters from the service file at ~/.pg_service.conf + // (which can be configured with PGSERVICEFILE). + // + // The service file is a INI-like file to configure connection parameters: + // + // [servicename] + // # Comment + // dbname=foo + // + // Unlike libpq, this does not look at the system-wide service file, as the + // location of this is a compile-time value that is not easy for pq to + // retrieve. + Service string `postgres:"service" env:"PGSERVICE"` + + // Path to connection service file. Defaults to ~/.pg_service.conf. + ServiceFile string `postgres:"-" env:"PGSERVICEFILE"` + + // Runtime parameters: any unrecognized parameter in the DSN will be added + // to this and sent to PostgreSQL during startup. + Runtime map[string]string `postgres:"-" env:"-"` + + // Multi contains additional connection details. The first value is + // available in [Config.Host], [Config.Hostaddr], and [Config.Port], and + // additional ones (if any) are available here. + Multi []ConfigMultihost + + // Record which parameters were given, so we can distinguish between an + // empty string "not given at all". + // + // The alternative is to use pointers or sql.Null[..], but that's more + // awkward to use. + set []string `env:"set"` + + multiHost []string + multiHostaddr []netip.Addr + multiPort []uint16 +} + +// ConfigMultihost specifies an additional server to try to connect to. +type ConfigMultihost struct { + Host string + Hostaddr netip.Addr + Port uint16 +} + +// NewConfig creates a new [Config] from the defaults, environment, service +// file, and DSN, in that order. That is: a service overrides any value from the +// environment, which in turn gets overridden by the same parameter in the +// connection string. +// +// Most connection parameters supported by PostgreSQL are supported; see the +// [Config] struct for supported parameters. pq also lets you specify any +// [run-time parameter] such as search_path or work_mem in the connection +// string. This is different from libpq, which uses the "options" parameter for +// this (which also works in pq). +// +// # key=value connection strings +// +// For key=value strings, use single quotes for values that contain whitespace +// or empty values. A backslash will escape the next character: +// +// "user=pqgo password='with spaces'" +// "user=''" +// "user=space\ man password='it\'s valid'" +// +// # URL connection strings +// +// pq supports URL-style postgres:// or postgresql:// connection strings in the +// form: +// +// postgres[ql]://[user[:pwd]@][net-location][:port][/dbname][?param1=value1&...] +// +// Go's [net/url.Parse] is more strict than PostgreSQL's URL parser and will +// (correctly) reject %2F in the host part. This means that unix-socket URLs: +// +// postgres://[user[:pwd]@][unix-socket][:port[/dbname]][?param1=value1&...] +// postgres://%2Ftmp%2Fpostgres/db +// +// will not work. You will need to use "host=/tmp/postgres dbname=db". +// +// Similarly, multiple ports also won't work, but ?port= will: +// +// postgres://host1,host2:5432,6543/dbname Doesn't work +// postgres://host1,host2/dbname?port=5432,6543 Works +// +// # Environment +// +// Most [PostgreSQL environment variables] are supported by pq. Environment +// variables have a lower precedence than explicitly provided connection +// parameters. pq will return an error if environment variables it does not +// support are set. Environment variables have a lower precedence than +// explicitly provided connection parameters. +// +// [PostgreSQL environment variables]: http://www.postgresql.org/docs/current/static/libpq-envars.html +// [run-time parameter]: http://www.postgresql.org/docs/current/static/runtime-config.html +func NewConfig(dsn string) (Config, error) { + return newConfig(dsn, os.Environ()) +} + +// Clone returns a copy of the [Config]. +func (cfg Config) Clone() Config { + rt := make(map[string]string) + for k, v := range cfg.Runtime { + rt[k] = v } + c := cfg + c.Runtime = rt + c.set = append([]string{}, cfg.set...) + return c +} - if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { - dsn, err = ParseURL(dsn) - if err != nil { - return nil, err - } +// hosts returns a slice of copies of this config, one for each host. +func (cfg Config) hosts() []Config { + cfgs := make([]Config, 1, len(cfg.Multi)+1) + cfgs[0] = cfg.Clone() + for _, m := range cfg.Multi { + c := cfg.Clone() + c.Host, c.Hostaddr, c.Port = m.Host, m.Hostaddr, m.Port + cfgs = append(cfgs, c) } - if err := parseOpts(dsn, o); err != nil { - return nil, err + if cfg.LoadBalanceHosts == LoadBalanceHostsRandom { + rand.Shuffle(len(cfgs), func(i, j int) { cfgs[i], cfgs[j] = cfgs[j], cfgs[i] }) } - // Use the "fallback" application name if necessary - if fallback, ok := o["fallback_application_name"]; ok { - if _, ok := o["application_name"]; !ok { - o["application_name"] = fallback + return cfgs +} + +func newConfig(dsn string, env []string) (Config, error) { + cfg := Config{ + Host: "localhost", + Port: 5432, + SSLSNI: true, + MinProtocolVersion: "3.0", + MaxProtocolVersion: "3.0", + } + if err := cfg.fromEnv(env); err != nil { + return Config{}, err + } + if err := cfg.fromDSN(dsn); err != nil { + return Config{}, err + } + if err := cfg.fromService(); err != nil { + return Config{}, err + } + + // Need to have exactly the same number of host and hostaddr, or only specify one. + if cfg.isset("host") && cfg.Host != "" && cfg.Hostaddr != (netip.Addr{}) && len(cfg.multiHost) != len(cfg.multiHostaddr) { + return Config{}, fmt.Errorf("pq: could not match %d host names to %d hostaddr values", + len(cfg.multiHost)+1, len(cfg.multiHostaddr)+1) + } + // Need one port that applies to all or exactly the same number of ports as hosts. + l, ll := max(len(cfg.multiHost), len(cfg.multiHostaddr)), len(cfg.multiPort) + if l > 0 && ll > 0 && l != ll { + return Config{}, fmt.Errorf("pq: could not match %d port numbers to %d hosts", ll+1, l+1) + } + + // Populate Multi + if len(cfg.multiHostaddr) > len(cfg.multiHost) { + cfg.multiHost = make([]string, len(cfg.multiHostaddr)) + } + for i, h := range cfg.multiHost { + p := cfg.Port + if len(cfg.multiPort) > 0 { + p = cfg.multiPort[i] } + var addr netip.Addr + if len(cfg.multiHostaddr) > 0 { + addr = cfg.multiHostaddr[i] + } + cfg.Multi = append(cfg.Multi, ConfigMultihost{ + Host: h, + Port: p, + Hostaddr: addr, + }) + } + + // Use the "fallback" application name if necessary + if cfg.isset("fallback_application_name") && !cfg.isset("application_name") { + cfg.ApplicationName = cfg.FallbackApplicationName } // We can't work with any client_encoding other than UTF-8 currently. @@ -87,34 +602,556 @@ func NewConnector(dsn string) (*Connector, error) { // parsing its value is not worth it. Instead, we always explicitly send // client_encoding as a separate run-time parameter, which should override // anything set in options. - if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { - return nil, errors.New("client_encoding must be absent or 'UTF8'") + if cfg.isset("client_encoding") && !isUTF8(cfg.ClientEncoding) { + return Config{}, fmt.Errorf(`pq: unsupported client_encoding %q: must be absent or "UTF8"`, cfg.ClientEncoding) } - o["client_encoding"] = "UTF8" // DateStyle needs a similar treatment. - if datestyle, ok := o["datestyle"]; ok { - if datestyle != "ISO, MDY" { - return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle) + if cfg.isset("datestyle") && cfg.Datestyle != "ISO, MDY" { + return Config{}, fmt.Errorf(`pq: unsupported datestyle %q: must be absent or "ISO, MDY"`, cfg.Datestyle) + } + cfg.ClientEncoding, cfg.Datestyle = "UTF8", "ISO, MDY" + + // Set default user if not explicitly provided. + if !cfg.isset("user") { + u, err := pqutil.User() + if err != nil { + return Config{}, err + } + cfg.User = u + } + + // SSL is not necessary or supported over UNIX domain sockets. + if nw, _ := cfg.network(); nw == "unix" { + cfg.SSLMode = SSLModeDisable + } + + if cfg.MinProtocolVersion > cfg.MaxProtocolVersion { + return Config{}, fmt.Errorf("pq: min_protocol_version %q cannot be greater than max_protocol_version %q", + cfg.MinProtocolVersion, cfg.MaxProtocolVersion) + } + if cfg.SSLNegotiation == SSLNegotiationDirect { + switch cfg.SSLMode { + case SSLModeDisable, SSLModeAllow, SSLModePrefer: + return Config{}, fmt.Errorf( + `pq: weak sslmode %q may not be used with sslnegotiation=direct (use "require", "verify-ca", or "verify-full")`, + cfg.SSLMode) + } + } + if cfg.SSLRootCert == "system" { + if !cfg.isset("sslmode") { + cfg.SSLMode = SSLModeVerifyFull + } + if cfg.SSLMode != SSLModeVerifyFull { + return Config{}, fmt.Errorf( + `pq: weak sslmode %q may not be used with sslrootcert=system (use "verify-full")`, + cfg.SSLMode) } - } else { - o["datestyle"] = "ISO, MDY" } - // If a user is not provided by any other means, the last - // resort is to use the current operating system provided user - // name. - if _, ok := o["user"]; !ok { - u, err := userCurrent() + return cfg, nil +} + +func (cfg Config) network() (string, string) { + if cfg.Hostaddr != (netip.Addr{}) { + return "tcp", net.JoinHostPort(cfg.Hostaddr.String(), strconv.Itoa(int(cfg.Port))) + } + // UNIX domain sockets are either represented by an (absolute) file system + // path or they live in the abstract name space (starting with an @). + if filepath.IsAbs(cfg.Host) || strings.HasPrefix(cfg.Host, "@") { + sockPath := filepath.Join(cfg.Host, ".s.PGSQL."+strconv.Itoa(int(cfg.Port))) + return "unix", sockPath + } + return "tcp", net.JoinHostPort(cfg.Host, strconv.Itoa(int(cfg.Port))) +} + +func (cfg *Config) fromEnv(env []string) error { + e := make(map[string]string) + for _, v := range env { + k, v, ok := strings.Cut(v, "=") + if !ok { + continue + } + switch k { + case "PGREQUIRESSL", "PGSSLCOMPRESSION", // Deprecated. + "PGREALM", "PGGSSENCMODE", "PGGSSDELEGATION", "PGGSSLIB", // krb stuff + "PGREQUIREAUTH", "PGCHANNELBINDING", + "PGSSLCERTMODE", "PGSSLCRL", "PGSSLCRLDIR", "PGREQUIREPEER": + return fmt.Errorf("pq: environment variable $%s is not supported", k) + case "PGKRBSRVNAME": + if newGss == nil { + return fmt.Errorf("pq: environment variable $%s is not supported as Kerberos is not enabled", k) + } + } + e[k] = v + } + return cfg.setFromTag(e, "env", false) +} + +// fromDSN parses the options from name and adds them to the values. +// +// The parsing code is based on conninfo_parse from libpq's fe-connect.c +func (cfg *Config) fromDSN(dsn string) error { + if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { + var err error + dsn, err = convertURL(dsn) if err != nil { - return nil, err + return err + } + } + + var ( + opt = make(map[string]string) + s = []rune(dsn) + i int + next = func() (rune, bool) { + if i >= len(s) { + return 0, false + } + r := s[i] + i++ + return r, true + } + skipSpaces = func() (rune, bool) { + r, ok := next() + for unicode.IsSpace(r) && ok { + r, ok = next() + } + return r, ok + } + ) + + for { + var ( + keyRunes, valRunes []rune + r rune + ok bool + ) + + if r, ok = skipSpaces(); !ok { + break + } + + // Scan the key + for !unicode.IsSpace(r) && r != '=' { + keyRunes = append(keyRunes, r) + if r, ok = next(); !ok { + break + } + } + + // Skip any whitespace if we're not at the = yet + if r != '=' { + r, ok = skipSpaces() + } + + // The current character should be = + if r != '=' || !ok { + return fmt.Errorf(`missing "=" after %q in connection info string`, string(keyRunes)) + } + + // Skip any whitespace after the = + if r, ok = skipSpaces(); !ok { + // If we reach the end here, the last value is just an empty string as per libpq. + opt[string(keyRunes)] = "" + break + } + + if r != '\'' { + for !unicode.IsSpace(r) { + if r == '\\' { + if r, ok = next(); !ok { + return fmt.Errorf(`missing character after backslash`) + } + } + valRunes = append(valRunes, r) + + if r, ok = next(); !ok { + break + } + } + } else { + quote: + for { + if r, ok = next(); !ok { + return fmt.Errorf(`unterminated quoted string literal in connection string`) + } + switch r { + case '\'': + break quote + case '\\': + r, _ = next() + fallthrough + default: + valRunes = append(valRunes, r) + } + } + } + + opt[string(keyRunes)] = string(valRunes) + } + + return cfg.setFromTag(opt, "postgres", false) +} + +func (cfg *Config) fromService() error { + if cfg.Service == "" { + return nil + } + + if !cfg.isset("PGSERVICEFILE") { + if home := pqutil.Home(false); home != "" { + cfg.ServiceFile = filepath.Join(home, ".pg_service.conf") + } + } + + opts, err := pgservice.FindService(cfg.ServiceFile, cfg.Service) + if err != nil { + return fmt.Errorf("pq: %w", err) + } + return cfg.setFromTag(opts, "postgres", true) +} + +func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) error { + f := "pq: wrong value for %q: " + if tag == "env" { + f = "pq: wrong value for $%s: " + } + var ( + types = reflect.TypeOf(cfg).Elem() + values = reflect.ValueOf(cfg).Elem() + ) + for i := 0; i < types.NumField(); i++ { + var ( + rt = types.Field(i) + rv = values.Field(i) + k = rt.Tag.Get(tag) + connectTimeout = (tag == "postgres" && k == "connect_timeout") || (tag == "env" && k == "PGCONNECT_TIMEOUT") + host = (tag == "postgres" && k == "host") || (tag == "env" && k == "PGHOST") + hostaddr = (tag == "postgres" && k == "hostaddr") || (tag == "env" && k == "PGHOSTADDR") + port = (tag == "postgres" && k == "port") || (tag == "env" && k == "PGPORT") + sslmode = (tag == "postgres" && k == "sslmode") || (tag == "env" && k == "PGSSLMODE") + sslnegotiation = (tag == "postgres" && k == "sslnegotiation") || (tag == "env" && k == "PGSSLNEGOTIATION") + targetsessionattrs = (tag == "postgres" && k == "target_session_attrs") || (tag == "env" && k == "PGTARGETSESSIONATTRS") + loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts") || (tag == "env" && k == "PGLOADBALANCEHOSTS") + minprotocolversion = (tag == "postgres" && k == "min_protocol_version") || (tag == "env" && k == "PGMINPROTOCOLVERSION") + maxprotocolversion = (tag == "postgres" && k == "max_protocol_version") || (tag == "env" && k == "PGMAXPROTOCOLVERSION") + sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version") || (tag == "env" && k == "SSLPGMINPROTOCOLVERSION") + sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version") || (tag == "env" && k == "SSLPGMAXPROTOCOLVERSION") + ) + if k == "" || k == "-" { + continue + } + + v, ok := o[k] + delete(o, k) + if ok { + t, ok := rt.Tag.Lookup("postgres") + if !ok || t == "" || t == "-" { // For PGSERVICEFILE, which can only be from env + t, ok = rt.Tag.Lookup("env") + } + if ok && t != "" && t != "-" { + cfg.set = append(cfg.set, t) + } + switch rt.Type.Kind() { + default: + return fmt.Errorf("don't know how to set %s: unknown type %s", rt.Name, rt.Type.Kind()) + case reflect.Struct: + if rt.Type == reflect.TypeOf(netip.Addr{}) { + if hostaddr { + vv := strings.Split(v, ",") + v = vv[0] + for _, vvv := range vv[1:] { + if vvv == "" { + cfg.multiHostaddr = append(cfg.multiHostaddr, netip.Addr{}) + } else { + ip, err := netip.ParseAddr(vvv) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + cfg.multiHostaddr = append(cfg.multiHostaddr, ip) + } + } + } + ip, err := netip.ParseAddr(v) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.Set(reflect.ValueOf(ip)) + } else { + return fmt.Errorf("don't know how to set %s: unknown type %s", rt.Name, rt.Type) + } + case reflect.String: + if sslmode && !slices.Contains(sslModes, SSLMode(v)) && !(strings.HasPrefix(v, "pqgo-") && hasTLSConfig(v[5:])) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(sslModes)) + } + if sslnegotiation && !slices.Contains(sslNegotiations, SSLNegotiation(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(sslNegotiations)) + } + if targetsessionattrs && !slices.Contains(targetSessionAttrs, TargetSessionAttrs(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(targetSessionAttrs)) + } + if loadbalancehosts && !slices.Contains(loadBalanceHosts, LoadBalanceHosts(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(loadBalanceHosts)) + } + if (minprotocolversion || maxprotocolversion) && !slices.Contains(protocolVersions, ProtocolVersion(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(protocolVersions)) + } + if (sslminprotocolversion || sslmaxprotocolversion) && !slices.Contains(sslProtocolVersions, SSLProtocolVersion(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(sslProtocolVersions)) + } + if host { + vv := strings.Split(v, ",") + v = vv[0] + for i, vvv := range vv[1:] { + if vvv == "" { + vv[i+1] = "localhost" + } + } + cfg.multiHost = append(cfg.multiHost, vv[1:]...) + } + rv.SetString(v) + case reflect.Int64: + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + if connectTimeout { + n = int64(time.Duration(n) * time.Second) + } + rv.SetInt(n) + case reflect.Uint16: + if port { + vv := strings.Split(v, ",") + v = vv[0] + for _, vvv := range vv[1:] { + if vvv == "" { + vvv = "5432" + } + n, err := strconv.ParseUint(vvv, 10, 16) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + cfg.multiPort = append(cfg.multiPort, uint16(n)) + } + } + n, err := strconv.ParseUint(v, 10, 16) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.SetUint(n) + case reflect.Bool: + b, err := pqutil.ParseBool(v) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.SetBool(b) + } + } + } + + if service && len(o) > 0 { + // TODO(go1.23): use maps.Keys once we require Go 1.23. + var key string + for k := range o { + key = k + break + } + return fmt.Errorf("pq: unknown setting %q in service file for service %q", key, cfg.Service) + } + + // Set run-time; we delete map keys as they're set in the struct. + if !service && tag == "postgres" { + // Make sure database= sets dbname=, as that previously worked (kind of + // by accident). + // TODO(v2): remove + if d, ok := o["database"]; ok { + cfg.Database = d + delete(o, "database") + } + cfg.Runtime = o + } + + return nil +} + +// Should generally only be used from newConfig(), as it will never be set if +// people go outside that. +func (cfg Config) isset(name string) bool { + return slices.Contains(cfg.set, name) +} + +// Convert to a map; used only in tests. +func (cfg Config) tomap() map[string]string { + var ( + o = make(map[string]string) + values = reflect.ValueOf(cfg) + types = reflect.TypeOf(cfg) + ) + for i := 0; i < types.NumField(); i++ { + var ( + rt = types.Field(i) + rv = values.Field(i) + k = rt.Tag.Get("postgres") + ) + if k == "" || k == "-" { + continue + } + if !rv.IsZero() || slices.Contains(cfg.set, k) { + switch rt.Type.Kind() { + default: + if s, ok := rv.Interface().(fmt.Stringer); ok { + o[k] = s.String() + } else { + o[k] = rv.String() + } + case reflect.Uint16: + n := rv.Uint() + o[k] = strconv.FormatUint(n, 10) + case reflect.Int64: + n := rv.Int() + if k == "connect_timeout" { + n = int64(time.Duration(n) / time.Second) + } + o[k] = strconv.FormatInt(n, 10) + case reflect.Bool: + if rv.Bool() { + o[k] = "yes" + } else { + o[k] = "no" + } + } + } + } + for k, v := range cfg.Runtime { + o[k] = v + } + return o +} + +// Create DSN for this config; used only in tests. +func (cfg Config) string() string { + var ( + m = cfg.tomap() + keys = make([]string, 0, len(m)) + ) + for k := range m { + switch k { + case "datestyle", "client_encoding": + continue + case "host", "port", "user", "sslsni", "min_protocol_version", "max_protocol_version": + if !cfg.isset(k) { + continue + } + } + if k == "application_name" && m[k] == "pqgo" { + continue + } + if k == "host" && len(cfg.multiHost) > 0 { + m[k] += "," + strings.Join(cfg.multiHost, ",") + } + if k == "hostaddr" && len(cfg.multiHostaddr) > 0 { + for _, ha := range cfg.multiHostaddr { + m[k] += "," + if ha != (netip.Addr{}) { + m[k] += ha.String() + } + } + } + if k == "port" && len(cfg.multiPort) > 0 { + for _, p := range cfg.multiPort { + m[k] += "," + strconv.Itoa(int(p)) + } + } + keys = append(keys, k) + } + sort.Strings(keys) + + var b strings.Builder + for i, k := range keys { + if i > 0 { + b.WriteByte(' ') + } + b.WriteString(k) + b.WriteByte('=') + var ( + v = m[k] + nv = make([]rune, 0, len(v)+2) + quote = v == "" + ) + for _, c := range v { + if c == ' ' { + quote = true + } + if c == '\'' { + nv = append(nv, '\\') + } + nv = append(nv, c) + } + if quote { + b.WriteByte('\'') + } + b.WriteString(string(nv)) + if quote { + b.WriteByte('\'') + } + } + return b.String() +} + +// Recognize all sorts of silly things as "UTF-8", like Postgres does +func isUTF8(name string) bool { + s := strings.Map(func(c rune) rune { + if 'A' <= c && c <= 'Z' { + return c + ('a' - 'A') + } + if 'a' <= c && c <= 'z' || '0' <= c && c <= '9' { + return c } - o["user"] = u + return -1 // discard + }, name) + return s == "utf8" || s == "unicode" +} + +func convertURL(url string) (string, error) { + u, err := neturl.Parse(url) + if err != nil { + return "", err + } + + if u.Scheme != "postgres" && u.Scheme != "postgresql" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + var kvs []string + escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) + accrue := func(k, v string) { + if v != "" { + kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") + } + } + + if u.User != nil { + pw, _ := u.User.Password() + accrue("user", u.User.Username()) + accrue("password", pw) + } + + if host, port, err := net.SplitHostPort(u.Host); err != nil { + accrue("host", u.Host) + } else { + accrue("host", host) + accrue("port", port) + } + + if u.Path != "" { + accrue("dbname", u.Path[1:]) } - // SSL is not necessary or supported over UNIX domain sockets - if network, _ := network(o); network == "unix" { - o["sslmode"] = "disable" + q := u.Query() + for k := range q { + accrue(k, q.Get(k)) } - return &Connector{opts: o, dialer: defaultDialer{}}, nil + sort.Strings(kvs) // Makes testing easier (not a performance concern) + return strings.Join(kvs, " "), nil } diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go index a8f16b2b..a7c73e01 100644 --- a/vendor/github.com/lib/pq/copy.go +++ b/vendor/github.com/lib/pq/copy.go @@ -1,13 +1,15 @@ package pq import ( - "bytes" "context" "database/sql/driver" "encoding/binary" "errors" "fmt" + "os" "sync" + + "github.com/lib/pq/internal/proto" ) var ( @@ -15,64 +17,28 @@ var ( errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") errCopyToNotSupported = errors.New("pq: COPY TO is not supported") errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") - errCopyInProgress = errors.New("pq: COPY in progress") ) -// CopyIn creates a COPY FROM statement which can be prepared with -// Tx.Prepare(). The target table should be visible in search_path. -func CopyIn(table string, columns ...string) string { - buffer := bytes.NewBufferString("COPY ") - BufferQuoteIdentifier(table, buffer) - buffer.WriteString(" (") - makeStmt(buffer, columns...) - return buffer.String() -} - -// MakeStmt makes the stmt string for CopyIn and CopyInSchema. -func makeStmt(buffer *bytes.Buffer, columns ...string) { - //s := bytes.NewBufferString() - for i, col := range columns { - if i != 0 { - buffer.WriteString(", ") - } - BufferQuoteIdentifier(col, buffer) - } - buffer.WriteString(") FROM STDIN") -} - -// CopyInSchema creates a COPY FROM statement which can be prepared with -// Tx.Prepare(). -func CopyInSchema(schema, table string, columns ...string) string { - buffer := bytes.NewBufferString("COPY ") - BufferQuoteIdentifier(schema, buffer) - buffer.WriteRune('.') - BufferQuoteIdentifier(table, buffer) - buffer.WriteString(" (") - makeStmt(buffer, columns...) - return buffer.String() -} - type copyin struct { cn *conn buffer []byte rowData chan []byte done chan bool - - closed bool - - mu struct { + closed bool + mu struct { sync.Mutex err error driver.Result } } -const ciBufferSize = 64 * 1024 - -// flush buffer before the buffer is filled up and needs reallocation -const ciBufferFlushSize = 63 * 1024 +const ( + ciBufferSize = 64 * 1024 + // flush buffer before the buffer is filled up and needs reallocation + ciBufferFlushSize = 63 * 1024 +) -func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) { +func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, resErr error) { if !cn.isInTransaction() { return nil, errCopyNotSupportedOutsideTxn } @@ -84,69 +50,84 @@ func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) { done: make(chan bool, 1), } // add CopyData identifier + 4 bytes for message length - ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0) + ci.buffer = append(ci.buffer, byte(proto.CopyDataRequest), 0, 0, 0, 0) - b := cn.writeBuf('Q') + b := cn.writeBuf(proto.Query) b.string(q) - cn.send(b) + err := cn.send(b) + if err != nil { + return nil, err + } awaitCopyInResponse: for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, err + } switch t { - case 'G': + case proto.CopyInResponse: if r.byte() != 0 { - err = errBinaryCopyNotSupported + resErr = errBinaryCopyNotSupported break awaitCopyInResponse } go ci.resploop() return ci, nil - case 'H': - err = errCopyToNotSupported + case proto.CopyOutResponse: + resErr = errCopyToNotSupported break awaitCopyInResponse - case 'E': - err = parseError(r) - case 'Z': - if err == nil { + case proto.ErrorResponse: + resErr = parseError(r, q) + case proto.ReadyForQuery: + if resErr == nil { ci.setBad(driver.ErrBadConn) - errorf("unexpected ReadyForQuery in response to COPY") + return nil, fmt.Errorf("pq: unexpected ReadyForQuery in response to COPY") } cn.processReadyForQuery(r) - return nil, err + return nil, resErr default: ci.setBad(driver.ErrBadConn) - errorf("unknown response for copy query: %q", t) + return nil, fmt.Errorf("pq: unknown response for copy query: %q", t) } } // something went wrong, abort COPY before we return - b = cn.writeBuf('f') - b.string(err.Error()) - cn.send(b) + b = cn.writeBuf(proto.CopyFail) + b.string(resErr.Error()) + err = cn.send(b) + if err != nil { + return nil, err + } for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, err + } + switch t { - case 'c', 'C', 'E': - case 'Z': + case proto.CopyDoneResponse, proto.CommandComplete, proto.ErrorResponse: + case proto.ReadyForQuery: // correctly aborted, we're done cn.processReadyForQuery(r) - return nil, err + return nil, resErr default: ci.setBad(driver.ErrBadConn) - errorf("unknown response for CopyFail: %q", t) + return nil, fmt.Errorf("pq: unknown response for CopyFail: %q", t) } } } -func (ci *copyin) flush(buf []byte) { - // set message length (without message identifier) - binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1)) - - _, err := ci.cn.c.Write(buf) - if err != nil { - panic(err) +func (ci *copyin) flush(buf []byte) error { + if len(buf)-1 > proto.MaxUint32 { + return errors.New("pq: too many columns") + } + if debugProto { + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", proto.RequestCode(buf[0]), len(buf)-5, buf[5:]) } + binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1)) // Set message length (without message identifier). + _, err := ci.cn.c.Write(buf) + return err } func (ci *copyin) resploop() { @@ -160,20 +141,23 @@ func (ci *copyin) resploop() { return } switch t { - case 'C': + case proto.CommandComplete: // complete - res, _ := ci.cn.parseComplete(r.string()) + res, _, err := ci.cn.parseComplete(r.string()) + if err != nil { + panic(err) + } ci.setResult(res) - case 'N': + case proto.NoticeResponse: if n := ci.cn.noticeHandler; n != nil { - n(parseError(&r)) + n(parseError(&r, "")) } - case 'Z': + case proto.ReadyForQuery: ci.cn.processReadyForQuery(&r) ci.done <- true return - case 'E': - err := parseError(&r) + case proto.ErrorResponse: + err := parseError(&r, "") ci.setError(err) default: ci.setBad(driver.ErrBadConn) @@ -240,16 +224,13 @@ func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) { // You need to call Exec(nil) to sync the COPY stream and to get any // errors from pending data, since Stmt.Close() doesn't return errors // to the user. -func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { +func (ci *copyin) Exec(v []driver.Value) (driver.Result, error) { if ci.closed { return nil, errCopyInClosed } - if err := ci.getBad(); err != nil { return nil, err } - defer ci.cn.errRecover(&err) - if err := ci.err(); err != nil { return nil, err } @@ -258,13 +239,18 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { if err := ci.Close(); err != nil { return driver.RowsAffected(0), err } - return ci.getResult(), nil } - numValues := len(v) + var ( + numValues = len(v) + err error + ) for i, value := range v { - ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value) + ci.buffer, err = appendEncodedText(ci.buffer, value) + if err != nil { + return nil, ci.cn.handleError(err) + } if i < numValues-1 { ci.buffer = append(ci.buffer, '\t') } @@ -273,7 +259,10 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { ci.buffer = append(ci.buffer, '\n') if len(ci.buffer) > ciBufferFlushSize { - ci.flush(ci.buffer) + err := ci.flush(ci.buffer) + if err != nil { + return nil, ci.cn.handleError(err) + } // reset buffer, keep bytes for message identifier and length ci.buffer = ci.buffer[:5] } @@ -288,20 +277,16 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { // You need to call Exec(nil) to sync the COPY stream and to get any // errors from pending data, since Stmt.Close() doesn't return errors // to the user. -func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) { +func (ci *copyin) CopyData(ctx context.Context, line string) (driver.Result, error) { if ci.closed { return nil, errCopyInClosed } - if finish := ci.cn.watchCancel(ctx); finish != nil { defer finish() } - if err := ci.getBad(); err != nil { return nil, err } - defer ci.cn.errRecover(&err) - if err := ci.err(); err != nil { return nil, err } @@ -310,7 +295,11 @@ func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, e ci.buffer = append(ci.buffer, '\n') if len(ci.buffer) > ciBufferFlushSize { - ci.flush(ci.buffer) + err := ci.flush(ci.buffer) + if err != nil { + return nil, ci.cn.handleError(err) + } + // reset buffer, keep bytes for message identifier and length ci.buffer = ci.buffer[:5] } @@ -318,7 +307,7 @@ func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, e return driver.RowsAffected(0), nil } -func (ci *copyin) Close() (err error) { +func (ci *copyin) Close() error { if ci.closed { // Don't do anything, we're already closed return nil } @@ -327,19 +316,21 @@ func (ci *copyin) Close() (err error) { if err := ci.getBad(); err != nil { return err } - defer ci.cn.errRecover(&err) if len(ci.buffer) > 0 { - ci.flush(ci.buffer) + err := ci.flush(ci.buffer) + if err != nil { + return ci.cn.handleError(err) + } } // Avoid touching the scratch buffer as resploop could be using it. - err = ci.cn.sendSimpleMessage('c') + err := ci.cn.sendSimpleMessage(proto.CopyDoneRequest) if err != nil { - return err + return ci.cn.handleError(err) } <-ci.done - ci.cn.inCopy = false + ci.cn.inProgress.Store(false) if err := ci.err(); err != nil { return err diff --git a/vendor/github.com/lib/pq/deprecated.go b/vendor/github.com/lib/pq/deprecated.go new file mode 100644 index 00000000..d43934a0 --- /dev/null +++ b/vendor/github.com/lib/pq/deprecated.go @@ -0,0 +1,133 @@ +package pq + +import ( + "bytes" + "database/sql" + + "github.com/lib/pq/pqerror" +) + +// [pq.Error.Severity] values. +// +// Deprecated: use pqerror.Severity[..] values. +// +//go:fix inline +const ( + Efatal = pqerror.SeverityFatal + Epanic = pqerror.SeverityPanic + Ewarning = pqerror.SeverityWarning + Enotice = pqerror.SeverityNotice + Edebug = pqerror.SeverityDebug + Einfo = pqerror.SeverityInfo + Elog = pqerror.SeverityLog +) + +// PGError is an interface used by previous versions of pq. +// +// Deprecated: use the Error type. This is never used. +type PGError interface { + Error() string + Fatal() bool + Get(k byte) (v string) +} + +// Get implements the legacy PGError interface. +// +// Deprecated: new code should use the fields of the Error struct directly. +func (e *Error) Get(k byte) (v string) { + switch k { + case 'S': + return e.Severity + case 'C': + return string(e.Code) + case 'M': + return e.Message + case 'D': + return e.Detail + case 'H': + return e.Hint + case 'P': + return e.Position + case 'p': + return e.InternalPosition + case 'q': + return e.InternalQuery + case 'W': + return e.Where + case 's': + return e.Schema + case 't': + return e.Table + case 'c': + return e.Column + case 'd': + return e.DataTypeName + case 'n': + return e.Constraint + case 'F': + return e.File + case 'L': + return e.Line + case 'R': + return e.Routine + } + return "" +} + +// ParseURL converts a url to a connection string for driver.Open. +// +// Deprecated: directly passing an URL to sql.Open("postgres", "postgres://...") +// now works, and calling this manually is no longer required. +func ParseURL(url string) (string, error) { return convertURL(url) } + +// NullTime represents a [time.Time] that may be null. +// +// Deprecated: this is an alias for [sql.NullTime]. +// +//go:fix inline +type NullTime = sql.NullTime + +// CopyIn creates a COPY FROM statement which can be prepared with Tx.Prepare(). +// The target table should be visible in search_path. +// +// It copies all columns if the list of columns is empty. +// +// Deprecated: there is no need to use this query builder, you can use: +// +// tx.Prepare("copy tbl (col1, col2) from stdin") +func CopyIn(table string, columns ...string) string { + b := bytes.NewBufferString("COPY ") + BufferQuoteIdentifier(table, b) + makeStmt(b, columns...) + return b.String() +} + +// CopyInSchema creates a COPY FROM statement which can be prepared with +// Tx.Prepare(). +// +// Deprecated: there is no need to use this query builder, you can use: +// +// tx.Prepare("copy schema.tbl (col1, col2) from stdin") +func CopyInSchema(schema, table string, columns ...string) string { + b := bytes.NewBufferString("COPY ") + BufferQuoteIdentifier(schema, b) + b.WriteRune('.') + BufferQuoteIdentifier(table, b) + makeStmt(b, columns...) + return b.String() +} + +func makeStmt(b *bytes.Buffer, columns ...string) { + if len(columns) == 0 { + b.WriteString(" FROM STDIN") + return + } + b.WriteString(" (") + for i, col := range columns { + if i != 0 { + b.WriteString(", ") + } + BufferQuoteIdentifier(col, b) + } + b.WriteString(") FROM STDIN") +} diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go index b5718480..9d9d78e4 100644 --- a/vendor/github.com/lib/pq/doc.go +++ b/vendor/github.com/lib/pq/doc.go @@ -1,8 +1,8 @@ /* -Package pq is a pure Go Postgres driver for the database/sql package. +Package pq is a Go PostgreSQL driver for database/sql. -In most cases clients will use the database/sql package instead of -using this package directly. For example: +Most clients will use the database/sql package instead of using this package +directly. For example: import ( "database/sql" @@ -11,239 +11,113 @@ using this package directly. For example: ) func main() { - connStr := "user=pqgotest dbname=pqgotest sslmode=verify-full" - db, err := sql.Open("postgres", connStr) + dsn := "user=pqgo dbname=pqgo sslmode=verify-full" + db, err := sql.Open("postgres", dsn) if err != nil { log.Fatal(err) } age := 21 - rows, err := db.Query("SELECT name FROM users WHERE age = $1", age) - … + rows, err := db.Query("select name from users where age = $1", age) + // … } -You can also connect to a database using a URL. For example: +You can also connect with an URL: - connStr := "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full" - db, err := sql.Open("postgres", connStr) + dsn := "postgres://pqgo:password@localhost/pqgo?sslmode=verify-full" + db, err := sql.Open("postgres", dsn) +# Connection String Parameters -Connection String Parameters +See [NewConfig]. +# Queries -Similarly to libpq, when establishing a connection using pq you are expected to -supply a connection string containing zero or more parameters. -A subset of the connection parameters supported by libpq are also supported by pq. -Additionally, pq also lets you specify run-time parameters (such as search_path or work_mem) -directly in the connection string. This is different from libpq, which does not allow -run-time parameters in the connection string, instead requiring you to supply -them in the options parameter. +database/sql does not dictate any specific format for parameter placeholders, +and pq uses the PostgreSQL-native ordinal markers ($1, $2, etc.). The same +placeholder can be used more than once: -For compatibility with libpq, the following special connection parameters are -supported: + rows, err := db.Query( + `select * from users where name = $1 or age between $2 and $2 + 3`, + "Duck", 64) - * dbname - The name of the database to connect to - * user - The user to sign in as - * password - The user's password - * host - The host to connect to. Values that start with / are for unix - domain sockets. (default is localhost) - * port - The port to bind to. (default is 5432) - * sslmode - Whether or not to use SSL (default is require, this is not - the default for libpq) - * fallback_application_name - An application_name to fall back to if one isn't provided. - * connect_timeout - Maximum wait for connection, in seconds. Zero or - not specified means wait indefinitely. - * sslcert - Cert file location. The file must contain PEM encoded data. - * sslkey - Key file location. The file must contain PEM encoded data. - * sslrootcert - The location of the root certificate file. The file - must contain PEM encoded data. +pq does not support [sql.Result.LastInsertId]. Use the RETURNING clause with a +Query or QueryRow call instead to return the identifier: -Valid values for sslmode are: - - * disable - No SSL - * require - Always SSL (skip verification) - * verify-ca - Always SSL (verify that the certificate presented by the - server was signed by a trusted CA) - * verify-full - Always SSL (verify that the certification presented by - the server was signed by a trusted CA and the server host name - matches the one in the certificate) - -See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING -for more information about connection string parameters. - -Use single quotes for values that contain whitespace: - - "user=pqgotest password='with spaces'" - -A backslash will escape the next character in values: - - "user=space\ man password='it\'s valid'" - -Note that the connection parameter client_encoding (which sets the -text encoding for the connection) may be set but must be "UTF8", -matching with the same rules as Postgres. It is an error to provide -any other value. - -In addition to the parameters listed above, any run-time parameter that can be -set at backend start time can be set in the connection string. For more -information, see -http://www.postgresql.org/docs/current/static/runtime-config.html. - -Most environment variables as specified at http://www.postgresql.org/docs/current/static/libpq-envars.html -supported by libpq are also supported by pq. If any of the environment -variables not supported by pq are set, pq will panic during connection -establishment. Environment variables have a lower precedence than explicitly -provided connection parameters. - -The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html -is supported, but on Windows PGPASSFILE must be specified explicitly. - - -Queries - - -database/sql does not dictate any specific format for parameter -markers in query strings, and pq uses the Postgres-native ordinal markers, -as shown above. The same marker can be reused for the same parameter: - - rows, err := db.Query(`SELECT name FROM users WHERE favorite_fruit = $1 - OR age BETWEEN $2 AND $2 + 3`, "orange", 64) - -pq does not support the LastInsertId() method of the Result type in database/sql. -To return the identifier of an INSERT (or UPDATE or DELETE), use the Postgres -RETURNING clause with a standard Query or QueryRow call: + row := db.QueryRow(`insert into users(name, age) values('Scrooge McDuck', 93) returning id`) var userid int - err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age) - VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid) - -For more details on RETURNING, see the Postgres documentation: - - http://www.postgresql.org/docs/current/static/sql-insert.html - http://www.postgresql.org/docs/current/static/sql-update.html - http://www.postgresql.org/docs/current/static/sql-delete.html + err := row.Scan(&userid) -For additional instructions on querying see the documentation for the database/sql package. +# Data Types - -Data Types - - -Parameters pass through driver.DefaultParameterConverter before they are handled -by this package. When the binary_parameters connection option is enabled, -[]byte values are sent directly to the backend as data in binary format. +Parameters pass through [driver.DefaultParameterConverter] before they are handled +by this package. When the binary_parameters connection option is enabled, []byte +values are sent directly to the backend as data in binary format. This package returns the following types for values from the PostgreSQL backend: - - integer types smallint, integer, and bigint are returned as int64 - - floating-point types real and double precision are returned as float64 - - character types char, varchar, and text are returned as string - - temporal types date, time, timetz, timestamp, and timestamptz are - returned as time.Time - - the boolean type is returned as bool - - the bytea type is returned as []byte + - integer types smallint, integer, and bigint are returned as int64 + - floating-point types real and double precision are returned as float64 + - character types char, varchar, and text are returned as string + - temporal types date, time, timetz, timestamp, and timestamptz are + returned as time.Time + - the boolean type is returned as bool + - the bytea type is returned as []byte All other types are returned directly from the backend as []byte values in text format. +# Errors -Errors - - -pq may return errors of type *pq.Error which can be interrogated for error details: +pq may return errors of type [*pq.Error] which contain error details: - if err, ok := err.(*pq.Error); ok { - fmt.Println("pq error:", err.Code.Name()) - } - -See the pq.Error type for details. - - -Bulk imports - -You can perform bulk imports by preparing a statement returned by pq.CopyIn (or -pq.CopyInSchema) in an explicit transaction (sql.Tx). The returned statement -handle can then be repeatedly "executed" to copy data into the target table. -After all data has been processed you should call Exec() once with no arguments -to flush all buffered data. Any call to Exec() might return an error which -should be handled appropriately, but because of the internal buffering an error -returned by Exec() might not be related to the data passed in the call that -failed. - -CopyIn uses COPY FROM internally. It is not possible to COPY outside of an -explicit transaction in pq. - -Usage example: - - txn, err := db.Begin() - if err != nil { - log.Fatal(err) - } - - stmt, err := txn.Prepare(pq.CopyIn("users", "name", "age")) - if err != nil { - log.Fatal(err) - } - - for _, user := range users { - _, err = stmt.Exec(user.Name, int64(user.Age)) - if err != nil { - log.Fatal(err) - } + pqErr := new(pq.Error) + if errors.As(err, &pqErr) { + fmt.Println("pq error:", pqErr.Code.Name()) } - _, err = stmt.Exec() - if err != nil { - log.Fatal(err) - } +# Bulk imports - err = stmt.Close() - if err != nil { - log.Fatal(err) - } +You can perform bulk imports by preparing a "COPY [..] FROM STDIN" statement in +a transaction ([sql.Tx]). The returned [sql.Stmt] handle can then be repeatedly +"executed" to copy data into the target table. After all data has been processed +you should call Exec() once with no arguments to flush all buffered data. Any +call to Exec() might return an error which should be handled appropriately, but +because of the internal buffering an error returned by Exec() might not be +related to the data passed in the call that failed. - err = txn.Commit() - if err != nil { - log.Fatal(err) - } +It is not possible to COPY outside of an explicit transaction in pq. +Use nil for NULL, or explicitly add WITH NULL 'SOME STRING' (the default of \N +doesn't work). -Notifications +# Notifications +PostgreSQL supports a simple publish/subscribe model using PostgreSQL's [NOTIFY] mechanism. -PostgreSQL supports a simple publish/subscribe model over database -connections. See http://www.postgresql.org/docs/current/static/sql-notify.html -for more information about the general mechanism. - -To start listening for notifications, you first have to open a new connection -to the database by calling NewListener. This connection can not be used for -anything other than LISTEN / NOTIFY. Calling Listen will open a "notification +To start listening for notifications, you first have to open a new connection to +the database by calling [NewListener]. This connection can not be used for +anything other than LISTEN / NOTIFY. Calling Listen will open a "notification channel"; once a notification channel is open, a notification generated on that -channel will effect a send on the Listener.Notify channel. A notification +channel will effect a send on the Listener.Notify channel. A notification channel will remain open until Unlisten is called, though connection loss might -result in some notifications being lost. To solve this problem, Listener sends -a nil pointer over the Notify channel any time the connection is re-established -following a connection loss. The application can get information about the -state of the underlying connection by setting an event callback in the call to +result in some notifications being lost. To solve this problem, Listener sends a +nil pointer over the Notify channel any time the connection is re-established +following a connection loss. The application can get information about the state +of the underlying connection by setting an event callback in the call to NewListener. -A single Listener can safely be used from concurrent goroutines, which means +A single [Listener] can safely be used from concurrent goroutines, which means that there is often no need to create more than one Listener in your -application. However, a Listener is always connected to a single database, so +application. However, a Listener is always connected to a single database, so you will need to create a new Listener instance for every database you want to receive notifications in. The channel name in both Listen and Unlisten is case sensitive, and can contain -any characters legal in an identifier (see -http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -for more information). Note that the channel name will be truncated to 63 -bytes by the PostgreSQL server. - -You can find a complete, working example of Listener usage at -https://godoc.org/github.com/lib/pq/example/listen. - - -Kerberos Support +any characters legal in an [identifier]. Note that the channel name will be +truncated to 63 bytes by the PostgreSQL server. +# Kerberos Support If you need support for Kerberos authentication, add the following to your main package: @@ -254,15 +128,10 @@ package: pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() }) } -This package is in a separate module so that users who don't need Kerberos -don't have to download unnecessary dependencies. - -When imported, additional connection string parameters are supported: +This package is in a separate module so that users who don't need Kerberos don't +have to add unnecessary dependencies. - * krbsrvname - GSS (Kerberos) service name when constructing the - SPN (default is `postgres`). This will be combined with the host - to form the full SPN: `krbsrvname/host`. - * krbspn - GSS (Kerberos) SPN. This takes priority over - `krbsrvname` if present. +[identifier]: http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +[NOTIFY]: http://www.postgresql.org/docs/current/static/sql-notify.html */ package pq diff --git a/vendor/github.com/lib/pq/encode.go b/vendor/github.com/lib/pq/encode.go index bffe6096..f9b65051 100644 --- a/vendor/github.com/lib/pq/encode.go +++ b/vendor/github.com/lib/pq/encode.go @@ -2,171 +2,170 @@ package pq import ( "bytes" - "database/sql/driver" "encoding/binary" "encoding/hex" "errors" "fmt" - "math" - "regexp" "strconv" "strings" - "sync" "time" + "github.com/lib/pq/internal/pqtime" "github.com/lib/pq/oid" ) -var time2400Regex = regexp.MustCompile(`^(24:00(?::00(?:\.0+)?)?)(?:[Z+-].*)?$`) - -func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte { +func binaryEncode(x any) ([]byte, error) { switch v := x.(type) { case []byte: - return v + return v, nil default: - return encode(parameterStatus, x, oid.T_unknown) + return encode(x, oid.T_unknown) } } -func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { +func encode(x any, pgtypOid oid.Oid) ([]byte, error) { switch v := x.(type) { case int64: - return strconv.AppendInt(nil, v, 10) + return strconv.AppendInt(nil, v, 10), nil case float64: - return strconv.AppendFloat(nil, v, 'f', -1, 64) + return strconv.AppendFloat(nil, v, 'f', -1, 64), nil case []byte: + if v == nil { + return nil, nil + } if pgtypOid == oid.T_bytea { - return encodeBytea(parameterStatus.serverVersion, v) + return encodeBytea(v), nil } - - return v + return v, nil case string: if pgtypOid == oid.T_bytea { - return encodeBytea(parameterStatus.serverVersion, []byte(v)) + return encodeBytea([]byte(v)), nil } - - return []byte(v) + return []byte(v), nil case bool: - return strconv.AppendBool(nil, v) + return strconv.AppendBool(nil, v), nil case time.Time: - return formatTs(v) - + return formatTS(v), nil default: - errorf("encode: unknown type for %T", v) + return nil, fmt.Errorf("pq: encode: unknown type for %T", v) } - - panic("not reached") } -func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { +func decode(ps *parameterStatus, s []byte, typ oid.Oid, f format) (any, error) { switch f { case formatBinary: - return binaryDecode(parameterStatus, s, typ) + return binaryDecode(s, typ) case formatText: - return textDecode(parameterStatus, s, typ) + return textDecode(ps, s, typ) default: - panic("not reached") + panic("unreachable") } } -func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { +func binaryDecode(s []byte, typ oid.Oid) (any, error) { switch typ { case oid.T_bytea: - return s + return s, nil case oid.T_int8: - return int64(binary.BigEndian.Uint64(s)) + return int64(binary.BigEndian.Uint64(s)), nil case oid.T_int4: - return int64(int32(binary.BigEndian.Uint32(s))) + return int64(int32(binary.BigEndian.Uint32(s))), nil case oid.T_int2: - return int64(int16(binary.BigEndian.Uint16(s))) + return int64(int16(binary.BigEndian.Uint16(s))), nil case oid.T_uuid: - b, err := decodeUUIDBinary(s) - if err != nil { - panic(err) - } - return b - + return decodeUUIDBinary(s) default: - errorf("don't know how to decode binary parameter of type %d", uint32(typ)) + return nil, fmt.Errorf("pq: don't know how to decode binary parameter of type %d", uint32(typ)) + } + +} + +// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. +func decodeUUIDBinary(src []byte) ([]byte, error) { + if len(src) != 16 { + return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) } - panic("not reached") + dst := make([]byte, 36) + dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' + hex.Encode(dst[0:], src[0:4]) + hex.Encode(dst[9:], src[4:6]) + hex.Encode(dst[14:], src[6:8]) + hex.Encode(dst[19:], src[8:10]) + hex.Encode(dst[24:], src[10:16]) + return dst, nil } -func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { +func textDecode(ps *parameterStatus, s []byte, typ oid.Oid) (any, error) { switch typ { - case oid.T_char, oid.T_varchar, oid.T_text: - return string(s) + case oid.T_char, oid.T_bpchar, oid.T_varchar, oid.T_text: + return string(s), nil case oid.T_bytea: b, err := parseBytea(s) if err != nil { - errorf("%s", err) + err = errors.New("pq: " + err.Error()) } - return b + return b, err case oid.T_timestamptz: - return parseTs(parameterStatus.currentLocation, string(s)) + return parseTS(ps.currentLocation, string(s)) case oid.T_timestamp, oid.T_date: - return parseTs(nil, string(s)) + return parseTS(nil, string(s)) case oid.T_time: - return mustParse("15:04:05", typ, s) + return parseTime(typ, s) case oid.T_timetz: - return mustParse("15:04:05-07", typ, s) + return parseTime(typ, s) case oid.T_bool: - return s[0] == 't' + return s[0] == 't', nil case oid.T_int8, oid.T_int4, oid.T_int2: i, err := strconv.ParseInt(string(s), 10, 64) if err != nil { - errorf("%s", err) + err = errors.New("pq: " + err.Error()) } - return i + return i, err case oid.T_float4, oid.T_float8: // We always use 64 bit parsing, regardless of whether the input text is for // a float4 or float8, because clients expect float64s for all float datatypes // and returning a 32-bit parsed float64 produces lossy results. f, err := strconv.ParseFloat(string(s), 64) if err != nil { - errorf("%s", err) + err = errors.New("pq: " + err.Error()) } - return f + return f, err } - - return s + return s, nil } // appendEncodedText encodes item in text format as required by COPY // and appends to buf -func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte { +func appendEncodedText(buf []byte, x any) ([]byte, error) { switch v := x.(type) { case int64: - return strconv.AppendInt(buf, v, 10) + return strconv.AppendInt(buf, v, 10), nil case float64: - return strconv.AppendFloat(buf, v, 'f', -1, 64) + return strconv.AppendFloat(buf, v, 'f', -1, 64), nil case []byte: - encodedBytea := encodeBytea(parameterStatus.serverVersion, v) - return appendEscapedText(buf, string(encodedBytea)) + encodedBytea := encodeBytea(v) + return appendEscapedText(buf, string(encodedBytea)), nil case string: - return appendEscapedText(buf, v) + return appendEscapedText(buf, v), nil case bool: - return strconv.AppendBool(buf, v) + return strconv.AppendBool(buf, v), nil case time.Time: - return append(buf, formatTs(v)...) + return append(buf, formatTS(v)...), nil case nil: - return append(buf, "\\N"...) + return append(buf, `\N`...), nil default: - errorf("encode: unknown type for %T", v) + return nil, fmt.Errorf("pq: encode: unknown type for %T", v) } - - panic("not reached") } func appendEscapedText(buf []byte, text string) []byte { escapeNeeded := false startPos := 0 - var c byte // check if we need to escape for i := 0; i < len(text); i++ { - c = text[i] + c := text[i] if c == '\\' || c == '\n' || c == '\r' || c == '\t' { escapeNeeded = true startPos = i @@ -180,8 +179,7 @@ func appendEscapedText(buf []byte, text string) []byte { // copy till first char to escape, iterate the rest result := append(buf, text[:startPos]...) for i := startPos; i < len(text); i++ { - c = text[i] - switch c { + switch c := text[i]; c { case '\\': result = append(result, '\\', '\\') case '\n': @@ -197,119 +195,62 @@ func appendEscapedText(buf []byte, text string) []byte { return result } -func mustParse(f string, typ oid.Oid, s []byte) time.Time { +func parseTime(typ oid.Oid, s []byte) (time.Time, error) { str := string(s) - // Check for a minute and second offset in the timezone. - if typ == oid.T_timestamptz || typ == oid.T_timetz { - for i := 3; i <= 6; i += 3 { - if str[len(str)-i] == ':' { - f += ":00" - continue - } - break + f := "15:04:05" + if typ == oid.T_timetz { + f = "15:04:05-07" + // PostgreSQL just sends the hour if the minute and second is 0: + // 22:04:59+00 + // 22:04:59+08 + // 22:04:59+08:30 + // 22:04:59+08:30:40 + // 23:00:00.112321+02:12:13 + // So add those to the format string. + c := strings.Count(str, ":") + if c > 3 { + f = "15:04:05-07:00:00" + } else if c > 2 { + f = "15:04:05-07:00" } } - // Special case for 24:00 time. - // Unfortunately, golang does not parse 24:00 as a proper time. - // In this case, we want to try "round to the next day", to differentiate. - // As such, we find if the 24:00 time matches at the beginning; if so, - // we default it back to 00:00 but add a day later. + // Go doesn't parse 24:00, so manually set that to midnight on Jan 2. 24:00 + // is never with subseconds but may have a timezone: + // 24:00:00 + // 24:00:00+08 + // 24:00:00-08:01:01 var is2400Time bool - switch typ { - case oid.T_timetz, oid.T_time: - if matches := time2400Regex.FindStringSubmatch(str); matches != nil { - // Concatenate timezone information at the back. - str = "00:00:00" + str[len(matches[1]):] - is2400Time = true + if strings.HasPrefix(str, "24:00:00") { + is2400Time = true + if len(str) > 8 { + str = "00:00:00" + str[8:] + } else { + str = "00:00:00" } } + t, err := time.Parse(f, str) if err != nil { - errorf("decode: %s", err) + return time.Time{}, errors.New("pq: " + err.Error()) } if is2400Time { t = t.Add(24 * time.Hour) } - return t -} - -var errInvalidTimestamp = errors.New("invalid timestamp") - -type timestampParser struct { - err error + // TODO(v2): it uses UTC, which it shouldn't. But I'm afraid changing it now + // will break people's code. + //if typ == oid.T_time { + // // Don't use UTC but time.FixedZone("", 0) + // t = t.In(globalLocationCache.getLocation(0)) + //} + return t, nil } -func (p *timestampParser) expect(str string, char byte, pos int) { - if p.err != nil { - return - } - if pos+1 > len(str) { - p.err = errInvalidTimestamp - return - } - if c := str[pos]; c != char && p.err == nil { - p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c) - } -} - -func (p *timestampParser) mustAtoi(str string, begin int, end int) int { - if p.err != nil { - return 0 - } - if begin < 0 || end < 0 || begin > end || end > len(str) { - p.err = errInvalidTimestamp - return 0 - } - result, err := strconv.Atoi(str[begin:end]) - if err != nil { - if p.err == nil { - p.err = fmt.Errorf("expected number; got '%v'", str) - } - return 0 - } - return result -} - -// The location cache caches the time zones typically used by the client. -type locationCache struct { - cache map[int]*time.Location - lock sync.Mutex -} - -// All connections share the same list of timezones. Benchmarking shows that -// about 5% speed could be gained by putting the cache in the connection and -// losing the mutex, at the cost of a small amount of memory and a somewhat -// significant increase in code complexity. -var globalLocationCache = newLocationCache() - -func newLocationCache() *locationCache { - return &locationCache{cache: make(map[int]*time.Location)} -} - -// Returns the cached timezone for the specified offset, creating and caching -// it if necessary. -func (c *locationCache) getLocation(offset int) *time.Location { - c.lock.Lock() - defer c.lock.Unlock() - - location, ok := c.cache[offset] - if !ok { - location = time.FixedZone("", offset) - c.cache[offset] = location - } - - return location -} - -var infinityTsEnabled = false -var infinityTsNegative time.Time -var infinityTsPositive time.Time - -const ( - infinityTsEnabledAlready = "pq: infinity timestamp enabled already" - infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" +var ( + infinityTSEnabled = false + infinityTSNegative time.Time + infinityTSPositive time.Time ) // EnableInfinityTs controls the handling of Postgres' "-infinity" and @@ -333,170 +274,63 @@ const ( // undefined behavior. If EnableInfinityTs is called more than once, it will // panic. func EnableInfinityTs(negative time.Time, positive time.Time) { - if infinityTsEnabled { - panic(infinityTsEnabledAlready) + if infinityTSEnabled { + panic("pq: infinity timestamp already enabled") } if !negative.Before(positive) { - panic(infinityTsNegativeMustBeSmaller) + panic("pq: infinity timestamp: negative value must be smaller (before) than positive") } - infinityTsEnabled = true - infinityTsNegative = negative - infinityTsPositive = positive + infinityTSEnabled = true + infinityTSNegative = negative + infinityTSPositive = positive } -/* - * Testing might want to toggle infinityTsEnabled - */ -func disableInfinityTs() { - infinityTsEnabled = false +// Testing might want to toggle infinityTSEnabled +func disableInfinityTS() { + infinityTSEnabled = false } -// This is a time function specific to the Postgres default DateStyle -// setting ("ISO, MDY"), the only one we currently support. This -// accounts for the discrepancies between the parsing available with -// time.Parse and the Postgres date formatting quirks. -func parseTs(currentLocation *time.Location, str string) interface{} { +// This is a time function specific to the Postgres default DateStyle setting +// ("ISO, MDY"), the only one we currently support. This accounts for the +// discrepancies between the parsing available with time.Parse and the Postgres +// date formatting quirks. +func parseTS(currentLocation *time.Location, str string) (any, error) { switch str { case "-infinity": - if infinityTsEnabled { - return infinityTsNegative + if infinityTSEnabled { + return infinityTSNegative, nil } - return []byte(str) + return []byte(str), nil case "infinity": - if infinityTsEnabled { - return infinityTsPositive + if infinityTSEnabled { + return infinityTSPositive, nil } - return []byte(str) + return []byte(str), nil } t, err := ParseTimestamp(currentLocation, str) if err != nil { - panic(err) + err = errors.New("pq: " + err.Error()) } - return t + return t, err } // ParseTimestamp parses Postgres' text format. It returns a time.Time in // currentLocation iff that time's offset agrees with the offset sent from the -// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the -// fixed offset offset provided by the Postgres server. +// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the fixed +// offset offset provided by the Postgres server. func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) { - p := timestampParser{} - - monSep := strings.IndexRune(str, '-') - // this is Gregorian year, not ISO Year - // In Gregorian system, the year 1 BC is followed by AD 1 - year := p.mustAtoi(str, 0, monSep) - daySep := monSep + 3 - month := p.mustAtoi(str, monSep+1, daySep) - p.expect(str, '-', daySep) - timeSep := daySep + 3 - day := p.mustAtoi(str, daySep+1, timeSep) - - minLen := monSep + len("01-01") + 1 - - isBC := strings.HasSuffix(str, " BC") - if isBC { - minLen += 3 - } - - var hour, minute, second int - if len(str) > minLen { - p.expect(str, ' ', timeSep) - minSep := timeSep + 3 - p.expect(str, ':', minSep) - hour = p.mustAtoi(str, timeSep+1, minSep) - secSep := minSep + 3 - p.expect(str, ':', secSep) - minute = p.mustAtoi(str, minSep+1, secSep) - secEnd := secSep + 3 - second = p.mustAtoi(str, secSep+1, secEnd) - } - remainderIdx := monSep + len("01-01 00:00:00") + 1 - // Three optional (but ordered) sections follow: the - // fractional seconds, the time zone offset, and the BC - // designation. We set them up here and adjust the other - // offsets if the preceding sections exist. - - nanoSec := 0 - tzOff := 0 - - if remainderIdx < len(str) && str[remainderIdx] == '.' { - fracStart := remainderIdx + 1 - fracOff := strings.IndexAny(str[fracStart:], "-+Z ") - if fracOff < 0 { - fracOff = len(str) - fracStart - } - fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff) - nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff)))) - - remainderIdx += fracOff + 1 - } - if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') { - // time zone separator is always '-' or '+' or 'Z' (UTC is +00) - var tzSign int - switch c := str[tzStart]; c { - case '-': - tzSign = -1 - case '+': - tzSign = +1 - default: - return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c) - } - tzHours := p.mustAtoi(str, tzStart+1, tzStart+3) - remainderIdx += 3 - var tzMin, tzSec int - if remainderIdx < len(str) && str[remainderIdx] == ':' { - tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) - remainderIdx += 3 - } - if remainderIdx < len(str) && str[remainderIdx] == ':' { - tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) - remainderIdx += 3 - } - tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) - } else if tzStart < len(str) && str[tzStart] == 'Z' { - // time zone Z separator indicates UTC is +00 - remainderIdx += 1 - } - - var isoYear int - - if isBC { - isoYear = 1 - year - remainderIdx += 3 - } else { - isoYear = year - } - if remainderIdx < len(str) { - return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:]) - } - t := time.Date(isoYear, time.Month(month), day, - hour, minute, second, nanoSec, - globalLocationCache.getLocation(tzOff)) - - if currentLocation != nil { - // Set the location of the returned Time based on the session's - // TimeZone value, but only if the local time zone database agrees with - // the remote database on the offset. - lt := t.In(currentLocation) - _, newOff := lt.Zone() - if newOff == tzOff { - t = lt - } - } - - return t, p.err + return pqtime.Parse(currentLocation, str) } -// formatTs formats t into a format postgres understands. -func formatTs(t time.Time) []byte { - if infinityTsEnabled { +// formatTS formats t into a format postgres understands. +func formatTS(t time.Time) []byte { + if infinityTSEnabled { // t <= -infinity : ! (t > -infinity) - if !t.After(infinityTsNegative) { + if !t.After(infinityTSNegative) { return []byte("-infinity") } // t >= infinity : ! (!t < infinity) - if !t.Before(infinityTsPositive) { + if !t.Before(infinityTSPositive) { return []byte("infinity") } } @@ -505,128 +339,62 @@ func formatTs(t time.Time) []byte { // FormatTimestamp formats t into Postgres' text format for timestamps. func FormatTimestamp(t time.Time) []byte { - // Need to send dates before 0001 A.D. with " BC" suffix, instead of the - // minus sign preferred by Go. - // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on - bc := false - if t.Year() <= 0 { - // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" - t = t.AddDate((-t.Year())*2+1, 0, 0) - bc = true - } - b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) - - _, offset := t.Zone() - offset %= 60 - if offset != 0 { - // RFC3339Nano already printed the minus sign - if offset < 0 { - offset = -offset - } - - b = append(b, ':') - if offset < 10 { - b = append(b, '0') - } - b = strconv.AppendInt(b, int64(offset), 10) - } - - if bc { - b = append(b, " BC"...) - } - return b + return pqtime.Format(t) } // Parse a bytea value received from the server. Both "hex" and the legacy // "escape" format are supported. func parseBytea(s []byte) (result []byte, err error) { + // Hex format. if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { - // bytea_output = hex s = s[2:] // trim off leading "\\x" result = make([]byte, hex.DecodedLen(len(s))) _, err := hex.Decode(result, s) if err != nil { return nil, err } - } else { - // bytea_output = escape - for len(s) > 0 { - if s[0] == '\\' { - // escaped '\\' - if len(s) >= 2 && s[1] == '\\' { - result = append(result, '\\') - s = s[2:] - continue - } - - // '\\' followed by an octal number - if len(s) < 4 { - return nil, fmt.Errorf("invalid bytea sequence %v", s) - } - r, err := strconv.ParseUint(string(s[1:4]), 8, 8) - if err != nil { - return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) - } - result = append(result, byte(r)) - s = s[4:] - } else { - // We hit an unescaped, raw byte. Try to read in as many as - // possible in one go. - i := bytes.IndexByte(s, '\\') - if i == -1 { - result = append(result, s...) - break - } - result = append(result, s[:i]...) - s = s[i:] - } - } + return result, nil } - return result, nil -} + // Escape format. + for len(s) > 0 { + if s[0] == '\\' { + // escaped '\\' + if len(s) >= 2 && s[1] == '\\' { + result = append(result, '\\') + s = s[2:] + continue + } -func encodeBytea(serverVersion int, v []byte) (result []byte) { - if serverVersion >= 90000 { - // Use the hex format if we know that the server supports it - result = make([]byte, 2+hex.EncodedLen(len(v))) - result[0] = '\\' - result[1] = 'x' - hex.Encode(result[2:], v) - } else { - // .. or resort to "escape" - for _, b := range v { - if b == '\\' { - result = append(result, '\\', '\\') - } else if b < 0x20 || b > 0x7e { - result = append(result, []byte(fmt.Sprintf("\\%03o", b))...) - } else { - result = append(result, b) + // '\\' followed by an octal number + if len(s) < 4 { + return nil, fmt.Errorf("invalid bytea sequence %v", s) + } + r, err := strconv.ParseUint(string(s[1:4]), 8, 8) + if err != nil { + return nil, fmt.Errorf("could not parse bytea value: %w", err) + } + result = append(result, byte(r)) + s = s[4:] + } else { + // We hit an unescaped, raw byte. Try to read in as many as + // possible in one go. + i := bytes.IndexByte(s, '\\') + if i == -1 { + result = append(result, s...) + break } + result = append(result, s[:i]...) + s = s[i:] } } - - return result -} - -// NullTime represents a time.Time that may be null. NullTime implements the -// sql.Scanner interface so it can be used as a scan destination, similar to -// sql.NullString. -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} - -// Scan implements the Scanner interface. -func (nt *NullTime) Scan(value interface{}) error { - nt.Time, nt.Valid = value.(time.Time) - return nil + return result, nil } -// Value implements the driver Valuer interface. -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil +func encodeBytea(v []byte) (result []byte) { + result = make([]byte, 2+hex.EncodedLen(len(v))) + result[0] = '\\' + result[1] = 'x' + hex.Encode(result[2:], v) + return result } diff --git a/vendor/github.com/lib/pq/error.go b/vendor/github.com/lib/pq/error.go index f67c5a5f..7d061875 100644 --- a/vendor/github.com/lib/pq/error.go +++ b/vendor/github.com/lib/pq/error.go @@ -6,362 +6,133 @@ import ( "io" "net" "runtime" -) + "strconv" + "strings" + "unicode/utf8" -// Error severities -const ( - Efatal = "FATAL" - Epanic = "PANIC" - Ewarning = "WARNING" - Enotice = "NOTICE" - Edebug = "DEBUG" - Einfo = "INFO" - Elog = "LOG" + "github.com/lib/pq/pqerror" ) -// Error represents an error communicating with the server. +// Error returned by the PostgreSQL server. +// +// The [Error] method returns the error message and error code: +// +// pq: invalid input syntax for type json (22P02) +// +// The [ErrorWithDetail] method also includes the error Detail, Hint, and +// location context (if any): // -// See http://www.postgresql.org/docs/current/static/protocol-error-fields.html for details of the fields +// ERROR: invalid input syntax for type json (22P02) +// DETAIL: Token "asd" is invalid. +// CONTEXT: line 5, column 8: +// +// 3 | 'def', +// 4 | 123, +// 5 | 'foo', 'asd'::jsonb +// ^ type Error struct { - Severity string - Code ErrorCode - Message string - Detail string - Hint string - Position string + // [Efatal], [Epanic], [Ewarning], [Enotice], [Edebug], [Einfo], or [Elog]. + // Always present. + Severity string + + // SQLSTATE code. Always present. + Code pqerror.Code + + // Primary human-readable error message. This should be accurate but terse + // (typically one line). Always present. + Message string + + // Optional secondary error message carrying more detail about the problem. + // Might run to multiple lines. + Detail string + + // Optional suggestion what to do about the problem. This is intended to + // differ from Detail in that it offers advice (potentially inappropriate) + // rather than hard facts. Might run to multiple lines. + Hint string + + // error position as an index into the original query string, as decimal + // ASCII integer. The first character has index 1, and positions are + // measured in characters not bytes. + Position string + + // This is defined the same as the Position field, but it is used when the + // cursor position refers to an internally generated command rather than the + // one submitted by the client. The InternalQuery field will always appear + // when this field appears. InternalPosition string - InternalQuery string - Where string - Schema string - Table string - Column string - DataTypeName string - Constraint string - File string - Line string - Routine string -} -// ErrorCode is a five-character error code. -type ErrorCode string + // Text of a failed internally-generated command. This could be, for + // example, an SQL query issued by a PL/pgSQL function. + InternalQuery string -// Name returns a more human friendly rendering of the error code, namely the -// "condition name". -// -// See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for -// details. -func (ec ErrorCode) Name() string { - return errorCodeNames[ec] -} + // An indication of the context in which the error occurred. Presently this + // includes a call stack traceback of active procedural language functions + // and internally-generated queries. The trace is one entry per line, most + // recent first. + Where string -// ErrorClass is only the class part of an error code. -type ErrorClass string + // If the error was associated with a specific database object, the name of + // the schema containing that object, if any. + Schema string -// Name returns the condition name of an error class. It is equivalent to the -// condition name of the "standard" error code (i.e. the one having the last -// three characters "000"). -func (ec ErrorClass) Name() string { - return errorCodeNames[ErrorCode(ec+"000")] -} + // If the error was associated with a specific table, the name of the table. + // (Refer to the schema name field for the name of the table's schema.) + Table string -// Class returns the error class, e.g. "28". -// -// See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for -// details. -func (ec ErrorCode) Class() ErrorClass { - return ErrorClass(ec[0:2]) -} + // If the error was associated with a specific table column, the name of the + // column. (Refer to the schema and table name fields to identify the + // table.) + Column string + + // If the error was associated with a specific data type, the name of the + // data type. (Refer to the schema name field for the name of the data + // type's schema.) + DataTypeName string + + // If the error was associated with a specific constraint, the name of the + // constraint. Refer to fields listed above for the associated table or + // domain. (For this purpose, indexes are treated as constraints, even if + // they weren't created with constraint syntax.) + Constraint string + + // File name of the source-code location where the error was reported. + File string + + // Line number of the source-code location where the error was reported. + Line string -// errorCodeNames is a mapping between the five-character error codes and the -// human readable "condition names". It is derived from the list at -// http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html -var errorCodeNames = map[ErrorCode]string{ - // Class 00 - Successful Completion - "00000": "successful_completion", - // Class 01 - Warning - "01000": "warning", - "0100C": "dynamic_result_sets_returned", - "01008": "implicit_zero_bit_padding", - "01003": "null_value_eliminated_in_set_function", - "01007": "privilege_not_granted", - "01006": "privilege_not_revoked", - "01004": "string_data_right_truncation", - "01P01": "deprecated_feature", - // Class 02 - No Data (this is also a warning class per the SQL standard) - "02000": "no_data", - "02001": "no_additional_dynamic_result_sets_returned", - // Class 03 - SQL Statement Not Yet Complete - "03000": "sql_statement_not_yet_complete", - // Class 08 - Connection Exception - "08000": "connection_exception", - "08003": "connection_does_not_exist", - "08006": "connection_failure", - "08001": "sqlclient_unable_to_establish_sqlconnection", - "08004": "sqlserver_rejected_establishment_of_sqlconnection", - "08007": "transaction_resolution_unknown", - "08P01": "protocol_violation", - // Class 09 - Triggered Action Exception - "09000": "triggered_action_exception", - // Class 0A - Feature Not Supported - "0A000": "feature_not_supported", - // Class 0B - Invalid Transaction Initiation - "0B000": "invalid_transaction_initiation", - // Class 0F - Locator Exception - "0F000": "locator_exception", - "0F001": "invalid_locator_specification", - // Class 0L - Invalid Grantor - "0L000": "invalid_grantor", - "0LP01": "invalid_grant_operation", - // Class 0P - Invalid Role Specification - "0P000": "invalid_role_specification", - // Class 0Z - Diagnostics Exception - "0Z000": "diagnostics_exception", - "0Z002": "stacked_diagnostics_accessed_without_active_handler", - // Class 20 - Case Not Found - "20000": "case_not_found", - // Class 21 - Cardinality Violation - "21000": "cardinality_violation", - // Class 22 - Data Exception - "22000": "data_exception", - "2202E": "array_subscript_error", - "22021": "character_not_in_repertoire", - "22008": "datetime_field_overflow", - "22012": "division_by_zero", - "22005": "error_in_assignment", - "2200B": "escape_character_conflict", - "22022": "indicator_overflow", - "22015": "interval_field_overflow", - "2201E": "invalid_argument_for_logarithm", - "22014": "invalid_argument_for_ntile_function", - "22016": "invalid_argument_for_nth_value_function", - "2201F": "invalid_argument_for_power_function", - "2201G": "invalid_argument_for_width_bucket_function", - "22018": "invalid_character_value_for_cast", - "22007": "invalid_datetime_format", - "22019": "invalid_escape_character", - "2200D": "invalid_escape_octet", - "22025": "invalid_escape_sequence", - "22P06": "nonstandard_use_of_escape_character", - "22010": "invalid_indicator_parameter_value", - "22023": "invalid_parameter_value", - "2201B": "invalid_regular_expression", - "2201W": "invalid_row_count_in_limit_clause", - "2201X": "invalid_row_count_in_result_offset_clause", - "22009": "invalid_time_zone_displacement_value", - "2200C": "invalid_use_of_escape_character", - "2200G": "most_specific_type_mismatch", - "22004": "null_value_not_allowed", - "22002": "null_value_no_indicator_parameter", - "22003": "numeric_value_out_of_range", - "2200H": "sequence_generator_limit_exceeded", - "22026": "string_data_length_mismatch", - "22001": "string_data_right_truncation", - "22011": "substring_error", - "22027": "trim_error", - "22024": "unterminated_c_string", - "2200F": "zero_length_character_string", - "22P01": "floating_point_exception", - "22P02": "invalid_text_representation", - "22P03": "invalid_binary_representation", - "22P04": "bad_copy_file_format", - "22P05": "untranslatable_character", - "2200L": "not_an_xml_document", - "2200M": "invalid_xml_document", - "2200N": "invalid_xml_content", - "2200S": "invalid_xml_comment", - "2200T": "invalid_xml_processing_instruction", - // Class 23 - Integrity Constraint Violation - "23000": "integrity_constraint_violation", - "23001": "restrict_violation", - "23502": "not_null_violation", - "23503": "foreign_key_violation", - "23505": "unique_violation", - "23514": "check_violation", - "23P01": "exclusion_violation", - // Class 24 - Invalid Cursor State - "24000": "invalid_cursor_state", - // Class 25 - Invalid Transaction State - "25000": "invalid_transaction_state", - "25001": "active_sql_transaction", - "25002": "branch_transaction_already_active", - "25008": "held_cursor_requires_same_isolation_level", - "25003": "inappropriate_access_mode_for_branch_transaction", - "25004": "inappropriate_isolation_level_for_branch_transaction", - "25005": "no_active_sql_transaction_for_branch_transaction", - "25006": "read_only_sql_transaction", - "25007": "schema_and_data_statement_mixing_not_supported", - "25P01": "no_active_sql_transaction", - "25P02": "in_failed_sql_transaction", - // Class 26 - Invalid SQL Statement Name - "26000": "invalid_sql_statement_name", - // Class 27 - Triggered Data Change Violation - "27000": "triggered_data_change_violation", - // Class 28 - Invalid Authorization Specification - "28000": "invalid_authorization_specification", - "28P01": "invalid_password", - // Class 2B - Dependent Privilege Descriptors Still Exist - "2B000": "dependent_privilege_descriptors_still_exist", - "2BP01": "dependent_objects_still_exist", - // Class 2D - Invalid Transaction Termination - "2D000": "invalid_transaction_termination", - // Class 2F - SQL Routine Exception - "2F000": "sql_routine_exception", - "2F005": "function_executed_no_return_statement", - "2F002": "modifying_sql_data_not_permitted", - "2F003": "prohibited_sql_statement_attempted", - "2F004": "reading_sql_data_not_permitted", - // Class 34 - Invalid Cursor Name - "34000": "invalid_cursor_name", - // Class 38 - External Routine Exception - "38000": "external_routine_exception", - "38001": "containing_sql_not_permitted", - "38002": "modifying_sql_data_not_permitted", - "38003": "prohibited_sql_statement_attempted", - "38004": "reading_sql_data_not_permitted", - // Class 39 - External Routine Invocation Exception - "39000": "external_routine_invocation_exception", - "39001": "invalid_sqlstate_returned", - "39004": "null_value_not_allowed", - "39P01": "trigger_protocol_violated", - "39P02": "srf_protocol_violated", - // Class 3B - Savepoint Exception - "3B000": "savepoint_exception", - "3B001": "invalid_savepoint_specification", - // Class 3D - Invalid Catalog Name - "3D000": "invalid_catalog_name", - // Class 3F - Invalid Schema Name - "3F000": "invalid_schema_name", - // Class 40 - Transaction Rollback - "40000": "transaction_rollback", - "40002": "transaction_integrity_constraint_violation", - "40001": "serialization_failure", - "40003": "statement_completion_unknown", - "40P01": "deadlock_detected", - // Class 42 - Syntax Error or Access Rule Violation - "42000": "syntax_error_or_access_rule_violation", - "42601": "syntax_error", - "42501": "insufficient_privilege", - "42846": "cannot_coerce", - "42803": "grouping_error", - "42P20": "windowing_error", - "42P19": "invalid_recursion", - "42830": "invalid_foreign_key", - "42602": "invalid_name", - "42622": "name_too_long", - "42939": "reserved_name", - "42804": "datatype_mismatch", - "42P18": "indeterminate_datatype", - "42P21": "collation_mismatch", - "42P22": "indeterminate_collation", - "42809": "wrong_object_type", - "42703": "undefined_column", - "42883": "undefined_function", - "42P01": "undefined_table", - "42P02": "undefined_parameter", - "42704": "undefined_object", - "42701": "duplicate_column", - "42P03": "duplicate_cursor", - "42P04": "duplicate_database", - "42723": "duplicate_function", - "42P05": "duplicate_prepared_statement", - "42P06": "duplicate_schema", - "42P07": "duplicate_table", - "42712": "duplicate_alias", - "42710": "duplicate_object", - "42702": "ambiguous_column", - "42725": "ambiguous_function", - "42P08": "ambiguous_parameter", - "42P09": "ambiguous_alias", - "42P10": "invalid_column_reference", - "42611": "invalid_column_definition", - "42P11": "invalid_cursor_definition", - "42P12": "invalid_database_definition", - "42P13": "invalid_function_definition", - "42P14": "invalid_prepared_statement_definition", - "42P15": "invalid_schema_definition", - "42P16": "invalid_table_definition", - "42P17": "invalid_object_definition", - // Class 44 - WITH CHECK OPTION Violation - "44000": "with_check_option_violation", - // Class 53 - Insufficient Resources - "53000": "insufficient_resources", - "53100": "disk_full", - "53200": "out_of_memory", - "53300": "too_many_connections", - "53400": "configuration_limit_exceeded", - // Class 54 - Program Limit Exceeded - "54000": "program_limit_exceeded", - "54001": "statement_too_complex", - "54011": "too_many_columns", - "54023": "too_many_arguments", - // Class 55 - Object Not In Prerequisite State - "55000": "object_not_in_prerequisite_state", - "55006": "object_in_use", - "55P02": "cant_change_runtime_param", - "55P03": "lock_not_available", - // Class 57 - Operator Intervention - "57000": "operator_intervention", - "57014": "query_canceled", - "57P01": "admin_shutdown", - "57P02": "crash_shutdown", - "57P03": "cannot_connect_now", - "57P04": "database_dropped", - // Class 58 - System Error (errors external to PostgreSQL itself) - "58000": "system_error", - "58030": "io_error", - "58P01": "undefined_file", - "58P02": "duplicate_file", - // Class F0 - Configuration File Error - "F0000": "config_file_error", - "F0001": "lock_file_exists", - // Class HV - Foreign Data Wrapper Error (SQL/MED) - "HV000": "fdw_error", - "HV005": "fdw_column_name_not_found", - "HV002": "fdw_dynamic_parameter_value_needed", - "HV010": "fdw_function_sequence_error", - "HV021": "fdw_inconsistent_descriptor_information", - "HV024": "fdw_invalid_attribute_value", - "HV007": "fdw_invalid_column_name", - "HV008": "fdw_invalid_column_number", - "HV004": "fdw_invalid_data_type", - "HV006": "fdw_invalid_data_type_descriptors", - "HV091": "fdw_invalid_descriptor_field_identifier", - "HV00B": "fdw_invalid_handle", - "HV00C": "fdw_invalid_option_index", - "HV00D": "fdw_invalid_option_name", - "HV090": "fdw_invalid_string_length_or_buffer_length", - "HV00A": "fdw_invalid_string_format", - "HV009": "fdw_invalid_use_of_null_pointer", - "HV014": "fdw_too_many_handles", - "HV001": "fdw_out_of_memory", - "HV00P": "fdw_no_schemas", - "HV00J": "fdw_option_name_not_found", - "HV00K": "fdw_reply_handle", - "HV00Q": "fdw_schema_not_found", - "HV00R": "fdw_table_not_found", - "HV00L": "fdw_unable_to_create_execution", - "HV00M": "fdw_unable_to_create_reply", - "HV00N": "fdw_unable_to_establish_connection", - // Class P0 - PL/pgSQL Error - "P0000": "plpgsql_error", - "P0001": "raise_exception", - "P0002": "no_data_found", - "P0003": "too_many_rows", - // Class XX - Internal Error - "XX000": "internal_error", - "XX001": "data_corrupted", - "XX002": "index_corrupted", + // Name of the source-code routine reporting the error. + Routine string + + query string } -func parseError(r *readBuf) *Error { - err := new(Error) +type ( + // ErrorCode is a five-character error code. + // + // Deprecated: use pqerror.Code + // + //go:fix inline + ErrorCode = pqerror.Code + + // ErrorClass is only the class part of an error code. + // + // Deprecated: use pqerror.Class + // + //go:fix inline + ErrorClass = pqerror.Class +) + +func parseError(r *readBuf, q string) *Error { + err := &Error{query: q} for t := r.byte(); t != 0; t = r.byte() { msg := r.string() switch t { case 'S': err.Severity = msg case 'C': - err.Code = ErrorCode(msg) + err.Code = pqerror.Code(msg) case 'M': err.Message = msg case 'D': @@ -398,126 +169,156 @@ func parseError(r *readBuf) *Error { } // Fatal returns true if the Error Severity is fatal. -func (err *Error) Fatal() bool { - return err.Severity == Efatal -} +func (e *Error) Fatal() bool { return e.Severity == pqerror.SeverityFatal } // SQLState returns the SQLState of the error. -func (err *Error) SQLState() string { - return string(err.Code) -} +func (e *Error) SQLState() string { return string(e.Code) } -// Get implements the legacy PGError interface. New code should use the fields -// of the Error struct directly. -func (err *Error) Get(k byte) (v string) { - switch k { - case 'S': - return err.Severity - case 'C': - return string(err.Code) - case 'M': - return err.Message - case 'D': - return err.Detail - case 'H': - return err.Hint - case 'P': - return err.Position - case 'p': - return err.InternalPosition - case 'q': - return err.InternalQuery - case 'W': - return err.Where - case 's': - return err.Schema - case 't': - return err.Table - case 'c': - return err.Column - case 'd': - return err.DataTypeName - case 'n': - return err.Constraint - case 'F': - return err.File - case 'L': - return err.Line - case 'R': - return err.Routine +func (e *Error) Error() string { + msg := e.Message + if e.query != "" && e.Position != "" { + pos, err := strconv.Atoi(e.Position) + if err == nil { + lines := strings.Split(e.query, "\n") + line, col := posToLine(pos, lines) + if len(lines) == 1 { + msg += " at column " + strconv.Itoa(col) + } else { + msg += " at position " + strconv.Itoa(line) + ":" + strconv.Itoa(col) + } + } } - return "" -} -func (err *Error) Error() string { - return "pq: " + err.Message + if e.Code != "" { + return "pq: " + msg + " (" + string(e.Code) + ")" + } + return "pq: " + msg } -// PGError is an interface used by previous versions of pq. It is provided -// only to support legacy code. New code should use the Error type. -type PGError interface { - Error() string - Fatal() bool - Get(k byte) (v string) -} +// ErrorWithDetail returns the error message with detailed information and +// location context (if any). +// +// See the documentation on [Error]. +func (e *Error) ErrorWithDetail() string { + b := new(strings.Builder) + b.Grow(len(e.Message) + len(e.Detail) + len(e.Hint) + 30) + b.WriteString("ERROR: ") + b.WriteString(e.Message) + if e.Code != "" { + b.WriteString(" (") + b.WriteString(string(e.Code)) + b.WriteByte(')') + } + if e.Detail != "" { + b.WriteString("\nDETAIL: ") + b.WriteString(e.Detail) + } + if e.Hint != "" { + b.WriteString("\nHINT: ") + b.WriteString(e.Hint) + } -func errorf(s string, args ...interface{}) { - panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) -} + if e.query != "" && e.Position != "" { + b.Grow(512) + pos, err := strconv.Atoi(e.Position) + if err != nil { + return b.String() + } + lines := strings.Split(e.query, "\n") + line, col := posToLine(pos, lines) + + fmt.Fprintf(b, "\nCONTEXT: line %d, column %d:\n\n", line, col) + if line > 2 { + fmt.Fprintf(b, "% 7d | %s\n", line-2, expandTab(lines[line-3])) + } + if line > 1 { + fmt.Fprintf(b, "% 7d | %s\n", line-1, expandTab(lines[line-2])) + } + /// Expand tabs, so that the ^ is at at the correct position, but leave + /// "column 10-13" intact. Adjusting this to the visual column would be + /// better, but we don't know the tabsize of the user in their editor, + /// which can be 8, 4, 2, or something else. We can't know. So leaving + /// it as the character index is probably the "most correct". + expanded := expandTab(lines[line-1]) + diff := len(expanded) - len(lines[line-1]) + fmt.Fprintf(b, "% 7d | %s\n", line, expanded) + fmt.Fprintf(b, "% 10s%s%s\n", "", strings.Repeat(" ", col-1+diff), "^") + } -// TODO(ainar-g) Rename to errorf after removing panics. -func fmterrorf(s string, args ...interface{}) error { - return fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)) + return b.String() } -func errRecoverNoErrBadConn(err *error) { - e := recover() - if e == nil { - // Do nothing - return +func posToLine(pos int, lines []string) (line, col int) { + read := 0 + for i := range lines { + line++ + ll := utf8.RuneCountInString(lines[i]) + 1 // +1 for the removed newline + if read+ll >= pos { + col = max(pos-read, 1) // Should be lower than 1, but just in case. + break + } + read += ll } - var ok bool - *err, ok = e.(error) - if !ok { - *err = fmt.Errorf("pq: unexpected error: %#v", e) + return line, col +} + +func expandTab(s string) string { + var ( + b strings.Builder + l int + fill = func(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = ' ' + } + return string(b) + } + ) + b.Grow(len(s)) + for _, r := range s { + switch r { + case '\t': + tw := 8 - l%8 + b.WriteString(fill(tw)) + l += tw + default: + b.WriteRune(r) + l += 1 + } } + return b.String() } -func (cn *conn) errRecover(err *error) { - e := recover() - switch v := e.(type) { +func (cn *conn) handleError(reported error, query ...string) error { + switch err := reported.(type) { case nil: - // Do nothing - case runtime.Error: + return nil + case runtime.Error, *net.OpError: cn.err.set(driver.ErrBadConn) - panic(v) - case *Error: - if v.Fatal() { - *err = driver.ErrBadConn - } else { - *err = v - } - case *net.OpError: - cn.err.set(driver.ErrBadConn) - *err = v case *safeRetryError: cn.err.set(driver.ErrBadConn) - *err = driver.ErrBadConn + reported = driver.ErrBadConn + case *Error: + if len(query) > 0 && query[0] != "" { + err.query = query[0] + reported = err + } + if err.Fatal() { + reported = driver.ErrBadConn + } case error: - if v == io.EOF || v.Error() == "remote error: handshake failure" { - *err = driver.ErrBadConn - } else { - *err = v + if err == io.EOF || err == io.ErrUnexpectedEOF || err.Error() == "remote error: handshake failure" { + reported = driver.ErrBadConn } - default: cn.err.set(driver.ErrBadConn) - panic(fmt.Sprintf("unknown error: %#v", e)) + reported = fmt.Errorf("pq: unknown error %T: %[1]s", err) } // Any time we return ErrBadConn, we need to remember it since *Tx doesn't // mark the connection bad in database/sql. - if *err == driver.ErrBadConn { + if reported == driver.ErrBadConn { cn.err.set(driver.ErrBadConn) } + return reported } diff --git a/vendor/github.com/lib/pq/internal/pgpass/pgpass.go b/vendor/github.com/lib/pq/internal/pgpass/pgpass.go new file mode 100644 index 00000000..4da35385 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pgpass/pgpass.go @@ -0,0 +1,70 @@ +package pgpass + +import ( + "bufio" + "os" + "path/filepath" + "strings" + + "github.com/lib/pq/internal/pqutil" +) + +func PasswordFromPgpass(passfile, user, password, host, port, dbname string) string { + if password != "" { // Do not process .pgpass if a password was supplied. + return password + } + + filename := pqutil.Pgpass(passfile) + if filename == "" { + return "" + } + + fp, err := os.Open(filename) + if err != nil { + return "" + } + defer fp.Close() + + scan := bufio.NewScanner(fp) + for scan.Scan() { + line := scan.Text() + if len(line) == 0 || line[0] == '#' { + continue + } + split := splitFields(line) + if len(split) != 5 { + continue + } + + socket := host == "" || filepath.IsAbs(host) || strings.HasPrefix(host, "@") + if (split[0] == "*" || split[0] == host || (split[0] == "localhost" && socket)) && + (split[1] == "*" || split[1] == port) && + (split[2] == "*" || split[2] == dbname) && + (split[3] == "*" || split[3] == user) { + return split[4] + } + } + + return "" +} + +func splitFields(s string) []string { + var ( + fs = make([]string, 0, 5) + f = make([]rune, 0, len(s)) + esc bool + ) + for _, c := range s { + switch { + case esc: + f, esc = append(f, c), false + case c == '\\': + esc = true + case c == ':': + fs, f = append(fs, string(f)), f[:0] + default: + f = append(f, c) + } + } + return append(fs, string(f)) +} diff --git a/vendor/github.com/lib/pq/internal/pgservice/pgservice.go b/vendor/github.com/lib/pq/internal/pgservice/pgservice.go new file mode 100644 index 00000000..9842648c --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pgservice/pgservice.go @@ -0,0 +1,70 @@ +package pgservice + +import ( + "bufio" + "fmt" + "os" + "strings" + + "github.com/lib/pq/internal/pqutil" +) + +func FindService(path string, service string) (map[string]string, error) { + fp, err := os.Open(path) + if err != nil { + if pqutil.ErrNotExists(err) { + // libpq just returns "definition of service not found" if the + // default file doesn't exist, but IMO that's confusing. + return nil, fmt.Errorf("service file %q not found", path) + } + return nil, err + } + defer fp.Close() + + var ( + scan = bufio.NewScanner(fp) + i int + ) + for scan.Scan() { + i++ + line := strings.TrimSpace(scan.Text()) + if line == "" || line[0] == '#' { + continue + } + + // [service] header that we want. + if line[0] == '[' && line[len(line)-1] == ']' && strings.TrimSpace(line[1:len(line)-1]) == service { + opts := make(map[string]string) + for scan.Scan() { + i++ + line := strings.TrimSpace(scan.Text()) + if line == "" || line[0] == '#' { + continue + } + // Next header: our work here is done. + if line[0] == '[' && line[len(line)-1] == ']' { + return opts, nil + } + + k, v, ok := strings.Cut(line, "=") + if !ok { + return nil, fmt.Errorf("line %d: missing '=' in %q", i, line) + } + k, v = strings.TrimSpace(k), strings.TrimSpace(v) + if k == "" { + return nil, fmt.Errorf("line %d: no value before '=' in %q", i, line) + } + opts[k] = v + } + if scan.Err() != nil { + return nil, scan.Err() + } + return opts, nil + } + } + if scan.Err() != nil { + return nil, scan.Err() + } + + return nil, fmt.Errorf("definition of service %q not found", service) +} diff --git a/vendor/github.com/lib/pq/internal/pqsql/copy.go b/vendor/github.com/lib/pq/internal/pqsql/copy.go new file mode 100644 index 00000000..ccb688f6 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqsql/copy.go @@ -0,0 +1,37 @@ +package pqsql + +// StartsWithCopy reports if the SQL strings start with "copy", ignoring +// whitespace, comments, and casing. +func StartsWithCopy(query string) bool { + if len(query) < 4 { + return false + } + var linecmt, blockcmt bool + for i := 0; i < len(query); i++ { + c := query[i] + if linecmt { + linecmt = c != '\n' + continue + } + if blockcmt { + blockcmt = !(c == '/' && query[i-1] == '*') + continue + } + if c == '-' && len(query) > i+1 && query[i+1] == '-' { + linecmt = true + continue + } + if c == '/' && len(query) > i+1 && query[i+1] == '*' { + blockcmt = true + continue + } + if c == ' ' || c == '\t' || c == '\r' || c == '\n' { + continue + } + + // First non-comment and non-whitespace. + return len(query) > i+3 && c|0x20 == 'c' && query[i+1]|0x20 == 'o' && + query[i+2]|0x20 == 'p' && query[i+3]|0x20 == 'y' + } + return false +} diff --git a/vendor/github.com/lib/pq/internal/pqtime/loc.go b/vendor/github.com/lib/pq/internal/pqtime/loc.go new file mode 100644 index 00000000..d23dd5b0 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqtime/loc.go @@ -0,0 +1,37 @@ +package pqtime + +import ( + "sync" + "time" +) + +// The location cache caches the time zones typically used by the client. +type locationCache struct { + cache map[int]*time.Location + lock sync.Mutex +} + +// All connections share the same list of timezones. Benchmarking shows that +// about 5% speed could be gained by putting the cache in the connection and +// losing the mutex, at the cost of a small amount of memory and a somewhat +// significant increase in code complexity. +var globalLocationCache = &locationCache{cache: make(map[int]*time.Location)} + +func Reset() { + globalLocationCache = &locationCache{cache: make(map[int]*time.Location)} +} + +// Returns the cached timezone for the specified offset, creating and caching +// it if necessary. +func (c *locationCache) getLocation(offset int) *time.Location { + c.lock.Lock() + defer c.lock.Unlock() + l, ok := c.cache[offset] + if !ok { + // TODO(v2): for offset=0 it should use some descriptive text like + // "without time zone". + l = time.FixedZone("", offset) + c.cache[offset] = l + } + return l +} diff --git a/vendor/github.com/lib/pq/internal/pqtime/pqtime.go b/vendor/github.com/lib/pq/internal/pqtime/pqtime.go new file mode 100644 index 00000000..28008e86 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqtime/pqtime.go @@ -0,0 +1,190 @@ +package pqtime + +import ( + "errors" + "fmt" + "math" + "strconv" + "strings" + "time" +) + +var errInvalidTimestamp = errors.New("invalid timestamp") + +type timestampParser struct { + err error +} + +func (p *timestampParser) expect(str string, char byte, pos int) { + if p.err != nil { + return + } + if pos+1 > len(str) { + p.err = errInvalidTimestamp + return + } + if c := str[pos]; c != char && p.err == nil { + p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c) + } +} + +func (p *timestampParser) mustAtoi(str string, begin int, end int) int { + if p.err != nil { + return 0 + } + if begin < 0 || end < 0 || begin > end || end > len(str) { + p.err = errInvalidTimestamp + return 0 + } + result, err := strconv.Atoi(str[begin:end]) + if err != nil { + if p.err == nil { + p.err = fmt.Errorf("expected number; got '%v'", str) + } + return 0 + } + return result +} + +func Parse(currentLocation *time.Location, str string) (time.Time, error) { + p := timestampParser{} + + monSep := strings.IndexRune(str, '-') + // this is Gregorian year, not ISO Year + // In Gregorian system, the year 1 BC is followed by AD 1 + year := p.mustAtoi(str, 0, monSep) + daySep := monSep + 3 + month := p.mustAtoi(str, monSep+1, daySep) + p.expect(str, '-', daySep) + timeSep := daySep + 3 + day := p.mustAtoi(str, daySep+1, timeSep) + + minLen := monSep + len("01-01") + 1 + + isBC := strings.HasSuffix(str, " BC") + if isBC { + minLen += 3 + } + + var hour, minute, second int + if len(str) > minLen { + p.expect(str, ' ', timeSep) + minSep := timeSep + 3 + p.expect(str, ':', minSep) + hour = p.mustAtoi(str, timeSep+1, minSep) + secSep := minSep + 3 + p.expect(str, ':', secSep) + minute = p.mustAtoi(str, minSep+1, secSep) + secEnd := secSep + 3 + second = p.mustAtoi(str, secSep+1, secEnd) + } + remainderIdx := monSep + len("01-01 00:00:00") + 1 + // Three optional (but ordered) sections follow: the + // fractional seconds, the time zone offset, and the BC + // designation. We set them up here and adjust the other + // offsets if the preceding sections exist. + + nanoSec := 0 + tzOff := 0 + + if remainderIdx < len(str) && str[remainderIdx] == '.' { + fracStart := remainderIdx + 1 + fracOff := strings.IndexAny(str[fracStart:], "-+Z ") + if fracOff < 0 { + fracOff = len(str) - fracStart + } + fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff) + nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff)))) + + remainderIdx += fracOff + 1 + } + if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') { + // time zone separator is always '-' or '+' or 'Z' (UTC is +00) + var tzSign int + switch c := str[tzStart]; c { + case '-': + tzSign = -1 + case '+': + tzSign = +1 + default: + return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c) + } + tzHours := p.mustAtoi(str, tzStart+1, tzStart+3) + remainderIdx += 3 + var tzMin, tzSec int + if remainderIdx < len(str) && str[remainderIdx] == ':' { + tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) + remainderIdx += 3 + } + if remainderIdx < len(str) && str[remainderIdx] == ':' { + tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) + remainderIdx += 3 + } + tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) + } else if tzStart < len(str) && str[tzStart] == 'Z' { + // time zone Z separator indicates UTC is +00 + remainderIdx += 1 + } + + var isoYear int + + if isBC { + isoYear = 1 - year + remainderIdx += 3 + } else { + isoYear = year + } + if remainderIdx < len(str) { + return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:]) + } + t := time.Date(isoYear, time.Month(month), day, + hour, minute, second, nanoSec, + globalLocationCache.getLocation(tzOff)) + + if currentLocation != nil { + // Set the location of the returned Time based on the session's + // TimeZone value, but only if the local time zone database agrees with + // the remote database on the offset. + lt := t.In(currentLocation) + _, newOff := lt.Zone() + if newOff == tzOff { + t = lt + } + } + + return t, p.err +} + +// Format into Postgres' text format for timestamps. +func Format(t time.Time) []byte { + // Need to send dates before 0001 A.D. with " BC" suffix, instead of the + // minus sign preferred by Go. + // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on + bc := false + if t.Year() <= 0 { + // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" + t = t.AddDate((-t.Year())*2+1, 0, 0) + bc = true + } + b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) + + _, offset := t.Zone() + offset %= 60 + if offset != 0 { + // RFC3339Nano already printed the minus sign + if offset < 0 { + offset = -offset + } + + b = append(b, ':') + if offset < 10 { + b = append(b, '0') + } + b = strconv.AppendInt(b, int64(offset), 10) + } + + if bc { + b = append(b, " BC"...) + } + return b +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/path.go b/vendor/github.com/lib/pq/internal/pqutil/path.go new file mode 100644 index 00000000..dd0d5af0 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/path.go @@ -0,0 +1,91 @@ +package pqutil + +import ( + "errors" + "fmt" + "io" + "os" + "os/user" + "path/filepath" + "runtime" + "syscall" +) + +// Home gets the PostgreSQL configuration dir in the user's home directory: +// %APPDATA%/postgresql on Windows, and $HOME/.postgresql/postgresql.crt +// everywhere else. +// +// Returns an empy string if no home directory was found. +// +// Matches pqGetHomeDirectory() from PostgreSQL. +// https://github.com/postgres/postgres/blob/2b117bb/src/interfaces/libpq/fe-connect.c#L8214 +func Home(subdir bool) string { + if runtime.GOOS == "windows" { + // pq uses SHGetFolderPath(), which is deprecated but x/sys/windows has + // KnownFolderPath(). We don't really want to pull that in though, so + // use APPDATA env. This is also what PostgreSQL uses in some other + // codepaths (get_home_path() for example). + ad := os.Getenv("APPDATA") + if ad == "" { + return "" + } + return filepath.Join(ad, "postgresql") + } + + home, _ := os.UserHomeDir() + if home == "" { + u, err := user.Current() + if err != nil { + return "" + } + home = u.HomeDir + } + // libpq reads some files from ~/ and some from ~/.postgresql – on Windows + // it always uses %APPDATA%/postgresql. + if subdir { + home = filepath.Join(home, ".postgresql") + } + return home +} + +// ErrNotExists reports if err is a "path doesn't exist" type error. +// +// fs.ErrNotExist is not enough, as "/dev/null/somefile" will return ENOTDIR +// instead of ENOENT. +func ErrNotExists(err error) bool { + perr := new(os.PathError) + if errors.As(err, &perr) && (perr.Err == syscall.ENOENT || perr.Err == syscall.ENOTDIR) { + return true + } + return false +} + +var WarnFD io.Writer = os.Stderr + +// Pgpass gets the filepath to the pgpass file to use, returning "" if a pgpass +// file shouldn't be used. +func Pgpass(passfile string) string { + // Get passfile from the options. + if passfile == "" { + home := Home(false) + if home == "" { + return "" + } + passfile = filepath.Join(home, ".pgpass") + } + + // On Win32, the directory is protected, so we don't have to check the file. + if runtime.GOOS != "windows" { + fi, err := os.Stat(passfile) + if err != nil { + return "" + } + if fi.Mode().Perm()&(0x77) != 0 { + fmt.Fprintf(WarnFD, + "WARNING: password file %q has group or world access; permissions should be u=rw (0600) or less\n", + passfile) + return "" + } + } + return passfile +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/perm.go b/vendor/github.com/lib/pq/internal/pqutil/perm.go new file mode 100644 index 00000000..05fb9a6a --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/perm.go @@ -0,0 +1,64 @@ +//go:build !windows && !plan9 + +package pqutil + +import ( + "errors" + "os" + "syscall" +) + +var ( + ErrSSLKeyUnknownOwnership = errors.New("pq: could not get owner information for private key, may not be properly protected") + ErrSSLKeyHasWorldPermissions = errors.New("pq: private key has world access; permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less") +) + +// SSLKeyPermissions checks the permissions on user-supplied SSL key files, +// which should have very little access. libpq does not check key file +// permissions on Windows. +// +// If the file is owned by the same user the process is running as, the file +// should only have 0600. If the file is owned by root, and the group matches +// the group that the process is running in, the permissions cannot be more than +// 0640. The file should never have world permissions. +// +// Returns an error when the permission check fails. +func SSLKeyPermissions(sslkey string) error { + fi, err := os.Stat(sslkey) + if err != nil { + return err + } + + return CheckPermissions(fi) +} + +func CheckPermissions(fi os.FileInfo) error { + // The maximum permissions that a private key file owned by a regular user + // is allowed to have. This translates to u=rw. Regardless of if we're + // running as root or not, 0600 is acceptable, so we return if no bits + // beyond the regular user permission mask are set. + if fi.Mode().Perm()&^os.FileMode(0o600) == 0 { + return nil + } + + // We need to pull the Unix file information to get the file's owner. + // If we can't access it, there's some sort of operating system level error + // and we should fail rather than attempting to use faulty information. + sys, ok := fi.Sys().(*syscall.Stat_t) + if !ok { + return ErrSSLKeyUnknownOwnership + } + + // if the file is owned by root, we allow 0640 (u=rw,g=r) to match what + // Postgres does. + if sys.Uid == 0 { + // The maximum permissions that a private key file owned by root is + // allowed to have. This translates to u=rw,g=r. + if fi.Mode().Perm()&^os.FileMode(0o640) != 0 { + return ErrSSLKeyHasWorldPermissions + } + return nil + } + + return ErrSSLKeyHasWorldPermissions +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/perm_unsupported.go b/vendor/github.com/lib/pq/internal/pqutil/perm_unsupported.go new file mode 100644 index 00000000..3ce75957 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/perm_unsupported.go @@ -0,0 +1,12 @@ +//go:build windows || plan9 + +package pqutil + +import "errors" + +var ( + ErrSSLKeyUnknownOwnership = errors.New("unused") + ErrSSLKeyHasWorldPermissions = errors.New("unused") +) + +func SSLKeyPermissions(sslkey string) error { return nil } diff --git a/vendor/github.com/lib/pq/internal/pqutil/pqutil.go b/vendor/github.com/lib/pq/internal/pqutil/pqutil.go new file mode 100644 index 00000000..ca869e9c --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/pqutil.go @@ -0,0 +1,32 @@ +package pqutil + +import ( + "strconv" + "strings" +) + +// ParseBool is like strconv.ParseBool, but also accepts "yes"/"no" and +// "on"/"off". +func ParseBool(str string) (bool, error) { + switch str { + case "1", "t", "T", "true", "TRUE", "True", "yes", "on": + return true, nil + case "0", "f", "F", "false", "FALSE", "False", "no", "off": + return false, nil + } + return false, &strconv.NumError{Func: "ParseBool", Num: str, Err: strconv.ErrSyntax} +} + +func Join[S ~[]E, E ~string](s S) string { + var b strings.Builder + for i := range s { + if i > 0 { + b.WriteString(", ") + } + if i == len(s)-1 { + b.WriteString("or ") + } + b.WriteString(string(s[i])) + } + return b.String() +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/user_other.go b/vendor/github.com/lib/pq/internal/pqutil/user_other.go new file mode 100644 index 00000000..09e4f8df --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/user_other.go @@ -0,0 +1,9 @@ +//go:build js || android || hurd || zos || wasip1 || appengine + +package pqutil + +import "errors" + +func User() (string, error) { + return "", errors.New("pqutil.User: not supported on current platform") +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/user_posix.go b/vendor/github.com/lib/pq/internal/pqutil/user_posix.go new file mode 100644 index 00000000..bd0ece6d --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/user_posix.go @@ -0,0 +1,25 @@ +//go:build !windows && !js && !android && !hurd && !zos && !wasip1 && !appengine + +package pqutil + +import ( + "os" + "os/user" + "runtime" +) + +func User() (string, error) { + env := "USER" + if runtime.GOOS == "plan9" { + env = "user" + } + if n := os.Getenv(env); n != "" { + return n, nil + } + + u, err := user.Current() + if err != nil { + return "", err + } + return u.Username, nil +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/user_windows.go b/vendor/github.com/lib/pq/internal/pqutil/user_windows.go new file mode 100644 index 00000000..960cb805 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/user_windows.go @@ -0,0 +1,28 @@ +//go:build windows && !appengine + +package pqutil + +import ( + "path/filepath" + "syscall" +) + +func User() (string, error) { + // Perform Windows user name lookup identically to libpq. + // + // The PostgreSQL code makes use of the legacy Win32 function GetUserName, + // and that function has not been imported into stock Go. GetUserNameEx is + // available though, the difference being that a wider range of names are + // available. To get the output to be the same as GetUserName, only the + // base (or last) component of the result is returned. + var ( + name = make([]uint16, 128) + pwnameSz = uint32(len(name)) - 1 + ) + err := syscall.GetUserNameEx(syscall.NameSamCompatible, &name[0], &pwnameSz) + if err != nil { + return "", err + } + s := syscall.UTF16ToString(name) + return filepath.Base(s), nil +} diff --git a/vendor/github.com/lib/pq/internal/proto/proto.go b/vendor/github.com/lib/pq/internal/proto/proto.go new file mode 100644 index 00000000..e8b4bc59 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/proto/proto.go @@ -0,0 +1,186 @@ +// From src/include/libpq/protocol.h and src/include/libpq/pqcomm.h – PostgreSQL 18.1 + +package proto + +import ( + "fmt" + "strconv" +) + +// Constants from pqcomm.h +const ( + ProtocolVersion30 = (3 << 16) | 0 //lint:ignore SA4016 x + ProtocolVersion32 = (3 << 16) | 2 // PostgreSQL ≥18. + CancelRequestCode = (1234 << 16) | 5678 + NegotiateSSLCode = (1234 << 16) | 5679 + NegotiateGSSCode = (1234 << 16) | 5680 +) + +// Constants from fe-connect.c +const ( + MaxErrlen = 30_000 // https://github.com/postgres/postgres/blob/c6a10a89f/src/interfaces/libpq/fe-connect.c#L4067 +) + +// RequestCode is a request codes sent by the frontend. +type RequestCode byte + +// These are the request codes sent by the frontend. +const ( + Bind = RequestCode('B') + Close = RequestCode('C') + Describe = RequestCode('D') + Execute = RequestCode('E') + FunctionCall = RequestCode('F') + Flush = RequestCode('H') + Parse = RequestCode('P') + Query = RequestCode('Q') + Sync = RequestCode('S') + Terminate = RequestCode('X') + CopyFail = RequestCode('f') + GSSResponse = RequestCode('p') + PasswordMessage = RequestCode('p') + SASLInitialResponse = RequestCode('p') + SASLResponse = RequestCode('p') + CopyDoneRequest = RequestCode('c') + CopyDataRequest = RequestCode('d') +) + +func (r RequestCode) String() string { + s, ok := map[RequestCode]string{ + Bind: "Bind", + Close: "Close", + Describe: "Describe", + Execute: "Execute", + FunctionCall: "FunctionCall", + Flush: "Flush", + Parse: "Parse", + Query: "Query", + Sync: "Sync", + Terminate: "Terminate", + CopyFail: "CopyFail", + // These are all the same :-/ + //GSSResponse: "GSSResponse", + PasswordMessage: "PasswordMessage", + //SASLInitialResponse: "SASLInitialResponse", + //SASLResponse: "SASLResponse", + CopyDoneRequest: "CopyDone", + CopyDataRequest: "CopyData", + }[r] + if !ok { + s = "" + } + c := string(r) + if r <= 0x1f || r == 0x7f { + c = fmt.Sprintf("0x%x", string(r)) + } + return "(" + c + ") " + s +} + +// ResponseCode is a response codes sent by the backend. +type ResponseCode byte + +// These are the response codes sent by the backend. +const ( + ParseComplete = ResponseCode('1') + BindComplete = ResponseCode('2') + CloseComplete = ResponseCode('3') + NotificationResponse = ResponseCode('A') + CommandComplete = ResponseCode('C') + DataRow = ResponseCode('D') + ErrorResponse = ResponseCode('E') + CopyInResponse = ResponseCode('G') + CopyOutResponse = ResponseCode('H') + EmptyQueryResponse = ResponseCode('I') + BackendKeyData = ResponseCode('K') + NoticeResponse = ResponseCode('N') + AuthenticationRequest = ResponseCode('R') + ParameterStatus = ResponseCode('S') + RowDescription = ResponseCode('T') + FunctionCallResponse = ResponseCode('V') + CopyBothResponse = ResponseCode('W') + ReadyForQuery = ResponseCode('Z') + NoData = ResponseCode('n') + PortalSuspended = ResponseCode('s') + ParameterDescription = ResponseCode('t') + NegotiateProtocolVersion = ResponseCode('v') + CopyDoneResponse = ResponseCode('c') + CopyDataResponse = ResponseCode('d') +) + +func (r ResponseCode) String() string { + s, ok := map[ResponseCode]string{ + ParseComplete: "ParseComplete", + BindComplete: "BindComplete", + CloseComplete: "CloseComplete", + NotificationResponse: "NotificationResponse", + CommandComplete: "CommandComplete", + DataRow: "DataRow", + ErrorResponse: "ErrorResponse", + CopyInResponse: "CopyInResponse", + CopyOutResponse: "CopyOutResponse", + EmptyQueryResponse: "EmptyQueryResponse", + BackendKeyData: "BackendKeyData", + NoticeResponse: "NoticeResponse", + AuthenticationRequest: "AuthRequest", + ParameterStatus: "ParamStatus", + RowDescription: "RowDescription", + FunctionCallResponse: "FunctionCallResponse", + CopyBothResponse: "CopyBothResponse", + ReadyForQuery: "ReadyForQuery", + NoData: "NoData", + PortalSuspended: "PortalSuspended", + ParameterDescription: "ParamDescription", + NegotiateProtocolVersion: "NegotiateProtocolVersion", + CopyDoneResponse: "CopyDone", + CopyDataResponse: "CopyData", + }[r] + if !ok { + s = "" + } + c := string(r) + if r <= 0x1f || r == 0x7f { + c = fmt.Sprintf("0x%x", string(r)) + } + return "(" + c + ") " + s +} + +// AuthCode are authentication request codes sent by the backend. +type AuthCode int32 + +// These are the authentication request codes sent by the backend. +const ( + AuthReqOk = AuthCode(0) // User is authenticated + AuthReqKrb4 = AuthCode(1) // Kerberos V4. Not supported any more. + AuthReqKrb5 = AuthCode(2) // Kerberos V5. Not supported any more. + AuthReqPassword = AuthCode(3) // Password + AuthReqCrypt = AuthCode(4) // crypt password. Not supported any more. + AuthReqMD5 = AuthCode(5) // md5 password + _ = AuthCode(6) // 6 is available. It was used for SCM creds, not supported any more. + AuthReqGSS = AuthCode(7) // GSSAPI without wrap() + AuthReqGSSCont = AuthCode(8) // Continue GSS exchanges + AuthReqSSPI = AuthCode(9) // SSPI negotiate without wrap() + AuthReqSASL = AuthCode(10) // Begin SASL authentication + AuthReqSASLCont = AuthCode(11) // Continue SASL authentication + AuthReqSASLFin = AuthCode(12) // Final SASL message +) + +func (a AuthCode) String() string { + s, ok := map[AuthCode]string{ + AuthReqOk: "ok", + AuthReqKrb4: "krb4", + AuthReqKrb5: "krb5", + AuthReqPassword: "password", + AuthReqCrypt: "crypt", + AuthReqMD5: "md5", + AuthReqGSS: "GDD", + AuthReqGSSCont: "GSSCont", + AuthReqSSPI: "SSPI", + AuthReqSASL: "SASL", + AuthReqSASLCont: "SASLCont", + AuthReqSASLFin: "SASLFin", + }[a] + if !ok { + s = "" + } + return s + " (" + strconv.Itoa(int(a)) + ")" +} diff --git a/vendor/github.com/lib/pq/internal/proto/sz_32.go b/vendor/github.com/lib/pq/internal/proto/sz_32.go new file mode 100644 index 00000000..68065591 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/proto/sz_32.go @@ -0,0 +1,7 @@ +//go:build 386 || arm || mips || mipsle + +package proto + +import "math" + +const MaxUint32 = math.MaxInt diff --git a/vendor/github.com/lib/pq/internal/proto/sz_64.go b/vendor/github.com/lib/pq/internal/proto/sz_64.go new file mode 100644 index 00000000..2b8ad897 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/proto/sz_64.go @@ -0,0 +1,7 @@ +//go:build !386 && !arm && !mips && !mipsle + +package proto + +import "math" + +const MaxUint32 = math.MaxUint32 diff --git a/vendor/github.com/lib/pq/notice.go b/vendor/github.com/lib/pq/notice.go index 70ad122a..7b9ff392 100644 --- a/vendor/github.com/lib/pq/notice.go +++ b/vendor/github.com/lib/pq/notice.go @@ -1,6 +1,3 @@ -//go:build go1.10 -// +build go1.10 - package pq import ( @@ -10,7 +7,7 @@ import ( // NoticeHandler returns the notice handler on the given connection, if any. A // runtime panic occurs if c is not a pq connection. This is rarely used -// directly, use ConnectorNoticeHandler and ConnectorWithNoticeHandler instead. +// directly, use [ConnectorNoticeHandler] and [ConnectorWithNoticeHandler] instead. func NoticeHandler(c driver.Conn) func(*Error) { return c.(*conn).noticeHandler } @@ -18,7 +15,7 @@ func NoticeHandler(c driver.Conn) func(*Error) { // SetNoticeHandler sets the given notice handler on the given connection. A // runtime panic occurs if c is not a pq connection. A nil handler may be used // to unset it. This is rarely used directly, use ConnectorNoticeHandler and -// ConnectorWithNoticeHandler instead. +// [ConnectorWithNoticeHandler] instead. // // Note: Notice handlers are executed synchronously by pq meaning commands // won't continue to be processed until the handler returns. @@ -44,7 +41,7 @@ func (n *NoticeHandlerConnector) Connect(ctx context.Context) (driver.Conn, erro } // ConnectorNoticeHandler returns the currently set notice handler, if any. If -// the given connector is not a result of ConnectorWithNoticeHandler, nil is +// the given connector is not a result of [ConnectorWithNoticeHandler], nil is // returned. func ConnectorNoticeHandler(c driver.Connector) func(*Error) { if c, ok := c.(*NoticeHandlerConnector); ok { diff --git a/vendor/github.com/lib/pq/notify.go b/vendor/github.com/lib/pq/notify.go index 5c421fdb..4f4c4227 100644 --- a/vendor/github.com/lib/pq/notify.go +++ b/vendor/github.com/lib/pq/notify.go @@ -1,33 +1,29 @@ package pq -// Package pq is a pure Go Postgres driver for the database/sql package. -// This module contains support for Postgres LISTEN/NOTIFY. - import ( "context" "database/sql/driver" "errors" "fmt" + "net" "sync" "sync/atomic" "time" + + "github.com/lib/pq/internal/proto" ) // Notification represents a single notification from the database. type Notification struct { - // Process ID (PID) of the notifying postgres backend. - BePid int - // Name of the channel the notification was sent on. - Channel string - // Payload, or the empty string if unspecified. - Extra string + BePid int // Process ID (PID) of the notifying postgres backend. + Channel string // Name of the channel the notification was sent on. + Extra string // Payload, or the empty string if unspecified. } func recvNotification(r *readBuf) *Notification { bePid := r.int32() channel := r.string() extra := r.string() - return &Notification{bePid, channel, extra} } @@ -41,8 +37,8 @@ func SetNotificationHandler(c driver.Conn, handler func(*Notification)) { c.(*conn).notificationHandler = handler } -// NotificationHandlerConnector wraps a regular connector and sets a notification handler -// on it. +// NotificationHandlerConnector wraps a regular connector and sets a +// notification handler on it. type NotificationHandlerConnector struct { driver.Connector notificationHandler func(*Notification) @@ -58,9 +54,9 @@ func (n *NotificationHandlerConnector) Connect(ctx context.Context) (driver.Conn return c, err } -// ConnectorNotificationHandler returns the currently set notification handler, if any. If -// the given connector is not a result of ConnectorWithNotificationHandler, nil is -// returned. +// ConnectorNotificationHandler returns the currently set notification handler, +// if any. If the given connector is not a result of +// [ConnectorWithNotificationHandler], nil is returned. func ConnectorNotificationHandler(c driver.Connector) func(*Notification) { if c, ok := c.(*NotificationHandlerConnector); ok { return c.notificationHandler @@ -68,11 +64,11 @@ func ConnectorNotificationHandler(c driver.Connector) func(*Notification) { return nil } -// ConnectorWithNotificationHandler creates or sets the given handler for the given -// connector. If the given connector is a result of calling this function +// ConnectorWithNotificationHandler creates or sets the given handler for the +// given connector. If the given connector is a result of calling this function // previously, it is simply set on the given connector and returned. Otherwise, -// this returns a new connector wrapping the given one and setting the notification -// handler. A nil notification handler may be used to unset it. +// this returns a new connector wrapping the given one and setting the +// notification handler. A nil notification handler may be used to unset it. // // The returned connector is intended to be used with database/sql.OpenDB. // @@ -93,28 +89,22 @@ const ( ) type message struct { - typ byte + typ proto.ResponseCode err error } var errListenerConnClosed = errors.New("pq: ListenerConn has been closed") -// ListenerConn is a low-level interface for waiting for notifications. You -// should use Listener instead. +// ListenerConn is a low-level interface for waiting for notifications. You +// should use [Listener] instead. type ListenerConn struct { - // guards cn and err - connectionLock sync.Mutex - cn *conn - err error - - connState int32 - - // the sending goroutine will be holding this lock - senderLock sync.Mutex - + connectionLock sync.Mutex // guards cn and err + senderLock sync.Mutex // the sending goroutine will be holding this lock + cn *conn + err error + connState int32 notificationChan chan<- *Notification - - replyChan chan message + replyChan chan message } // NewListenerConn creates a new ListenerConn. Use NewListener instead. @@ -136,7 +126,6 @@ func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*Listen } go l.listenerConnMain() - return l, nil } @@ -164,7 +153,7 @@ func (l *ListenerConn) releaseSenderLock() { l.senderLock.Unlock() } -// setState advances the protocol state to newState. Returns false if moving +// setState advances the protocol state to newState. Returns false if moving // to that state from the current state is not allowed. func (l *ListenerConn) setState(newState int32) bool { var expectedState int32 @@ -185,12 +174,10 @@ func (l *ListenerConn) setState(newState int32) bool { // Main logic is here: receive messages from the postgres backend, forward // notifications and query replies and keep the internal state in sync with the -// protocol state. Returns when the connection has been lost, is about to go +// protocol state. Returns when the connection has been lost, is about to go // away or should be discarded because we couldn't agree on the state with the // server backend. func (l *ListenerConn) listenerConnLoop() (err error) { - defer errRecoverNoErrBadConn(&err) - r := &readBuf{} for { t, err := l.cn.recvMessage(r) @@ -199,43 +186,43 @@ func (l *ListenerConn) listenerConnLoop() (err error) { } switch t { - case 'A': + case proto.NotificationResponse: // recvNotification copies all the data so we don't need to worry // about the scratch buffer being overwritten. l.notificationChan <- recvNotification(r) - case 'T', 'D': + case proto.RowDescription, proto.DataRow: // only used by tests; ignore - case 'E': + case proto.ErrorResponse: // We might receive an ErrorResponse even when not in a query; it // is expected that the server will close the connection after // that, but we should make sure that the error we display is the // one from the stray ErrorResponse, not io.ErrUnexpectedEOF. if !l.setState(connStateExpectReadyForQuery) { - return parseError(r) + return parseError(r, "") } - l.replyChan <- message{t, parseError(r)} + l.replyChan <- message{t, parseError(r, "")} - case 'C', 'I': + case proto.CommandComplete, proto.EmptyQueryResponse: if !l.setState(connStateExpectReadyForQuery) { // protocol out of sync return fmt.Errorf("unexpected CommandComplete") } // ExecSimpleQuery doesn't need to know about this message - case 'Z': + case proto.ReadyForQuery: if !l.setState(connStateIdle) { // protocol out of sync return fmt.Errorf("unexpected ReadyForQuery") } l.replyChan <- message{t, nil} - case 'S': + case proto.ParameterStatus: // ignore - case 'N': + case proto.NoticeResponse: if n := l.cn.noticeHandler; n != nil { - n(parseError(r)) + n(parseError(r, "")) } default: return fmt.Errorf("unexpected message %q from server in listenerConnLoop", t) @@ -244,25 +231,25 @@ func (l *ListenerConn) listenerConnLoop() (err error) { } // This is the main routine for the goroutine receiving on the database -// connection. Most of the main logic is in listenerConnLoop. +// connection. Most of the main logic is in listenerConnLoop. func (l *ListenerConn) listenerConnMain() { err := l.listenerConnLoop() // listenerConnLoop terminated; we're done, but we still have to clean up. // Make sure nobody tries to start any new queries by making sure the err - // pointer is set. It is important that we do not overwrite its value; a - // connection could be closed by either this goroutine or one sending on - // the connection -- whoever closes the connection is assumed to have the - // more meaningful error message (as the other one will probably get + // pointer is set. It is important that we do not overwrite its value; a + // connection could be closed by either this goroutine or one sending on the + // connection – whoever closes the connection is assumed to have the more + // meaningful error message (as the other one will probably get // net.errClosed), so that goroutine sets the error we expose while the - // other error is discarded. If the connection is lost while two - // goroutines are operating on the socket, it probably doesn't matter which - // error we expose so we don't try to do anything more complex. + // other error is discarded. If the connection is lost while two goroutines + // are operating on the socket, it probably doesn't matter which error we + // expose so we don't try to do anything more complex. l.connectionLock.Lock() if l.err == nil { l.err = err } - l.cn.Close() + _ = l.cn.Close() l.connectionLock.Unlock() // There might be a query in-flight; make sure nobody's waiting for a @@ -290,30 +277,27 @@ func (l *ListenerConn) UnlistenAll() (bool, error) { return l.ExecSimpleQuery("UNLISTEN *") } -// Ping the remote server to make sure it's alive. Non-nil error means the +// Ping the remote server to make sure it's alive. Non-nil error means the // connection has failed and should be abandoned. func (l *ListenerConn) Ping() error { sent, err := l.ExecSimpleQuery("") if !sent { return err } - if err != nil { - // shouldn't happen + if err != nil { // shouldn't happen panic(err) } return nil } -// Attempt to send a query on the connection. Returns an error if sending the -// query failed, and the caller should initiate closure of this connection. -// The caller must be holding senderLock (see acquireSenderLock and +// Attempt to send a query on the connection. Returns an error if sending the +// query failed, and the caller should initiate closure of this connection. The +// caller must be holding senderLock (see acquireSenderLock and // releaseSenderLock). func (l *ListenerConn) sendSimpleQuery(q string) (err error) { - defer errRecoverNoErrBadConn(&err) - - // must set connection state before sending the query + // Must set connection state before sending the query if !l.setState(connStateExpectResponse) { - panic("two queries running at the same time") + return errors.New("pq: two queries running at the same time") } // Can't use l.cn.writeBuf here because it uses the scratch buffer which @@ -323,18 +307,16 @@ func (l *ListenerConn) sendSimpleQuery(q string) (err error) { pos: 1, } b.string(q) - l.cn.send(b) - - return nil + return l.cn.send(b) } // ExecSimpleQuery executes a "simple query" (i.e. one with no bindable // parameters) on the connection. The possible return values are: -// 1) "executed" is true; the query was executed to completion on the -// database server. If the query failed, err will be set to the error -// returned by the database, otherwise err will be nil. -// 2) If "executed" is false, the query could not be executed on the remote -// server. err will be non-nil. +// 1. "executed" is true; the query was executed to completion on the database +// server. If the query failed, err will be set to the error returned by the +// database, otherwise err will be nil. +// 2. If "executed" is false, the query could not be executed on the remote +// server. err will be non-nil. // // After a call to ExecSimpleQuery has returned an executed=false value, the // connection has either been closed or will be closed shortly thereafter, and @@ -356,7 +338,7 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { l.err = err } l.connectionLock.Unlock() - l.cn.c.Close() + _ = l.cn.c.Close() return false, err } @@ -365,14 +347,14 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { m, ok := <-l.replyChan if !ok { // We lost the connection to server, don't bother waiting for a - // a response. err should have been set already. + // a response. err should have been set already. l.connectionLock.Lock() err := l.err l.connectionLock.Unlock() return false, err } switch m.typ { - case 'Z': + case proto.ReadyForQuery: // sanity check if m.err != nil { panic("m.err != nil") @@ -380,7 +362,7 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { // done; err might or might not be set return true, err - case 'E': + case proto.ErrorResponse: // sanity check if m.err == nil { panic("m.err == nil") @@ -414,8 +396,6 @@ func (l *ListenerConn) Err() error { return l.err } -var errListenerClosed = errors.New("pq: Listener has been closed") - // ErrChannelAlreadyOpen is returned from Listen when a channel is already // open. var ErrChannelAlreadyOpen = errors.New("pq: channel is already open") @@ -427,27 +407,25 @@ var ErrChannelNotOpen = errors.New("pq: channel is not open") type ListenerEventType int const ( - // ListenerEventConnected is emitted only when the database connection - // has been initially initialized. The err argument of the callback - // will always be nil. + // ListenerEventConnected is emitted only when the database connection has + // been initially initialized. The err argument of the callback will always + // be nil. ListenerEventConnected ListenerEventType = iota - // ListenerEventDisconnected is emitted after a database connection has - // been lost, either because of an error or because Close has been - // called. The err argument will be set to the reason the database - // connection was lost. + // ListenerEventDisconnected is emitted after a database connection has been + // lost, either because of an error or because Close has been called. The + // err argument will be set to the reason the database connection was lost. ListenerEventDisconnected - // ListenerEventReconnected is emitted after a database connection has - // been re-established after connection loss. The err argument of the - // callback will always be nil. After this event has been emitted, a - // nil pq.Notification is sent on the Listener.Notify channel. + // ListenerEventReconnected is emitted after a database connection has been + // re-established after connection loss. The err argument of the callback + // will always be nil. After this event has been emitted, a nil + // pq.Notification is sent on the Listener.Notify channel. ListenerEventReconnected - // ListenerEventConnectionAttemptFailed is emitted after a connection - // to the database was attempted, but failed. The err argument will be - // set to an error describing why the connection attempt did not - // succeed. + // ListenerEventConnectionAttemptFailed is emitted after a connection to the + // database was attempted, but failed. The err argument will be set to an + // error describing why the connection attempt did not succeed. ListenerEventConnectionAttemptFailed ) @@ -455,17 +433,26 @@ const ( // constants' documentation. type EventCallbackType func(event ListenerEventType, err error) +func (l ListenerEventType) String() string { + return map[ListenerEventType]string{ + ListenerEventConnected: "connected", + ListenerEventDisconnected: "disconnected", + ListenerEventReconnected: "reconnected", + ListenerEventConnectionAttemptFailed: "connectionAttemptFailed", + }[l] +} + // Listener provides an interface for listening to notifications from a -// PostgreSQL database. For general usage information, see section +// PostgreSQL database. For general usage information, see section // "Notifications". // // Listener can safely be used from concurrently running goroutines. type Listener struct { - // Channel for receiving notifications from the database. In some cases a - // nil value will be sent. See section "Notifications" above. + // Channel for receiving notifications from the database. In some cases a + // nil value will be sent. See section "Notifications" above. Notify chan *Notification - name string + dsn string minReconnectInterval time.Duration maxReconnectInterval time.Duration dialer Dialer @@ -484,98 +471,85 @@ type Listener struct { // name should be set to a connection string to be used to establish the // database connection (see section "Connection String Parameters" above). // -// minReconnectInterval controls the duration to wait before trying to -// re-establish the database connection after connection loss. After each -// consecutive failure this interval is doubled, until maxReconnectInterval is -// reached. Successfully completing the connection establishment procedure -// resets the interval back to minReconnectInterval. +// minReconnect controls the duration to wait before trying to re-establish the +// database connection after connection loss. After each consecutive failure +// this interval is doubled, until maxReconnect is reached. Successfully +// completing the connection establishment procedure resets the interval back to +// minReconnect. // -// The last parameter eventCallback can be set to a function which will be -// called by the Listener when the state of the underlying database connection -// changes. This callback will be called by the goroutine which dispatches the -// notifications over the Notify channel, so you should try to avoid doing -// potentially time-consuming operations from the callback. -func NewListener(name string, - minReconnectInterval time.Duration, - maxReconnectInterval time.Duration, - eventCallback EventCallbackType) *Listener { - return NewDialListener(defaultDialer{}, name, minReconnectInterval, maxReconnectInterval, eventCallback) +// The last parameter cb can be set to a function which will be called by the +// Listener when the state of the underlying database connection changes. This +// callback will be called by the goroutine which dispatches the notifications +// over the Notify channel, so you should try to avoid doing potentially +// time-consuming operations from the callback. +func NewListener(dsn string, minReconnect, maxReconnect time.Duration, cb EventCallbackType) *Listener { + return NewDialListener(defaultDialer{}, dsn, minReconnect, maxReconnect, cb) } // NewDialListener is like NewListener but it takes a Dialer. -func NewDialListener(d Dialer, - name string, - minReconnectInterval time.Duration, - maxReconnectInterval time.Duration, - eventCallback EventCallbackType) *Listener { - +func NewDialListener(d Dialer, dsn string, minReconnect, maxReconnect time.Duration, cb EventCallbackType) *Listener { l := &Listener{ - name: name, - minReconnectInterval: minReconnectInterval, - maxReconnectInterval: maxReconnectInterval, + dsn: dsn, + minReconnectInterval: minReconnect, + maxReconnectInterval: maxReconnect, dialer: d, - eventCallback: eventCallback, - - channels: make(map[string]struct{}), - - Notify: make(chan *Notification, 32), + eventCallback: cb, + channels: make(map[string]struct{}), + Notify: make(chan *Notification, 32), } l.reconnectCond = sync.NewCond(&l.lock) - go l.listenerMain() - return l } -// NotificationChannel returns the notification channel for this listener. -// This is the same channel as Notify, and will not be recreated during the -// life time of the Listener. +// NotificationChannel returns the notification channel for this listener. This +// is the same channel as Notify, and will not be recreated during the life time +// of the Listener. func (l *Listener) NotificationChannel() <-chan *Notification { return l.Notify } -// Listen starts listening for notifications on a channel. Calls to this +// Listen starts listening for notifications on a channel. Calls to this // function will block until an acknowledgement has been received from the -// server. Note that Listener automatically re-establishes the connection -// after connection loss, so this function may block indefinitely if the -// connection can not be re-established. +// server. Note that Listener automatically re-establishes the connection after +// connection loss, so this function may block indefinitely if the connection +// can not be re-established. // // Listen will only fail in three conditions: -// 1) The channel is already open. The returned error will be -// ErrChannelAlreadyOpen. -// 2) The query was executed on the remote server, but PostgreSQL returned an -// error message in response to the query. The returned error will be a -// pq.Error containing the information the server supplied. -// 3) Close is called on the Listener before the request could be completed. +// 1. The channel is already open. The returned error will be +// [ErrChannelAlreadyOpen]. +// 2. The query was executed on the remote server, but PostgreSQL returned an +// error message in response to the query. The returned error will be a +// [pq.Error] containing the information the server supplied. +// 3. Close is called on the Listener before the request could be completed. // // The channel name is case-sensitive. func (l *Listener) Listen(channel string) error { l.lock.Lock() defer l.lock.Unlock() - if l.isClosed { - return errListenerClosed + return net.ErrClosed } // The server allows you to issue a LISTEN on a channel which is already // open, but it seems useful to be able to detect this case to spot for - // mistakes in application logic. If the application genuinely does't - // care, it can check the exported error and ignore it. + // mistakes in application logic. If the application genuinely does't care, + // it can check the exported error and ignore it. _, exists := l.channels[channel] if exists { return ErrChannelAlreadyOpen } if l.cn != nil { - // If gotResponse is true but error is set, the query was executed on - // the remote server, but resulted in an error. This should be - // relatively rare, so it's fine if we just pass the error to our - // caller. However, if gotResponse is false, we could not complete the - // query on the remote server and our underlying connection is about - // to go away, so we only add relname to l.channels, and wait for - // resync() to take care of the rest. - gotResponse, err := l.cn.Listen(channel) - if gotResponse && err != nil { + // If resp is true but error is set then the query was executed on the + // remote server but resulted in an error. This should be relatively + // rare, so it's fine if we just pass the error to our caller. + // If resp is false then we could not complete the query on the remote + // server and our underlying connection is about to go away, so we only + // add relname to l.channels, and wait for resync() to take care of the + // rest. + resp, err := l.cn.Listen(channel) + if resp && err != nil { return err } } @@ -585,16 +559,16 @@ func (l *Listener) Listen(channel string) error { l.reconnectCond.Wait() // we let go of the mutex for a while if l.isClosed { - return errListenerClosed + return net.ErrClosed } } return nil } -// Unlisten removes a channel from the Listener's channel list. Returns +// Unlisten removes a channel from the Listener's channel list. Returns // ErrChannelNotOpen if the Listener is not listening on the specified channel. -// Returns immediately with no error if there is no connection. Note that you +// Returns immediately with no error if there is no connection. Note that you // might still get notifications for this channel even after Unlisten has // returned. // @@ -604,7 +578,7 @@ func (l *Listener) Unlisten(channel string) error { defer l.lock.Unlock() if l.isClosed { - return errListenerClosed + return net.ErrClosed } // Similarly to LISTEN, this is not an error in Postgres, but it seems @@ -615,11 +589,11 @@ func (l *Listener) Unlisten(channel string) error { } if l.cn != nil { - // Similarly to Listen (see comment in that function), the caller - // should only be bothered with an error if it came from the backend as - // a response to our query. - gotResponse, err := l.cn.Unlisten(channel) - if gotResponse && err != nil { + // Similarly to Listen (see comment there), the caller should only be + // bothered with an error if it came from the backend as a response to + // our query. + resp, err := l.cn.Unlisten(channel) + if resp && err != nil { return err } } @@ -629,8 +603,8 @@ func (l *Listener) Unlisten(channel string) error { return nil } -// UnlistenAll removes all channels from the Listener's channel list. Returns -// immediately with no error if there is no connection. Note that you might +// UnlistenAll removes all channels from the Listener's channel list. Returns +// immediately with no error if there is no connection. Note that you might // still get notifications for any of the deleted channels even after // UnlistenAll has returned. func (l *Listener) UnlistenAll() error { @@ -638,7 +612,7 @@ func (l *Listener) UnlistenAll() error { defer l.lock.Unlock() if l.isClosed { - return errListenerClosed + return net.ErrClosed } if l.cn != nil { @@ -656,14 +630,14 @@ func (l *Listener) UnlistenAll() error { return nil } -// Ping the remote server to make sure it's alive. Non-nil return value means +// Ping the remote server to make sure it's alive. Non-nil return value means // that there is no active connection. func (l *Listener) Ping() error { l.lock.Lock() defer l.lock.Unlock() if l.isClosed { - return errListenerClosed + return net.ErrClosed } if l.cn == nil { return errors.New("no connection") @@ -672,8 +646,8 @@ func (l *Listener) Ping() error { return l.cn.Ping() } -// Clean up after losing the server connection. Returns l.cn.Err(), which -// should have the reason the connection was lost. +// Clean up after losing the server connection. Returns l.cn.Err(), which should +// have the reason the connection was lost. func (l *Listener) disconnectCleanup() error { l.lock.Lock() defer l.lock.Unlock() @@ -689,7 +663,7 @@ func (l *Listener) disconnectCleanup() error { } err := l.cn.Err() - l.cn.Close() + _ = l.cn.Close() l.cn = nil return err } @@ -722,10 +696,10 @@ func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notificatio }(notificationChan) // Ignore notifications while synchronization is going on to avoid - // deadlocks. We have to send a nil notification over Notify anyway as - // we can't possibly know which notifications (if any) were lost while - // the connection was down, so there's no reason to try and process - // these messages at all. + // deadlocks. We have to send a nil notification over Notify anyway as we + // can't possibly know which notifications (if any) were lost while the + // connection was down, so there's no reason to try and process these + // messages at all. for { select { case _, ok := <-notificationChan: @@ -748,41 +722,44 @@ func (l *Listener) closed() bool { } func (l *Listener) connect() error { + l.lock.Lock() + defer l.lock.Unlock() + if l.isClosed { + return net.ErrClosed + } + notificationChan := make(chan *Notification, 32) - cn, err := newDialListenerConn(l.dialer, l.name, notificationChan) + + var err error + l.cn, err = newDialListenerConn(l.dialer, l.dsn, notificationChan) if err != nil { return err } - l.lock.Lock() - defer l.lock.Unlock() - - err = l.resync(cn, notificationChan) + err = l.resync(l.cn, notificationChan) if err != nil { - cn.Close() + _ = l.cn.Close() return err } - l.cn = cn l.connNotificationChan = notificationChan l.reconnectCond.Broadcast() - return nil } // Close disconnects the Listener from the database and shuts it down. -// Subsequent calls to its methods will return an error. Close returns an -// error if the connection has already been closed. +// Subsequent calls to its methods will return an error. Close returns an error +// if the connection has already been closed. func (l *Listener) Close() error { l.lock.Lock() defer l.lock.Unlock() if l.isClosed { - return errListenerClosed + return net.ErrClosed } if l.cn != nil { - l.cn.Close() + _ = l.cn.Close() } l.isClosed = true @@ -801,21 +778,21 @@ func (l *Listener) emitEvent(event ListenerEventType, err error) { // Main logic here: maintain a connection to the server when possible, wait // for notifications and emit events. func (l *Listener) listenerConnLoop() { - var nextReconnect time.Time - - reconnectInterval := l.minReconnectInterval + var ( + nextReconnect time.Time + reconnectInterval = l.minReconnectInterval + ) for { for { err := l.connect() if err == nil { break } - if l.closed() { return } - l.emitEvent(ListenerEventConnectionAttemptFailed, err) + l.emitEvent(ListenerEventConnectionAttemptFailed, err) time.Sleep(reconnectInterval) reconnectInterval *= 2 if reconnectInterval > l.maxReconnectInterval { @@ -835,8 +812,7 @@ func (l *Listener) listenerConnLoop() { for { notification, ok := <-l.connNotificationChan - if !ok { - // lost connection, loop again + if !ok { // lost connection, loop again break } l.Notify <- notification diff --git a/vendor/github.com/lib/pq/oid/doc.go b/vendor/github.com/lib/pq/oid/doc.go index caaede24..a4865066 100644 --- a/vendor/github.com/lib/pq/oid/doc.go +++ b/vendor/github.com/lib/pq/oid/doc.go @@ -1,5 +1,6 @@ -// Package oid contains OID constants -// as defined by the Postgres server. +//go:generate go run ./gen.go + +// Package oid contains OID constants as defined by the Postgres server. package oid // Oid is a Postgres Object ID. diff --git a/vendor/github.com/lib/pq/pqerror/codes.go b/vendor/github.com/lib/pq/pqerror/codes.go new file mode 100644 index 00000000..f5576644 --- /dev/null +++ b/vendor/github.com/lib/pq/pqerror/codes.go @@ -0,0 +1,581 @@ +// Code generated by gen.go. DO NOT EDIT. + +// Last updated for PostgreSQL 18.3 + +package pqerror + +var ( + ClassSuccessfulCompletion = Class("00") // Successful Completion + ClassWarning = Class("01") // Warning + ClassNoData = Class("02") // No Data (this is also a warning class per the SQL standard) + ClassSQLStatementNotYetComplete = Class("03") // SQL Statement Not Yet Complete + ClassConnectionException = Class("08") // Connection Exception + ClassTriggeredActionException = Class("09") // Triggered Action Exception + ClassFeatureNotSupported = Class("0A") // Feature Not Supported + ClassInvalidTransactionInitiation = Class("0B") // Invalid Transaction Initiation + ClassLocatorException = Class("0F") // Locator Exception + ClassInvalidGrantor = Class("0L") // Invalid Grantor + ClassInvalidRoleSpecification = Class("0P") // Invalid Role Specification + ClassDiagnosticsException = Class("0Z") // Diagnostics Exception + ClassCaseNotFound = Class("20") // Case Not Found + ClassCardinalityViolation = Class("21") // Cardinality Violation + ClassDataException = Class("22") // Data Exception + ClassIntegrityConstraintViolation = Class("23") // Integrity Constraint Violation + ClassInvalidCursorState = Class("24") // Invalid Cursor State + ClassInvalidTransactionState = Class("25") // Invalid Transaction State + ClassInvalidSQLStatementName = Class("26") // Invalid SQL Statement Name + ClassTriggeredDataChangeViolation = Class("27") // Triggered Data Change Violation + ClassInvalidAuthorizationSpecification = Class("28") // Invalid Authorization Specification + ClassDependentPrivilegeDescriptorsStillExist = Class("2B") // Dependent Privilege Descriptors Still Exist + ClassInvalidTransactionTermination = Class("2D") // Invalid Transaction Termination + ClassSQLRoutineException = Class("2F") // SQL Routine Exception + ClassInvalidCursorName = Class("34") // Invalid Cursor Name + ClassExternalRoutineException = Class("38") // External Routine Exception + ClassExternalRoutineInvocationException = Class("39") // External Routine Invocation Exception + ClassSavepointException = Class("3B") // Savepoint Exception + ClassInvalidCatalogName = Class("3D") // Invalid Catalog Name + ClassInvalidSchemaName = Class("3F") // Invalid Schema Name + ClassTransactionRollback = Class("40") // Transaction Rollback + ClassSyntaxErrorOrAccessRuleViolation = Class("42") // Syntax Error or Access Rule Violation + ClassWithCheckOptionViolation = Class("44") // WITH CHECK OPTION Violation + ClassInsufficientResources = Class("53") // Insufficient Resources + ClassProgramLimitExceeded = Class("54") // Program Limit Exceeded + ClassObjectNotInPrerequisiteState = Class("55") // Object Not In Prerequisite State + ClassOperatorIntervention = Class("57") // Operator Intervention + ClassSystemError = Class("58") // System Error (errors external to PostgreSQL itself) + ClassConfigFileError = Class("F0") // Configuration File Error + ClassFDWError = Class("HV") // Foreign Data Wrapper Error (SQL/MED) + ClassPLpgSQLError = Class("P0") // PL/pgSQL Error + ClassInternalError = Class("XX") // Internal Error +) + +// A list of all error codes used in PostgreSQL. +var ( + SuccessfulCompletion = Code("00000") // Class 00 - Successful Completion + Warning = Code("01000") // Class 01 - Warning + WarningDynamicResultSetsReturned = Code("0100C") + WarningImplicitZeroBitPadding = Code("01008") + WarningNullValueEliminatedInSetFunction = Code("01003") + WarningPrivilegeNotGranted = Code("01007") + WarningPrivilegeNotRevoked = Code("01006") + WarningStringDataRightTruncation = Code("01004") + WarningDeprecatedFeature = Code("01P01") + NoData = Code("02000") // Class 02 - No Data (this is also a warning class per the SQL standard) + NoAdditionalDynamicResultSetsReturned = Code("02001") + SQLStatementNotYetComplete = Code("03000") // Class 03 - SQL Statement Not Yet Complete + ConnectionException = Code("08000") // Class 08 - Connection Exception + ConnectionDoesNotExist = Code("08003") + ConnectionFailure = Code("08006") + SQLClientUnableToEstablishSQLConnection = Code("08001") + SQLServerRejectedEstablishmentOfSQLConnection = Code("08004") + TransactionResolutionUnknown = Code("08007") + ProtocolViolation = Code("08P01") + TriggeredActionException = Code("09000") // Class 09 - Triggered Action Exception + FeatureNotSupported = Code("0A000") // Class 0A - Feature Not Supported + InvalidTransactionInitiation = Code("0B000") // Class 0B - Invalid Transaction Initiation + LocatorException = Code("0F000") // Class 0F - Locator Exception + LEInvalidSpecification = Code("0F001") + InvalidGrantor = Code("0L000") // Class 0L - Invalid Grantor + InvalidGrantOperation = Code("0LP01") + InvalidRoleSpecification = Code("0P000") // Class 0P - Invalid Role Specification + DiagnosticsException = Code("0Z000") // Class 0Z - Diagnostics Exception + StackedDiagnosticsAccessedWithoutActiveHandler = Code("0Z002") + InvalidArgumentForXquery = Code("10608") + CaseNotFound = Code("20000") // Class 20 - Case Not Found + CardinalityViolation = Code("21000") // Class 21 - Cardinality Violation + DataException = Code("22000") // Class 22 - Data Exception + ArraySubscriptError = Code("2202E") + CharacterNotInRepertoire = Code("22021") + DatetimeFieldOverflow = Code("22008") + DivisionByZero = Code("22012") + ErrorInAssignment = Code("22005") + EscapeCharacterConflict = Code("2200B") + IndicatorOverflow = Code("22022") + IntervalFieldOverflow = Code("22015") + InvalidArgumentForLog = Code("2201E") + InvalidArgumentForNtile = Code("22014") + InvalidArgumentForNthValue = Code("22016") + InvalidArgumentForPowerFunction = Code("2201F") + InvalidArgumentForWidthBucketFunction = Code("2201G") + InvalidCharacterValueForCast = Code("22018") + InvalidDatetimeFormat = Code("22007") + InvalidEscapeCharacter = Code("22019") + InvalidEscapeOctet = Code("2200D") + InvalidEscapeSequence = Code("22025") + NonstandardUseOfEscapeCharacter = Code("22P06") + InvalidIndicatorParameterValue = Code("22010") + InvalidParameterValue = Code("22023") + InvalidPrecedingOrFollowingSize = Code("22013") + InvalidRegularExpression = Code("2201B") + InvalidRowCountInLimitClause = Code("2201W") + InvalidRowCountInResultOffsetClause = Code("2201X") + InvalidTablesampleArgument = Code("2202H") + InvalidTablesampleRepeat = Code("2202G") + InvalidTimeZoneDisplacementValue = Code("22009") + InvalidUseOfEscapeCharacter = Code("2200C") + MostSpecificTypeMismatch = Code("2200G") + NullValueNotAllowed = Code("22004") + NullValueNoIndicatorParameter = Code("22002") + NumericValueOutOfRange = Code("22003") + SequenceGeneratorLimitExceeded = Code("2200H") + StringDataLengthMismatch = Code("22026") + StringDataRightTruncation = Code("22001") + SubstringError = Code("22011") + TrimError = Code("22027") + UnterminatedCString = Code("22024") + ZeroLengthCharacterString = Code("2200F") + FloatingPointException = Code("22P01") + InvalidTextRepresentation = Code("22P02") + InvalidBinaryRepresentation = Code("22P03") + BadCopyFileFormat = Code("22P04") + UntranslatableCharacter = Code("22P05") + NotAnXMLDocument = Code("2200L") + InvalidXMLDocument = Code("2200M") + InvalidXMLContent = Code("2200N") + InvalidXMLComment = Code("2200S") + InvalidXMLProcessingInstruction = Code("2200T") + DuplicateJSONObjectKeyValue = Code("22030") + InvalidArgumentForSQLJSONDatetimeFunction = Code("22031") + InvalidJSONText = Code("22032") + InvalidSQLJSONSubscript = Code("22033") + MoreThanOneSQLJSONItem = Code("22034") + NoSQLJSONItem = Code("22035") + NonNumericSQLJSONItem = Code("22036") + NonUniqueKeysInAJSONObject = Code("22037") + SingletonSQLJSONItemRequired = Code("22038") + SQLJSONArrayNotFound = Code("22039") + SQLJSONMemberNotFound = Code("2203A") + SQLJSONNumberNotFound = Code("2203B") + SQLJSONObjectNotFound = Code("2203C") + TooManyJSONArrayElements = Code("2203D") + TooManyJSONObjectMembers = Code("2203E") + SQLJSONScalarRequired = Code("2203F") + SQLJSONItemCannotBeCastToTargetType = Code("2203G") + IntegrityConstraintViolation = Code("23000") // Class 23 - Integrity Constraint Violation + RestrictViolation = Code("23001") + NotNullViolation = Code("23502") + ForeignKeyViolation = Code("23503") + UniqueViolation = Code("23505") + CheckViolation = Code("23514") + ExclusionViolation = Code("23P01") + InvalidCursorState = Code("24000") // Class 24 - Invalid Cursor State + InvalidTransactionState = Code("25000") // Class 25 - Invalid Transaction State + ActiveSQLTransaction = Code("25001") + BranchTransactionAlreadyActive = Code("25002") + HeldCursorRequiresSameIsolationLevel = Code("25008") + InappropriateAccessModeForBranchTransaction = Code("25003") + InappropriateIsolationLevelForBranchTransaction = Code("25004") + NoActiveSQLTransactionForBranchTransaction = Code("25005") + ReadOnlySQLTransaction = Code("25006") + SchemaAndDataStatementMixingNotSupported = Code("25007") + NoActiveSQLTransaction = Code("25P01") + InFailedSQLTransaction = Code("25P02") + IdleInTransactionSessionTimeout = Code("25P03") + TransactionTimeout = Code("25P04") + InvalidSQLStatementName = Code("26000") // Class 26 - Invalid SQL Statement Name + TriggeredDataChangeViolation = Code("27000") // Class 27 - Triggered Data Change Violation + InvalidAuthorizationSpecification = Code("28000") // Class 28 - Invalid Authorization Specification + InvalidPassword = Code("28P01") + DependentPrivilegeDescriptorsStillExist = Code("2B000") // Class 2B - Dependent Privilege Descriptors Still Exist + DependentObjectsStillExist = Code("2BP01") + InvalidTransactionTermination = Code("2D000") // Class 2D - Invalid Transaction Termination + SQLRoutineException = Code("2F000") // Class 2F - SQL Routine Exception + SREFunctionExecutedNoReturnStatement = Code("2F005") + SREModifyingSQLDataNotPermitted = Code("2F002") + SREProhibitedSQLStatementAttempted = Code("2F003") + SREReadingSQLDataNotPermitted = Code("2F004") + InvalidCursorName = Code("34000") // Class 34 - Invalid Cursor Name + ExternalRoutineException = Code("38000") // Class 38 - External Routine Exception + EREContainingSQLNotPermitted = Code("38001") + EREModifyingSQLDataNotPermitted = Code("38002") + EREProhibitedSQLStatementAttempted = Code("38003") + EREReadingSQLDataNotPermitted = Code("38004") + ExternalRoutineInvocationException = Code("39000") // Class 39 - External Routine Invocation Exception + ERIEInvalidSQLSTATEReturned = Code("39001") + ERIENullValueNotAllowed = Code("39004") + ERIETriggerProtocolViolated = Code("39P01") + ERIESrfProtocolViolated = Code("39P02") + ERIEEventTriggerProtocolViolated = Code("39P03") + SavepointException = Code("3B000") // Class 3B - Savepoint Exception + SEInvalidSpecification = Code("3B001") + InvalidCatalogName = Code("3D000") // Class 3D - Invalid Catalog Name + InvalidSchemaName = Code("3F000") // Class 3F - Invalid Schema Name + TransactionRollback = Code("40000") // Class 40 - Transaction Rollback + TRIntegrityConstraintViolation = Code("40002") + TRSerializationFailure = Code("40001") + TRStatementCompletionUnknown = Code("40003") + TRDeadlockDetected = Code("40P01") + SyntaxErrorOrAccessRuleViolation = Code("42000") // Class 42 - Syntax Error or Access Rule Violation + SyntaxError = Code("42601") + InsufficientPrivilege = Code("42501") + CannotCoerce = Code("42846") + GroupingError = Code("42803") + WindowingError = Code("42P20") + InvalidRecursion = Code("42P19") + InvalidForeignKey = Code("42830") + InvalidName = Code("42602") + NameTooLong = Code("42622") + ReservedName = Code("42939") + DatatypeMismatch = Code("42804") + IndeterminateDatatype = Code("42P18") + CollationMismatch = Code("42P21") + IndeterminateCollation = Code("42P22") + WrongObjectType = Code("42809") + GeneratedAlways = Code("428C9") + UndefinedColumn = Code("42703") + UndefinedFunction = Code("42883") + UndefinedTable = Code("42P01") + UndefinedParameter = Code("42P02") + UndefinedObject = Code("42704") + DuplicateColumn = Code("42701") + DuplicateCursor = Code("42P03") + DuplicateDatabase = Code("42P04") + DuplicateFunction = Code("42723") + DuplicatePstatement = Code("42P05") + DuplicateSchema = Code("42P06") + DuplicateTable = Code("42P07") + DuplicateAlias = Code("42712") + DuplicateObject = Code("42710") + AmbiguousColumn = Code("42702") + AmbiguousFunction = Code("42725") + AmbiguousParameter = Code("42P08") + AmbiguousAlias = Code("42P09") + InvalidColumnReference = Code("42P10") + InvalidColumnDefinition = Code("42611") + InvalidCursorDefinition = Code("42P11") + InvalidDatabaseDefinition = Code("42P12") + InvalidFunctionDefinition = Code("42P13") + InvalidPstatementDefinition = Code("42P14") + InvalidSchemaDefinition = Code("42P15") + InvalidTableDefinition = Code("42P16") + InvalidObjectDefinition = Code("42P17") + WithCheckOptionViolation = Code("44000") // Class 44 - WITH CHECK OPTION Violation + InsufficientResources = Code("53000") // Class 53 - Insufficient Resources + DiskFull = Code("53100") + OutOfMemory = Code("53200") + TooManyConnections = Code("53300") + ConfigurationLimitExceeded = Code("53400") + ProgramLimitExceeded = Code("54000") // Class 54 - Program Limit Exceeded + StatementTooComplex = Code("54001") + TooManyColumns = Code("54011") + TooManyArguments = Code("54023") + ObjectNotInPrerequisiteState = Code("55000") // Class 55 - Object Not In Prerequisite State + ObjectInUse = Code("55006") + CantChangeRuntimeParam = Code("55P02") + LockNotAvailable = Code("55P03") + UnsafeNewEnumValueUsage = Code("55P04") + OperatorIntervention = Code("57000") // Class 57 - Operator Intervention + QueryCanceled = Code("57014") + AdminShutdown = Code("57P01") + CrashShutdown = Code("57P02") + CannotConnectNow = Code("57P03") + DatabaseDropped = Code("57P04") + IdleSessionTimeout = Code("57P05") + SystemError = Code("58000") // Class 58 - System Error (errors external to PostgreSQL itself) + IOError = Code("58030") + UndefinedFile = Code("58P01") + DuplicateFile = Code("58P02") + FileNameTooLong = Code("58P03") + ConfigFileError = Code("F0000") // Class F0 - Configuration File Error + LockFileExists = Code("F0001") + FDWError = Code("HV000") // Class HV - Foreign Data Wrapper Error (SQL/MED) + FDWColumnNameNotFound = Code("HV005") + FDWDynamicParameterValueNeeded = Code("HV002") + FDWFunctionSequenceError = Code("HV010") + FDWInconsistentDescriptorInformation = Code("HV021") + FDWInvalidAttributeValue = Code("HV024") + FDWInvalidColumnName = Code("HV007") + FDWInvalidColumnNumber = Code("HV008") + FDWInvalidDataType = Code("HV004") + FDWInvalidDataTypeDescriptors = Code("HV006") + FDWInvalidDescriptorFieldIdentifier = Code("HV091") + FDWInvalidHandle = Code("HV00B") + FDWInvalidOptionIndex = Code("HV00C") + FDWInvalidOptionName = Code("HV00D") + FDWInvalidStringLengthOrBufferLength = Code("HV090") + FDWInvalidStringFormat = Code("HV00A") + FDWInvalidUseOfNullPointer = Code("HV009") + FDWTooManyHandles = Code("HV014") + FDWOutOfMemory = Code("HV001") + FDWNoSchemas = Code("HV00P") + FDWOptionNameNotFound = Code("HV00J") + FDWReplyHandle = Code("HV00K") + FDWSchemaNotFound = Code("HV00Q") + FDWTableNotFound = Code("HV00R") + FDWUnableToCreateExecution = Code("HV00L") + FDWUnableToCreateReply = Code("HV00M") + FDWUnableToEstablishConnection = Code("HV00N") + PLpgSQLError = Code("P0000") // Class P0 - PL/pgSQL Error + RaiseException = Code("P0001") + NoDataFound = Code("P0002") + TooManyRows = Code("P0003") + AssertFailure = Code("P0004") + InternalError = Code("XX000") // Class XX - Internal Error + DataCorrupted = Code("XX001") + IndexCorrupted = Code("XX002") +) + +var errorCodeNames = map[Code]string{ + "00000": "successful_completion", + "01000": "warning", + "0100C": "dynamic_result_sets_returned", + "01008": "implicit_zero_bit_padding", + "01003": "null_value_eliminated_in_set_function", + "01007": "privilege_not_granted", + "01006": "privilege_not_revoked", + "01004": "string_data_right_truncation", + "01P01": "deprecated_feature", + "02000": "no_data", + "02001": "no_additional_dynamic_result_sets_returned", + "03000": "sql_statement_not_yet_complete", + "08000": "connection_exception", + "08003": "connection_does_not_exist", + "08006": "connection_failure", + "08001": "sqlclient_unable_to_establish_sqlconnection", + "08004": "sqlserver_rejected_establishment_of_sqlconnection", + "08007": "transaction_resolution_unknown", + "08P01": "protocol_violation", + "09000": "triggered_action_exception", + "0A000": "feature_not_supported", + "0B000": "invalid_transaction_initiation", + "0F000": "locator_exception", + "0F001": "invalid_locator_specification", + "0L000": "invalid_grantor", + "0LP01": "invalid_grant_operation", + "0P000": "invalid_role_specification", + "0Z000": "diagnostics_exception", + "0Z002": "stacked_diagnostics_accessed_without_active_handler", + "10608": "invalid_argument_for_xquery", + "20000": "case_not_found", + "21000": "cardinality_violation", + "22000": "data_exception", + "2202E": "array_subscript_error", + "22021": "character_not_in_repertoire", + "22008": "datetime_field_overflow", + "22012": "division_by_zero", + "22005": "error_in_assignment", + "2200B": "escape_character_conflict", + "22022": "indicator_overflow", + "22015": "interval_field_overflow", + "2201E": "invalid_argument_for_logarithm", + "22014": "invalid_argument_for_ntile_function", + "22016": "invalid_argument_for_nth_value_function", + "2201F": "invalid_argument_for_power_function", + "2201G": "invalid_argument_for_width_bucket_function", + "22018": "invalid_character_value_for_cast", + "22007": "invalid_datetime_format", + "22019": "invalid_escape_character", + "2200D": "invalid_escape_octet", + "22025": "invalid_escape_sequence", + "22P06": "nonstandard_use_of_escape_character", + "22010": "invalid_indicator_parameter_value", + "22023": "invalid_parameter_value", + "22013": "invalid_preceding_or_following_size", + "2201B": "invalid_regular_expression", + "2201W": "invalid_row_count_in_limit_clause", + "2201X": "invalid_row_count_in_result_offset_clause", + "2202H": "invalid_tablesample_argument", + "2202G": "invalid_tablesample_repeat", + "22009": "invalid_time_zone_displacement_value", + "2200C": "invalid_use_of_escape_character", + "2200G": "most_specific_type_mismatch", + "22004": "null_value_not_allowed", + "22002": "null_value_no_indicator_parameter", + "22003": "numeric_value_out_of_range", + "2200H": "sequence_generator_limit_exceeded", + "22026": "string_data_length_mismatch", + "22001": "string_data_right_truncation", + "22011": "substring_error", + "22027": "trim_error", + "22024": "unterminated_c_string", + "2200F": "zero_length_character_string", + "22P01": "floating_point_exception", + "22P02": "invalid_text_representation", + "22P03": "invalid_binary_representation", + "22P04": "bad_copy_file_format", + "22P05": "untranslatable_character", + "2200L": "not_an_xml_document", + "2200M": "invalid_xml_document", + "2200N": "invalid_xml_content", + "2200S": "invalid_xml_comment", + "2200T": "invalid_xml_processing_instruction", + "22030": "duplicate_json_object_key_value", + "22031": "invalid_argument_for_sql_json_datetime_function", + "22032": "invalid_json_text", + "22033": "invalid_sql_json_subscript", + "22034": "more_than_one_sql_json_item", + "22035": "no_sql_json_item", + "22036": "non_numeric_sql_json_item", + "22037": "non_unique_keys_in_a_json_object", + "22038": "singleton_sql_json_item_required", + "22039": "sql_json_array_not_found", + "2203A": "sql_json_member_not_found", + "2203B": "sql_json_number_not_found", + "2203C": "sql_json_object_not_found", + "2203D": "too_many_json_array_elements", + "2203E": "too_many_json_object_members", + "2203F": "sql_json_scalar_required", + "2203G": "sql_json_item_cannot_be_cast_to_target_type", + "23000": "integrity_constraint_violation", + "23001": "restrict_violation", + "23502": "not_null_violation", + "23503": "foreign_key_violation", + "23505": "unique_violation", + "23514": "check_violation", + "23P01": "exclusion_violation", + "24000": "invalid_cursor_state", + "25000": "invalid_transaction_state", + "25001": "active_sql_transaction", + "25002": "branch_transaction_already_active", + "25008": "held_cursor_requires_same_isolation_level", + "25003": "inappropriate_access_mode_for_branch_transaction", + "25004": "inappropriate_isolation_level_for_branch_transaction", + "25005": "no_active_sql_transaction_for_branch_transaction", + "25006": "read_only_sql_transaction", + "25007": "schema_and_data_statement_mixing_not_supported", + "25P01": "no_active_sql_transaction", + "25P02": "in_failed_sql_transaction", + "25P03": "idle_in_transaction_session_timeout", + "25P04": "transaction_timeout", + "26000": "invalid_sql_statement_name", + "27000": "triggered_data_change_violation", + "28000": "invalid_authorization_specification", + "28P01": "invalid_password", + "2B000": "dependent_privilege_descriptors_still_exist", + "2BP01": "dependent_objects_still_exist", + "2D000": "invalid_transaction_termination", + "2F000": "sql_routine_exception", + "2F005": "function_executed_no_return_statement", + "2F002": "modifying_sql_data_not_permitted", + "2F003": "prohibited_sql_statement_attempted", + "2F004": "reading_sql_data_not_permitted", + "34000": "invalid_cursor_name", + "38000": "external_routine_exception", + "38001": "containing_sql_not_permitted", + "38002": "modifying_sql_data_not_permitted", + "38003": "prohibited_sql_statement_attempted", + "38004": "reading_sql_data_not_permitted", + "39000": "external_routine_invocation_exception", + "39001": "invalid_sqlstate_returned", + "39004": "null_value_not_allowed", + "39P01": "trigger_protocol_violated", + "39P02": "srf_protocol_violated", + "39P03": "event_trigger_protocol_violated", + "3B000": "savepoint_exception", + "3B001": "invalid_savepoint_specification", + "3D000": "invalid_catalog_name", + "3F000": "invalid_schema_name", + "40000": "transaction_rollback", + "40002": "transaction_integrity_constraint_violation", + "40001": "serialization_failure", + "40003": "statement_completion_unknown", + "40P01": "deadlock_detected", + "42000": "syntax_error_or_access_rule_violation", + "42601": "syntax_error", + "42501": "insufficient_privilege", + "42846": "cannot_coerce", + "42803": "grouping_error", + "42P20": "windowing_error", + "42P19": "invalid_recursion", + "42830": "invalid_foreign_key", + "42602": "invalid_name", + "42622": "name_too_long", + "42939": "reserved_name", + "42804": "datatype_mismatch", + "42P18": "indeterminate_datatype", + "42P21": "collation_mismatch", + "42P22": "indeterminate_collation", + "42809": "wrong_object_type", + "428C9": "generated_always", + "42703": "undefined_column", + "42883": "undefined_function", + "42P01": "undefined_table", + "42P02": "undefined_parameter", + "42704": "undefined_object", + "42701": "duplicate_column", + "42P03": "duplicate_cursor", + "42P04": "duplicate_database", + "42723": "duplicate_function", + "42P05": "duplicate_prepared_statement", + "42P06": "duplicate_schema", + "42P07": "duplicate_table", + "42712": "duplicate_alias", + "42710": "duplicate_object", + "42702": "ambiguous_column", + "42725": "ambiguous_function", + "42P08": "ambiguous_parameter", + "42P09": "ambiguous_alias", + "42P10": "invalid_column_reference", + "42611": "invalid_column_definition", + "42P11": "invalid_cursor_definition", + "42P12": "invalid_database_definition", + "42P13": "invalid_function_definition", + "42P14": "invalid_prepared_statement_definition", + "42P15": "invalid_schema_definition", + "42P16": "invalid_table_definition", + "42P17": "invalid_object_definition", + "44000": "with_check_option_violation", + "53000": "insufficient_resources", + "53100": "disk_full", + "53200": "out_of_memory", + "53300": "too_many_connections", + "53400": "configuration_limit_exceeded", + "54000": "program_limit_exceeded", + "54001": "statement_too_complex", + "54011": "too_many_columns", + "54023": "too_many_arguments", + "55000": "object_not_in_prerequisite_state", + "55006": "object_in_use", + "55P02": "cant_change_runtime_param", + "55P03": "lock_not_available", + "55P04": "unsafe_new_enum_value_usage", + "57000": "operator_intervention", + "57014": "query_canceled", + "57P01": "admin_shutdown", + "57P02": "crash_shutdown", + "57P03": "cannot_connect_now", + "57P04": "database_dropped", + "57P05": "idle_session_timeout", + "58000": "system_error", + "58030": "io_error", + "58P01": "undefined_file", + "58P02": "duplicate_file", + "58P03": "file_name_too_long", + "F0000": "config_file_error", + "F0001": "lock_file_exists", + "HV000": "fdw_error", + "HV005": "fdw_column_name_not_found", + "HV002": "fdw_dynamic_parameter_value_needed", + "HV010": "fdw_function_sequence_error", + "HV021": "fdw_inconsistent_descriptor_information", + "HV024": "fdw_invalid_attribute_value", + "HV007": "fdw_invalid_column_name", + "HV008": "fdw_invalid_column_number", + "HV004": "fdw_invalid_data_type", + "HV006": "fdw_invalid_data_type_descriptors", + "HV091": "fdw_invalid_descriptor_field_identifier", + "HV00B": "fdw_invalid_handle", + "HV00C": "fdw_invalid_option_index", + "HV00D": "fdw_invalid_option_name", + "HV090": "fdw_invalid_string_length_or_buffer_length", + "HV00A": "fdw_invalid_string_format", + "HV009": "fdw_invalid_use_of_null_pointer", + "HV014": "fdw_too_many_handles", + "HV001": "fdw_out_of_memory", + "HV00P": "fdw_no_schemas", + "HV00J": "fdw_option_name_not_found", + "HV00K": "fdw_reply_handle", + "HV00Q": "fdw_schema_not_found", + "HV00R": "fdw_table_not_found", + "HV00L": "fdw_unable_to_create_execution", + "HV00M": "fdw_unable_to_create_reply", + "HV00N": "fdw_unable_to_establish_connection", + "P0000": "plpgsql_error", + "P0001": "raise_exception", + "P0002": "no_data_found", + "P0003": "too_many_rows", + "P0004": "assert_failure", + "XX000": "internal_error", + "XX001": "data_corrupted", + "XX002": "index_corrupted", +} diff --git a/vendor/github.com/lib/pq/pqerror/pqerror.go b/vendor/github.com/lib/pq/pqerror/pqerror.go new file mode 100644 index 00000000..29e49e99 --- /dev/null +++ b/vendor/github.com/lib/pq/pqerror/pqerror.go @@ -0,0 +1,35 @@ +//go:generate go run gen.go + +// Package pqerror contains PostgreSQL error codes for use with pq.Error. +package pqerror + +// Code is a five-character error code. +type Code string + +// Name returns a more human friendly rendering of the error code, namely the +// "condition name". +func (ec Code) Name() string { return errorCodeNames[ec] } + +// Class returns the error class, e.g. "28". +func (ec Code) Class() Class { return Class(ec[:2]) } + +// Class is only the class part of an error code. +type Class string + +// Name returns the condition name of an error class. It is equivalent to the +// condition name of the "standard" error code (i.e. the one having the last +// three characters "000"). +func (ec Class) Name() string { return errorCodeNames[Code(ec+"000")] } + +// TODO(v2): use "type Severity string" for the below. + +// Error severity values. +const ( + SeverityFatal = "FATAL" + SeverityPanic = "PANIC" + SeverityWarning = "WARNING" + SeverityNotice = "NOTICE" + SeverityDebug = "DEBUG" + SeverityInfo = "INFO" + SeverityLog = "LOG" +) diff --git a/vendor/github.com/lib/pq/quote.go b/vendor/github.com/lib/pq/quote.go new file mode 100644 index 00000000..909e41ec --- /dev/null +++ b/vendor/github.com/lib/pq/quote.go @@ -0,0 +1,71 @@ +package pq + +import ( + "bytes" + "strings" +) + +// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be +// used as part of an SQL statement. For example: +// +// tblname := "my_table" +// data := "my_data" +// quoted := pq.QuoteIdentifier(tblname) +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) +// +// Any double quotes in name will be escaped. The quoted identifier will be case +// sensitive when used in a query. If the input string contains a zero byte, the +// result will be truncated immediately before it. +func QuoteIdentifier(name string) string { + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + return `"` + strings.Replace(name, `"`, `""`, -1) + `"` +} + +// BufferQuoteIdentifier satisfies the same purpose as QuoteIdentifier, but backed by a +// byte buffer. +func BufferQuoteIdentifier(name string, buffer *bytes.Buffer) { + // TODO(v2): this should have accepted an io.Writer, not *bytes.Buffer. + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + buffer.WriteRune('"') + buffer.WriteString(strings.Replace(name, `"`, `""`, -1)) + buffer.WriteRune('"') +} + +// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal +// to DDL and other statements that do not accept parameters) to be used as part +// of an SQL statement. For example: +// +// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") +// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) +// +// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be +// replaced by two backslashes (i.e. "\\") and the C-style escape identifier +// that PostgreSQL provides ('E') will be prepended to the string. +func QuoteLiteral(literal string) string { + // This follows the PostgreSQL internal algorithm for handling quoted literals + // from libpq, which can be found in the "PQEscapeStringInternal" function, + // which is found in the libpq/fe-exec.c source file: + // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c + // + // substitute any single-quotes (') with two single-quotes ('') + literal = strings.Replace(literal, `'`, `''`, -1) + // determine if the string has any backslashes (\) in it. + // if it does, replace any backslashes (\) with two backslashes (\\) + // then, we need to wrap the entire string with a PostgreSQL + // C-style escape. Per how "PQEscapeStringInternal" handles this case, we + // also add a space before the "E" + if strings.Contains(literal, `\`) { + literal = strings.Replace(literal, `\`, `\\`, -1) + literal = ` E'` + literal + `'` + } else { + // otherwise, we can just wrap the literal with a pair of single quotes + literal = `'` + literal + `'` + } + return literal +} diff --git a/vendor/github.com/lib/pq/rows.go b/vendor/github.com/lib/pq/rows.go index c6aa5b9a..2029bfed 100644 --- a/vendor/github.com/lib/pq/rows.go +++ b/vendor/github.com/lib/pq/rows.go @@ -1,13 +1,182 @@ package pq import ( + "database/sql/driver" + "fmt" + "io" "math" "reflect" "time" + "github.com/lib/pq/internal/proto" "github.com/lib/pq/oid" ) +type noRows struct{} + +var emptyRows noRows + +var _ driver.Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { return 0, errNoLastInsertID } +func (noRows) RowsAffected() (int64, error) { return 0, errNoRowsAffected } + +type ( + rowsHeader struct { + colNames []string + colTyps []fieldDesc + colFmts []format + } + rows struct { + cn *conn + finish func() + rowsHeader + done bool + rb readBuf + result driver.Result + tag string + + next *rowsHeader + } +) + +func (rs *rows) Close() error { + if finish := rs.finish; finish != nil { + defer finish() + } + // no need to look at cn.bad as Next() will + for { + err := rs.Next(nil) + switch err { + case nil: + case io.EOF: + // rs.Next can return io.EOF on both ReadyForQuery and + // RowDescription (used with HasNextResultSet). We need to fetch + // messages until we hit a ReadyForQuery, which is done by waiting + // for done to be set. + if rs.done { + return nil + } + default: + return err + } + } +} + +func (rs *rows) Columns() []string { + return rs.colNames +} + +func (rs *rows) Result() driver.Result { + if rs.result == nil { + return emptyRows + } + return rs.result +} + +func (rs *rows) Tag() string { + return rs.tag +} + +func (rs *rows) Next(dest []driver.Value) (resErr error) { + if rs.done { + return io.EOF + } + if err := rs.cn.err.getForNext(); err != nil { + return err + } + + for { + t, err := rs.cn.recv1Buf(&rs.rb) + if err != nil { + return rs.cn.handleError(err) + } + switch t { + case proto.ErrorResponse: + resErr = parseError(&rs.rb, "") + case proto.CommandComplete, proto.EmptyQueryResponse: + if t == proto.CommandComplete { + rs.result, rs.tag, err = rs.cn.parseComplete(rs.rb.string()) + if err != nil { + return rs.cn.handleError(err) + } + } + continue + case proto.ReadyForQuery: + rs.cn.processReadyForQuery(&rs.rb) + rs.done = true + if resErr != nil { + return rs.cn.handleError(resErr) + } + return io.EOF + case proto.DataRow: + n := rs.rb.int16() + if resErr != nil { + rs.cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected DataRow after error %s", resErr) + } + if n < len(dest) { + dest = dest[:n] + } + for i := range dest { + l := rs.rb.int32() + if l == -1 { + dest[i] = nil + continue + } + dest[i], err = decode(&rs.cn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) + if err != nil { + return rs.cn.handleError(err) + } + } + return rs.cn.handleError(resErr) + case proto.RowDescription: + next := parsePortalRowDescribe(&rs.rb) + rs.next = &next + return io.EOF + default: + return fmt.Errorf("pq: unexpected message after execute: %q", t) + } + } +} + +func (rs *rows) HasNextResultSet() bool { + hasNext := rs.next != nil && !rs.done + return hasNext +} + +func (rs *rows) NextResultSet() error { + if rs.next == nil { + return io.EOF + } + rs.rowsHeader = *rs.next + rs.next = nil + return nil +} + +// ColumnTypeScanType returns the value type that can be used to scan types into. +func (rs *rows) ColumnTypeScanType(index int) reflect.Type { + return rs.colTyps[index].Type() +} + +// ColumnTypeDatabaseTypeName return the database system type name. +func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { + return rs.colTyps[index].Name() +} + +// ColumnTypeLength returns the length of the column type if the column is a +// variable length type. If the column is not a variable length type ok +// should return false. +func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { + return rs.colTyps[index].Length() +} + +// ColumnTypePrecisionScale should return the precision and scale for decimal +// types. If not applicable, ok should be false. +func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return rs.colTyps[index].PrecisionScale() +} + const headerSize = 4 type fieldDesc struct { @@ -29,7 +198,11 @@ func (fd fieldDesc) Type() reflect.Type { return reflect.TypeOf(int32(0)) case oid.T_int2: return reflect.TypeOf(int16(0)) - case oid.T_varchar, oid.T_text: + case oid.T_float8: + return reflect.TypeOf(float64(0)) + case oid.T_float4: + return reflect.TypeOf(float32(0)) + case oid.T_varchar, oid.T_text, oid.T_varbit, oid.T_bit: return reflect.TypeOf("") case oid.T_bool: return reflect.TypeOf(false) @@ -38,7 +211,7 @@ func (fd fieldDesc) Type() reflect.Type { case oid.T_bytea: return reflect.TypeOf([]byte(nil)) default: - return reflect.TypeOf(new(interface{})).Elem() + return reflect.TypeOf(new(any)).Elem() } } @@ -52,6 +225,8 @@ func (fd fieldDesc) Length() (length int64, ok bool) { return math.MaxInt64, true case oid.T_varchar, oid.T_bpchar: return int64(fd.Mod - headerSize), true + case oid.T_varbit, oid.T_bit: + return int64(fd.Mod), true default: return 0, false } @@ -68,26 +243,3 @@ func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { return 0, 0, false } } - -// ColumnTypeScanType returns the value type that can be used to scan types into. -func (rs *rows) ColumnTypeScanType(index int) reflect.Type { - return rs.colTyps[index].Type() -} - -// ColumnTypeDatabaseTypeName return the database system type name. -func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { - return rs.colTyps[index].Name() -} - -// ColumnTypeLength returns the length of the column type if the column is a -// variable length type. If the column is not a variable length type ok -// should return false. -func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { - return rs.colTyps[index].Length() -} - -// ColumnTypePrecisionScale should return the precision and scale for decimal -// types. If not applicable, ok should be false. -func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { - return rs.colTyps[index].PrecisionScale() -} diff --git a/vendor/github.com/lib/pq/scram/scram.go b/vendor/github.com/lib/pq/scram/scram.go index 477216b6..7ed7a993 100644 --- a/vendor/github.com/lib/pq/scram/scram.go +++ b/vendor/github.com/lib/pq/scram/scram.go @@ -25,7 +25,6 @@ // Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802. // // http://tools.ietf.org/html/rfc5802 -// package scram import ( @@ -43,17 +42,16 @@ import ( // // A Client may be used within a SASL conversation with logic resembling: // -// var in []byte -// var client = scram.NewClient(sha1.New, user, pass) -// for client.Step(in) { -// out := client.Out() -// // send out to server -// in := serverOut -// } -// if client.Err() != nil { -// // auth failed -// } -// +// var in []byte +// var client = scram.NewClient(sha1.New, user, pass) +// for client.Step(in) { +// out := client.Out() +// // send out to server +// in := serverOut +// } +// if client.Err() != nil { +// // auth failed +// } type Client struct { newHash func() hash.Hash @@ -73,8 +71,7 @@ type Client struct { // // For SCRAM-SHA-256, for example, use: // -// client := scram.NewClient(sha256.New, user, pass) -// +// client := scram.NewClient(sha256.New, user, pass) func NewClient(newHash func() hash.Hash, user, pass string) *Client { c := &Client{ newHash: newHash, @@ -133,7 +130,7 @@ func (c *Client) step1(in []byte) error { const nonceLen = 16 buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen)) if _, err := rand.Read(buf[:nonceLen]); err != nil { - return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err) + return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %w", err) } c.clientNonce = buf[nonceLen:] b64.Encode(c.clientNonce, buf[:nonceLen]) diff --git a/vendor/github.com/lib/pq/ssl.go b/vendor/github.com/lib/pq/ssl.go index 36b61ba4..b9357854 100644 --- a/vendor/github.com/lib/pq/ssl.go +++ b/vendor/github.com/lib/pq/ssl.go @@ -1,204 +1,312 @@ package pq import ( + "bytes" "crypto/tls" "crypto/x509" - "io/ioutil" + "encoding/pem" + "errors" + "fmt" "net" "os" - "os/user" "path/filepath" + "slices" "strings" + "sync" + + "github.com/lib/pq/internal/pqutil" +) + +// Registry for custom tls.Configs +var ( + tlsConfs = make(map[string]*tls.Config) + tlsConfsMu sync.RWMutex ) +// RegisterTLSConfig registers a custom [tls.Config]. They are used by using +// sslmode=pqgo-«key» in the connection string. +// +// Set the config to nil to remove a configuration. +func RegisterTLSConfig(key string, config *tls.Config) error { + key = strings.TrimPrefix(key, "pqgo-") + if config == nil { + tlsConfsMu.Lock() + delete(tlsConfs, key) + tlsConfsMu.Unlock() + return nil + } + + tlsConfsMu.Lock() + tlsConfs[key] = config + tlsConfsMu.Unlock() + return nil +} + +func hasTLSConfig(key string) bool { + tlsConfsMu.RLock() + defer tlsConfsMu.RUnlock() + _, ok := tlsConfs[key] + return ok +} + +func getTLSConfigClone(key string) *tls.Config { + tlsConfsMu.RLock() + defer tlsConfsMu.RUnlock() + if v, ok := tlsConfs[key]; ok { + return v.Clone() + } + return nil +} + // ssl generates a function to upgrade a net.Conn based on the "sslmode" and // related settings. The function is nil when no upgrade should take place. -func ssl(o values) (func(net.Conn) (net.Conn, error), error) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o["sslmode"]; mode { - // "require" is the default. - case "", "require": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. +// +// Don't refer to Config.SSLMode here, as the mode in arguments may be different +// in case of sslmode=allow or prefer. +func ssl(cfg Config, mode SSLMode) (func(net.Conn) (net.Conn, error), error) { + var ( + home = pqutil.Home(true) + // Don't set defaults here, because tlsConf may be overwritten if a + // custom one was registered. Set it after the sslmode switch. + tlsConf = &tls.Config{} + // Only verify the CA signing but not the hostname. + verifyCaOnly = false + ) + if mode.useSSL() && !cfg.SSLInline && cfg.SSLRootCert == "" && home != "" { + f := filepath.Join(home, "root.crt") + if _, err := os.Stat(f); err == nil { + cfg.SSLRootCert = f + } + } + switch { + case mode == SSLModeDisable || mode == SSLModeAllow: + return nil, nil + + case mode == "" || mode == SSLModeRequire || mode == SSLModePrefer: + // Skip TLS's own verification since it requires full verification. tlsConf.InsecureSkipVerify = true // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: // - // Note: For backwards compatibility with earlier versions of - // PostgreSQL, if a root CA file exists, the behavior of - // sslmode=require will be the same as that of verify-ca, meaning the - // server certificate is validated against the CA. Relying on this - // behavior is discouraged, and applications that need certificate - // validation should always use verify-ca or verify-full. - if sslrootcert, ok := o["sslrootcert"]; ok { - if _, err := os.Stat(sslrootcert); err == nil { + // For backwards compatibility with earlier versions of PostgreSQL, if a + // root CA file exists, the behavior of sslmode=require will be the same + // as that of verify-ca, meaning the server certificate is validated + // against the CA. Relying on this behavior is discouraged, and + // applications that need certificate validation should always use + // verify-ca or verify-full. + if cfg.SSLRootCert != "" { + if cfg.SSLInline { + verifyCaOnly = true + } else if _, err := os.Stat(cfg.SSLRootCert); err == nil { verifyCaOnly = true - } else { - delete(o, "sslrootcert") + } else if cfg.SSLRootCert != "system" { + cfg.SSLRootCert = "" } } - case "verify-ca": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. + case mode == SSLModeVerifyCA: + // Skip TLS's own verification since it requires full verification. tlsConf.InsecureSkipVerify = true verifyCaOnly = true - case "verify-full": - tlsConf.ServerName = o["host"] - case "disable": - return nil, nil + case mode == SSLModeVerifyFull: + tlsConf.ServerName = cfg.Host + case strings.HasPrefix(string(mode), "pqgo-"): + tlsConf = getTLSConfigClone(string(mode[5:])) + if tlsConf == nil { + return nil, fmt.Errorf(`pq: unknown custom sslmode %q`, mode) + } default: - return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) + panic("unreachable") } - // Set Server Name Indication (SNI), if enabled by connection parameters. - // By default SNI is on, any value which is not starting with "1" disables - // SNI -- that is the same check vanilla libpq uses. - if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") { - // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 - // or IPv6). This check is coded already crypto.tls.hostnameInSNI, so - // just always set ServerName here and let crypto/tls do the filtering. - tlsConf.ServerName = o["host"] + tlsConf.MinVersion = cfg.SSLMinProtocolVersion.tlsconf() + tlsConf.MaxVersion = cfg.SSLMaxProtocolVersion.tlsconf() + + // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 or + // IPv6). This check is coded already crypto.tls.hostnameInSNI, so just + // always set ServerName here and let crypto/tls do the filtering. + if cfg.SSLSNI { + tlsConf.ServerName = cfg.Host } - err := sslClientCertificates(&tlsConf, o) + err := sslClientCertificates(tlsConf, cfg, home) if err != nil { return nil, err } - err = sslCertificateAuthority(&tlsConf, o) + rootPem, err := sslCertificateAuthority(tlsConf, cfg) if err != nil { return nil, err } + sslAppendIntermediates(tlsConf, cfg, rootPem) // Accept renegotiation requests initiated by the backend. // - // Renegotiation was deprecated then removed from PostgreSQL 9.5, but - // the default configuration of older versions has it enabled. Redshift - // also initiates renegotiations and cannot be reconfigured. + // Renegotiation was deprecated then removed from PostgreSQL 9.5, but the + // default configuration of older versions has it enabled. Redshift also + // initiates renegotiations and cannot be reconfigured. + // + // TODO: I think this can be removed? tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient return func(conn net.Conn) (net.Conn, error) { - client := tls.Client(conn, &tlsConf) + client := tls.Client(conn, tlsConf) if verifyCaOnly { - err := sslVerifyCertificateAuthority(client, &tlsConf) + err := client.Handshake() if err != nil { - return nil, err + return client, err } + var ( + certs = client.ConnectionState().PeerCertificates + opts = x509.VerifyOptions{Intermediates: x509.NewCertPool(), Roots: tlsConf.RootCAs} + ) + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + _, err = certs[0].Verify(opts) + return client, err } return client, nil }, nil } // sslClientCertificates adds the certificate specified in the "sslcert" and +// // "sslkey" settings, or if they aren't set, from the .postgresql directory // in the user's home directory. The configured files must exist and have // the correct permissions. -func sslClientCertificates(tlsConf *tls.Config, o values) error { - sslinline := o["sslinline"] - if sslinline == "true" { - cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) +func sslClientCertificates(tlsConf *tls.Config, cfg Config, home string) error { + if cfg.SSLInline { + cert, err := tls.X509KeyPair([]byte(cfg.SSLCert), []byte(cfg.SSLKey)) if err != nil { return err } - tlsConf.Certificates = []tls.Certificate{cert} + // Use GetClientCertificate instead of the Certificates field. When + // Certificates is set, Go's TLS client only sends the cert if the + // server's CertificateRequest includes a CA that issued it. When the + // client cert was signed by an intermediate CA but the server only + // advertises the root CA, Go skips sending the cert entirely. + // GetClientCertificate bypasses this filtering. + tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return &cert, nil + } return nil } - // user.Current() might fail when cross-compiling. We have to ignore the - // error and continue without home directory defaults, since we wouldn't - // know from where to load them. - user, _ := user.Current() - - // In libpq, the client certificate is only loaded if the setting is not blank. - // - // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 - sslcert := o["sslcert"] - if len(sslcert) == 0 && user != nil { - sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + // Only load client certificate and key if the setting is not blank, like libpq. + if cfg.SSLCert == "" && home != "" { + cfg.SSLCert = filepath.Join(home, "postgresql.crt") } - // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045 - if len(sslcert) == 0 { + if cfg.SSLCert == "" { return nil } - // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054 - if _, err := os.Stat(sslcert); os.IsNotExist(err) { - return nil - } else if err != nil { + _, err := os.Stat(cfg.SSLCert) + if err != nil { + if pqutil.ErrNotExists(err) { + return nil + } return err } // In libpq, the ssl key is only loaded if the setting is not blank. - // - // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222 - sslkey := o["sslkey"] - if len(sslkey) == 0 && user != nil { - sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + if cfg.SSLKey == "" && home != "" { + cfg.SSLKey = filepath.Join(home, "postgresql.key") } - - if len(sslkey) > 0 { - if err := sslKeyPermissions(sslkey); err != nil { + if cfg.SSLKey != "" { + err := pqutil.SSLKeyPermissions(cfg.SSLKey) + if err != nil { return err } } - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey) if err != nil { return err } - tlsConf.Certificates = []tls.Certificate{cert} + // Using GetClientCertificate instead of Certificates per comment above. + tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return &cert, nil + } return nil } +var testSystemRoots *x509.CertPool + // sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. -func sslCertificateAuthority(tlsConf *tls.Config, o values) error { - // In libpq, the root certificate is only loaded if the setting is not blank. - // - // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951 - if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { - tlsConf.RootCAs = x509.NewCertPool() - - sslinline := o["sslinline"] - - var cert []byte - if sslinline == "true" { - cert = []byte(sslrootcert) - } else { - var err error - cert, err = ioutil.ReadFile(sslrootcert) - if err != nil { - return err - } - } +func sslCertificateAuthority(tlsConf *tls.Config, cfg Config) ([]byte, error) { + // Only load root certificate if not blank, like libpq. + if cfg.SSLRootCert == "" { + return nil, nil + } + + if cfg.SSLRootCert == "system" { + // No work to do as system CAs are used by default if RootCAs is nil. + tlsConf.RootCAs = testSystemRoots + return nil, nil + } + + tlsConf.RootCAs = x509.NewCertPool() - if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { - return fmterrorf("couldn't parse pem in sslrootcert") + var cert []byte + if cfg.SSLInline { + cert = []byte(cfg.SSLRootCert) + } else { + var err error + cert, err = os.ReadFile(cfg.SSLRootCert) + if err != nil { + return nil, err } } - return nil + if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { + return nil, errors.New("pq: couldn't parse pem from sslrootcert") + } + return cert, nil } -// sslVerifyCertificateAuthority carries out a TLS handshake to the server and -// verifies the presented certificate against the CA, i.e. the one specified in -// sslrootcert or the system CA if sslrootcert was not specified. -func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error { - err := client.Handshake() - if err != nil { - return err - } - certs := client.ConnectionState().PeerCertificates - opts := x509.VerifyOptions{ - DNSName: client.ConnectionState().ServerName, - Intermediates: x509.NewCertPool(), - Roots: tlsConf.RootCAs, +// sslAppendIntermediates appends intermediate CA certificates from sslrootcert +// to the client certificate chain. This is needed so the server can verify the +// client cert when it was signed by an intermediate CA — without this, the TLS +// handshake only sends the leaf client cert. +func sslAppendIntermediates(tlsConf *tls.Config, cfg Config, rootPem []byte) { + if cfg.SSLRootCert == "" || tlsConf.GetClientCertificate == nil || len(rootPem) == 0 { + return } - for i, cert := range certs { - if i == 0 { + + var ( + pemData = slices.Clone(rootPem) + intermediates [][]byte + ) + for { + var block *pem.Block + block, pemData = pem.Decode(pemData) + if block == nil { + break + } + if block.Type != "CERTIFICATE" { continue } - opts.Intermediates.AddCert(cert) + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + continue + } + // Skip self-signed root CAs; only append intermediates. + if cert.IsCA && !bytes.Equal(cert.RawIssuer, cert.RawSubject) { + intermediates = append(intermediates, block.Bytes) + } + } + if len(intermediates) == 0 { + return + } + + // Wrap the existing GetClientCertificate to append intermediate certs to + // the certificate chain returned during the TLS handshake. + origGetCert := tlsConf.GetClientCertificate + tlsConf.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { + cert, err := origGetCert(info) + if err != nil { + return cert, err + } + cert.Certificate = append(cert.Certificate, intermediates...) + return cert, nil } - _, err = certs[0].Verify(opts) - return err } diff --git a/vendor/github.com/lib/pq/ssl_permissions.go b/vendor/github.com/lib/pq/ssl_permissions.go deleted file mode 100644 index d587f102..00000000 --- a/vendor/github.com/lib/pq/ssl_permissions.go +++ /dev/null @@ -1,93 +0,0 @@ -//go:build !windows -// +build !windows - -package pq - -import ( - "errors" - "os" - "syscall" -) - -const ( - rootUserID = uint32(0) - - // The maximum permissions that a private key file owned by a regular user - // is allowed to have. This translates to u=rw. - maxUserOwnedKeyPermissions os.FileMode = 0600 - - // The maximum permissions that a private key file owned by root is allowed - // to have. This translates to u=rw,g=r. - maxRootOwnedKeyPermissions os.FileMode = 0640 -) - -var ( - errSSLKeyHasUnacceptableUserPermissions = errors.New("permissions for files not owned by root should be u=rw (0600) or less") - errSSLKeyHasUnacceptableRootPermissions = errors.New("permissions for root owned files should be u=rw,g=r (0640) or less") -) - -// sslKeyPermissions checks the permissions on user-supplied ssl key files. -// The key file should have very little access. -// -// libpq does not check key file permissions on Windows. -func sslKeyPermissions(sslkey string) error { - info, err := os.Stat(sslkey) - if err != nil { - return err - } - - err = hasCorrectPermissions(info) - - // return ErrSSLKeyHasWorldPermissions for backwards compatability with - // existing code. - if err == errSSLKeyHasUnacceptableUserPermissions || err == errSSLKeyHasUnacceptableRootPermissions { - err = ErrSSLKeyHasWorldPermissions - } - return err -} - -// hasCorrectPermissions checks the file info (and the unix-specific stat_t -// output) to verify that the permissions on the file are correct. -// -// If the file is owned by the same user the process is running as, -// the file should only have 0600 (u=rw). If the file is owned by root, -// and the group matches the group that the process is running in, the -// permissions cannot be more than 0640 (u=rw,g=r). The file should -// never have world permissions. -// -// Returns an error when the permission check fails. -func hasCorrectPermissions(info os.FileInfo) error { - // if file's permission matches 0600, allow access. - userPermissionMask := (os.FileMode(0777) ^ maxUserOwnedKeyPermissions) - - // regardless of if we're running as root or not, 0600 is acceptable, - // so we return if we match the regular user permission mask. - if info.Mode().Perm()&userPermissionMask == 0 { - return nil - } - - // We need to pull the Unix file information to get the file's owner. - // If we can't access it, there's some sort of operating system level error - // and we should fail rather than attempting to use faulty information. - sysInfo := info.Sys() - if sysInfo == nil { - return ErrSSLKeyUnknownOwnership - } - - unixStat, ok := sysInfo.(*syscall.Stat_t) - if !ok { - return ErrSSLKeyUnknownOwnership - } - - // if the file is owned by root, we allow 0640 (u=rw,g=r) to match what - // Postgres does. - if unixStat.Uid == rootUserID { - rootPermissionMask := (os.FileMode(0777) ^ maxRootOwnedKeyPermissions) - if info.Mode().Perm()&rootPermissionMask != 0 { - return errSSLKeyHasUnacceptableRootPermissions - } - return nil - } - - return errSSLKeyHasUnacceptableUserPermissions -} diff --git a/vendor/github.com/lib/pq/ssl_windows.go b/vendor/github.com/lib/pq/ssl_windows.go deleted file mode 100644 index 73663c8f..00000000 --- a/vendor/github.com/lib/pq/ssl_windows.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build windows -// +build windows - -package pq - -// sslKeyPermissions checks the permissions on user-supplied ssl key files. -// The key file should have very little access. -// -// libpq does not check key file permissions on Windows. -func sslKeyPermissions(string) error { return nil } diff --git a/vendor/github.com/lib/pq/staticcheck.conf b/vendor/github.com/lib/pq/staticcheck.conf new file mode 100644 index 00000000..83abe48e --- /dev/null +++ b/vendor/github.com/lib/pq/staticcheck.conf @@ -0,0 +1,5 @@ +checks = [ + 'all', + '-ST1000', # "Must have at least one package comment" + '-ST1003', # "func EnableInfinityTs should be EnableInfinityTS" +] diff --git a/vendor/github.com/lib/pq/stmt.go b/vendor/github.com/lib/pq/stmt.go new file mode 100644 index 00000000..ca6ecc89 --- /dev/null +++ b/vendor/github.com/lib/pq/stmt.go @@ -0,0 +1,150 @@ +package pq + +import ( + "context" + "database/sql/driver" + "fmt" + "os" + + "github.com/lib/pq/internal/proto" + "github.com/lib/pq/oid" +) + +type stmt struct { + cn *conn + name string + rowsHeader + colFmtData []byte + paramTyps []oid.Oid + closed bool +} + +func (st *stmt) Close() error { + if st.closed { + return nil + } + if err := st.cn.err.get(); err != nil { + return err + } + + w := st.cn.writeBuf(proto.Close) + w.byte(proto.Sync) + w.string(st.name) + err := st.cn.send(w) + if err != nil { + return st.cn.handleError(err) + } + err = st.cn.send(st.cn.writeBuf(proto.Sync)) + if err != nil { + return st.cn.handleError(err) + } + + t, _, err := st.cn.recv1() + if err != nil { + return st.cn.handleError(err) + } + if t != proto.CloseComplete { + st.cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected close response: %q", t) + } + st.closed = true + + t, r, err := st.cn.recv1() + if err != nil { + return st.cn.handleError(err) + } + if t != proto.ReadyForQuery { + st.cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: expected ready for query, but got: %q", t) + } + st.cn.processReadyForQuery(r) + + return nil +} + +func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { + return st.query(toNamedValue(v)) +} + +func (st *stmt) query(v []driver.NamedValue) (*rows, error) { + if err := st.cn.err.get(); err != nil { + return nil, err + } + + err := st.exec(v) + if err != nil { + return nil, st.cn.handleError(err) + } + return &rows{ + cn: st.cn, + rowsHeader: st.rowsHeader, + }, nil +} + +func (st *stmt) Exec(v []driver.Value) (driver.Result, error) { + return st.ExecContext(context.Background(), toNamedValue(v)) +} + +func (st *stmt) exec(v []driver.NamedValue) error { + if debugProto { + fmt.Fprintf(os.Stderr, " START stmt.exec\n") + defer fmt.Fprintf(os.Stderr, " END stmt.exec\n") + } + if len(v) >= 65536 { + return fmt.Errorf("pq: got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) + } + if len(v) != len(st.paramTyps) { + return fmt.Errorf("pq: got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) + } + + cn := st.cn + w := cn.writeBuf(proto.Bind) + w.byte(0) // unnamed portal + w.string(st.name) + + if cn.cfg.BinaryParameters { + err := cn.sendBinaryParameters(w, v) + if err != nil { + return err + } + } else { + w.int16(0) + w.int16(len(v)) + for i, x := range v { + if x.Value == nil { + w.int32(-1) + } else { + b, err := encode(x.Value, st.paramTyps[i]) + if err != nil { + return err + } + if b == nil { + w.int32(-1) + } else { + w.int32(len(b)) + w.bytes(b) + } + } + } + } + w.bytes(st.colFmtData) + + w.next(proto.Execute) + w.byte(0) + w.int32(0) + + w.next(proto.Sync) + err := cn.send(w) + if err != nil { + return err + } + err = cn.readBindResponse() + if err != nil { + return err + } + return cn.postExecuteWorkaround() +} + +func (st *stmt) NumInput() int { + return len(st.paramTyps) +} diff --git a/vendor/github.com/lib/pq/url.go b/vendor/github.com/lib/pq/url.go deleted file mode 100644 index aec6e95b..00000000 --- a/vendor/github.com/lib/pq/url.go +++ /dev/null @@ -1,76 +0,0 @@ -package pq - -import ( - "fmt" - "net" - nurl "net/url" - "sort" - "strings" -) - -// ParseURL no longer needs to be used by clients of this library since supplying a URL as a -// connection string to sql.Open() is now supported: -// -// sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full") -// -// It remains exported here for backwards-compatibility. -// -// ParseURL converts a url to a connection string for driver.Open. -// Example: -// -// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full" -// -// converts to: -// -// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full" -// -// A minimal example: -// -// "postgres://" -// -// This will be blank, causing driver.Open to use all of the defaults -func ParseURL(url string) (string, error) { - u, err := nurl.Parse(url) - if err != nil { - return "", err - } - - if u.Scheme != "postgres" && u.Scheme != "postgresql" { - return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) - } - - var kvs []string - escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) - accrue := func(k, v string) { - if v != "" { - kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") - } - } - - if u.User != nil { - v := u.User.Username() - accrue("user", v) - - v, _ = u.User.Password() - accrue("password", v) - } - - if host, port, err := net.SplitHostPort(u.Host); err != nil { - accrue("host", u.Host) - } else { - accrue("host", host) - accrue("port", port) - } - - if u.Path != "" { - accrue("dbname", u.Path[1:]) - } - - q := u.Query() - for k := range q { - accrue(k, q.Get(k)) - } - - sort.Strings(kvs) // Makes testing easier (not a performance concern) - return strings.Join(kvs, " "), nil -} diff --git a/vendor/github.com/lib/pq/user_other.go b/vendor/github.com/lib/pq/user_other.go deleted file mode 100644 index 3dae8f55..00000000 --- a/vendor/github.com/lib/pq/user_other.go +++ /dev/null @@ -1,10 +0,0 @@ -// Package pq is a pure Go Postgres driver for the database/sql package. - -//go:build js || android || hurd || zos -// +build js android hurd zos - -package pq - -func userCurrent() (string, error) { - return "", ErrCouldNotDetectUsername -} diff --git a/vendor/github.com/lib/pq/user_posix.go b/vendor/github.com/lib/pq/user_posix.go deleted file mode 100644 index 5f2d439b..00000000 --- a/vendor/github.com/lib/pq/user_posix.go +++ /dev/null @@ -1,25 +0,0 @@ -// Package pq is a pure Go Postgres driver for the database/sql package. - -//go:build aix || darwin || dragonfly || freebsd || (linux && !android) || nacl || netbsd || openbsd || plan9 || solaris || rumprun || illumos -// +build aix darwin dragonfly freebsd linux,!android nacl netbsd openbsd plan9 solaris rumprun illumos - -package pq - -import ( - "os" - "os/user" -) - -func userCurrent() (string, error) { - u, err := user.Current() - if err == nil { - return u.Username, nil - } - - name := os.Getenv("USER") - if name != "" { - return name, nil - } - - return "", ErrCouldNotDetectUsername -} diff --git a/vendor/github.com/lib/pq/user_windows.go b/vendor/github.com/lib/pq/user_windows.go deleted file mode 100644 index 2b691267..00000000 --- a/vendor/github.com/lib/pq/user_windows.go +++ /dev/null @@ -1,27 +0,0 @@ -// Package pq is a pure Go Postgres driver for the database/sql package. -package pq - -import ( - "path/filepath" - "syscall" -) - -// Perform Windows user name lookup identically to libpq. -// -// The PostgreSQL code makes use of the legacy Win32 function -// GetUserName, and that function has not been imported into stock Go. -// GetUserNameEx is available though, the difference being that a -// wider range of names are available. To get the output to be the -// same as GetUserName, only the base (or last) component of the -// result is returned. -func userCurrent() (string, error) { - pw_name := make([]uint16, 128) - pwname_size := uint32(len(pw_name)) - 1 - err := syscall.GetUserNameEx(syscall.NameSamCompatible, &pw_name[0], &pwname_size) - if err != nil { - return "", ErrCouldNotDetectUsername - } - s := syscall.UTF16ToString(pw_name) - u := filepath.Base(s) - return u, nil -} diff --git a/vendor/github.com/lib/pq/uuid.go b/vendor/github.com/lib/pq/uuid.go deleted file mode 100644 index 9a1b9e07..00000000 --- a/vendor/github.com/lib/pq/uuid.go +++ /dev/null @@ -1,23 +0,0 @@ -package pq - -import ( - "encoding/hex" - "fmt" -) - -// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. -func decodeUUIDBinary(src []byte) ([]byte, error) { - if len(src) != 16 { - return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) - } - - dst := make([]byte, 36) - dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' - hex.Encode(dst[0:], src[0:4]) - hex.Encode(dst[9:], src[4:6]) - hex.Encode(dst[14:], src[6:8]) - hex.Encode(dst[19:], src[8:10]) - hex.Encode(dst[24:], src[10:16]) - - return dst, nil -} diff --git a/vendor/github.com/playwright-community/playwright-go/CONTRIBUTING.md b/vendor/github.com/playwright-community/playwright-go/CONTRIBUTING.md index 3b11995e..f3de9018 100644 --- a/vendor/github.com/playwright-community/playwright-go/CONTRIBUTING.md +++ b/vendor/github.com/playwright-community/playwright-go/CONTRIBUTING.md @@ -20,20 +20,20 @@ BROWSER=chromium HEADLESS=1 go test -v --race ./... ### Roll 1. Find out to which upstream version you want to roll, and change the value of `playwrightCliVersion` in the **run.go** to the new version. -1. Download current version of Playwright driver `go run scripts/install-browsers/main.go` -1. Apply patch `bash scripts/apply-patch.sh` -1. Fix merge conflicts if any, otherwise ignore this step. Once you are happy you can commit the changes `cd playwright; git commit -am "apply patch" && cd ..` -1. Regenerate a new patch `bash scripts/update-patch.sh` -1. Generate go code `go generate ./...` +2. Download current version of Playwright driver `go run scripts/install-browsers/main.go` +3. Apply patch `bash scripts/apply-patch.sh` +4. Fix merge conflicts if any, otherwise ignore this step. Once you are happy you can commit the changes `cd playwright; git commit -am "apply patch" && cd ..` +5. Regenerate a new patch `bash scripts/update-patch.sh` +6. Generate go code `go generate ./...` To adapt to the new version of Playwright's protocol and feature updates, you may need to modify the patch. Refer to the following steps: 1. Apply patch `bash scripts/apply-patch.sh` -1. `cd playwright` -1. Revert the patch`git reset HEAD~1` -1. Modify the files under `docs/src/api`, etc. as needed. Available references: +2. `cd playwright` +3. Revert the patch`git reset HEAD~1` +4. Modify the files under `docs/src/api`, etc. as needed. Available references: - Protocol `packages/protocol/src/protocol.yml` - [Playwright python](https://github.com/microsoft/playwright-python) -1. Commit the changes `git commit -am "apply patch"` -1. Regenerate a new patch `bash scripts/update-patch.sh` -1. Generate go code `go generate ./...`. +5. Commit the changes `git commit -am "apply patch"` +6. Regenerate a new patch `bash scripts/update-patch.sh` +7. Generate go code `go generate ./...`. diff --git a/vendor/github.com/playwright-community/playwright-go/README.md b/vendor/github.com/playwright-community/playwright-go/README.md index 13175b48..f0573cfc 100644 --- a/vendor/github.com/playwright-community/playwright-go/README.md +++ b/vendor/github.com/playwright-community/playwright-go/README.md @@ -5,7 +5,7 @@ [![PkgGoDev](https://pkg.go.dev/badge/github.com/playwright-community/playwright-go)](https://pkg.go.dev/github.com/playwright-community/playwright-go) [![License](https://img.shields.io/badge/License-MIT-blue.svg)](http://opensource.org/licenses/MIT) [![Go Report Card](https://goreportcard.com/badge/github.com/playwright-community/playwright-go)](https://goreportcard.com/report/github.com/playwright-community/playwright-go) ![Build Status](https://github.com/playwright-community/playwright-go/workflows/Go/badge.svg) -[![Join Slack](https://img.shields.io/badge/join-slack-infomational)](https://aka.ms/playwright-slack) [![Coverage Status](https://coveralls.io/repos/github/playwright-community/playwright-go/badge.svg?branch=main)](https://coveralls.io/github/playwright-community/playwright-go?branch=main) [![Chromium version](https://img.shields.io/badge/chromium-136.0.7103.25-blue.svg?logo=google-chrome)](https://www.chromium.org/Home) [![Firefox version](https://img.shields.io/badge/firefox-137.0-blue.svg?logo=mozilla-firefox)](https://www.mozilla.org/en-US/firefox/new/) [![WebKit version](https://img.shields.io/badge/webkit-18.4-blue.svg?logo=safari)](https://webkit.org/) +[![Join Slack](https://img.shields.io/badge/join-slack-infomational)](https://aka.ms/playwright-slack) [![Coverage Status](https://coveralls.io/repos/github/playwright-community/playwright-go/badge.svg?branch=main)](https://coveralls.io/github/playwright-community/playwright-go?branch=main) [![Chromium version](https://img.shields.io/badge/chromium-143.0.7499.4-blue.svg?logo=google-chrome)](https://www.chromium.org/Home) [![Firefox version](https://img.shields.io/badge/firefox-144.0.2-blue.svg?logo=mozilla-firefox)](https://www.mozilla.org/en-US/firefox/new/) [![WebKit version](https://img.shields.io/badge/webkit-26.0-blue.svg?logo=safari)](https://webkit.org/) [API reference](https://playwright.dev/docs/api/class-playwright) | [Example recipes](https://github.com/playwright-community/playwright-go/tree/main/examples) @@ -13,9 +13,9 @@ Playwright is a Go library to automate [Chromium](https://www.chromium.org/Home) | | Linux | macOS | Windows | | :--- | :---: | :---: | :---: | -| Chromium 136.0.7103.25 | ✅ | ✅ | ✅ | -| WebKit 18.4 | ✅ | ✅ | ✅ | -| Firefox 137.0 | ✅ | ✅ | ✅ | +| Chromium 143.0.7499.4 | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| WebKit 26.0 | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| Firefox 144.0.2 | :white_check_mark: | :white_check_mark: | :white_check_mark: | Headless execution is supported for all the browsers on all platforms. diff --git a/vendor/github.com/playwright-community/playwright-go/apiresponse_assertions.go b/vendor/github.com/playwright-community/playwright-go/apiresponse_assertions.go index 187618e2..a86c349e 100644 --- a/vendor/github.com/playwright-community/playwright-go/apiresponse_assertions.go +++ b/vendor/github.com/playwright-community/playwright-go/apiresponse_assertions.go @@ -67,9 +67,6 @@ func subString(s string, start, length int) string { length = 0 } rs := []rune(s) - end := start + length - if end > len(rs) { - end = len(rs) - } + end := min(start+length, len(rs)) return string(rs[start:end]) } diff --git a/vendor/github.com/playwright-community/playwright-go/artifact.go b/vendor/github.com/playwright-community/playwright-go/artifact.go index c76b8927..bb66c106 100644 --- a/vendor/github.com/playwright-community/playwright-go/artifact.go +++ b/vendor/github.com/playwright-community/playwright-go/artifact.go @@ -23,7 +23,7 @@ func (a *artifactImpl) PathAfterFinished() (string, error) { func (a *artifactImpl) SaveAs(path string) error { if !a.connection.isRemote { - _, err := a.channel.Send("saveAs", map[string]interface{}{ + _, err := a.channel.Send("saveAs", map[string]any{ "path": path, }) return err @@ -63,7 +63,7 @@ func (a *artifactImpl) ReadIntoBuffer() ([]byte, error) { return stream.(*streamImpl).ReadAll() } -func newArtifact(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *artifactImpl { +func newArtifact(parent *channelOwner, objectType string, guid string, initializer map[string]any) *artifactImpl { artifact := &artifactImpl{} artifact.createChannelOwner(artifact, parent, objectType, guid, initializer) return artifact diff --git a/vendor/github.com/playwright-community/playwright-go/assertions.go b/vendor/github.com/playwright-community/playwright-go/assertions.go index 5e0e7105..c1756de4 100644 --- a/vendor/github.com/playwright-community/playwright-go/assertions.go +++ b/vendor/github.com/playwright-community/playwright-go/assertions.go @@ -45,20 +45,20 @@ type expectedTextValue struct { } type frameExpectOptions struct { - ExpressionArg interface{} `json:"expressionArg,omitempty"` + ExpressionArg any `json:"expressionArg,omitempty"` ExpectedText []expectedTextValue `json:"expectedText,omitempty"` ExpectedNumber *float64 `json:"expectedNumber,omitempty"` - ExpectedValue interface{} `json:"expectedValue,omitempty"` + ExpectedValue any `json:"expectedValue,omitempty"` UseInnerText *bool `json:"useInnerText,omitempty"` IsNot bool `json:"isNot"` Timeout *float64 `json:"timeout"` } type frameExpectResult struct { - Matches bool `json:"matches"` - Received interface{} `json:"received,omitempty"` - TimedOut *bool `json:"timedOut,omitempty"` - Log []string `json:"log,omitempty"` + Matches bool `json:"matches"` + Received any `json:"received,omitempty"` + TimedOut *bool `json:"timedOut,omitempty"` + Log []string `json:"log,omitempty"` } type assertionsBase struct { @@ -70,7 +70,7 @@ type assertionsBase struct { func (b *assertionsBase) expect( expression string, options frameExpectOptions, - expected interface{}, + expected any, message string, ) error { options.IsNot = b.isNot @@ -101,7 +101,7 @@ func (b *assertionsBase) expect( } func toExpectedTextValues( - items []interface{}, + items []any, matchSubstring bool, normalizeWhiteSpace bool, ignoreCase *bool, @@ -132,13 +132,13 @@ func toExpectedTextValues( return out, nil } -func convertToInterfaceList(v interface{}) []interface{} { +func convertToInterfaceList(v any) []any { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Slice { - return []interface{}{v} + return []any{v} } - list := make([]interface{}, rv.Len()) + list := make([]any, rv.Len()) for i := 0; i < rv.Len(); i++ { list[i] = rv.Index(i).Interface() } diff --git a/vendor/github.com/playwright-community/playwright-go/binding_call.go b/vendor/github.com/playwright-community/playwright-go/binding_call.go index 84689921..21b13348 100644 --- a/vendor/github.com/playwright-community/playwright-go/binding_call.go +++ b/vendor/github.com/playwright-community/playwright-go/binding_call.go @@ -23,15 +23,15 @@ type BindingSource struct { } // ExposedFunction represents the func signature of an exposed function -type ExposedFunction = func(args ...interface{}) interface{} +type ExposedFunction = func(args ...any) any // BindingCallFunction represents the func signature of an exposed binding call func -type BindingCallFunction func(source *BindingSource, args ...interface{}) interface{} +type BindingCallFunction func(source *BindingSource, args ...any) any func (b *bindingCallImpl) Call(f BindingCallFunction) { defer func() { if r := recover(); r != nil { - if _, err := b.channel.Send("reject", map[string]interface{}{ + if _, err := b.channel.Send("reject", map[string]any{ "error": serializeError(r.(error)), }); err != nil { logger.Error("could not reject BindingCall", "error", err) @@ -45,18 +45,18 @@ func (b *bindingCallImpl) Call(f BindingCallFunction) { Page: frame.Page(), Frame: frame, } - var result interface{} + var result any if handle, ok := b.initializer["handle"]; ok { result = f(source, fromChannel(handle)) } else { - initializerArgs := b.initializer["args"].([]interface{}) - funcArgs := []interface{}{} - for i := 0; i < len(initializerArgs); i++ { + initializerArgs := b.initializer["args"].([]any) + funcArgs := []any{} + for i := range initializerArgs { funcArgs = append(funcArgs, parseResult(initializerArgs[i])) } result = f(source, funcArgs...) } - _, err := b.channel.Send("resolve", map[string]interface{}{ + _, err := b.channel.Send("resolve", map[string]any{ "result": serializeArgument(result), }) if err != nil { @@ -64,12 +64,12 @@ func (b *bindingCallImpl) Call(f BindingCallFunction) { } } -func serializeError(err error) map[string]interface{} { +func serializeError(err error) map[string]any { st := stack.Trace().TrimRuntime() if len(st) == 0 { // https://github.com/go-stack/stack/issues/27 st = stack.Trace() } - return map[string]interface{}{ + return map[string]any{ "error": &Error{ Name: "Playwright for Go Error", Message: err.Error(), @@ -80,7 +80,7 @@ func serializeError(err error) map[string]interface{} { } } -func newBindingCall(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *bindingCallImpl { +func newBindingCall(parent *channelOwner, objectType string, guid string, initializer map[string]any) *bindingCallImpl { bt := &bindingCallImpl{} bt.createChannelOwner(bt, parent, objectType, guid, initializer) return bt diff --git a/vendor/github.com/playwright-community/playwright-go/browser.go b/vendor/github.com/playwright-community/playwright-go/browser.go index c87540ab..1f611c98 100644 --- a/vendor/github.com/playwright-community/playwright-go/browser.go +++ b/vendor/github.com/playwright-community/playwright-go/browser.go @@ -30,7 +30,7 @@ func (b *browserImpl) IsConnected() bool { } func (b *browserImpl) NewContext(options ...BrowserNewContextOptions) (BrowserContext, error) { - overrides := map[string]interface{}{} + overrides := map[string]any{} option := BrowserNewContextOptions{} if len(options) == 1 { option = options[0] @@ -93,6 +93,9 @@ func (b *browserImpl) NewContext(options ...BrowserNewContextOptions) (BrowserCo context := fromChannel(channel).(*browserContextImpl) context.browser = b b.browserType.(*browserTypeImpl).didCreateContext(context, &option, nil) + if err := context.initializeHarFromOptions(); err != nil { + return nil, err + } return context, nil } @@ -139,7 +142,7 @@ func (b *browserImpl) Close(options ...BrowserCloseOptions) (err error) { if b.shouldCloseConnectionOnClose { err = b.connection.Stop() } else if b.closeReason != nil { - _, err = b.channel.Send("close", map[string]interface{}{ + _, err = b.channel.Send("close", map[string]any{ "reason": b.closeReason, }) } else { @@ -156,7 +159,7 @@ func (b *browserImpl) Version() string { } func (b *browserImpl) StartTracing(options ...BrowserStartTracingOptions) error { - overrides := map[string]interface{}{} + overrides := map[string]any{} option := BrowserStartTracingOptions{} if len(options) == 1 { option = options[0] @@ -215,7 +218,7 @@ func (b *browserImpl) OnDisconnected(fn func(Browser)) { b.On("disconnected", fn) } -func newBrowser(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *browserImpl { +func newBrowser(parent *channelOwner, objectType string, guid string, initializer map[string]any) *browserImpl { b := &browserImpl{ isConnected: true, contexts: make([]BrowserContext, 0), @@ -227,11 +230,11 @@ func newBrowser(parent *channelOwner, objectType string, guid string, initialize return b } -func transformClientCertificate(clientCertificates []ClientCertificate) ([]map[string]interface{}, error) { - results := make([]map[string]interface{}, 0) +func transformClientCertificate(clientCertificates []ClientCertificate) ([]map[string]any, error) { + results := make([]map[string]any, 0) for _, cert := range clientCertificates { - data := map[string]interface{}{ + data := map[string]any{ "origin": cert.Origin, "passphrase": cert.Passphrase, } diff --git a/vendor/github.com/playwright-community/playwright-go/browser_context.go b/vendor/github.com/playwright-community/playwright-go/browser_context.go index 1d420d3b..8f636bec 100644 --- a/vendor/github.com/playwright-community/playwright-go/browser_context.go +++ b/vendor/github.com/playwright-community/playwright-go/browser_context.go @@ -9,6 +9,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "github.com/playwright-community/playwright-go/internal/safe" ) @@ -16,7 +17,7 @@ import ( type browserContextImpl struct { channelOwner timeoutSettings *timeoutSettings - closeWasCalled bool + closeWasCalled atomic.Bool options *BrowserNewContextOptions pages []Page routes []*routeHandlerEntry @@ -45,7 +46,7 @@ func (b *browserContextImpl) SetDefaultNavigationTimeout(timeout float64) { func (b *browserContextImpl) setDefaultNavigationTimeoutImpl(timeout *float64) { b.timeoutSettings.SetDefaultNavigationTimeout(timeout) - b.channel.SendNoReplyInternal("setDefaultNavigationTimeoutNoReply", map[string]interface{}{ + b.channel.SendNoReplyInternal("setDefaultNavigationTimeoutNoReply", map[string]any{ "timeout": timeout, }) } @@ -56,7 +57,7 @@ func (b *browserContextImpl) SetDefaultTimeout(timeout float64) { func (b *browserContextImpl) setDefaultTimeoutImpl(timeout *float64) { b.timeoutSettings.SetDefaultTimeout(timeout) - b.channel.SendNoReplyInternal("setDefaultTimeoutNoReply", map[string]interface{}{ + b.channel.SendNoReplyInternal("setDefaultTimeoutNoReply", map[string]any{ "timeout": timeout, }) } @@ -75,8 +76,8 @@ func (b *browserContextImpl) Tracing() Tracing { return b.tracing } -func (b *browserContextImpl) NewCDPSession(page interface{}) (CDPSession, error) { - params := map[string]interface{}{} +func (b *browserContextImpl) NewCDPSession(page any) (CDPSession, error) { + params := map[string]any{} if p, ok := page.(*pageImpl); ok { params["page"] = p.channel @@ -108,14 +109,14 @@ func (b *browserContextImpl) NewPage() (Page, error) { } func (b *browserContextImpl) Cookies(urls ...string) ([]Cookie, error) { - result, err := b.channel.Send("cookies", map[string]interface{}{ + result, err := b.channel.Send("cookies", map[string]any{ "urls": urls, }) if err != nil { return nil, err } - cookies := make([]Cookie, len(result.([]interface{}))) - for i, item := range result.([]interface{}) { + cookies := make([]Cookie, len(result.([]any))) + for i, item := range result.([]any) { cookie := &Cookie{} remapMapToStruct(item, cookie) cookies[i] = *cookie @@ -124,14 +125,14 @@ func (b *browserContextImpl) Cookies(urls ...string) ([]Cookie, error) { } func (b *browserContextImpl) AddCookies(cookies []OptionalCookie) error { - _, err := b.channel.Send("addCookies", map[string]interface{}{ + _, err := b.channel.Send("addCookies", map[string]any{ "cookies": cookies, }) return err } func (b *browserContextImpl) ClearCookies(options ...BrowserContextClearCookiesOptions) error { - params := map[string]interface{}{} + params := map[string]any{} if len(options) == 1 { if options[0].Domain != nil { switch t := options[0].Domain.(type) { @@ -181,7 +182,7 @@ func (b *browserContextImpl) ClearCookies(options ...BrowserContextClearCookiesO } func (b *browserContextImpl) GrantPermissions(permissions []string, options ...BrowserContextGrantPermissionsOptions) error { - _, err := b.channel.Send("grantPermissions", map[string]interface{}{ + _, err := b.channel.Send("grantPermissions", map[string]any{ "permissions": permissions, }, options) return err @@ -193,26 +194,26 @@ func (b *browserContextImpl) ClearPermissions() error { } func (b *browserContextImpl) SetGeolocation(geolocation *Geolocation) error { - _, err := b.channel.Send("setGeolocation", map[string]interface{}{ + _, err := b.channel.Send("setGeolocation", map[string]any{ "geolocation": geolocation, }) return err } func (b *browserContextImpl) ResetGeolocation() error { - _, err := b.channel.Send("setGeolocation", map[string]interface{}{}) + _, err := b.channel.Send("setGeolocation", map[string]any{}) return err } func (b *browserContextImpl) SetExtraHTTPHeaders(headers map[string]string) error { - _, err := b.channel.Send("setExtraHTTPHeaders", map[string]interface{}{ + _, err := b.channel.Send("setExtraHTTPHeaders", map[string]any{ "headers": serializeMapToNameAndValue(headers), }) return err } func (b *browserContextImpl) SetOffline(offline bool) error { - _, err := b.channel.Send("setOffline", map[string]interface{}{ + _, err := b.channel.Send("setOffline", map[string]any{ "offline": offline, }) return err @@ -230,7 +231,7 @@ func (b *browserContextImpl) AddInitScript(script Script) error { } source = string(content) } - _, err := b.channel.Send("addInitScript", map[string]interface{}{ + _, err := b.channel.Send("addInitScript", map[string]any{ "source": source, }) return err @@ -249,7 +250,7 @@ func (b *browserContextImpl) ExposeBinding(name string, binding BindingCallFunct if _, ok := b.bindings.Load(name); ok { return fmt.Errorf("Function '%s' has been already registered", name) } - _, err := b.channel.Send("exposeBinding", map[string]interface{}{ + _, err := b.channel.Send("exposeBinding", map[string]any{ "name": name, "needsHandle": needsHandle, }) @@ -261,19 +262,19 @@ func (b *browserContextImpl) ExposeBinding(name string, binding BindingCallFunct } func (b *browserContextImpl) ExposeFunction(name string, binding ExposedFunction) error { - return b.ExposeBinding(name, func(source *BindingSource, args ...interface{}) interface{} { + return b.ExposeBinding(name, func(source *BindingSource, args ...any) any { return binding(args...) }) } -func (b *browserContextImpl) Route(url interface{}, handler routeHandler, times ...int) error { +func (b *browserContextImpl) Route(url any, handler routeHandler, times ...int) error { b.Lock() defer b.Unlock() b.routes = slices.Insert(b.routes, 0, newRouteHandlerEntry(newURLMatcher(url, b.options.BaseURL), handler, times...)) return b.updateInterceptionPatterns() } -func (b *browserContextImpl) Unroute(url interface{}, handlers ...routeHandler) error { +func (b *browserContextImpl) Unroute(url any, handlers ...routeHandler) error { removed, remaining, err := unroute(b.routes, url, handlers...) if err != nil { return err @@ -351,13 +352,13 @@ func (b *browserContextImpl) RouteFromHAR(har string, options ...BrowserContextR return router.addContextRoute(b) } -func (b *browserContextImpl) WaitForEvent(event string, options ...BrowserContextWaitForEventOptions) (interface{}, error) { +func (b *browserContextImpl) WaitForEvent(event string, options ...BrowserContextWaitForEventOptions) (any, error) { return b.waiterForEvent(event, options...).Wait() } func (b *browserContextImpl) waiterForEvent(event string, options ...BrowserContextWaitForEventOptions) *waiter { timeout := b.timeoutSettings.Timeout() - var predicate interface{} = nil + var predicate any = nil if len(options) == 1 { if options[0].Timeout != nil { timeout = *options[0].Timeout @@ -386,7 +387,7 @@ func (b *browserContextImpl) ExpectConsoleMessage(cb func() error, options ...Br return ret.(ConsoleMessage), nil } -func (b *browserContextImpl) ExpectEvent(event string, cb func() error, options ...BrowserContextExpectEventOptions) (interface{}, error) { +func (b *browserContextImpl) ExpectEvent(event string, cb func() error, options ...BrowserContextExpectEventOptions) (any, error) { if len(options) == 1 { return b.waiterForEvent(event, BrowserContextWaitForEventOptions(options[0])).RunAndWait(cb) } @@ -411,15 +412,15 @@ func (b *browserContextImpl) ExpectPage(cb func() error, options ...BrowserConte } func (b *browserContextImpl) Close(options ...BrowserContextCloseOptions) error { - if b.closeWasCalled { + if b.closeWasCalled.Load() { return nil } if len(options) == 1 { b.closeReason = options[0].Reason } - b.closeWasCalled = true + b.closeWasCalled.Store(true) - _, err := b.channel.connection.WrapAPICall(func() (interface{}, error) { + _, err := b.channel.connection.WrapAPICall(func() (any, error) { return nil, b.request.Dispose(APIRequestContextDisposeOptions{ Reason: b.closeReason, }) @@ -428,9 +429,9 @@ func (b *browserContextImpl) Close(options ...BrowserContextCloseOptions) error return err } - innerClose := func() (interface{}, error) { + innerClose := func() (any, error) { for harId, harMetaData := range b.harRecorders { - overrides := map[string]interface{}{} + overrides := map[string]any{} if harId != "" { overrides["harId"] = harId } @@ -467,7 +468,7 @@ func (b *browserContextImpl) Close(options ...BrowserContextCloseOptions) error return err } - _, err = b.channel.Send("close", map[string]interface{}{ + _, err = b.channel.Send("close", map[string]any{ "reason": b.closeReason, }) if err != nil { @@ -479,13 +480,13 @@ func (b *browserContextImpl) Close(options ...BrowserContextCloseOptions) error type browserContextRecordIntoHarOptions struct { Page Page - URL interface{} + URL any UpdateContent *HarContentPolicy UpdateMode *HarMode } func (b *browserContextImpl) recordIntoHar(har string, options ...browserContextRecordIntoHarOptions) error { - overrides := map[string]interface{}{} + overrides := map[string]any{} harOptions := recordHarInputOptions{ Path: har, Content: HarContentPolicyAttach, @@ -584,7 +585,7 @@ func (b *browserContextImpl) onRoute(route *routeImpl) { b.Lock() defer b.Unlock() if len(b.routes) == 0 { - _, err := b.connection.WrapAPICall(func() (interface{}, error) { + _, err := b.connection.WrapAPICall(func() (any, error) { err := b.updateInterceptionPatterns() return nil, err }, true) @@ -597,7 +598,7 @@ func (b *browserContextImpl) onRoute(route *routeImpl) { url := route.Request().URL() for _, handlerEntry := range routes { // If the page or the context was closed we stall all requests right away. - if (page != nil && page.closeWasCalled) || b.closeWasCalled { + if (page != nil && page.closeWasCalled.Load()) || b.closeWasCalled.Load() { return } if !handlerEntry.Matches(url) { @@ -627,7 +628,7 @@ func (b *browserContextImpl) onRoute(route *routeImpl) { func (b *browserContextImpl) updateInterceptionPatterns() error { patterns := prepareInterceptionPatterns(b.routes) - _, err := b.channel.Send("setNetworkInterceptionPatterns", map[string]interface{}{ + _, err := b.channel.Send("setNetworkInterceptionPatterns", map[string]any{ "patterns": patterns, }) return err @@ -642,7 +643,7 @@ func (b *browserContextImpl) pause() <-chan error { return ret } -func (b *browserContextImpl) onBackgroundPage(ev map[string]interface{}) { +func (b *browserContextImpl) onBackgroundPage(ev map[string]any) { b.Lock() p := fromChannel(ev["page"]).(*pageImpl) p.browserContext = b @@ -662,17 +663,41 @@ func (b *browserContextImpl) setOptions(options *BrowserNewContextOptions, trace options = &BrowserNewContextOptions{} } b.options = options - if b.options != nil && b.options.RecordHarPath != nil { - b.harRecorders[""] = harRecordingMetadata{ - Path: *b.options.RecordHarPath, - Content: b.options.RecordHarContent, - } - } if tracesDir != nil { b.tracing.tracesDir = *tracesDir } } +// initializeHarFromOptions starts HAR recording if RecordHarPath is set in options. +// This must be called after context creation to properly register the HAR recorder on the server. +func (b *browserContextImpl) initializeHarFromOptions() error { + if b.options == nil || b.options.RecordHarPath == nil { + return nil + } + path := *b.options.RecordHarPath + // Determine default content policy based on file extension + var content *HarContentPolicy + if strings.HasSuffix(strings.ToLower(path), ".zip") { + content = HarContentPolicyAttach + } else { + content = HarContentPolicyEmbed + } + if b.options.RecordHarContent != nil { + content = b.options.RecordHarContent + } else if b.options.RecordHarOmitContent != nil && *b.options.RecordHarOmitContent { + content = HarContentPolicyOmit + } + mode := HarModeFull + if b.options.RecordHarMode != nil { + mode = b.options.RecordHarMode + } + return b.recordIntoHar(path, browserContextRecordIntoHarOptions{ + URL: b.options.RecordHarURLFilter, + UpdateContent: content, + UpdateMode: mode, + }) +} + func (b *browserContextImpl) BackgroundPages() []Page { b.Lock() defer b.Unlock() @@ -725,7 +750,7 @@ func (b *browserContextImpl) OnWebError(fn func(WebError)) { b.On("weberror", fn) } -func (b *browserContextImpl) RouteWebSocket(url interface{}, handler func(WebSocketRoute)) error { +func (b *browserContextImpl) RouteWebSocket(url any, handler func(WebSocketRoute)) error { b.Lock() defer b.Unlock() b.webSocketRoutes = slices.Insert(b.webSocketRoutes, 0, newWebSocketRouteHandler(newURLMatcher(url, b.options.BaseURL), handler)) @@ -753,7 +778,7 @@ func (b *browserContextImpl) onWebSocketRoute(wr WebSocketRoute) { func (b *browserContextImpl) updateWebSocketInterceptionPatterns() error { patterns := prepareWebSocketRouteHandlerInterceptionPatterns(b.webSocketRoutes) - _, err := b.channel.Send("setWebSocketInterceptionPatterns", map[string]interface{}{ + _, err := b.channel.Send("setWebSocketInterceptionPatterns", map[string]any{ "patterns": patterns, }) return err @@ -771,7 +796,7 @@ func (b *browserContextImpl) effectiveCloseReason() *string { return nil } -func newBrowserContext(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *browserContextImpl { +func newBrowserContext(parent *channelOwner, objectType string, guid string, initializer map[string]any) *browserContextImpl { bt := &browserContextImpl{ timeoutSettings: newTimeoutSettings(nil), pages: make([]Page, 0), @@ -790,36 +815,52 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini bt.tracing = fromChannel(initializer["tracing"]).(*tracingImpl) bt.request = fromChannel(initializer["requestContext"]).(*apiRequestContextImpl) bt.clock = newClock(bt) - bt.channel.On("bindingCall", func(params map[string]interface{}) { + + // Register this context with the selectors manager for custom selector engines + if bt.browser != nil && bt.browser.browserType != nil { + if browserType, ok := bt.browser.browserType.(*browserTypeImpl); ok && browserType.playwright != nil { + browserType.playwright.Selectors.(*selectorsImpl).addContext(bt) + } + } + + bt.channel.On("bindingCall", func(params map[string]any) { bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl)) }) - bt.channel.On("close", bt.onClose) - bt.channel.On("page", func(payload map[string]interface{}) { + bt.channel.On("close", func() { + // Unregister this context from the selectors manager + if bt.browser != nil && bt.browser.browserType != nil { + if browserType, ok := bt.browser.browserType.(*browserTypeImpl); ok && browserType.playwright != nil { + browserType.playwright.Selectors.(*selectorsImpl).removeContext(bt) + } + } + bt.onClose() + }) + bt.channel.On("page", func(payload map[string]any) { bt.onPage(fromChannel(payload["page"]).(*pageImpl)) }) - bt.channel.On("route", func(params map[string]interface{}) { + bt.channel.On("route", func(params map[string]any) { bt.channel.CreateTask(func() { bt.onRoute(fromChannel(params["route"]).(*routeImpl)) }) }) - bt.channel.On("webSocketRoute", func(params map[string]interface{}) { + bt.channel.On("webSocketRoute", func(params map[string]any) { bt.channel.CreateTask(func() { bt.onWebSocketRoute(fromChannel(params["webSocketRoute"]).(*webSocketRouteImpl)) }) }) bt.channel.On("backgroundPage", bt.onBackgroundPage) - bt.channel.On("serviceWorker", func(params map[string]interface{}) { + bt.channel.On("serviceWorker", func(params map[string]any) { bt.onServiceWorker(fromChannel(params["worker"]).(*workerImpl)) }) - bt.channel.On("console", func(ev map[string]interface{}) { + bt.channel.On("console", func(ev map[string]any) { message := newConsoleMessage(ev) bt.Emit("console", message) if message.page != nil { message.page.Emit("console", message) } }) - bt.channel.On("dialog", func(params map[string]interface{}) { + bt.channel.On("dialog", func(params map[string]any) { dialog := fromChannel(params["dialog"]).(*dialogImpl) go func() { hasListeners := bt.Emit("dialog", dialog) @@ -843,9 +884,9 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini }() }) bt.channel.On( - "pageError", func(ev map[string]interface{}) { + "pageError", func(ev map[string]any) { pwErr := &Error{} - remapMapToStruct(ev["error"].(map[string]interface{})["error"], pwErr) + remapMapToStruct(ev["error"].(map[string]any)["error"], pwErr) err := parseError(*pwErr) page := fromNullableChannel(ev["page"]) if page != nil { @@ -856,7 +897,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini } }, ) - bt.channel.On("request", func(ev map[string]interface{}) { + bt.channel.On("request", func(ev map[string]any) { request := fromChannel(ev["request"]).(*requestImpl) page := fromNullableChannel(ev["page"]) bt.Emit("request", request) @@ -864,7 +905,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini page.(*pageImpl).Emit("request", request) } }) - bt.channel.On("requestFailed", func(ev map[string]interface{}) { + bt.channel.On("requestFailed", func(ev map[string]any) { request := fromChannel(ev["request"]).(*requestImpl) failureText := ev["failureText"] if failureText != nil { @@ -878,7 +919,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini } }) - bt.channel.On("requestFinished", func(ev map[string]interface{}) { + bt.channel.On("requestFinished", func(ev map[string]any) { request := fromChannel(ev["request"]).(*requestImpl) response := fromNullableChannel(ev["response"]) page := fromNullableChannel(ev["page"]) @@ -891,7 +932,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini close(response.(*responseImpl).finished) } }) - bt.channel.On("response", func(ev map[string]interface{}) { + bt.channel.On("response", func(ev map[string]any) { response := fromChannel(ev["response"]).(*responseImpl) page := fromNullableChannel(ev["page"]) bt.Emit("response", response) diff --git a/vendor/github.com/playwright-community/playwright-go/browser_type.go b/vendor/github.com/playwright-community/playwright-go/browser_type.go index 41a8b184..95d611a1 100644 --- a/vendor/github.com/playwright-community/playwright-go/browser_type.go +++ b/vendor/github.com/playwright-community/playwright-go/browser_type.go @@ -18,7 +18,11 @@ func (b *browserTypeImpl) ExecutablePath() string { } func (b *browserTypeImpl) Launch(options ...BrowserTypeLaunchOptions) (Browser, error) { - overrides := map[string]interface{}{} + overrides := map[string]any{} + // timeout is required in Playwright v1.57+ protocol + if len(options) == 0 || options[0].Timeout == nil { + overrides["timeout"] = float64(30000) // default 30s + } if len(options) == 1 && options[0].Env != nil { overrides["env"] = serializeMapToNameAndValue(options[0].Env) options[0].Env = nil @@ -33,9 +37,13 @@ func (b *browserTypeImpl) Launch(options ...BrowserTypeLaunchOptions) (Browser, } func (b *browserTypeImpl) LaunchPersistentContext(userDataDir string, options ...BrowserTypeLaunchPersistentContextOptions) (BrowserContext, error) { - overrides := map[string]interface{}{ + overrides := map[string]any{ "userDataDir": userDataDir, } + // timeout is required in Playwright v1.57+ protocol + if len(options) == 0 || options[0].Timeout == nil { + overrides["timeout"] = float64(30000) // default 30s + } option := &BrowserNewContextOptions{} var tracesDir *string = nil if len(options) == 1 { @@ -87,22 +95,30 @@ func (b *browserTypeImpl) LaunchPersistentContext(userDataDir string, options .. options[0].RecordHarOmitContent = nil } } - channel, err := b.channel.Send("launchPersistentContext", options, overrides) + response, err := b.channel.SendReturnAsDict("launchPersistentContext", options, overrides) if err != nil { return nil, err } - context := fromChannel(channel).(*browserContextImpl) + context := fromChannel(response["context"]).(*browserContextImpl) b.didCreateContext(context, option, tracesDir) + if err := context.initializeHarFromOptions(); err != nil { + return nil, err + } return context, nil } func (b *browserTypeImpl) Connect(wsEndpoint string, options ...BrowserTypeConnectOptions) (Browser, error) { - overrides := map[string]interface{}{ + overrides := map[string]any{ "wsEndpoint": wsEndpoint, "headers": map[string]string{ - "x-playwright-browser": b.Name(), + "x-playwright-browser": b.Name(), + "x-playwright-launch-options": "{}", }, } + // timeout is required in Playwright v1.57+ protocol + if len(options) == 0 || options[0].Timeout == nil { + overrides["timeout"] = float64(0) // default no timeout + } if len(options) == 1 { if options[0].Headers != nil { for k, v := range options[0].Headers { @@ -144,9 +160,13 @@ func (b *browserTypeImpl) Connect(wsEndpoint string, options ...BrowserTypeConne } func (b *browserTypeImpl) ConnectOverCDP(endpointURL string, options ...BrowserTypeConnectOverCDPOptions) (Browser, error) { - overrides := map[string]interface{}{ + overrides := map[string]any{ "endpointURL": endpointURL, } + // timeout is required in Playwright v1.57+ protocol + if len(options) == 0 || options[0].Timeout == nil { + overrides["timeout"] = float64(30000) // default 30s + } if len(options) == 1 { if options[0].Headers != nil { overrides["headers"] = serializeMapToNameAndValue(options[0].Headers) @@ -174,7 +194,7 @@ func (b *browserTypeImpl) didLaunchBrowser(browser *browserImpl) { browser.browserType = b } -func newBrowserType(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *browserTypeImpl { +func newBrowserType(parent *channelOwner, objectType string, guid string, initializer map[string]any) *browserTypeImpl { bt := &browserTypeImpl{} bt.createChannelOwner(bt, parent, objectType, guid, initializer) return bt diff --git a/vendor/github.com/playwright-community/playwright-go/cdp_session.go b/vendor/github.com/playwright-community/playwright-go/cdp_session.go index e9bba82d..ca83e272 100644 --- a/vendor/github.com/playwright-community/playwright-go/cdp_session.go +++ b/vendor/github.com/playwright-community/playwright-go/cdp_session.go @@ -9,8 +9,8 @@ func (c *cdpSessionImpl) Detach() error { return err } -func (c *cdpSessionImpl) Send(method string, params map[string]interface{}) (interface{}, error) { - result, err := c.channel.Send("send", map[string]interface{}{ +func (c *cdpSessionImpl) Send(method string, params map[string]any) (any, error) { + result, err := c.channel.Send("send", map[string]any{ "method": method, "params": params, }) @@ -21,16 +21,16 @@ func (c *cdpSessionImpl) Send(method string, params map[string]interface{}) (int return result, err } -func (c *cdpSessionImpl) onEvent(params map[string]interface{}) { +func (c *cdpSessionImpl) onEvent(params map[string]any) { c.Emit(params["method"].(string), params["params"]) } -func newCDPSession(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *cdpSessionImpl { +func newCDPSession(parent *channelOwner, objectType string, guid string, initializer map[string]any) *cdpSessionImpl { bt := &cdpSessionImpl{} bt.createChannelOwner(bt, parent, objectType, guid, initializer) - bt.channel.On("event", func(params map[string]interface{}) { + bt.channel.On("event", func(params map[string]any) { bt.onEvent(params) }) diff --git a/vendor/github.com/playwright-community/playwright-go/channel.go b/vendor/github.com/playwright-community/playwright-go/channel.go index b0bded46..eb1d4983 100644 --- a/vendor/github.com/playwright-community/playwright-go/channel.go +++ b/vendor/github.com/playwright-community/playwright-go/channel.go @@ -10,7 +10,7 @@ type channel struct { guid string connection *connection owner *channelOwner // to avoid type conversion - object interface{} // retain type info (for fromChannel needed) + object any // retain type info (for fromChannel needed) } func (c *channel) MarshalJSON() ([]byte, error) { @@ -36,20 +36,36 @@ func (c *channel) CreateTask(fn func()) { }() } -func (c *channel) Send(method string, options ...interface{}) (interface{}, error) { - return c.connection.WrapAPICall(func() (interface{}, error) { - return c.innerSend(method, options...).GetResultValue() +func (c *channel) Send(method string, options ...any) (any, error) { + return c.connection.WrapAPICall(func() (any, error) { + result, err := c.innerSend(method, options...).GetResultValue() + if err != nil { + return nil, err + } + // GUIDs are now always eagerly resolved in connection.Dispatch + return result, nil }, c.owner.isInternalType) } -func (c *channel) SendReturnAsDict(method string, options ...interface{}) (map[string]interface{}, error) { - ret, err := c.connection.WrapAPICall(func() (interface{}, error) { - return c.innerSend(method, options...).GetResult() +func (c *channel) SendReturnAsDict(method string, options ...any) (map[string]any, error) { + ret, err := c.connection.WrapAPICall(func() (any, error) { + result, err := c.innerSend(method, options...).GetResult() + if err != nil { + return nil, err + } + // GUIDs are now always eagerly resolved in connection.Dispatch + return result, nil }, c.owner.isInternalType) - return ret.(map[string]interface{}), err + if err != nil { + return nil, err + } + if ret == nil { + return make(map[string]any), nil + } + return ret.(map[string]any), nil } -func (c *channel) innerSend(method string, options ...interface{}) *protocolCallback { +func (c *channel) innerSend(method string, options ...any) *protocolCallback { if err := c.connection.err.Get(); err != nil { c.connection.err.Set(nil) pc := newProtocolCallback(false, c.connection.abort) @@ -62,17 +78,17 @@ func (c *channel) innerSend(method string, options ...interface{}) *protocolCall // SendNoReply ignores return value and errors // almost equivalent to `send(...).catch(() => {})` -func (c *channel) SendNoReply(method string, options ...interface{}) { +func (c *channel) SendNoReply(method string, options ...any) { c.innerSendNoReply(method, c.owner.isInternalType, options...) } -func (c *channel) SendNoReplyInternal(method string, options ...interface{}) { +func (c *channel) SendNoReplyInternal(method string, options ...any) { c.innerSendNoReply(method, true, options...) } -func (c *channel) innerSendNoReply(method string, isInternal bool, options ...interface{}) { +func (c *channel) innerSendNoReply(method string, isInternal bool, options ...any) { params := transformOptions(options...) - _, err := c.connection.WrapAPICall(func() (interface{}, error) { + _, err := c.connection.WrapAPICall(func() (any, error) { return c.connection.sendMessageToServer(c.owner, method, params, true).GetResult() }, isInternal) if err != nil { @@ -81,7 +97,7 @@ func (c *channel) innerSendNoReply(method string, isInternal bool, options ...in } } -func newChannel(owner *channelOwner, object interface{}) *channel { +func newChannel(owner *channelOwner, object any) *channel { channel := &channel{ connection: owner.connection, guid: owner.guid, diff --git a/vendor/github.com/playwright-community/playwright-go/channel_owner.go b/vendor/github.com/playwright-community/playwright-go/channel_owner.go index 5159eb2c..6e8450bf 100644 --- a/vendor/github.com/playwright-community/playwright-go/channel_owner.go +++ b/vendor/github.com/playwright-community/playwright-go/channel_owner.go @@ -13,7 +13,7 @@ type channelOwner struct { objects map[string]*channelOwner eventToSubscriptionMapping map[string]string connection *connection - initializer map[string]interface{} + initializer map[string]any parent *channelOwner wasCollected bool isInternalType bool @@ -49,36 +49,36 @@ func (c *channelOwner) setEventSubscriptionMapping(mapping map[string]string) { func (c *channelOwner) updateSubscription(event string, enabled bool) { protocolEvent, ok := c.eventToSubscriptionMapping[event] if ok { - c.channel.SendNoReplyInternal("updateSubscription", map[string]interface{}{ + c.channel.SendNoReplyInternal("updateSubscription", map[string]any{ "event": protocolEvent, "enabled": enabled, }) } } -func (c *channelOwner) Once(name string, handler interface{}) { +func (c *channelOwner) Once(name string, handler any) { c.addEvent(name, handler, true) } -func (c *channelOwner) On(name string, handler interface{}) { +func (c *channelOwner) On(name string, handler any) { c.addEvent(name, handler, false) } -func (c *channelOwner) addEvent(name string, handler interface{}, once bool) { +func (c *channelOwner) addEvent(name string, handler any, once bool) { if c.ListenerCount(name) == 0 { c.updateSubscription(name, true) } c.eventEmitter.addEvent(name, handler, once) } -func (c *channelOwner) RemoveListener(name string, handler interface{}) { +func (c *channelOwner) RemoveListener(name string, handler any) { c.eventEmitter.RemoveListener(name, handler) if c.ListenerCount(name) == 0 { c.updateSubscription(name, false) } } -func (c *channelOwner) createChannelOwner(self interface{}, parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) { +func (c *channelOwner) createChannelOwner(self any, parent *channelOwner, objectType string, guid string, initializer map[string]any) { c.objectType = objectType c.guid = guid c.wasCollected = false @@ -105,18 +105,20 @@ type rootChannelOwner struct { } func (r *rootChannelOwner) initialize() (*Playwright, error) { - ret, err := r.channel.SendReturnAsDict("initialize", map[string]interface{}{ + ret, err := r.channel.SendReturnAsDict("initialize", map[string]any{ "sdkLanguage": "javascript", }) if err != nil { return nil, err } - return fromChannel(ret["playwright"]).(*Playwright), nil + // GUIDs are now always eagerly resolved in connection.Dispatch + playwrightValue := ret["playwright"] + return fromChannel(playwrightValue).(*Playwright), nil } func newRootChannelOwner(connection *connection) *rootChannelOwner { c := &rootChannelOwner{} c.connection = connection - c.createChannelOwner(c, nil, "Root", "", make(map[string]interface{})) + c.createChannelOwner(c, nil, "Root", "", make(map[string]any)) return c } diff --git a/vendor/github.com/playwright-community/playwright-go/clock.go b/vendor/github.com/playwright-community/playwright-go/clock.go index 8bab0374..ba4c7e02 100644 --- a/vendor/github.com/playwright-community/playwright-go/clock.go +++ b/vendor/github.com/playwright-community/playwright-go/clock.go @@ -15,7 +15,7 @@ func newClock(bCtx *browserContextImpl) Clock { } } -func (c *clockImpl) FastForward(ticks interface{}) error { +func (c *clockImpl) FastForward(ticks any) error { params, err := parseTicks(ticks) if err != nil { return err @@ -41,7 +41,7 @@ func (c *clockImpl) Install(options ...ClockInstallOptions) (err error) { return err } -func (c *clockImpl) PauseAt(time interface{}) error { +func (c *clockImpl) PauseAt(time any) error { params, err := parseTime(time) if err != nil { return err @@ -56,7 +56,7 @@ func (c *clockImpl) Resume() error { return err } -func (c *clockImpl) RunFor(ticks interface{}) error { +func (c *clockImpl) RunFor(ticks any) error { params, err := parseTicks(ticks) if err != nil { return err @@ -66,7 +66,7 @@ func (c *clockImpl) RunFor(ticks interface{}) error { return err } -func (c *clockImpl) SetFixedTime(time interface{}) error { +func (c *clockImpl) SetFixedTime(time any) error { params, err := parseTime(time) if err != nil { return err @@ -76,7 +76,7 @@ func (c *clockImpl) SetFixedTime(time interface{}) error { return err } -func (c *clockImpl) SetSystemTime(time interface{}) error { +func (c *clockImpl) SetSystemTime(time any) error { params, err := parseTime(time) if err != nil { return err @@ -86,7 +86,7 @@ func (c *clockImpl) SetSystemTime(time interface{}) error { return err } -func parseTime(t interface{}) (map[string]any, error) { +func parseTime(t any) (map[string]any, error) { switch v := t.(type) { case int, int64: return map[string]any{"timeNumber": v}, nil @@ -99,7 +99,7 @@ func parseTime(t interface{}) (map[string]any, error) { } } -func parseTicks(ticks interface{}) (map[string]any, error) { +func parseTicks(ticks any) (map[string]any, error) { switch v := ticks.(type) { case int, int64: return map[string]any{"ticksNumber": v}, nil diff --git a/vendor/github.com/playwright-community/playwright-go/connection.go b/vendor/github.com/playwright-community/playwright-go/connection.go index ba1e365b..0e787462 100644 --- a/vendor/github.com/playwright-community/playwright-go/connection.go +++ b/vendor/github.com/playwright-community/playwright-go/connection.go @@ -100,15 +100,27 @@ func (c *connection) Dispatch(msg *message) { if msg.Error != nil { cb.SetError(parseError(msg.Error.Error)) } else { - cb.SetResult(c.replaceGuidsWithChannels(msg.Result).(map[string]interface{})) + // Always resolve GUIDs in responses, regardless of connection type + // The protocol guarantees that __create__ events arrive before responses that reference those objects + result, err := c.replaceGuidsWithChannels(msg.Result) + if err != nil { + cb.SetError(fmt.Errorf("failed to resolve response objects: %w", err)) + } else { + cb.SetResult(result.(map[string]any)) + } } return } object, _ := c.objects.Load(msg.GUID) if method == "__create__" { - c.createRemoteObject( + _, err := c.createRemoteObject( object, msg.Params["type"].(string), msg.Params["guid"].(string), msg.Params["initializer"], ) + if err != nil { + // Critical: object creation failure indicates corrupted protocol state + // Close connection to prevent cascade failures + c.cleanup(err) + } return } if object == nil { @@ -134,7 +146,15 @@ func (c *connection) Dispatch(msg *message) { if object.objectType == "JsonPipe" { object.channel.Emit(method, msg.Params) } else { - object.channel.Emit(method, c.replaceGuidsWithChannels(msg.Params)) + // Always resolve GUIDs in events, regardless of connection type + // The protocol guarantees that __create__ events arrive before events that reference those objects + params, err := c.replaceGuidsWithChannels(msg.Params) + if err != nil { + // Event parameters contain invalid references - connection is corrupted + c.cleanup(fmt.Errorf("failed to resolve event parameters for %s: %w", method, err)) + return + } + object.channel.Emit(method, params) } } @@ -142,13 +162,16 @@ func (c *connection) LocalUtils() *localUtilsImpl { return c.localUtils } -func (c *connection) createRemoteObject(parent *channelOwner, objectType string, guid string, initializer interface{}) interface{} { - initializer = c.replaceGuidsWithChannels(initializer) - result := createObjectFactory(parent, objectType, guid, initializer.(map[string]interface{})) - return result +func (c *connection) createRemoteObject(parent *channelOwner, objectType string, guid string, initializer any) (any, error) { + resolved, err := c.replaceGuidsWithChannels(initializer) + if err != nil { + return nil, fmt.Errorf("failed to resolve initializer for %s (guid=%s): %w", objectType, guid, err) + } + result := createObjectFactory(parent, objectType, guid, resolved.(map[string]any)) + return result, nil } -func (c *connection) WrapAPICall(cb func() (interface{}, error), isInternal bool) (interface{}, error) { +func (c *connection) WrapAPICall(cb func() (any, error), isInternal bool) (any, error) { if _, ok := c.apiZone.Load("apiZone"); ok { return cb() } @@ -156,34 +179,51 @@ func (c *connection) WrapAPICall(cb func() (interface{}, error), isInternal bool return cb() } -func (c *connection) replaceGuidsWithChannels(payload interface{}) interface{} { +func (c *connection) replaceGuidsWithChannels(payload any) (any, error) { if payload == nil { - return nil + return nil, nil } v := reflect.ValueOf(payload) if v.Kind() == reflect.Slice { - listV := payload.([]interface{}) - for i := 0; i < len(listV); i++ { - listV[i] = c.replaceGuidsWithChannels(listV[i]) + listV := payload.([]any) + for i := range listV { + resolved, err := c.replaceGuidsWithChannels(listV[i]) + if err != nil { + return nil, fmt.Errorf("failed to resolve slice element at index %d: %w", i, err) + } + listV[i] = resolved } - return listV + return listV, nil } if v.Kind() == reflect.Map { - mapV := payload.(map[string]interface{}) + mapV := payload.(map[string]any) + // Check if this map represents an object reference (has "guid" field) if guid, hasGUID := mapV["guid"]; hasGUID { - if channelOwner, ok := c.objects.Load(guid.(string)); ok { - return channelOwner.channel + guidStr, ok := guid.(string) + if !ok { + return nil, fmt.Errorf("guid field is not a string: %T", guid) } + // Try to load the object from connection's objects map + if channelOwner, ok := c.objects.Load(guidStr); ok { + return channelOwner.channel, nil + } + // Object not found - this indicates a protocol error or message ordering issue + return nil, fmt.Errorf("object with guid %s was not bound in the connection", guidStr) } + // Recursively process all values in the map for key := range mapV { - mapV[key] = c.replaceGuidsWithChannels(mapV[key]) + resolved, err := c.replaceGuidsWithChannels(mapV[key]) + if err != nil { + return nil, fmt.Errorf("failed to resolve map key '%s': %w", key, err) + } + mapV[key] = resolved } - return mapV + return mapV, nil } - return payload + return payload, nil } -func (c *connection) sendMessageToServer(object *channelOwner, method string, params interface{}, noReply bool) (cb *protocolCallback) { +func (c *connection) sendMessageToServer(object *channelOwner, method string, params any, noReply bool) (cb *protocolCallback) { cb = newProtocolCallback(noReply, c.abort) if err := c.closedError.Get(); err != nil { @@ -198,8 +238,8 @@ func (c *connection) sendMessageToServer(object *channelOwner, method string, pa id := c.lastID.Add(1) c.callbacks.Store(id, cb) var ( - metadata = make(map[string]interface{}, 0) - stack = make([]map[string]interface{}, 0) + metadata = make(map[string]any, 0) + stack = make([]map[string]any, 0) ) apiZone, ok := c.apiZone.LoadAndDelete("apiZone") if ok { @@ -209,7 +249,7 @@ func (c *connection) sendMessageToServer(object *channelOwner, method string, pa stack = append(stack, apiZone.(parsedStackTrace).frames...) } metadata["wallTime"] = time.Now().UnixMilli() - message := map[string]interface{}{ + message := map[string]any{ "id": id, "guid": object.guid, "method": method, @@ -237,8 +277,8 @@ func (c *connection) setInTracing(isTracing bool) { } type parsedStackTrace struct { - frames []map[string]interface{} - metadata map[string]interface{} + frames []map[string]any + metadata map[string]any } func serializeCallStack(isInternal bool) parsedStackTrace { @@ -259,19 +299,19 @@ func serializeCallStack(isInternal bool) parsedStackTrace { } st = st.TrimBelow(st[lastInternalIndex]) - callStack := make([]map[string]interface{}, 0) + callStack := make([]map[string]any, 0) for i, s := range st { if i == 0 { continue } - callStack = append(callStack, map[string]interface{}{ + callStack = append(callStack, map[string]any{ "file": s.Frame().File, "line": s.Frame().Line, "column": 0, "function": s.Frame().Function, }) } - metadata := make(map[string]interface{}) + metadata := make(map[string]any) if len(st) > 1 { metadata["location"] = serializeCallLocation(st[1]) } @@ -287,9 +327,9 @@ func serializeCallStack(isInternal bool) parsedStackTrace { } } -func serializeCallLocation(caller stack.Call) map[string]interface{} { +func serializeCallLocation(caller stack.Call) map[string]any { line, _ := strconv.Atoi(fmt.Sprintf("%d", caller)) - return map[string]interface{}{ + return map[string]any{ "file": fmt.Sprintf("%s", caller), "line": line, } @@ -313,11 +353,14 @@ func newConnection(transport transport, localUtils ...*localUtilsImpl) *connecti return connection } -func fromChannel(v interface{}) interface{} { - return v.(*channel).object +func fromChannel(v any) any { + if ch, ok := v.(*channel); ok { + return ch.object + } + panic(fmt.Sprintf("fromChannel: expected *channel, got %T: %+v", v, v)) } -func fromNullableChannel(v interface{}) interface{} { +func fromNullableChannel(v any) any { if v == nil { return nil } @@ -329,15 +372,17 @@ type protocolCallback struct { noReply bool abort <-chan struct{} once sync.Once - value map[string]interface{} + value map[string]any err error } -func (pc *protocolCallback) setResultOnce(result map[string]interface{}, err error) { +func (pc *protocolCallback) setResultOnce(result map[string]any, err error) { pc.once.Do(func() { pc.value = result pc.err = err - close(pc.done) + if pc.done != nil { + close(pc.done) + } }) } @@ -363,17 +408,17 @@ func (pc *protocolCallback) SetError(err error) { pc.setResultOnce(nil, err) } -func (pc *protocolCallback) SetResult(result map[string]interface{}) { +func (pc *protocolCallback) SetResult(result map[string]any) { pc.setResultOnce(result, nil) } -func (pc *protocolCallback) GetResult() (map[string]interface{}, error) { +func (pc *protocolCallback) GetResult() (map[string]any, error) { pc.waitResult() return pc.value, pc.err } // GetResultValue returns value if the map has only one element -func (pc *protocolCallback) GetResultValue() (interface{}, error) { +func (pc *protocolCallback) GetResultValue() (any, error) { pc.waitResult() if len(pc.value) == 0 { // empty map treated as nil return nil, pc.err diff --git a/vendor/github.com/playwright-community/playwright-go/console_message.go b/vendor/github.com/playwright-community/playwright-go/console_message.go index 4baf3f18..8b09b7ff 100644 --- a/vendor/github.com/playwright-community/playwright-go/console_message.go +++ b/vendor/github.com/playwright-community/playwright-go/console_message.go @@ -1,8 +1,9 @@ package playwright type consoleMessageImpl struct { - event map[string]interface{} - page Page + event map[string]any + page Page + worker Worker } func (c *consoleMessageImpl) Type() string { @@ -18,7 +19,7 @@ func (c *consoleMessageImpl) String() string { } func (c *consoleMessageImpl) Args() []JSHandle { - args := c.event["args"].([]interface{}) + args := c.event["args"].([]any) out := []JSHandle{} for idx := range args { out = append(out, fromChannel(args[idx]).(*jsHandleImpl)) @@ -36,12 +37,20 @@ func (c *consoleMessageImpl) Page() Page { return c.page } -func newConsoleMessage(event map[string]interface{}) *consoleMessageImpl { +func (c *consoleMessageImpl) Worker() (Worker, error) { + return c.worker, nil +} + +func newConsoleMessage(event map[string]any) *consoleMessageImpl { bt := &consoleMessageImpl{} bt.event = event page := fromNullableChannel(event["page"]) if page != nil { bt.page = page.(*pageImpl) } + worker := fromNullableChannel(event["worker"]) + if worker != nil { + bt.worker = worker.(*workerImpl) + } return bt } diff --git a/vendor/github.com/playwright-community/playwright-go/dialog.go b/vendor/github.com/playwright-community/playwright-go/dialog.go index 8d132342..faab25fc 100644 --- a/vendor/github.com/playwright-community/playwright-go/dialog.go +++ b/vendor/github.com/playwright-community/playwright-go/dialog.go @@ -22,7 +22,7 @@ func (d *dialogImpl) Accept(promptTextInput ...string) error { if len(promptTextInput) == 1 { promptText = &promptTextInput[0] } - _, err := d.channel.Send("accept", map[string]interface{}{ + _, err := d.channel.Send("accept", map[string]any{ "promptText": promptText, }) return err @@ -37,7 +37,7 @@ func (d *dialogImpl) Page() Page { return d.page } -func newDialog(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *dialogImpl { +func newDialog(parent *channelOwner, objectType string, guid string, initializer map[string]any) *dialogImpl { bt := &dialogImpl{} bt.createChannelOwner(bt, parent, objectType, guid, initializer) page := fromNullableChannel(initializer["page"]) diff --git a/vendor/github.com/playwright-community/playwright-go/element_handle.go b/vendor/github.com/playwright-community/playwright-go/element_handle.go index 62c41baa..9b7641ce 100644 --- a/vendor/github.com/playwright-community/playwright-go/element_handle.go +++ b/vendor/github.com/playwright-community/playwright-go/element_handle.go @@ -40,7 +40,7 @@ func (e *elementHandleImpl) ContentFrame() (Frame, error) { } func (e *elementHandleImpl) GetAttribute(name string) (string, error) { - attribute, err := e.channel.Send("getAttribute", map[string]interface{}{ + attribute, err := e.channel.Send("getAttribute", map[string]any{ "name": name, }) if attribute == nil { @@ -73,12 +73,12 @@ func (e *elementHandleImpl) InnerHTML() (string, error) { return innerHTML.(string), err } -func (e *elementHandleImpl) DispatchEvent(typ string, initObjects ...interface{}) error { - var initObject interface{} +func (e *elementHandleImpl) DispatchEvent(typ string, initObjects ...any) error { + var initObject any if len(initObjects) == 1 { initObject = initObjects[0] } - _, err := e.channel.Send("dispatchEvent", map[string]interface{}{ + _, err := e.channel.Send("dispatchEvent", map[string]any{ "type": typ, "eventInit": serializeArgument(initObject), }) @@ -101,7 +101,7 @@ func (e *elementHandleImpl) Dblclick(options ...ElementHandleDblclickOptions) er } func (e *elementHandleImpl) QuerySelector(selector string) (ElementHandle, error) { - channel, err := e.channel.Send("querySelector", map[string]interface{}{ + channel, err := e.channel.Send("querySelector", map[string]any{ "selector": selector, }) if err != nil { @@ -114,25 +114,25 @@ func (e *elementHandleImpl) QuerySelector(selector string) (ElementHandle, error } func (e *elementHandleImpl) QuerySelectorAll(selector string) ([]ElementHandle, error) { - channels, err := e.channel.Send("querySelectorAll", map[string]interface{}{ + channels, err := e.channel.Send("querySelectorAll", map[string]any{ "selector": selector, }) if err != nil { return nil, err } elements := make([]ElementHandle, 0) - for _, channel := range channels.([]interface{}) { + for _, channel := range channels.([]any) { elements = append(elements, fromChannel(channel).(*elementHandleImpl)) } return elements, nil } -func (e *elementHandleImpl) EvalOnSelector(selector string, expression string, options ...interface{}) (interface{}, error) { - var arg interface{} +func (e *elementHandleImpl) EvalOnSelector(selector string, expression string, options ...any) (any, error) { + var arg any if len(options) == 1 { arg = options[0] } - result, err := e.channel.Send("evalOnSelector", map[string]interface{}{ + result, err := e.channel.Send("evalOnSelector", map[string]any{ "selector": selector, "expression": expression, "arg": serializeArgument(arg), @@ -143,12 +143,12 @@ func (e *elementHandleImpl) EvalOnSelector(selector string, expression string, o return parseResult(result), nil } -func (e *elementHandleImpl) EvalOnSelectorAll(selector string, expression string, options ...interface{}) (interface{}, error) { - var arg interface{} +func (e *elementHandleImpl) EvalOnSelectorAll(selector string, expression string, options ...any) (any, error) { + var arg any if len(options) == 1 { arg = options[0] } - result, err := e.channel.Send("evalOnSelectorAll", map[string]interface{}{ + result, err := e.channel.Send("evalOnSelectorAll", map[string]any{ "selector": selector, "expression": expression, "arg": serializeArgument(arg), @@ -167,7 +167,7 @@ func (e *elementHandleImpl) ScrollIntoViewIfNeeded(options ...ElementHandleScrol return err } -func (e *elementHandleImpl) SetInputFiles(files interface{}, options ...ElementHandleSetInputFilesOptions) error { +func (e *elementHandleImpl) SetInputFiles(files any, options ...ElementHandleSetInputFilesOptions) error { frame, err := e.OwnerFrame() if err != nil { return err @@ -210,21 +210,21 @@ func (e *elementHandleImpl) Uncheck(options ...ElementHandleUncheckOptions) erro } func (e *elementHandleImpl) Press(key string, options ...ElementHandlePressOptions) error { - _, err := e.channel.Send("press", map[string]interface{}{ + _, err := e.channel.Send("press", map[string]any{ "key": key, }, options) return err } func (e *elementHandleImpl) Fill(value string, options ...ElementHandleFillOptions) error { - _, err := e.channel.Send("fill", map[string]interface{}{ + _, err := e.channel.Send("fill", map[string]any{ "value": value, }, options) return err } func (e *elementHandleImpl) Type(value string, options ...ElementHandleTypeOptions) error { - _, err := e.channel.Send("type", map[string]interface{}{ + _, err := e.channel.Send("type", map[string]any{ "text": value, }, options) return err @@ -242,19 +242,19 @@ func (e *elementHandleImpl) SelectText(options ...ElementHandleSelectTextOptions func (e *elementHandleImpl) Screenshot(options ...ElementHandleScreenshotOptions) ([]byte, error) { var path *string - overrides := map[string]interface{}{} + overrides := map[string]any{} if len(options) == 1 { path = options[0].Path options[0].Path = nil if options[0].Mask != nil { - masks := make([]map[string]interface{}, 0) + masks := make([]map[string]any, 0) for _, m := range options[0].Mask { if m.Err() != nil { // ErrLocatorNotSameFrame return nil, m.Err() } l, ok := m.(*locatorImpl) if ok { - masks = append(masks, map[string]interface{}{ + masks = append(masks, map[string]any{ "selector": l.selector, "frame": l.frame.channel, }) @@ -344,7 +344,7 @@ func (e *elementHandleImpl) IsVisible() (bool, error) { } func (e *elementHandleImpl) WaitForElementState(state ElementState, options ...ElementHandleWaitForElementStateOptions) error { - _, err := e.channel.Send("waitForElementState", map[string]interface{}{ + _, err := e.channel.Send("waitForElementState", map[string]any{ "state": state, }, options) if err != nil { @@ -354,7 +354,7 @@ func (e *elementHandleImpl) WaitForElementState(state ElementState, options ...E } func (e *elementHandleImpl) WaitForSelector(selector string, options ...ElementHandleWaitForSelectorOptions) (ElementHandle, error) { - ch, err := e.channel.Send("waitForSelector", map[string]interface{}{ + ch, err := e.channel.Send("waitForSelector", map[string]any{ "selector": selector, }, options) if err != nil { @@ -386,14 +386,14 @@ func (e *elementHandleImpl) SetChecked(checked bool, options ...ElementHandleSet } } -func newElementHandle(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *elementHandleImpl { +func newElementHandle(parent *channelOwner, objectType string, guid string, initializer map[string]any) *elementHandleImpl { bt := &elementHandleImpl{} bt.createChannelOwner(bt, parent, objectType, guid, initializer) return bt } -func transformToStringList(in interface{}) []string { - s := in.([]interface{}) +func transformToStringList(in any) []string { + s := in.([]any) var out []string for _, v := range s { diff --git a/vendor/github.com/playwright-community/playwright-go/event_emitter.go b/vendor/github.com/playwright-community/playwright-go/event_emitter.go index d4d62ef8..de3e0e86 100644 --- a/vendor/github.com/playwright-community/playwright-go/event_emitter.go +++ b/vendor/github.com/playwright-community/playwright-go/event_emitter.go @@ -8,11 +8,11 @@ import ( ) type EventEmitter interface { - Emit(name string, payload ...interface{}) bool + Emit(name string, payload ...any) bool ListenerCount(name string) int - On(name string, handler interface{}) - Once(name string, handler interface{}) - RemoveListener(name string, handler interface{}) + On(name string, handler any) + Once(name string, handler any) + RemoveListener(name string, handler any) RemoveListeners(name string) } @@ -27,7 +27,7 @@ type ( listeners []listener } listener struct { - handler interface{} + handler any once bool } ) @@ -36,7 +36,7 @@ func NewEventEmitter() EventEmitter { return &eventEmitter{} } -func (e *eventEmitter) Emit(name string, payload ...interface{}) (hasListener bool) { +func (e *eventEmitter) Emit(name string, payload ...any) (hasListener bool) { e.eventsMutex.Lock() e.init() @@ -49,15 +49,15 @@ func (e *eventEmitter) Emit(name string, payload ...interface{}) (hasListener bo return evt.callHandlers(payload...) > 0 } -func (e *eventEmitter) Once(name string, handler interface{}) { +func (e *eventEmitter) Once(name string, handler any) { e.addEvent(name, handler, true) } -func (e *eventEmitter) On(name string, handler interface{}) { +func (e *eventEmitter) On(name string, handler any) { e.addEvent(name, handler, false) } -func (e *eventEmitter) RemoveListener(name string, handler interface{}) { +func (e *eventEmitter) RemoveListener(name string, handler any) { e.eventsMutex.Lock() defer e.eventsMutex.Unlock() e.init() @@ -98,7 +98,7 @@ func (e *eventEmitter) ListenerCount(name string) int { return count } -func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) { +func (e *eventEmitter) addEvent(name string, handler any, once bool) { e.eventsMutex.Lock() defer e.eventsMutex.Unlock() e.init() @@ -118,7 +118,7 @@ func (e *eventEmitter) init() { } } -func (er *eventRegister) addHandler(handler interface{}, once bool) { +func (er *eventRegister) addHandler(handler any, once bool) { er.Lock() defer er.Unlock() er.listeners = append(er.listeners, listener{handler: handler, once: once}) @@ -130,7 +130,7 @@ func (er *eventRegister) count() int { return len(er.listeners) } -func (er *eventRegister) removeHandler(handler interface{}) { +func (er *eventRegister) removeHandler(handler any) { handlerPtr := reflect.ValueOf(handler).Pointer() er.listeners = slices.DeleteFunc(er.listeners, func(l listener) bool { @@ -138,7 +138,7 @@ func (er *eventRegister) removeHandler(handler interface{}) { }) } -func (er *eventRegister) callHandlers(payloads ...interface{}) int { +func (er *eventRegister) callHandlers(payloads ...any) int { payloadV := make([]reflect.Value, 0) for _, p := range payloads { diff --git a/vendor/github.com/playwright-community/playwright-go/fetch.go b/vendor/github.com/playwright-community/playwright-go/fetch.go index fc7f79f1..2100a5b9 100644 --- a/vendor/github.com/playwright-community/playwright-go/fetch.go +++ b/vendor/github.com/playwright-community/playwright-go/fetch.go @@ -14,7 +14,7 @@ type apiRequestImpl struct { } func (r *apiRequestImpl) NewContext(options ...APIRequestNewContextOptions) (APIRequestContext, error) { - overrides := map[string]interface{}{} + overrides := map[string]any{} if len(options) == 1 { if options[0].ClientCertificates != nil { certs, err := transformClientCertificate(options[0].ClientCertificates) @@ -41,13 +41,20 @@ func (r *apiRequestImpl) NewContext(options ...APIRequestNewContextOptions) (API options[0].StorageState = storageState options[0].StorageStatePath = nil } + if options[0].Timeout != nil { + overrides["timeout"] = options[0].Timeout + } } channel, err := r.channel.Send("newRequest", options, overrides) if err != nil { return nil, err } - return fromChannel(channel).(*apiRequestContextImpl), nil + ctx := fromChannel(channel).(*apiRequestContextImpl) + if len(options) == 1 && options[0].Timeout != nil { + ctx.defaultTimeout = options[0].Timeout + } + return ctx, nil } func newApiRequestImpl(pw *Playwright) *apiRequestImpl { @@ -56,15 +63,16 @@ func newApiRequestImpl(pw *Playwright) *apiRequestImpl { type apiRequestContextImpl struct { channelOwner - tracing *tracingImpl - closeReason *string + tracing *tracingImpl + closeReason *string + defaultTimeout *float64 } func (r *apiRequestContextImpl) Dispose(options ...APIRequestContextDisposeOptions) error { if len(options) == 1 { r.closeReason = options[0].Reason } - _, err := r.channel.Send("dispose", map[string]interface{}{ + _, err := r.channel.Send("dispose", map[string]any{ "reason": r.closeReason, }) if errors.Is(err, ErrTargetClosed) { @@ -87,7 +95,7 @@ func (r *apiRequestContextImpl) Delete(url string, options ...APIRequestContextD return r.Fetch(url, opts) } -func (r *apiRequestContextImpl) Fetch(urlOrRequest interface{}, options ...APIRequestContextFetchOptions) (APIResponse, error) { +func (r *apiRequestContextImpl) Fetch(urlOrRequest any, options ...APIRequestContextFetchOptions) (APIResponse, error) { switch v := urlOrRequest.(type) { case string: return r.innerFetch(v, nil, options...) @@ -102,7 +110,7 @@ func (r *apiRequestContextImpl) innerFetch(url string, request Request, options if r.closeReason != nil { return nil, fmt.Errorf("%w: %s", ErrTargetClosed, *r.closeReason) } - overrides := map[string]interface{}{} + overrides := map[string]any{} if url != "" { overrides["url"] = url } else if request != nil { @@ -154,7 +162,7 @@ func (r *apiRequestContextImpl) innerFetch(url string, request Request, options } case []byte: overrides["postData"] = base64.StdEncoding.EncodeToString(v) - case interface{}: + case any: data, err := json.Marshal(v) if err != nil { return nil, fmt.Errorf("could not marshal data: %w", err) @@ -165,22 +173,22 @@ func (r *apiRequestContextImpl) innerFetch(url string, request Request, options } options[0].Data = nil } else if options[0].Form != nil { - form, ok := options[0].Form.(map[string]interface{}) + form, ok := options[0].Form.(map[string]any) if !ok { return nil, errors.New("form must be a map") } overrides["formData"] = serializeMapToNameValue(form) options[0].Form = nil } else if options[0].Multipart != nil { - _, ok := options[0].Multipart.(map[string]interface{}) + _, ok := options[0].Multipart.(map[string]any) if !ok { return nil, errors.New("multipart must be a map") } - multipartData := []map[string]interface{}{} - for name, value := range options[0].Multipart.(map[string]interface{}) { + multipartData := []map[string]any{} + for name, value := range options[0].Multipart.(map[string]any) { switch v := value.(type) { case InputFile: - multipartData = append(multipartData, map[string]interface{}{ + multipartData = append(multipartData, map[string]any{ "name": name, "file": map[string]string{ "name": v.Name, @@ -189,7 +197,7 @@ func (r *apiRequestContextImpl) innerFetch(url string, request Request, options }, }) default: - multipartData = append(multipartData, map[string]interface{}{ + multipartData = append(multipartData, map[string]any{ "name": name, "value": String(fmt.Sprintf("%v", v)), }) @@ -207,6 +215,10 @@ func (r *apiRequestContextImpl) innerFetch(url string, request Request, options overrides["params"] = serializeMapToNameValue(options[0].Params) options[0].Params = nil } + // Use context-level timeout as default if no per-request timeout specified + if options[0].Timeout == nil && r.defaultTimeout != nil { + overrides["timeout"] = *r.defaultTimeout + } } response, err := r.channel.Send("fetch", options, overrides) @@ -214,7 +226,7 @@ func (r *apiRequestContextImpl) innerFetch(url string, request Request, options return nil, err } - return newAPIResponse(r, response.(map[string]interface{})), nil + return newAPIResponse(r, response.(map[string]any)), nil } func (r *apiRequestContextImpl) Get(url string, options ...APIRequestContextGetOptions) (APIResponse, error) { @@ -309,21 +321,23 @@ func (r *apiRequestContextImpl) StorageState(path ...string) (*StorageState, err return &storageState, nil } -func newAPIRequestContext(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *apiRequestContextImpl { +func newAPIRequestContext(parent *channelOwner, objectType string, guid string, initializer map[string]any) *apiRequestContextImpl { rc := &apiRequestContextImpl{} rc.createChannelOwner(rc, parent, objectType, guid, initializer) - rc.tracing = fromChannel(initializer["tracing"]).(*tracingImpl) + if tracingValue := initializer["tracing"]; tracingValue != nil { + rc.tracing = fromNullableChannel(tracingValue).(*tracingImpl) + } return rc } type apiResponseImpl struct { request *apiRequestContextImpl - initializer map[string]interface{} + initializer map[string]any headers *rawHeaders } func (r *apiResponseImpl) Body() ([]byte, error) { - result, err := r.request.channel.SendReturnAsDict("fetchResponseBody", []map[string]interface{}{ + result, err := r.request.channel.SendReturnAsDict("fetchResponseBody", []map[string]any{ { "fetchUid": r.fetchUid(), }, @@ -342,7 +356,7 @@ func (r *apiResponseImpl) Body() ([]byte, error) { } func (r *apiResponseImpl) Dispose() error { - _, err := r.request.channel.Send("disposeAPIResponse", []map[string]interface{}{ + _, err := r.request.channel.Send("disposeAPIResponse", []map[string]any{ { "fetchUid": r.fetchUid(), }, @@ -358,7 +372,7 @@ func (r *apiResponseImpl) HeadersArray() []NameValue { return r.headers.HeadersArray() } -func (r *apiResponseImpl) JSON(v interface{}) error { +func (r *apiResponseImpl) JSON(v any) error { body, err := r.Body() if err != nil { return err @@ -395,20 +409,20 @@ func (r *apiResponseImpl) fetchUid() string { } func (r *apiResponseImpl) fetchLog() ([]string, error) { - ret, err := r.request.channel.Send("fetchLog", map[string]interface{}{ + ret, err := r.request.channel.Send("fetchLog", map[string]any{ "fetchUid": r.fetchUid(), }) if err != nil { return nil, err } - result := make([]string, len(ret.([]interface{}))) - for i, v := range ret.([]interface{}) { + result := make([]string, len(ret.([]any))) + for i, v := range ret.([]any) { result[i] = v.(string) } return result, nil } -func newAPIResponse(context *apiRequestContextImpl, initializer map[string]interface{}) *apiResponseImpl { +func newAPIResponse(context *apiRequestContextImpl, initializer map[string]any) *apiResponseImpl { return &apiResponseImpl{ request: context, initializer: initializer, @@ -416,7 +430,7 @@ func newAPIResponse(context *apiRequestContextImpl, initializer map[string]inter } } -func countNonNil(args ...interface{}) int { +func countNonNil(args ...any) int { count := 0 for _, v := range args { if v != nil { @@ -439,7 +453,7 @@ func isJsonContentType(headers []map[string]string) bool { return false } -func serializeMapToNameValue(data map[string]interface{}) []map[string]string { +func serializeMapToNameValue(data map[string]any) []map[string]string { serialized := make([]map[string]string, 0, len(data)) for k, v := range data { serialized = append(serialized, map[string]string{ diff --git a/vendor/github.com/playwright-community/playwright-go/file_chooser.go b/vendor/github.com/playwright-community/playwright-go/file_chooser.go index 119e8858..db9b1463 100644 --- a/vendor/github.com/playwright-community/playwright-go/file_chooser.go +++ b/vendor/github.com/playwright-community/playwright-go/file_chooser.go @@ -28,7 +28,7 @@ type InputFile struct { Buffer []byte `json:"buffer"` } -func (f *fileChooserImpl) SetFiles(files interface{}, options ...FileChooserSetFilesOptions) error { +func (f *fileChooserImpl) SetFiles(files any, options ...FileChooserSetFilesOptions) error { if len(options) == 1 { return f.elementHandle.SetInputFiles(files, ElementHandleSetInputFilesOptions(options[0])) } diff --git a/vendor/github.com/playwright-community/playwright-go/frame.go b/vendor/github.com/playwright-community/playwright-go/frame.go index b571c8ed..a0a7b92f 100644 --- a/vendor/github.com/playwright-community/playwright-go/frame.go +++ b/vendor/github.com/playwright-community/playwright-go/frame.go @@ -20,7 +20,7 @@ type frameImpl struct { loadStates mapset.Set[string] } -func newFrame(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *frameImpl { +func newFrame(parent *channelOwner, objectType string, guid string, initializer map[string]any) *frameImpl { var loadStates mapset.Set[string] if ls, ok := initializer["loadStates"].([]string); ok { @@ -60,9 +60,14 @@ func (f *frameImpl) Name() string { } func (f *frameImpl) SetContent(content string, options ...FrameSetContentOptions) error { - _, err := f.channel.Send("setContent", map[string]interface{}{ + overrides := map[string]any{ "html": content, - }, options) + } + // timeout is required in Playwright v1.57+ protocol + if len(options) == 0 || options[0].Timeout == nil { + overrides["timeout"] = f.page.timeoutSettings.NavigationTimeout() + } + _, err := f.channel.Send("setContent", overrides, options) return err } @@ -75,9 +80,14 @@ func (f *frameImpl) Content() (string, error) { } func (f *frameImpl) Goto(url string, options ...FrameGotoOptions) (Response, error) { - channel, err := f.channel.Send("goto", map[string]interface{}{ + overrides := map[string]any{ "url": url, - }, options) + } + // timeout is required in Playwright v1.57+ protocol + if len(options) == 0 || options[0].Timeout == nil { + overrides["timeout"] = f.page.timeoutSettings.NavigationTimeout() + } + channel, err := f.channel.Send("goto", overrides, options) if err != nil { return nil, fmt.Errorf("Frame.Goto %s: %w", url, err) } @@ -144,7 +154,7 @@ func (f *frameImpl) waitForLoadStateImpl(state string, timeout *float64, cb func if err != nil { return err } - waiter.WaitForEvent(f, "loadstate", func(payload interface{}) bool { + waiter.WaitForEvent(f, "loadstate", func(payload any) bool { gotState := payload.(string) return gotState == state }) @@ -157,7 +167,7 @@ func (f *frameImpl) waitForLoadStateImpl(state string, timeout *float64, cb func } } -func (f *frameImpl) WaitForURL(url interface{}, options ...FrameWaitForURLOptions) error { +func (f *frameImpl) WaitForURL(url any, options ...FrameWaitForURLOptions) error { if f.page == nil { return errors.New("frame is detached") } @@ -205,8 +215,8 @@ func (f *frameImpl) ExpectNavigation(cb func() error, options ...FrameExpectNavi if option.URL != nil { matcher = newURLMatcher(option.URL, f.page.browserContext.options.BaseURL) } - predicate := func(events ...interface{}) bool { - ev := events[0].(map[string]interface{}) + predicate := func(events ...any) bool { + ev := events[0].(map[string]any) err, ok := ev["error"] if ok { // Any failed navigation results in a rejection. @@ -232,9 +242,9 @@ func (f *frameImpl) ExpectNavigation(cb func() error, options ...FrameExpectNavi return nil, err } } - event := eventData.(map[string]interface{}) - if event["newDocument"] != nil && event["newDocument"].(map[string]interface{})["request"] != nil { - request := fromChannel(event["newDocument"].(map[string]interface{})["request"]).(*requestImpl) + event := eventData.(map[string]any) + if event["newDocument"] != nil && event["newDocument"].(map[string]any)["request"] != nil { + request := fromChannel(event["newDocument"].(map[string]any)["request"]).(*requestImpl) return request.Response() } return nil, nil @@ -252,7 +262,7 @@ func (f *frameImpl) setNavigationWaiter(timeout *float64) (*waiter, error) { } waiter.RejectOnEvent(f.page, "close", f.page.closeErrorWithReason()) waiter.RejectOnEvent(f.page, "crash", fmt.Errorf("Navigation failed because page crashed!")) - waiter.RejectOnEvent(f.page, "framedetached", fmt.Errorf("Navigating frame was detached!"), func(payload interface{}) bool { + waiter.RejectOnEvent(f.page, "framedetached", fmt.Errorf("Navigating frame was detached!"), func(payload any) bool { frame, ok := payload.(*frameImpl) if ok && frame == f { return true @@ -262,7 +272,7 @@ func (f *frameImpl) setNavigationWaiter(timeout *float64) (*waiter, error) { return waiter, nil } -func (f *frameImpl) onFrameNavigated(ev map[string]interface{}) { +func (f *frameImpl) onFrameNavigated(ev map[string]any) { f.Lock() f.url = ev["url"].(string) f.name = ev["name"].(string) @@ -274,7 +284,7 @@ func (f *frameImpl) onFrameNavigated(ev map[string]interface{}) { } } -func (f *frameImpl) onLoadState(ev map[string]interface{}) { +func (f *frameImpl) onLoadState(ev map[string]any) { if ev["add"] != nil { add := ev["add"].(string) f.loadStates.Add(add) @@ -291,7 +301,7 @@ func (f *frameImpl) onLoadState(ev map[string]interface{}) { } func (f *frameImpl) QuerySelector(selector string, options ...FrameQuerySelectorOptions) (ElementHandle, error) { - params := map[string]interface{}{ + params := map[string]any{ "selector": selector, } if len(options) == 1 { @@ -308,25 +318,25 @@ func (f *frameImpl) QuerySelector(selector string, options ...FrameQuerySelector } func (f *frameImpl) QuerySelectorAll(selector string) ([]ElementHandle, error) { - channels, err := f.channel.Send("querySelectorAll", map[string]interface{}{ + channels, err := f.channel.Send("querySelectorAll", map[string]any{ "selector": selector, }) if err != nil { return nil, err } elements := make([]ElementHandle, 0) - for _, channel := range channels.([]interface{}) { + for _, channel := range channels.([]any) { elements = append(elements, fromChannel(channel).(*elementHandleImpl)) } return elements, nil } -func (f *frameImpl) Evaluate(expression string, options ...interface{}) (interface{}, error) { - var arg interface{} +func (f *frameImpl) Evaluate(expression string, options ...any) (any, error) { + var arg any if len(options) == 1 { arg = options[0] } - result, err := f.channel.Send("evaluateExpression", map[string]interface{}{ + result, err := f.channel.Send("evaluateExpression", map[string]any{ "expression": expression, "arg": serializeArgument(arg), }) @@ -336,8 +346,8 @@ func (f *frameImpl) Evaluate(expression string, options ...interface{}) (interfa return parseResult(result), nil } -func (f *frameImpl) EvalOnSelector(selector string, expression string, arg interface{}, options ...FrameEvalOnSelectorOptions) (interface{}, error) { - params := map[string]interface{}{ +func (f *frameImpl) EvalOnSelector(selector string, expression string, arg any, options ...FrameEvalOnSelectorOptions) (any, error) { + params := map[string]any{ "selector": selector, "expression": expression, "arg": serializeArgument(arg), @@ -353,12 +363,12 @@ func (f *frameImpl) EvalOnSelector(selector string, expression string, arg inter return parseResult(result), nil } -func (f *frameImpl) EvalOnSelectorAll(selector string, expression string, options ...interface{}) (interface{}, error) { - var arg interface{} +func (f *frameImpl) EvalOnSelectorAll(selector string, expression string, options ...any) (any, error) { + var arg any if len(options) == 1 { arg = options[0] } - result, err := f.channel.Send("evalOnSelectorAll", map[string]interface{}{ + result, err := f.channel.Send("evalOnSelectorAll", map[string]any{ "selector": selector, "expression": expression, "arg": serializeArgument(arg), @@ -369,12 +379,12 @@ func (f *frameImpl) EvalOnSelectorAll(selector string, expression string, option return parseResult(result), nil } -func (f *frameImpl) EvaluateHandle(expression string, options ...interface{}) (JSHandle, error) { - var arg interface{} +func (f *frameImpl) EvaluateHandle(expression string, options ...any) (JSHandle, error) { + var arg any if len(options) == 1 { arg = options[0] } - result, err := f.channel.Send("evaluateExpressionHandle", map[string]interface{}{ + result, err := f.channel.Send("evaluateExpressionHandle", map[string]any{ "expression": expression, "arg": serializeArgument(arg), }) @@ -389,14 +399,14 @@ func (f *frameImpl) EvaluateHandle(expression string, options ...interface{}) (J } func (f *frameImpl) Click(selector string, options ...FrameClickOptions) error { - _, err := f.channel.Send("click", map[string]interface{}{ + _, err := f.channel.Send("click", map[string]any{ "selector": selector, }, options) return err } func (f *frameImpl) WaitForSelector(selector string, options ...FrameWaitForSelectorOptions) (ElementHandle, error) { - channel, err := f.channel.Send("waitForSelector", map[string]interface{}{ + channel, err := f.channel.Send("waitForSelector", map[string]any{ "selector": selector, }, options) if err != nil { @@ -409,17 +419,17 @@ func (f *frameImpl) WaitForSelector(selector string, options ...FrameWaitForSele return channelOwner.(*elementHandleImpl), nil } -func (f *frameImpl) DispatchEvent(selector, typ string, eventInit interface{}, options ...FrameDispatchEventOptions) error { - _, err := f.channel.Send("dispatchEvent", map[string]interface{}{ +func (f *frameImpl) DispatchEvent(selector, typ string, eventInit any, options ...FrameDispatchEventOptions) error { + _, err := f.channel.Send("dispatchEvent", map[string]any{ "selector": selector, "type": typ, "eventInit": serializeArgument(eventInit), - }) + }, options) return err } func (f *frameImpl) InnerText(selector string, options ...FrameInnerTextOptions) (string, error) { - innerText, err := f.channel.Send("innerText", map[string]interface{}{ + innerText, err := f.channel.Send("innerText", map[string]any{ "selector": selector, }, options) if innerText == nil { @@ -429,7 +439,7 @@ func (f *frameImpl) InnerText(selector string, options ...FrameInnerTextOptions) } func (f *frameImpl) InnerHTML(selector string, options ...FrameInnerHTMLOptions) (string, error) { - innerHTML, err := f.channel.Send("innerHTML", map[string]interface{}{ + innerHTML, err := f.channel.Send("innerHTML", map[string]any{ "selector": selector, }, options) if innerHTML == nil { @@ -439,7 +449,7 @@ func (f *frameImpl) InnerHTML(selector string, options ...FrameInnerHTMLOptions) } func (f *frameImpl) GetAttribute(selector string, name string, options ...FrameGetAttributeOptions) (string, error) { - attribute, err := f.channel.Send("getAttribute", map[string]interface{}{ + attribute, err := f.channel.Send("getAttribute", map[string]any{ "selector": selector, "name": name, }, options) @@ -450,13 +460,13 @@ func (f *frameImpl) GetAttribute(selector string, name string, options ...FrameG } func (f *frameImpl) Hover(selector string, options ...FrameHoverOptions) error { - _, err := f.channel.Send("hover", map[string]interface{}{ + _, err := f.channel.Send("hover", map[string]any{ "selector": selector, }, options) return err } -func (f *frameImpl) SetInputFiles(selector string, files interface{}, options ...FrameSetInputFilesOptions) error { +func (f *frameImpl) SetInputFiles(selector string, files any, options ...FrameSetInputFilesOptions) error { params, err := convertInputFiles(files, f.page.browserContext) if err != nil { return err @@ -467,7 +477,7 @@ func (f *frameImpl) SetInputFiles(selector string, files interface{}, options .. } func (f *frameImpl) Type(selector, text string, options ...FrameTypeOptions) error { - _, err := f.channel.Send("type", map[string]interface{}{ + _, err := f.channel.Send("type", map[string]any{ "selector": selector, "text": text, }, options) @@ -475,7 +485,7 @@ func (f *frameImpl) Type(selector, text string, options ...FrameTypeOptions) err } func (f *frameImpl) Press(selector, key string, options ...FramePressOptions) error { - _, err := f.channel.Send("press", map[string]interface{}{ + _, err := f.channel.Send("press", map[string]any{ "selector": selector, "key": key, }, options) @@ -483,14 +493,14 @@ func (f *frameImpl) Press(selector, key string, options ...FramePressOptions) er } func (f *frameImpl) Check(selector string, options ...FrameCheckOptions) error { - _, err := f.channel.Send("check", map[string]interface{}{ + _, err := f.channel.Send("check", map[string]any{ "selector": selector, }, options) return err } func (f *frameImpl) Uncheck(selector string, options ...FrameUncheckOptions) error { - _, err := f.channel.Send("uncheck", map[string]interface{}{ + _, err := f.channel.Send("uncheck", map[string]any{ "selector": selector, }, options) return err @@ -500,17 +510,23 @@ func (f *frameImpl) WaitForTimeout(timeout float64) { time.Sleep(time.Duration(timeout) * time.Millisecond) } -func (f *frameImpl) WaitForFunction(expression string, arg interface{}, options ...FrameWaitForFunctionOptions) (JSHandle, error) { +func (f *frameImpl) WaitForFunction(expression string, arg any, options ...FrameWaitForFunctionOptions) (JSHandle, error) { var option FrameWaitForFunctionOptions if len(options) == 1 { option = options[0] } - result, err := f.channel.Send("waitForFunction", map[string]interface{}{ + overrides := map[string]any{ "expression": expression, "arg": serializeArgument(arg), - "timeout": option.Timeout, "polling": option.Polling, - }) + } + // timeout is required in Playwright v1.57+ protocol + if option.Timeout == nil { + overrides["timeout"] = f.page.timeoutSettings.Timeout() + } else { + overrides["timeout"] = option.Timeout + } + result, err := f.channel.Send("waitForFunction", overrides) if err != nil { return nil, err } @@ -534,14 +550,14 @@ func (f *frameImpl) ChildFrames() []Frame { } func (f *frameImpl) Dblclick(selector string, options ...FrameDblclickOptions) error { - _, err := f.channel.Send("dblclick", map[string]interface{}{ + _, err := f.channel.Send("dblclick", map[string]any{ "selector": selector, }, options) return err } func (f *frameImpl) Fill(selector string, value string, options ...FrameFillOptions) error { - _, err := f.channel.Send("fill", map[string]interface{}{ + _, err := f.channel.Send("fill", map[string]any{ "selector": selector, "value": value, }, options) @@ -549,7 +565,7 @@ func (f *frameImpl) Fill(selector string, value string, options ...FrameFillOpti } func (f *frameImpl) Focus(selector string, options ...FrameFocusOptions) error { - _, err := f.channel.Send("focus", map[string]interface{}{ + _, err := f.channel.Send("focus", map[string]any{ "selector": selector, }, options) return err @@ -572,7 +588,7 @@ func (f *frameImpl) ParentFrame() Frame { } func (f *frameImpl) TextContent(selector string, options ...FrameTextContentOptions) (string, error) { - textContent, err := f.channel.Send("textContent", map[string]interface{}{ + textContent, err := f.channel.Send("textContent", map[string]any{ "selector": selector, }, options) if textContent == nil { @@ -582,7 +598,7 @@ func (f *frameImpl) TextContent(selector string, options ...FrameTextContentOpti } func (f *frameImpl) Tap(selector string, options ...FrameTapOptions) error { - _, err := f.channel.Send("tap", map[string]interface{}{ + _, err := f.channel.Send("tap", map[string]any{ "selector": selector, }, options) return err @@ -591,7 +607,7 @@ func (f *frameImpl) Tap(selector string, options ...FrameTapOptions) error { func (f *frameImpl) SelectOption(selector string, values SelectOptionValues, options ...FrameSelectOptionOptions) ([]string, error) { opts := convertSelectOptionSet(values) - m := make(map[string]interface{}) + m := make(map[string]any) m["selector"] = selector for k, v := range opts { m[k] = v @@ -605,7 +621,7 @@ func (f *frameImpl) SelectOption(selector string, values SelectOptionValues, opt } func (f *frameImpl) IsChecked(selector string, options ...FrameIsCheckedOptions) (bool, error) { - checked, err := f.channel.Send("isChecked", map[string]interface{}{ + checked, err := f.channel.Send("isChecked", map[string]any{ "selector": selector, }, options) if err != nil { @@ -615,7 +631,7 @@ func (f *frameImpl) IsChecked(selector string, options ...FrameIsCheckedOptions) } func (f *frameImpl) IsDisabled(selector string, options ...FrameIsDisabledOptions) (bool, error) { - disabled, err := f.channel.Send("isDisabled", map[string]interface{}{ + disabled, err := f.channel.Send("isDisabled", map[string]any{ "selector": selector, }, options) if err != nil { @@ -625,7 +641,7 @@ func (f *frameImpl) IsDisabled(selector string, options ...FrameIsDisabledOption } func (f *frameImpl) IsEditable(selector string, options ...FrameIsEditableOptions) (bool, error) { - editable, err := f.channel.Send("isEditable", map[string]interface{}{ + editable, err := f.channel.Send("isEditable", map[string]any{ "selector": selector, }, options) if err != nil { @@ -635,7 +651,7 @@ func (f *frameImpl) IsEditable(selector string, options ...FrameIsEditableOption } func (f *frameImpl) IsEnabled(selector string, options ...FrameIsEnabledOptions) (bool, error) { - enabled, err := f.channel.Send("isEnabled", map[string]interface{}{ + enabled, err := f.channel.Send("isEnabled", map[string]any{ "selector": selector, }, options) if err != nil { @@ -645,7 +661,7 @@ func (f *frameImpl) IsEnabled(selector string, options ...FrameIsEnabledOptions) } func (f *frameImpl) IsHidden(selector string, options ...FrameIsHiddenOptions) (bool, error) { - hidden, err := f.channel.Send("isHidden", map[string]interface{}{ + hidden, err := f.channel.Send("isHidden", map[string]any{ "selector": selector, }, options) if err != nil { @@ -655,7 +671,7 @@ func (f *frameImpl) IsHidden(selector string, options ...FrameIsHiddenOptions) ( } func (f *frameImpl) IsVisible(selector string, options ...FrameIsVisibleOptions) (bool, error) { - visible, err := f.channel.Send("isVisible", map[string]interface{}{ + visible, err := f.channel.Send("isVisible", map[string]any{ "selector": selector, }, options) if err != nil { @@ -665,7 +681,7 @@ func (f *frameImpl) IsVisible(selector string, options ...FrameIsVisibleOptions) } func (f *frameImpl) InputValue(selector string, options ...FrameInputValueOptions) (string, error) { - value, err := f.channel.Send("inputValue", map[string]interface{}{ + value, err := f.channel.Send("inputValue", map[string]any{ "selector": selector, }, options) if value == nil { @@ -675,7 +691,7 @@ func (f *frameImpl) InputValue(selector string, options ...FrameInputValueOption } func (f *frameImpl) DragAndDrop(source, target string, options ...FrameDragAndDropOptions) error { - _, err := f.channel.Send("dragAndDrop", map[string]interface{}{ + _, err := f.channel.Send("dragAndDrop", map[string]any{ "source": source, "target": target, }, options) @@ -684,12 +700,12 @@ func (f *frameImpl) DragAndDrop(source, target string, options ...FrameDragAndDr func (f *frameImpl) SetChecked(selector string, checked bool, options ...FrameSetCheckedOptions) error { if checked { - _, err := f.channel.Send("check", map[string]interface{}{ + _, err := f.channel.Send("check", map[string]any{ "selector": selector, }, options) return err } else { - _, err := f.channel.Send("uncheck", map[string]interface{}{ + _, err := f.channel.Send("uncheck", map[string]any{ "selector": selector, }, options) return err @@ -709,7 +725,7 @@ func (f *frameImpl) Locator(selector string, options ...FrameLocatorOptions) Loc return newLocator(f, selector, option) } -func (f *frameImpl) GetByAltText(text interface{}, options ...FrameGetByAltTextOptions) Locator { +func (f *frameImpl) GetByAltText(text any, options ...FrameGetByAltTextOptions) Locator { exact := false if len(options) == 1 { if *options[0].Exact { @@ -719,7 +735,7 @@ func (f *frameImpl) GetByAltText(text interface{}, options ...FrameGetByAltTextO return f.Locator(getByAltTextSelector(text, exact)) } -func (f *frameImpl) GetByLabel(text interface{}, options ...FrameGetByLabelOptions) Locator { +func (f *frameImpl) GetByLabel(text any, options ...FrameGetByLabelOptions) Locator { exact := false if len(options) == 1 { if *options[0].Exact { @@ -729,7 +745,7 @@ func (f *frameImpl) GetByLabel(text interface{}, options ...FrameGetByLabelOptio return f.Locator(getByLabelSelector(text, exact)) } -func (f *frameImpl) GetByPlaceholder(text interface{}, options ...FrameGetByPlaceholderOptions) Locator { +func (f *frameImpl) GetByPlaceholder(text any, options ...FrameGetByPlaceholderOptions) Locator { exact := false if len(options) == 1 { if *options[0].Exact { @@ -746,11 +762,11 @@ func (f *frameImpl) GetByRole(role AriaRole, options ...FrameGetByRoleOptions) L return f.Locator(getByRoleSelector(role)) } -func (f *frameImpl) GetByTestId(testId interface{}) Locator { +func (f *frameImpl) GetByTestId(testId any) Locator { return f.Locator(getByTestIdSelector(getTestIdAttributeName(), testId)) } -func (f *frameImpl) GetByText(text interface{}, options ...FrameGetByTextOptions) Locator { +func (f *frameImpl) GetByText(text any, options ...FrameGetByTextOptions) Locator { exact := false if len(options) == 1 { if *options[0].Exact { @@ -760,7 +776,7 @@ func (f *frameImpl) GetByText(text interface{}, options ...FrameGetByTextOptions return f.Locator(getByTextSelector(text, exact)) } -func (f *frameImpl) GetByTitle(text interface{}, options ...FrameGetByTitleOptions) Locator { +func (f *frameImpl) GetByTitle(text any, options ...FrameGetByTitleOptions) Locator { exact := false if len(options) == 1 { if *options[0].Exact { @@ -775,14 +791,14 @@ func (f *frameImpl) FrameLocator(selector string) FrameLocator { } func (f *frameImpl) highlight(selector string) error { - _, err := f.channel.Send("highlight", map[string]interface{}{ + _, err := f.channel.Send("highlight", map[string]any{ "selector": selector, }) return err } func (f *frameImpl) queryCount(selector string) (int, error) { - response, err := f.channel.Send("queryCount", map[string]interface{}{ + response, err := f.channel.Send("queryCount", map[string]any{ "selector": selector, }) if err != nil { diff --git a/vendor/github.com/playwright-community/playwright-go/frame_locator.go b/vendor/github.com/playwright-community/playwright-go/frame_locator.go index d4b8fd0d..b55c4f34 100644 --- a/vendor/github.com/playwright-community/playwright-go/frame_locator.go +++ b/vendor/github.com/playwright-community/playwright-go/frame_locator.go @@ -23,7 +23,7 @@ func (fl *frameLocatorImpl) FrameLocator(selector string) FrameLocator { return newFrameLocator(fl.frame, fl.frameSelector+" >> internal:control=enter-frame >> "+selector) } -func (fl *frameLocatorImpl) GetByAltText(text interface{}, options ...FrameLocatorGetByAltTextOptions) Locator { +func (fl *frameLocatorImpl) GetByAltText(text any, options ...FrameLocatorGetByAltTextOptions) Locator { exact := false if len(options) == 1 { if *options[0].Exact { @@ -33,7 +33,7 @@ func (fl *frameLocatorImpl) GetByAltText(text interface{}, options ...FrameLocat return fl.Locator(getByAltTextSelector(text, exact)) } -func (fl *frameLocatorImpl) GetByLabel(text interface{}, options ...FrameLocatorGetByLabelOptions) Locator { +func (fl *frameLocatorImpl) GetByLabel(text any, options ...FrameLocatorGetByLabelOptions) Locator { exact := false if len(options) == 1 { if *options[0].Exact { @@ -43,7 +43,7 @@ func (fl *frameLocatorImpl) GetByLabel(text interface{}, options ...FrameLocator return fl.Locator(getByLabelSelector(text, exact)) } -func (fl *frameLocatorImpl) GetByPlaceholder(text interface{}, options ...FrameLocatorGetByPlaceholderOptions) Locator { +func (fl *frameLocatorImpl) GetByPlaceholder(text any, options ...FrameLocatorGetByPlaceholderOptions) Locator { exact := false if len(options) == 1 { if *options[0].Exact { @@ -60,11 +60,11 @@ func (fl *frameLocatorImpl) GetByRole(role AriaRole, options ...FrameLocatorGetB return fl.Locator(getByRoleSelector(role)) } -func (fl *frameLocatorImpl) GetByTestId(testId interface{}) Locator { +func (fl *frameLocatorImpl) GetByTestId(testId any) Locator { return fl.Locator(getByTestIdSelector(getTestIdAttributeName(), testId)) } -func (fl *frameLocatorImpl) GetByText(text interface{}, options ...FrameLocatorGetByTextOptions) Locator { +func (fl *frameLocatorImpl) GetByText(text any, options ...FrameLocatorGetByTextOptions) Locator { exact := false if len(options) == 1 { if *options[0].Exact { @@ -74,7 +74,7 @@ func (fl *frameLocatorImpl) GetByText(text interface{}, options ...FrameLocatorG return fl.Locator(getByTextSelector(text, exact)) } -func (fl *frameLocatorImpl) GetByTitle(text interface{}, options ...FrameLocatorGetByTitleOptions) Locator { +func (fl *frameLocatorImpl) GetByTitle(text any, options ...FrameLocatorGetByTitleOptions) Locator { exact := false if len(options) == 1 { if *options[0].Exact { @@ -88,7 +88,7 @@ func (fl *frameLocatorImpl) Last() FrameLocator { return newFrameLocator(fl.frame, fl.frameSelector+" >> nth=-1") } -func (fl *frameLocatorImpl) Locator(selectorOrLocator interface{}, options ...FrameLocatorLocatorOptions) Locator { +func (fl *frameLocatorImpl) Locator(selectorOrLocator any, options ...FrameLocatorLocatorOptions) Locator { var option LocatorOptions if len(options) == 1 { option = LocatorOptions{ diff --git a/vendor/github.com/playwright-community/playwright-go/generated-interfaces.go b/vendor/github.com/playwright-community/playwright-go/generated-interfaces.go index 187dc910..39e9e6f1 100644 --- a/vendor/github.com/playwright-community/playwright-go/generated-interfaces.go +++ b/vendor/github.com/playwright-community/playwright-go/generated-interfaces.go @@ -40,7 +40,7 @@ type APIRequestContext interface { // update context cookies from the response. The method will automatically follow redirects. // // urlOrRequest: Target URL or Request to get all parameters from. - Fetch(urlOrRequest interface{}, options ...APIRequestContextFetchOptions) (APIResponse, error) + Fetch(urlOrRequest any, options ...APIRequestContextFetchOptions) (APIResponse, error) // Sends HTTP(S) [GET] request and returns its // response. The method will populate request cookies from the context and update context cookies from the response. @@ -109,7 +109,7 @@ type APIResponse interface { // Returns the JSON representation of response body. // This method will throw if the response body is not parsable via `JSON.parse`. - JSON(v interface{}) error + JSON(v any) error // Contains a boolean stating whether the response was successful (status in the range 200-299) or not. Ok() bool @@ -217,8 +217,9 @@ type Browser interface { // Non-persistent browser contexts don't write any browsing data to disk. type BrowserContext interface { EventEmitter - // **NOTE** Only works with Chromium browser's persistent context. - // Emitted when new background page is created in the context. + // This event is not emitted. + // + // Deprecated: Background pages have been removed from Chromium together with Manifest V2 extensions. OnBackgroundPage(fn func(Page)) // Playwright has ability to mock clock and passage of time. @@ -292,11 +293,13 @@ type BrowserContext interface { // script: Script to be evaluated in all pages in the browser context. AddInitScript(script Script) error - // **NOTE** Background pages are only supported on Chromium-based browsers. - // All existing background pages in the context. + // Returns an empty list. + // + // Deprecated: Background pages have been removed from Chromium together with Manifest V2 extensions. BackgroundPages() []Page - // Returns the browser instance of the context. If it was launched as a persistent context null gets returned. + // Gets the browser instance that owns the context. Returns `null` if the context is created outside of normal + // browser, e.g. Android or Electron. Browser() Browser // Removes cookies from context. Accepts optional filter. @@ -351,6 +354,8 @@ type BrowserContext interface { // - `'clipboard-write'` // - `'geolocation'` // - `'gyroscope'` + // - `'local-fonts'` + // - `'local-network-access'` // - `'magnetometer'` // - `'microphone'` // - `'midi-sysex'` (system-exclusive midi) @@ -365,7 +370,7 @@ type BrowserContext interface { // // page: Target to create new session for. For backwards-compatibility, this parameter is named `page`, but it can be a // `Page` or `Frame` type. - NewCDPSession(page interface{}) (CDPSession, error) + NewCDPSession(page any) (CDPSession, error) // Creates a new page in the browser context. NewPage() (Page, error) @@ -388,7 +393,7 @@ type BrowserContext interface { // 2. handler: handler function to route the request. // // [this]: https://github.com/microsoft/playwright/issues/1090 - Route(url interface{}, handler routeHandler, times ...int) error + Route(url any, handler routeHandler, times ...int) error // If specified the network requests that are made in the context will be served from the HAR file. Read more about // [Replaying from HAR]. @@ -410,7 +415,7 @@ type BrowserContext interface { // 1. url: Only WebSockets with the url matching this pattern will be routed. A string pattern can be relative to the // “[object Object]” context option. // 2. handler: Handler function to route the WebSocket. - RouteWebSocket(url interface{}, handler func(WebSocketRoute)) error + RouteWebSocket(url any, handler func(WebSocketRoute)) error // **NOTE** Service workers are only supported on Chromium-based browsers. // All existing service workers in the context. @@ -465,7 +470,7 @@ type BrowserContext interface { // // 1. url: A glob pattern, regex pattern or predicate receiving [URL] used to register a routing with [BrowserContext.Route]. // 2. handler: Optional handler function used to register a routing with [BrowserContext.Route]. - Unroute(url interface{}, handler ...routeHandler) error + Unroute(url any, handler ...routeHandler) error // Performs action and waits for a [ConsoleMessage] to be logged by in the pages in the context. If predicate is // provided, it passes [ConsoleMessage] value into the `predicate` function and waits for `predicate(message)` to @@ -477,7 +482,7 @@ type BrowserContext interface { // value. Will throw an error if the context closes before the event is fired. Returns the event data value. // // event: Event name, same one would pass into `browserContext.on(event)`. - ExpectEvent(event string, cb func() error, options ...BrowserContextExpectEventOptions) (interface{}, error) + ExpectEvent(event string, cb func() error, options ...BrowserContextExpectEventOptions) (any, error) // Performs action and waits for a new [Page] to be created in the context. If predicate is provided, it passes [Page] // value into the `predicate` function and waits for `predicate(event)` to return a truthy value. Will throw an error @@ -490,7 +495,7 @@ type BrowserContext interface { // before the `event` is fired. // // event: Event name, same one typically passed into `*.on(event)`. - WaitForEvent(event string, options ...BrowserContextWaitForEventOptions) (interface{}, error) + WaitForEvent(event string, options ...BrowserContextWaitForEventOptions) (any, error) } // BrowserType provides methods to launch a specific browser instance or connect to an existing one. The following is @@ -538,6 +543,12 @@ type BrowserType interface { // **parent** directory of the "Profile Path" seen at `chrome://version`. // // Note that browsers do not allow launching multiple instances with the same User Data Directory. + // + // **NOTE** Chromium/Chrome: Due to recent Chrome policy changes, automating the default Chrome user profile is not + // supported. Pointing `userDataDir` to Chrome's main "User Data" directory (the profile used for your regular + // browsing) may result in pages not loading or the browser exiting. Create and use a separate directory (for example, + // an empty folder) as your automation profile instead. See https://developer.chrome.com/blog/remote-debugging-port + // for details. LaunchPersistentContext(userDataDir string, options ...BrowserTypeLaunchPersistentContextOptions) (BrowserContext, error) // Returns browser name. For example: `chromium`, `webkit` or `firefox`. @@ -564,7 +575,7 @@ type CDPSession interface { // // 1. method: Protocol method name. // 2. params: Optional method parameters. - Send(method string, params map[string]interface{}) (interface{}, error) + Send(method string, params map[string]any) (any, error) } // Accurately simulating time-dependent behavior is essential for verifying the correctness of applications. Learn @@ -579,7 +590,7 @@ type Clock interface { // // ticks: Time may be the number of milliseconds to advance the clock by or a human-readable string. Valid string formats are // "08" for eight seconds, "01:00" for one minute and "02:34:10" for two hours, 34 minutes and ten seconds. - FastForward(ticks interface{}) error + FastForward(ticks any) error // Install fake implementations for the following time-related functions: // - `Date` @@ -601,7 +612,7 @@ type Clock interface { // // ticks: Time may be the number of milliseconds to advance the clock by or a human-readable string. Valid string formats are // "08" for eight seconds, "01:00" for one minute and "02:34:10" for two hours, 34 minutes and ten seconds. - RunFor(ticks interface{}) error + RunFor(ticks any) error // Advance the clock by jumping forward in time and pause the time. Once this method is called, no timers are fired // unless [Clock.RunFor], [Clock.FastForward], [Clock.PauseAt] or [Clock.Resume] is called. @@ -609,7 +620,7 @@ type Clock interface { // at the specified time and pausing. // // time: Time to pause at. - PauseAt(time interface{}) error + PauseAt(time any) error // Resumes timers. Once this method is called, time resumes flowing, timers are fired as usual. Resume() error @@ -621,13 +632,13 @@ type Clock interface { // time: Time to be set. // // [clock emulation]: https://playwright.dev/docs/clock - SetFixedTime(time interface{}) error + SetFixedTime(time any) error // Sets system time, but does not trigger any timers. Use this to test how the web page reacts to a time shift, for // example switching from summer to winter time, or changing time zones. // // time: Time to be set. - SetSystemTime(time interface{}) error + SetSystemTime(time any) error } // [ConsoleMessage] objects are dispatched by page via the [Page.OnConsole] event. For each console message logged in @@ -651,6 +662,10 @@ type ConsoleMessage interface { // `trace`, `clear`, `startGroup`, `startGroupCollapsed`, `endGroup`, `assert`, `profile`, // `profileEnd`, `count`, `timeEnd`. Type() string + + // The web worker or service worker that produced this console message, if any. Note that console messages from web + // workers also have non-null [ConsoleMessage.Page]. + Worker() (Worker, error) } // [Dialog] objects are dispatched by page via the [Page.OnDialog] event. @@ -825,7 +840,7 @@ type ElementHandle interface { // [TouchEvent]: https://developer.mozilla.org/en-US/docs/Web/API/TouchEvent/TouchEvent // [WheelEvent]: https://developer.mozilla.org/en-US/docs/Web/API/WheelEvent/WheelEvent // [locators]: https://playwright.dev/docs/locators - DispatchEvent(typ string, eventInit ...interface{}) error + DispatchEvent(typ string, eventInit ...any) error // Returns the return value of “[object Object]”. // The method finds an element matching the specified selector in the `ElementHandle`s subtree and passes it as a @@ -839,7 +854,7 @@ type ElementHandle interface { // 2. expression: JavaScript expression to be evaluated in the browser context. If the expression evaluates to a function, the // function is automatically invoked. // 3. arg: Optional argument to pass to “[object Object]”. - EvalOnSelector(selector string, expression string, arg ...interface{}) (interface{}, error) + EvalOnSelector(selector string, expression string, arg ...any) (any, error) // Returns the return value of “[object Object]”. // The method finds all elements matching the specified selector in the `ElementHandle`'s subtree and passes an array @@ -853,7 +868,7 @@ type ElementHandle interface { // 2. expression: JavaScript expression to be evaluated in the browser context. If the expression evaluates to a function, the // function is automatically invoked. // 3. arg: Optional argument to pass to “[object Object]”. - EvalOnSelectorAll(selector string, expression string, arg ...interface{}) (interface{}, error) + EvalOnSelectorAll(selector string, expression string, arg ...any) (any, error) // This method waits for [actionability] checks, focuses the element, fills it and triggers an // `input` event after filling. Note that you can pass an empty string to clear the input field. @@ -1113,7 +1128,7 @@ type ElementHandle interface { // [input element]: https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input // [control]: https://developer.mozilla.org/en-US/docs/Web/API/HTMLLabelElement/control // [locators]: https://playwright.dev/docs/locators - SetInputFiles(files interface{}, options ...ElementHandleSetInputFilesOptions) error + SetInputFiles(files any, options ...ElementHandleSetInputFilesOptions) error // This method taps the element by performing the following steps: // 1. Wait for [actionability] checks on the element, unless “[object Object]” option is set. @@ -1215,7 +1230,7 @@ type FileChooser interface { // Sets the value of the file input this chooser is associated with. If some of the `filePaths` are relative paths, // then they are resolved relative to the current working directory. For empty array, clears the selected files. - SetFiles(files interface{}, options ...FileChooserSetFilesOptions) error + SetFiles(files any, options ...FileChooserSetFilesOptions) error } // At every point of time, page exposes its current frame tree via the [Page.MainFrame] and [Frame.ChildFrames] @@ -1329,7 +1344,7 @@ type Frame interface { // [TouchEvent]: https://developer.mozilla.org/en-US/docs/Web/API/TouchEvent/TouchEvent // [WheelEvent]: https://developer.mozilla.org/en-US/docs/Web/API/WheelEvent/WheelEvent // [locators]: https://playwright.dev/docs/locators - DispatchEvent(selector string, typ string, eventInit interface{}, options ...FrameDispatchEventOptions) error + DispatchEvent(selector string, typ string, eventInit any, options ...FrameDispatchEventOptions) error // // 1. source: A selector to search for an element to drag. If there are multiple elements satisfying the selector, the first will @@ -1350,7 +1365,7 @@ type Frame interface { // 2. expression: JavaScript expression to be evaluated in the browser context. If the expression evaluates to a function, the // function is automatically invoked. // 3. arg: Optional argument to pass to “[object Object]”. - EvalOnSelector(selector string, expression string, arg interface{}, options ...FrameEvalOnSelectorOptions) (interface{}, error) + EvalOnSelector(selector string, expression string, arg any, options ...FrameEvalOnSelectorOptions) (any, error) // Returns the return value of “[object Object]”. // The method finds all elements matching the specified selector within the frame and passes an array of matched @@ -1364,7 +1379,7 @@ type Frame interface { // 2. expression: JavaScript expression to be evaluated in the browser context. If the expression evaluates to a function, the // function is automatically invoked. // 3. arg: Optional argument to pass to “[object Object]”. - EvalOnSelectorAll(selector string, expression string, arg ...interface{}) (interface{}, error) + EvalOnSelectorAll(selector string, expression string, arg ...any) (any, error) // Returns the return value of “[object Object]”. // If the function passed to the [Frame.Evaluate] returns a [Promise], then [Frame.Evaluate] would wait for the @@ -1376,7 +1391,7 @@ type Frame interface { // 1. expression: JavaScript expression to be evaluated in the browser context. If the expression evaluates to a function, the // function is automatically invoked. // 2. arg: Optional argument to pass to “[object Object]”. - Evaluate(expression string, arg ...interface{}) (interface{}, error) + Evaluate(expression string, arg ...any) (any, error) // Returns the return value of “[object Object]” as a [JSHandle]. // The only difference between [Frame.Evaluate] and [Frame.EvaluateHandle] is that [Frame.EvaluateHandle] returns @@ -1387,7 +1402,7 @@ type Frame interface { // 1. expression: JavaScript expression to be evaluated in the browser context. If the expression evaluates to a function, the // function is automatically invoked. // 2. arg: Optional argument to pass to “[object Object]”. - EvaluateHandle(expression string, arg ...interface{}) (JSHandle, error) + EvaluateHandle(expression string, arg ...any) (JSHandle, error) // This method waits for an element matching “[object Object]”, waits for [actionability] checks, // focuses the element, fills it and triggers an `input` event after filling. Note that you can pass an empty string @@ -1445,18 +1460,18 @@ type Frame interface { // Allows locating elements by their alt text. // // text: Text to locate the element for. - GetByAltText(text interface{}, options ...FrameGetByAltTextOptions) Locator + GetByAltText(text any, options ...FrameGetByAltTextOptions) Locator // Allows locating input elements by the text of the associated `