From f7a502813aef0a8e38156b6ff59f1fd360b89df3 Mon Sep 17 00:00:00 2001 From: Leo Antunes Date: Fri, 25 Feb 2022 00:42:14 +0100 Subject: [PATCH] fix: deal with panics in Getter.Get --- singleflight/singleflight.go | 22 +++++++---- singleflight/singleflight_test.go | 66 +++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 7 deletions(-) diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go index a2a6326..41f3e4e 100644 --- a/singleflight/singleflight.go +++ b/singleflight/singleflight.go @@ -18,7 +18,10 @@ limitations under the License. // mechanism. package singleflight -import "sync" +import ( + "fmt" + "sync" +) // call is an in-flight or completed Do call type call struct { @@ -48,17 +51,22 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, err c.wg.Wait() return c.val, c.err } - c := new(call) + c := &call{ + err: fmt.Errorf("singleflight leader panicked"), + } c.wg.Add(1) g.m[key] = c g.mu.Unlock() - c.val, c.err = fn() - c.wg.Done() + defer func() { + c.wg.Done() - g.mu.Lock() - delete(g.m, key) - g.mu.Unlock() + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() + }() + + c.val, c.err = fn() return c.val, c.err } diff --git a/singleflight/singleflight_test.go b/singleflight/singleflight_test.go index 47b4d3d..8f4254a 100644 --- a/singleflight/singleflight_test.go +++ b/singleflight/singleflight_test.go @@ -19,6 +19,7 @@ package singleflight import ( "errors" "fmt" + "strings" "sync" "sync/atomic" "testing" @@ -83,3 +84,68 @@ func TestDoDupSuppress(t *testing.T) { t.Errorf("number of calls = %d; want 1", got) } } + +func TestDoPanic(t *testing.T) { + var g Group + var err error + func() { + defer func() { + // do not let the panic below leak to the test + _ = recover() + }() + _, err = g.Do("key", func() (interface{}, error) { + panic("something went horribly wrong") + }) + }() + if err != nil { + t.Errorf("Do error = %v; want someErr", err) + } + // ensure subsequent calls to same key still work + v, err := g.Do("key", func() (interface{}, error) { + return "foo", nil + }) + if err != nil { + t.Errorf("Do error = %v; want no error", err) + } + if v.(string) != "foo" { + t.Errorf("got %q; want %q", v, "foo") + } +} + +func TestDoConcurrentPanic(t *testing.T) { + var g Group + c := make(chan struct{}) + var calls int32 + fn := func() (interface{}, error) { + atomic.AddInt32(&calls, 1) + <-c + panic("something went horribly wrong") + } + + const n = 10 + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer func() { + // do not let the panic leak to the test + _ = recover() + wg.Done() + }() + + v, err := g.Do("key", fn) + if err == nil || !strings.Contains(err.Error(), "singleflight leader panicked") { + t.Errorf("Do error: %v; wanted 'singleflight panicked'", err) + } + if v != nil { + t.Errorf("got %q; want nil", v) + } + }() + } + time.Sleep(100 * time.Millisecond) // let goroutines above block + c <- struct{}{} + wg.Wait() + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("number of calls = %d; want 1", got) + } +}