diff options
Diffstat (limited to 'weed/worker/tasks/task.go')
| -rw-r--r-- | weed/worker/tasks/task.go | 66 |
1 files changed, 60 insertions, 6 deletions
diff --git a/weed/worker/tasks/task.go b/weed/worker/tasks/task.go index 9813ae97f..f3eed8b2d 100644 --- a/weed/worker/tasks/task.go +++ b/weed/worker/tasks/task.go @@ -7,6 +7,7 @@ import ( "time" "github.com/seaweedfs/seaweedfs/weed/glog" + "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb" "github.com/seaweedfs/seaweedfs/weed/worker/types" ) @@ -21,7 +22,8 @@ type BaseTask struct { estimatedDuration time.Duration logger TaskLogger loggerConfig TaskLoggerConfig - progressCallback func(float64) // Callback function for progress updates + progressCallback func(float64, string) // Callback function for progress updates + currentStage string // Current stage description } // NewBaseTask creates a new base task @@ -90,20 +92,64 @@ func (t *BaseTask) SetProgress(progress float64) { } oldProgress := t.progress callback := t.progressCallback + stage := t.currentStage t.progress = progress t.mutex.Unlock() // Log progress change if t.logger != nil && progress != oldProgress { - t.logger.LogProgress(progress, fmt.Sprintf("Progress updated from %.1f%% to %.1f%%", oldProgress, progress)) + message := stage + if message == "" { + message = fmt.Sprintf("Progress updated from %.1f%% to %.1f%%", oldProgress, progress) + } + t.logger.LogProgress(progress, message) } // Call progress callback if set if callback != nil && progress != oldProgress { - callback(progress) + callback(progress, stage) } } +// SetProgressWithStage sets the current progress with a stage description +func (t *BaseTask) SetProgressWithStage(progress float64, stage string) { + t.mutex.Lock() + if progress < 0 { + progress = 0 + } + if progress > 100 { + progress = 100 + } + callback := t.progressCallback + t.progress = progress + t.currentStage = stage + t.mutex.Unlock() + + // Log progress change + if t.logger != nil { + t.logger.LogProgress(progress, stage) + } + + // Call progress callback if set + if callback != nil { + callback(progress, stage) + } +} + +// SetCurrentStage sets the current stage description +func (t *BaseTask) SetCurrentStage(stage string) { + t.mutex.Lock() + defer t.mutex.Unlock() + t.currentStage = stage +} + +// GetCurrentStage returns the current stage description +func (t *BaseTask) GetCurrentStage() string { + t.mutex.RLock() + defer t.mutex.RUnlock() + return t.currentStage +} + // Cancel cancels the task func (t *BaseTask) Cancel() error { t.mutex.Lock() @@ -170,7 +216,7 @@ func (t *BaseTask) GetEstimatedDuration() time.Duration { } // SetProgressCallback sets the progress callback function -func (t *BaseTask) SetProgressCallback(callback func(float64)) { +func (t *BaseTask) SetProgressCallback(callback func(float64, string)) { t.mutex.Lock() defer t.mutex.Unlock() t.progressCallback = callback @@ -273,7 +319,7 @@ func (t *BaseTask) ExecuteTask(ctx context.Context, params types.TaskParams, exe if t.logger != nil { t.logger.LogWithFields("INFO", "Task execution started", map[string]interface{}{ "volume_id": params.VolumeID, - "server": params.Server, + "server": getServerFromSources(params.TypedParams.Sources), "collection": params.Collection, }) } @@ -362,7 +408,7 @@ func ValidateParams(params types.TaskParams, requiredFields ...string) error { return &ValidationError{Field: field, Message: "volume_id is required"} } case "server": - if params.Server == "" { + if len(params.TypedParams.Sources) == 0 { return &ValidationError{Field: field, Message: "server is required"} } case "collection": @@ -383,3 +429,11 @@ type ValidationError struct { func (e *ValidationError) Error() string { return e.Field + ": " + e.Message } + +// getServerFromSources extracts the server address from unified sources +func getServerFromSources(sources []*worker_pb.TaskSource) string { + if len(sources) > 0 { + return sources[0].Node + } + return "" +} |
