fix: deal with panics in Getter.Get

This commit is contained in:
Leo Antunes 2022-02-25 00:42:14 +01:00
parent ece2929696
commit f7a502813a
2 changed files with 81 additions and 7 deletions

View File

@ -18,7 +18,10 @@ limitations under the License.
// mechanism.
package singleflight
import "sync"
import (
"fmt"
"sync"
)
// call is an in-flight or completed Do call
type call struct {
@ -48,17 +51,22 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, err
c.wg.Wait()
return c.val, c.err
}
c := new(call)
c := &call{
err: fmt.Errorf("singleflight leader panicked"),
}
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
c.val, c.err = fn()
c.wg.Done()
defer func() {
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
}()
c.val, c.err = fn()
return c.val, c.err
}

View File

@ -19,6 +19,7 @@ package singleflight
import (
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"testing"
@ -83,3 +84,68 @@ func TestDoDupSuppress(t *testing.T) {
t.Errorf("number of calls = %d; want 1", got)
}
}
func TestDoPanic(t *testing.T) {
var g Group
var err error
func() {
defer func() {
// do not let the panic below leak to the test
_ = recover()
}()
_, err = g.Do("key", func() (interface{}, error) {
panic("something went horribly wrong")
})
}()
if err != nil {
t.Errorf("Do error = %v; want someErr", err)
}
// ensure subsequent calls to same key still work
v, err := g.Do("key", func() (interface{}, error) {
return "foo", nil
})
if err != nil {
t.Errorf("Do error = %v; want no error", err)
}
if v.(string) != "foo" {
t.Errorf("got %q; want %q", v, "foo")
}
}
func TestDoConcurrentPanic(t *testing.T) {
var g Group
c := make(chan struct{})
var calls int32
fn := func() (interface{}, error) {
atomic.AddInt32(&calls, 1)
<-c
panic("something went horribly wrong")
}
const n = 10
var wg sync.WaitGroup
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer func() {
// do not let the panic leak to the test
_ = recover()
wg.Done()
}()
v, err := g.Do("key", fn)
if err == nil || !strings.Contains(err.Error(), "singleflight leader panicked") {
t.Errorf("Do error: %v; wanted 'singleflight panicked'", err)
}
if v != nil {
t.Errorf("got %q; want nil", v)
}
}()
}
time.Sleep(100 * time.Millisecond) // let goroutines above block
c <- struct{}{}
wg.Wait()
if got := atomic.LoadInt32(&calls); got != 1 {
t.Errorf("number of calls = %d; want 1", got)
}
}