Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions prometheus/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,19 @@ func (err AlreadyRegisteredError) Error() string {
// by a Gatherer to report multiple errors during MetricFamily gathering.
type MultiError []error

// SafeMultiError is a thread-safe wrapper around MultiError using a mutex.
type SafeMultiError struct {
mu sync.Mutex
errs MultiError
}

// Appends the provided error to the contained MultiError in a thread-safe way.
func (s *SafeMultiError) Append(err error) {
s.mu.Lock()
s.errs.Append(err)
s.mu.Unlock()
}

// Error formats the contained errors as a bullet point list, preceded by the
// total number of errors. Note that this results in a multi-line string.
func (errs MultiError) Error() string {
Expand Down Expand Up @@ -408,6 +421,16 @@ func (r *Registry) MustRegister(cs ...Collector) {
}
}

// MustGather implements Gatherer.
// Wraps around Gather and panics if Gather fails for any reason.
func (r *Registry) MustGather() []*dto.MetricFamily {
mfs, err := r.Gather()
if err != nil {
panic(err)
}
return mfs
}

// Gather implements Gatherer.
func (r *Registry) Gather() ([]*dto.MetricFamily, error) {
r.mtx.RLock()
Expand All @@ -423,7 +446,7 @@ func (r *Registry) Gather() ([]*dto.MetricFamily, error) {
uncheckedMetricChan = make(chan Metric, capMetricChan)
metricHashes = map[uint64]struct{}{}
wg sync.WaitGroup
errs MultiError // The collected errors to return in the end.
safeErrs = &SafeMultiError{} // To collect errors in a threadsafe way
registeredDescIDs map[uint64]struct{} // Only used for pedantic checks
)

Expand Down Expand Up @@ -453,9 +476,9 @@ func (r *Registry) Gather() ([]*dto.MetricFamily, error) {
for {
select {
case collector := <-checkedCollectors:
collector.Collect(checkedMetricChan)
safeErrs.Append((safeCollect(collector, checkedMetricChan)))
case collector := <-uncheckedCollectors:
collector.Collect(uncheckedMetricChan)
safeErrs.Append(safeCollect(collector, uncheckedMetricChan))
default:
return
}
Expand Down Expand Up @@ -499,7 +522,7 @@ func (r *Registry) Gather() ([]*dto.MetricFamily, error) {
cmc = nil
break
}
errs.Append(processMetric(
safeErrs.Append(processMetric(
metric, metricFamiliesByName,
metricHashes,
registeredDescIDs,
Expand All @@ -509,7 +532,7 @@ func (r *Registry) Gather() ([]*dto.MetricFamily, error) {
umc = nil
break
}
errs.Append(processMetric(
safeErrs.Append(processMetric(
metric, metricFamiliesByName,
metricHashes,
nil,
Expand All @@ -526,7 +549,7 @@ func (r *Registry) Gather() ([]*dto.MetricFamily, error) {
cmc = nil
break
}
errs.Append(processMetric(
safeErrs.Append(processMetric(
metric, metricFamiliesByName,
metricHashes,
registeredDescIDs,
Expand All @@ -536,7 +559,7 @@ func (r *Registry) Gather() ([]*dto.MetricFamily, error) {
umc = nil
break
}
errs.Append(processMetric(
safeErrs.Append(processMetric(
metric, metricFamiliesByName,
metricHashes,
nil,
Expand All @@ -556,7 +579,8 @@ func (r *Registry) Gather() ([]*dto.MetricFamily, error) {
break
}
}
return internal.NormalizeMetricFamilies(metricFamiliesByName), errs.MaybeUnwrap()

return internal.NormalizeMetricFamilies(metricFamiliesByName), safeErrs.errs.MaybeUnwrap()
}

// Describe implements Collector.
Expand All @@ -571,6 +595,24 @@ func (r *Registry) Describe(ch chan<- *Desc) {
}
}

// Helper wrapper around Collector.Collect.
// It tries to collect from the channel, recovers on panic and
// if it has recovered from a panic, then it sends an InvalidMetric into
// the channel with an InvalidDesc, and an error that includes a stack trace.
func safeCollect(c Collector, ch chan<- Metric) (err error) {
defer func() {
if r := recover(); r != nil {
buf := make([]byte, 64<<10) // 64 KB
n := runtime.Stack(buf, false)
err = fmt.Errorf("prometheus collector panic recovered: type=%T: error=%v\nstack trace=%s", c, r, buf[:n])
ch <- NewInvalidMetric(NewInvalidDesc(err), err)
}
}()
c.Collect(ch)

return err
}

// Collect implements Collector.
func (r *Registry) Collect(ch chan<- Metric) {
r.mtx.RLock()
Expand Down
50 changes: 50 additions & 0 deletions prometheus/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"net/http/httptest"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -1306,6 +1307,55 @@ func (co *customCollector) Collect(ch chan<- prometheus.Metric) {
co.collectFunc(ch)
}

// TestCollectorOnMetricPanic ensures that if a collector panics while collecting a metric,
// the panic is recovered and the error is returned by Gather. It also checks that the metric
// collected before the panic is still present in the gathered metrics. Additionally,
// it verifies that if a collector does not panic, Gather returns the collected metrics without error.
func TestCollectorOnMetricPanic(t *testing.T) {
reg := prometheus.NewRegistry()

desc := prometheus.NewDesc("metric_a", "", nil, nil)
metric := prometheus.MustNewConstMetric(desc, prometheus.CounterValue, 1)
timestamp := time.Now()

panicCollector := &customCollector{
collectFunc: func(ch chan<- prometheus.Metric) {
ch <- prometheus.NewMetricWithTimestamp(timestamp, metric)
panic("test panic message") // Panic during metric collection
},
}
reg.MustRegister(panicCollector)

mfs, err := reg.Gather()
if err == nil {
t.Error("metric should return error instead of nil")
}

// Check if metric_a is there
if len(mfs) != 1 || mfs[0].GetName() != "metric_a" {
t.Error("expected metric_a to be present in the gathered metrics")
}
if !strings.Contains(err.Error(), "test panic message") {
t.Errorf("expected panic message in error, got: %v", err)
}

reg = prometheus.NewRegistry()
desc = prometheus.NewDesc("metric_b", "", nil, nil)
metric = prometheus.MustNewConstMetric(desc, prometheus.CounterValue, 1)
timestamp = time.Now()

nonPanicCollector := &customCollector{
collectFunc: func(ch chan<- prometheus.Metric) {
ch <- prometheus.NewMetricWithTimestamp(timestamp, metric)
},
}
reg.MustRegister(nonPanicCollector)
_, err2 := reg.Gather()
if err2 != nil {
t.Error("metric should not return error:", err2)
}
}

// TestCheckMetricConsistency
func TestCheckMetricConsistency(t *testing.T) {
reg := prometheus.NewRegistry()
Expand Down
Loading