diff options
| author | Chris Lu <chrislusf@users.noreply.github.com> | 2019-09-14 01:06:14 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-09-14 01:06:14 -0700 |
| commit | b0e4771135d5d7076e96ce9cb30a601ed0d3ba50 (patch) | |
| tree | 31a67399d61756deeb618347c4a2af5c98027f46 | |
| parent | ae53f636804e41c2c7a0817e8f35434a00b6eacb (diff) | |
| parent | bb31462b52d24bc7ebf54266b45801c198685f70 (diff) | |
| download | seaweedfs-b0e4771135d5d7076e96ce9cb30a601ed0d3ba50.tar.xz seaweedfs-b0e4771135d5d7076e96ce9cb30a601ed0d3ba50.zip | |
Merge pull request #1060 from divinerapier/master
fix: non-thread-safe rand will panic
| -rw-r--r-- | weed/wdclient/vid_map.go | 26 | ||||
| -rw-r--r-- | weed/wdclient/vid_map_test.go | 77 |
2 files changed, 98 insertions, 5 deletions
diff --git a/weed/wdclient/vid_map.go b/weed/wdclient/vid_map.go index 01d9cdaed..7a3f50aad 100644 --- a/weed/wdclient/vid_map.go +++ b/weed/wdclient/vid_map.go @@ -3,11 +3,11 @@ package wdclient import ( "errors" "fmt" - "math/rand" + "math" "strconv" "strings" "sync" - "time" + "sync/atomic" "github.com/chrislusf/seaweedfs/weed/glog" ) @@ -20,14 +20,25 @@ type Location struct { type vidMap struct { sync.RWMutex vid2Locations map[uint32][]Location - r *rand.Rand + + cursor int64 } func newVidMap() vidMap { return vidMap{ vid2Locations: make(map[uint32][]Location), - r: rand.New(rand.NewSource(time.Now().UnixNano())), + cursor: -1, + } +} + +func (vc *vidMap) getLocationIndex(length int64) (int64, error) { + if length <= 0 { + return 0, fmt.Errorf("invalid length: %d", length) } + if atomic.LoadInt64(&vc.cursor) == math.MaxInt64 { + atomic.CompareAndSwapInt64(&vc.cursor, math.MaxInt64, -1) + } + return atomic.AddInt64(&vc.cursor, 1) % length, nil } func (vc *vidMap) LookupVolumeServerUrl(vid string) (serverUrl string, err error) { @@ -94,7 +105,12 @@ func (vc *vidMap) GetRandomLocation(vid uint32) (serverUrl string, err error) { return "", fmt.Errorf("volume %d not found", vid) } - return locations[vc.r.Intn(len(locations))].Url, nil + index, err := vc.getLocationIndex(int64(len(locations))) + if err != nil { + return "", fmt.Errorf("volume %d. %v", vid, err) + } + + return locations[index].Url, nil } func (vc *vidMap) addLocation(vid uint32, location Location) { diff --git a/weed/wdclient/vid_map_test.go b/weed/wdclient/vid_map_test.go new file mode 100644 index 000000000..ae4680e7a --- /dev/null +++ b/weed/wdclient/vid_map_test.go @@ -0,0 +1,77 @@ +package wdclient + +import ( + "fmt" + "math" + "testing" +) + +func TestLocationIndex(t *testing.T) { + vm := vidMap{} + // test must be failed + mustFailed := func(length int64) { + _, err := vm.getLocationIndex(length) + if err == nil { + t.Errorf("length %d must be failed", length) + } + if err.Error() != fmt.Sprintf("invalid length: %d", length) { + t.Errorf("length %d must be failed. error: %v", length, err) + } + } + + mustFailed(-1) + mustFailed(0) + + mustOk := func(length, cursor, expect int64) { + if length <= 0 { + t.Fatal("please don't do this") + } + vm.cursor = cursor + got, err := vm.getLocationIndex(length) + if err != nil { + t.Errorf("length: %d, why? %v\n", length, err) + return + } + if got != expect { + t.Errorf("cursor: %d, length: %d, expect: %d, got: %d\n", cursor, length, expect, got) + return + } + } + + for i := int64(-1); i < 100; i++ { + mustOk(7, i, (i+1)%7) + } + + // when cursor reaches MaxInt64 + mustOk(7, math.MaxInt64, 0) + + // test with constructor + vm = newVidMap() + length := int64(7) + for i := int64(0); i < 100; i++ { + got, err := vm.getLocationIndex(length) + if err != nil { + t.Errorf("length: %d, why? %v\n", length, err) + return + } + if got != i%length { + t.Errorf("length: %d, i: %d, got: %d\n", length, i, got) + } + } +} + +func BenchmarkLocationIndex(b *testing.B) { + b.SetParallelism(8) + vm := vidMap{ + cursor: math.MaxInt64 - 10000, + } + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := vm.getLocationIndex(3) + if err != nil { + b.Error(err) + } + } + }) +} |
