While you can use `golang.org/x/sync/singleflight` directly for thundering-herd protection, it's a tad awkward. This package is a convenience wrapper that provides:
- No type assertions needed at the call site.
- Idiomatic placement of the error return.
- A protobuf-aware entrypoint that automatically clones shared values:
  - Uses `msg.CloneVT()` if available.
  - Uses `proto.Clone` otherwise.
- A Prometheus counter bumped on every call, for monitoring.
  - Just remove it if you don't need it.
Behind the scenes, a `singleflight.Group` is lazily created and registered for each
return type. This trades call-site convenience for an `RWMutex` acquisition on every
call. I'm happy with this trade-off since its cost will be dwarfed by the I/O work
being performed, but it could become a bottleneck for other workloads.
I've ported this tiny package across projects and companies, so I figured it was worth posting.
Example
package main
import (
"fmt"
"sync"
"path/to/herd"
)
func main() {
var wg sync.WaitGroup
for range 5 {
wg.Go(func() {
data, shared, err := herd.Do("foo", fetch)
fmt.Printf("shared=%t\terr=%v\tdata=%v\n", shared, err, string(data))
})
}
wg.Wait()
// Output:
// shared=false err=<nil> data=test data
// shared=true err=<nil> data=test data
// shared=true err=<nil> data=test data
// shared=true err=<nil> data=test data
// shared=true err=<nil> data=test data
}
// fetch stands in for an expensive loader; it always succeeds and
// yields the same fixed payload.
func fetch() (data []byte, err error) {
	data = []byte("test data")
	return data, nil
}
Code
package herd
import (
"reflect"
"strconv"
"sync"
"golang.org/x/sync/singleflight"
"google.golang.org/protobuf/proto"
"github.com/prometheus/client_golang/prometheus"
)
var (
	// mu guards every read and write of groups.
	mu sync.RWMutex

	// groups holds one singleflight.Group per callback result
	// type; entries are added on demand by group.
	groups = make(map[reflect.Type]*singleflight.Group)

	// counter is the herd_total metric, created and registered
	// during package initialisation.
	counter = registerCounter()
)
// Do runs a callback with herd protection, indicating
// whether the return values are shared or not.
//
// 'key' must be sufficiently unique to identify your workload
// under the given type 'T'. When called concurrently with the
// same 'key' value, only one invocation will execute the
// callback while the rest wait for its result. Note that the
// caller that actually executed the callback also sees
// 'shared==true' when other callers joined it.
//
// Users must take appropriate safety measures when
// returning pointers, such as cloning values if
// 'shared==true' or ensuring read-only access.
//
// Protobufs will be automatically cloned (via CloneVT when
// available, otherwise [proto.Clone]) if their execution was
// shared, and 'shared' will be reported as false since the
// caller now owns its own copy.
//
// When 'err' is non-nil the returned value is the zero value
// of 'T'. A nil result from the callback (for interface or
// pointer types 'T') is returned as the zero value of 'T'.
//
// Use [DoProto] to safely ignore the result of 'shared'.
//
// Under the hood, each type 'T' gets its own
// [singleflight.Group].
func Do[T any](key string, cb func() (T, error)) (value T, shared bool, err error) {
	anyValue, shared, err := do(key, cb)
	if err != nil {
		var zero T
		return zero, shared, err
	}
	// Try our best to clone on behalf of the caller if we
	// know how to, for safety.
	if shared {
		switch pb := anyValue.(type) {
		case vtproto[T]:
			// It's OK to lie since, from the perspective of the
			// caller, the value isn't shared.
			return pb.CloneVT(), false, nil
		case proto.Message:
			return proto.Clone(pb).(T), false, nil
		}
	}
	// Comma-ok form: when T is an interface type and the callback
	// returned nil, anyValue is a nil interface and a plain
	// assertion would panic. The zero value of T is the right
	// answer in that case.
	value, _ = anyValue.(T)
	return value, shared, nil
}
// DoProto runs a callback with herd protection, cloning the
// result when execution was shared so that every caller
// receives a message it owns exclusively. CloneVT is preferred
// when the message implements it; otherwise [proto.Clone] is
// used.
//
// Unlike [Do], it does not report whether execution was shared:
// after cloning, the distinction no longer matters to the caller.
// On error the zero value of 'T' is returned, and a nil result
// from the callback is returned as the zero value of 'T'.
func DoProto[T proto.Message](key string, cb func() (T, error)) (T, error) {
	anyValue, shared, err := do(key, cb)
	if err != nil {
		var zero T
		return zero, err
	}
	if shared {
		// Always clone for the caller before returning.
		switch pb := anyValue.(type) {
		case vtproto[T]:
			return pb.CloneVT(), nil
		case proto.Message:
			return proto.Clone(pb).(T), nil
		}
	}
	// Comma-ok form: if T is an interface type (e.g. proto.Message
	// itself) and the callback returned nil, anyValue is a nil
	// interface and a plain assertion would panic.
	value, _ := anyValue.(T)
	return value, nil
}
// do executes cb through the singleflight group dedicated to T and
// records one herd_total increment per caller, labelled with T's
// type name, whether the call failed, and whether the result was
// shared. The raw untyped result is returned for the exported
// wrappers to convert.
func do[T any](key string, cb func() (T, error)) (any, bool, error) {
	typ := reflect.TypeFor[T]()
	result, err, shared := group(typ).Do(key, func() (any, error) {
		v, cbErr := cb()
		return v, cbErr
	})

	labels := prometheus.Labels{
		"type":   typ.String(),
		"error":  strconv.FormatBool(err != nil),
		"shared": strconv.FormatBool(shared),
	}
	counter.With(labels).Inc()

	return result, shared, err
}
// vtproto matches any value exposing a CloneVT method that returns
// a value of type M. Do and DoProto check for it before falling
// back to proto.Clone when a shared result needs cloning.
type vtproto[M any] interface {
	CloneVT() M
}
// Forget tells the singleflight group associated with type T to
// forget about key, so the next call for that key runs its
// callback instead of waiting on an earlier in-flight call.
func Forget[T any](key string) {
	group(reflect.TypeFor[T]()).Forget(key)
}
// group returns the [singleflight.Group] registered for the given
// type, creating and registering one on first use.
//
// The common case only takes the read lock. On a miss, the map is
// re-checked under the write lock because another goroutine may
// have registered a group between releasing the read lock and
// acquiring the write lock.
func group(key reflect.Type) *singleflight.Group {
	mu.RLock()
	g, found := groups[key]
	mu.RUnlock()
	if found {
		return g
	}

	mu.Lock()
	defer mu.Unlock()
	if g, found = groups[key]; !found {
		g = new(singleflight.Group)
		groups[key] = g
	}
	return g
}
// registerCounter builds the herd_total counter vector and registers
// it with the default Prometheus registerer via MustRegister, which
// panics if registration fails.
//
// The 'key' passed to Do is deliberately not used as a label: keys
// such as 'get:dict:1' may span thousands of distinct values, making
// the metric expensive to record. Since herd groups are per type
// anyway, the result type is exposed instead, giving some context on
// what was shared without the cardinality cost.
func registerCounter() *prometheus.CounterVec {
	opts := prometheus.CounterOpts{
		Name: "herd_total",
		Help: "The total number of herd-protected functions",
	}
	labelNames := []string{"type", "error", "shared"}

	vec := prometheus.NewCounterVec(opts, labelNames)
	prometheus.MustRegister(vec)
	return vec
}