aboutsummaryrefslogtreecommitdiff
path: root/weed/shell/common.go
blob: cb2df58287a29c507c78967d68fd8c2bc540b39f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package shell

import (
	"errors"
	"fmt"
	"sync"
)

// DefaultMaxParallelization is the default upper bound on how many tasks may
// run at the same time for shell commands that support parallel execution.
var DefaultMaxParallelization = 10

// ErrorWaitGroup runs ErrorWaitGroupTasks with bounded concurrency and
// aggregates the errors they return.
//
// With a maxConcurrency of 1 (see NewErrorWaitGroup) tasks run synchronously
// inside Add, in call order. Otherwise every task runs on its own goroutine,
// with at most maxConcurrency tasks executing at once.
type ErrorWaitGroup struct {
	// maxConcurrency caps how many tasks may execute at the same time.
	maxConcurrency int

	// wg counts goroutines started by Add that have not finished yet.
	wg *sync.WaitGroup

	// wgSem is a counting semaphore; a running task holds one buffered slot.
	wgSem chan bool

	// errorsMu guards errors, which collects non-nil task and AddErrorf errors.
	errorsMu sync.Mutex
	errors   []error
}

// ErrorWaitGroupTask is a unit of work queued with ErrorWaitGroup.Add.
type ErrorWaitGroupTask func() error

// NewErrorWaitGroup returns an ErrorWaitGroup that executes at most
// maxConcurrency tasks at once. Values <= 0 are treated as 1, which makes
// Add run every task synchronously, in call order.
func NewErrorWaitGroup(maxConcurrency int) *ErrorWaitGroup {
	if maxConcurrency <= 0 {
		// no concurrency = one task at a time
		maxConcurrency = 1
	}
	return &ErrorWaitGroup{
		maxConcurrency: maxConcurrency,
		wg:             &sync.WaitGroup{},
		wgSem:          make(chan bool, maxConcurrency),
	}
}

// Reset restarts the ErrorWaitGroup, keeping its original settings. It first
// waits for any tasks still pending from earlier Add calls to finish, then
// discards all collected errors so the group can be reused.
//
// Reset must not be called concurrently with Add or Wait: it replaces the
// group's internal fields without synchronization.
func (ewg *ErrorWaitGroup) Reset() {
	// Drain tasks queued before the reset; otherwise they could keep running
	// and record stale errors into the fresh state.
	ewg.wg.Wait()

	ewg.wg = &sync.WaitGroup{}
	ewg.wgSem = make(chan bool, ewg.maxConcurrency)

	ewg.errorsMu.Lock()
	ewg.errors = nil
	ewg.errorsMu.Unlock()
}

// Add queues task f. When parallelization is off (maxConcurrency of 1), f runs
// synchronously before Add returns, keeping execution order deterministic.
// Otherwise f runs on a new goroutine once a concurrency slot is free; Add
// itself does not block. A non-nil error returned by f is reported by Wait.
func (ewg *ErrorWaitGroup) Add(f ErrorWaitGroupTask) {
	if ewg.maxConcurrency <= 1 {
		// Run f before touching ewg.errors: f may call AddErrorf itself, and
		// the order in which append's operands are evaluated relative to
		// f() is unspecified, so an inline append could drop that error.
		ewg.recordError(f())
		return
	}

	// Bind the current WaitGroup and semaphore so the goroutine releases
	// exactly the primitives it acquired.
	wg, sem := ewg.wg, ewg.wgSem
	wg.Add(1)
	go func() {
		defer wg.Done()

		sem <- true
		defer func() { <-sem }()

		ewg.recordError(f())
	}()
}

// AddErrorf records an error built with fmt.Errorf (so %w wrapping works)
// without queueing any task. It may be called from inside running tasks.
func (ewg *ErrorWaitGroup) AddErrorf(format string, a ...interface{}) {
	ewg.recordError(fmt.Errorf(format, a...))
}

// Wait blocks until every task started by Add has finished, then returns the
// collected errors combined with errors.Join, or nil if there were none.
// Collected errors are kept; call Reset to clear them before reuse.
func (ewg *ErrorWaitGroup) Wait() error {
	ewg.wg.Wait()

	ewg.errorsMu.Lock()
	defer ewg.errorsMu.Unlock()
	return errors.Join(ewg.errors...)
}

// recordError appends err to the collected errors under errorsMu; nil is ignored.
func (ewg *ErrorWaitGroup) recordError(err error) {
	if err == nil {
		return
	}
	ewg.errorsMu.Lock()
	ewg.errors = append(ewg.errors, err)
	ewg.errorsMu.Unlock()
}