aboutsummaryrefslogtreecommitdiff
path: root/weed/worker/tasks/task.go
diff options
context:
space:
mode:
Diffstat (limited to 'weed/worker/tasks/task.go')
-rw-r--r--weed/worker/tasks/task.go66
1 files changed, 60 insertions, 6 deletions
diff --git a/weed/worker/tasks/task.go b/weed/worker/tasks/task.go
index 9813ae97f..f3eed8b2d 100644
--- a/weed/worker/tasks/task.go
+++ b/weed/worker/tasks/task.go
@@ -7,6 +7,7 @@ import (
"time"
"github.com/seaweedfs/seaweedfs/weed/glog"
+ "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
"github.com/seaweedfs/seaweedfs/weed/worker/types"
)
@@ -21,7 +22,8 @@ type BaseTask struct {
estimatedDuration time.Duration
logger TaskLogger
loggerConfig TaskLoggerConfig
- progressCallback func(float64) // Callback function for progress updates
+ progressCallback func(float64, string) // Callback function for progress updates
+ currentStage string // Current stage description
}
// NewBaseTask creates a new base task
@@ -90,20 +92,64 @@ func (t *BaseTask) SetProgress(progress float64) {
}
oldProgress := t.progress
callback := t.progressCallback
+ stage := t.currentStage
t.progress = progress
t.mutex.Unlock()
// Log progress change
if t.logger != nil && progress != oldProgress {
- t.logger.LogProgress(progress, fmt.Sprintf("Progress updated from %.1f%% to %.1f%%", oldProgress, progress))
+ message := stage
+ if message == "" {
+ message = fmt.Sprintf("Progress updated from %.1f%% to %.1f%%", oldProgress, progress)
+ }
+ t.logger.LogProgress(progress, message)
}
// Call progress callback if set
if callback != nil && progress != oldProgress {
- callback(progress)
+ callback(progress, stage)
}
}
+// SetProgressWithStage sets the current progress with a stage description
+func (t *BaseTask) SetProgressWithStage(progress float64, stage string) {
+ t.mutex.Lock()
+ if progress < 0 {
+ progress = 0
+ }
+ if progress > 100 {
+ progress = 100
+ }
+ callback := t.progressCallback
+ t.progress = progress
+ t.currentStage = stage
+ t.mutex.Unlock()
+
+ // Log progress change
+ if t.logger != nil {
+ t.logger.LogProgress(progress, stage)
+ }
+
+ // Call progress callback if set
+ if callback != nil {
+ callback(progress, stage)
+ }
+}
+
+// SetCurrentStage sets the current stage description
+func (t *BaseTask) SetCurrentStage(stage string) {
+ t.mutex.Lock()
+ defer t.mutex.Unlock()
+ t.currentStage = stage
+}
+
+// GetCurrentStage returns the current stage description
+func (t *BaseTask) GetCurrentStage() string {
+ t.mutex.RLock()
+ defer t.mutex.RUnlock()
+ return t.currentStage
+}
+
// Cancel cancels the task
func (t *BaseTask) Cancel() error {
t.mutex.Lock()
@@ -170,7 +216,7 @@ func (t *BaseTask) GetEstimatedDuration() time.Duration {
}
// SetProgressCallback sets the progress callback function
-func (t *BaseTask) SetProgressCallback(callback func(float64)) {
+func (t *BaseTask) SetProgressCallback(callback func(float64, string)) {
t.mutex.Lock()
defer t.mutex.Unlock()
t.progressCallback = callback
@@ -273,7 +319,7 @@ func (t *BaseTask) ExecuteTask(ctx context.Context, params types.TaskParams, exe
if t.logger != nil {
t.logger.LogWithFields("INFO", "Task execution started", map[string]interface{}{
"volume_id": params.VolumeID,
- "server": params.Server,
+ "server": getServerFromSources(params.TypedParams.Sources),
"collection": params.Collection,
})
}
@@ -362,7 +408,7 @@ func ValidateParams(params types.TaskParams, requiredFields ...string) error {
return &ValidationError{Field: field, Message: "volume_id is required"}
}
case "server":
- if params.Server == "" {
+ if len(params.TypedParams.Sources) == 0 {
return &ValidationError{Field: field, Message: "server is required"}
}
case "collection":
@@ -383,3 +429,11 @@ type ValidationError struct {
func (e *ValidationError) Error() string {
return e.Field + ": " + e.Message
}
+
+// getServerFromSources extracts the server address from unified sources
+func getServerFromSources(sources []*worker_pb.TaskSource) string {
+ if len(sources) > 0 {
+ return sources[0].Node
+ }
+ return ""
+}