Diffstat (limited to 'weed/admin/dash/worker_grpc_server.go')
-rw-r--r--  weed/admin/dash/worker_grpc_server.go  461
1 file changed, 461 insertions, 0 deletions
diff --git a/weed/admin/dash/worker_grpc_server.go b/weed/admin/dash/worker_grpc_server.go
new file mode 100644
index 000000000..c824cc388
--- /dev/null
+++ b/weed/admin/dash/worker_grpc_server.go
@@ -0,0 +1,461 @@
+package dash
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/seaweedfs/seaweedfs/weed/glog"
+ "github.com/seaweedfs/seaweedfs/weed/pb"
+ "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
+ "github.com/seaweedfs/seaweedfs/weed/security"
+ "github.com/seaweedfs/seaweedfs/weed/util"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/peer"
+)
+
+// WorkerGrpcServer implements the WorkerService gRPC interface
+type WorkerGrpcServer struct {
+ worker_pb.UnimplementedWorkerServiceServer
+ adminServer *AdminServer
+
+ // Worker connection management
+ connections map[string]*WorkerConnection
+ connMutex sync.RWMutex
+
+ // gRPC server
+ grpcServer *grpc.Server
+ listener net.Listener
+ running bool
+ stopChan chan struct{}
+}
+
+// WorkerConnection represents an active worker connection
+type WorkerConnection struct {
+ workerID string
+ stream worker_pb.WorkerService_WorkerStreamServer
+ lastSeen time.Time
+ capabilities []MaintenanceTaskType
+ address string
+ maxConcurrent int32
+ outgoing chan *worker_pb.AdminMessage
+ ctx context.Context
+ cancel context.CancelFunc
+}
+
+// NewWorkerGrpcServer creates a new gRPC server for worker connections
+func NewWorkerGrpcServer(adminServer *AdminServer) *WorkerGrpcServer {
+ return &WorkerGrpcServer{
+ adminServer: adminServer,
+ connections: make(map[string]*WorkerConnection),
+ stopChan: make(chan struct{}),
+ }
+}
+
+// StartWithTLS starts the gRPC server on the specified port with optional TLS
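+// TLS material, if configured, is loaded for the grpc.admin component via
+// security.LoadServerTLS; when no certificates are configured the server
+// serves plaintext gRPC.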
+func (s *WorkerGrpcServer) StartWithTLS(port int) error {
+ if s.running {
+ return fmt.Errorf("worker gRPC server is already running")
+ }
+
+ // Create listener
+ listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
+ if err != nil {
+ return fmt.Errorf("failed to listen on port %d: %v", port, err)
+ }
+
+ // Create gRPC server with optional TLS
+ grpcServer := pb.NewGrpcServer(security.LoadServerTLS(util.GetViper(), "grpc.admin"))
+
+ worker_pb.RegisterWorkerServiceServer(grpcServer, s)
+
+ s.grpcServer = grpcServer
+ s.listener = listener
+ s.running = true
+
+ // Start cleanup routine
+ go s.cleanupRoutine()
+
+ // Start serving in a goroutine
+ go func() {
+ if err := s.grpcServer.Serve(listener); err != nil {
+ if s.running {
+ glog.Errorf("Worker gRPC server error: %v", err)
+ }
+ }
+ }()
+
+ return nil
+}
+
+// Stop stops the gRPC server
+func (s *WorkerGrpcServer) Stop() error {
+ if !s.running {
+ return nil
+ }
+
+ s.running = false
+ close(s.stopChan)
+
+ // Close all worker connections
+ s.connMutex.Lock()
+ for _, conn := range s.connections {
+ conn.cancel()
+ close(conn.outgoing)
+ }
+ s.connections = make(map[string]*WorkerConnection)
+ s.connMutex.Unlock()
+
+ // Stop gRPC server
+ if s.grpcServer != nil {
+ s.grpcServer.GracefulStop()
+ }
+
+ // Close listener
+ if s.listener != nil {
+ s.listener.Close()
+ }
+
+ glog.Infof("Worker gRPC server stopped")
+ return nil
+}
+
+// WorkerStream handles bidirectional communication with workers
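+// The first message on the stream must be a registration; after that the worker
+// sends heartbeats, task requests, progress updates, and completion notices,
+// while admin responses are queued on the per-connection outgoing channel and
+// drained by handleOutgoingMessages.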
+func (s *WorkerGrpcServer) WorkerStream(stream worker_pb.WorkerService_WorkerStreamServer) error {
+ ctx := stream.Context()
+
+ // Get client address
+ address := findClientAddress(ctx)
+
+ // Wait for initial registration message
+ msg, err := stream.Recv()
+ if err != nil {
+ return fmt.Errorf("failed to receive registration message: %v", err)
+ }
+
+ registration := msg.GetRegistration()
+ if registration == nil {
+ return fmt.Errorf("first message must be registration")
+ }
+ registration.Address = address
+
+ workerID := registration.WorkerId
+ if workerID == "" {
+ return fmt.Errorf("worker ID cannot be empty")
+ }
+
+ glog.Infof("Worker %s connecting from %s", workerID, registration.Address)
+
+ // Create worker connection
+ connCtx, connCancel := context.WithCancel(ctx)
+ conn := &WorkerConnection{
+ workerID: workerID,
+ stream: stream,
+ lastSeen: time.Now(),
+ address: registration.Address,
+ maxConcurrent: registration.MaxConcurrent,
+ outgoing: make(chan *worker_pb.AdminMessage, 100),
+ ctx: connCtx,
+ cancel: connCancel,
+ }
+
+ // Convert capabilities
+ capabilities := make([]MaintenanceTaskType, len(registration.Capabilities))
+ for i, cap := range registration.Capabilities {
+ capabilities[i] = MaintenanceTaskType(cap)
+ }
+ conn.capabilities = capabilities
+
+ // Register connection
+ s.connMutex.Lock()
+ s.connections[workerID] = conn
+ s.connMutex.Unlock()
+
+ // Register worker with maintenance manager
+ s.registerWorkerWithManager(conn)
+
+ // Send registration response
+ regResponse := &worker_pb.AdminMessage{
+ Timestamp: time.Now().Unix(),
+ Message: &worker_pb.AdminMessage_RegistrationResponse{
+ RegistrationResponse: &worker_pb.RegistrationResponse{
+ Success: true,
+ Message: "Worker registered successfully",
+ },
+ },
+ }
+
+ select {
+ case conn.outgoing <- regResponse:
+ case <-time.After(5 * time.Second):
+ glog.Errorf("Failed to send registration response to worker %s", workerID)
+ }
+
+ // Start outgoing message handler
+ go s.handleOutgoingMessages(conn)
+
+ // Handle incoming messages
+ for {
+ select {
+ case <-ctx.Done():
+ glog.Infof("Worker %s connection closed: %v", workerID, ctx.Err())
+ s.unregisterWorker(workerID)
+ return nil
+ case <-connCtx.Done():
+ glog.Infof("Worker %s connection cancelled", workerID)
+ s.unregisterWorker(workerID)
+ return nil
+ default:
+ }
+
+ msg, err := stream.Recv()
+ if err != nil {
+ if err == io.EOF {
+ glog.Infof("Worker %s disconnected", workerID)
+ } else {
+ glog.Errorf("Error receiving from worker %s: %v", workerID, err)
+ }
+ s.unregisterWorker(workerID)
+ return err
+ }
+
+ conn.lastSeen = time.Now()
+ s.handleWorkerMessage(conn, msg)
+ }
+}
+
+// handleOutgoingMessages sends messages to worker
+func (s *WorkerGrpcServer) handleOutgoingMessages(conn *WorkerConnection) {
+ for {
+ select {
+ case <-conn.ctx.Done():
+ return
+ case msg, ok := <-conn.outgoing:
+ if !ok {
+ return
+ }
+
+ if err := conn.stream.Send(msg); err != nil {
+ glog.Errorf("Failed to send message to worker %s: %v", conn.workerID, err)
+ conn.cancel()
+ return
+ }
+ }
+ }
+}
+
+// handleWorkerMessage processes incoming messages from workers
+func (s *WorkerGrpcServer) handleWorkerMessage(conn *WorkerConnection, msg *worker_pb.WorkerMessage) {
+ workerID := conn.workerID
+
+ switch m := msg.Message.(type) {
+ case *worker_pb.WorkerMessage_Heartbeat:
+ s.handleHeartbeat(conn, m.Heartbeat)
+
+ case *worker_pb.WorkerMessage_TaskRequest:
+ s.handleTaskRequest(conn, m.TaskRequest)
+
+ case *worker_pb.WorkerMessage_TaskUpdate:
+ s.handleTaskUpdate(conn, m.TaskUpdate)
+
+ case *worker_pb.WorkerMessage_TaskComplete:
+ s.handleTaskCompletion(conn, m.TaskComplete)
+
+ case *worker_pb.WorkerMessage_Shutdown:
+ glog.Infof("Worker %s shutting down: %s", workerID, m.Shutdown.Reason)
+ s.unregisterWorker(workerID)
+
+ default:
+ glog.Warningf("Unknown message type from worker %s", workerID)
+ }
+}
+
+// registerWorkerWithManager registers the worker with the maintenance manager
+func (s *WorkerGrpcServer) registerWorkerWithManager(conn *WorkerConnection) {
+ if s.adminServer.maintenanceManager == nil {
+ return
+ }
+
+ worker := &MaintenanceWorker{
+ ID: conn.workerID,
+ Address: conn.address,
+ LastHeartbeat: time.Now(),
+ Status: "active",
+ Capabilities: conn.capabilities,
+ MaxConcurrent: int(conn.maxConcurrent),
+ CurrentLoad: 0,
+ }
+
+ s.adminServer.maintenanceManager.RegisterWorker(worker)
+ glog.V(1).Infof("Registered worker %s with maintenance manager", conn.workerID)
+}
+
+// handleHeartbeat processes heartbeat messages
+func (s *WorkerGrpcServer) handleHeartbeat(conn *WorkerConnection, heartbeat *worker_pb.WorkerHeartbeat) {
+ if s.adminServer.maintenanceManager != nil {
+ s.adminServer.maintenanceManager.UpdateWorkerHeartbeat(conn.workerID)
+ }
+
+ // Send heartbeat response
+ response := &worker_pb.AdminMessage{
+ Timestamp: time.Now().Unix(),
+ Message: &worker_pb.AdminMessage_HeartbeatResponse{
+ HeartbeatResponse: &worker_pb.HeartbeatResponse{
+ Success: true,
+ Message: "Heartbeat acknowledged",
+ },
+ },
+ }
+
+ select {
+ case conn.outgoing <- response:
+ case <-time.After(time.Second):
+ glog.Warningf("Failed to send heartbeat response to worker %s", conn.workerID)
+ }
+}
+
+// handleTaskRequest processes task requests from workers
+func (s *WorkerGrpcServer) handleTaskRequest(conn *WorkerConnection, request *worker_pb.TaskRequest) {
+ if s.adminServer.maintenanceManager == nil {
+ return
+ }
+
+ // Get next task from maintenance manager
+ task := s.adminServer.maintenanceManager.GetNextTask(conn.workerID, conn.capabilities)
+
+ if task != nil {
+ // Send task assignment
+ assignment := &worker_pb.AdminMessage{
+ Timestamp: time.Now().Unix(),
+ Message: &worker_pb.AdminMessage_TaskAssignment{
+ TaskAssignment: &worker_pb.TaskAssignment{
+ TaskId: task.ID,
+ TaskType: string(task.Type),
+ Params: &worker_pb.TaskParams{
+ VolumeId: task.VolumeID,
+ Server: task.Server,
+ Collection: task.Collection,
+ Parameters: convertTaskParameters(task.Parameters),
+ },
+ Priority: int32(task.Priority),
+ CreatedTime: time.Now().Unix(),
+ },
+ },
+ }
+
+ select {
+ case conn.outgoing <- assignment:
+ glog.V(2).Infof("Assigned task %s to worker %s", task.ID, conn.workerID)
+ case <-time.After(time.Second):
+ glog.Warningf("Failed to send task assignment to worker %s", conn.workerID)
+ }
+ }
+}
+
+// handleTaskUpdate processes task progress updates
+func (s *WorkerGrpcServer) handleTaskUpdate(conn *WorkerConnection, update *worker_pb.TaskUpdate) {
+ if s.adminServer.maintenanceManager != nil {
+ s.adminServer.maintenanceManager.UpdateTaskProgress(update.TaskId, float64(update.Progress))
+ glog.V(3).Infof("Updated task %s progress: %.1f%%", update.TaskId, update.Progress)
+ }
+}
+
+// handleTaskCompletion processes task completion notifications
+func (s *WorkerGrpcServer) handleTaskCompletion(conn *WorkerConnection, completion *worker_pb.TaskComplete) {
+ if s.adminServer.maintenanceManager != nil {
+ errorMsg := ""
+ if !completion.Success {
+ errorMsg = completion.ErrorMessage
+ }
+ s.adminServer.maintenanceManager.CompleteTask(completion.TaskId, errorMsg)
+
+ if completion.Success {
+ glog.V(1).Infof("Worker %s completed task %s successfully", conn.workerID, completion.TaskId)
+ } else {
+ glog.Errorf("Worker %s failed task %s: %s", conn.workerID, completion.TaskId, completion.ErrorMessage)
+ }
+ }
+}
+
+// unregisterWorker removes a worker connection
+func (s *WorkerGrpcServer) unregisterWorker(workerID string) {
+ s.connMutex.Lock()
+ if conn, exists := s.connections[workerID]; exists {
+ conn.cancel()
+ close(conn.outgoing)
+ delete(s.connections, workerID)
+ }
+ s.connMutex.Unlock()
+
+ glog.V(1).Infof("Unregistered worker %s", workerID)
+}
+
+// cleanupRoutine periodically cleans up stale connections
+func (s *WorkerGrpcServer) cleanupRoutine() {
+ ticker := time.NewTicker(30 * time.Second)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-s.stopChan:
+ return
+ case <-ticker.C:
+ s.cleanupStaleConnections()
+ }
+ }
+}
+
+// cleanupStaleConnections removes connections that haven't been seen recently
+func (s *WorkerGrpcServer) cleanupStaleConnections() {
+ cutoff := time.Now().Add(-2 * time.Minute)
+
+ s.connMutex.Lock()
+ defer s.connMutex.Unlock()
+
+ for workerID, conn := range s.connections {
+ if conn.lastSeen.Before(cutoff) {
+ glog.Warningf("Cleaning up stale worker connection: %s", workerID)
+ conn.cancel()
+ close(conn.outgoing)
+ delete(s.connections, workerID)
+ }
+ }
+}
+
+// GetConnectedWorkers returns a list of currently connected workers
+func (s *WorkerGrpcServer) GetConnectedWorkers() []string {
+ s.connMutex.RLock()
+ defer s.connMutex.RUnlock()
+
+ workers := make([]string, 0, len(s.connections))
+ for workerID := range s.connections {
+ workers = append(workers, workerID)
+ }
+ return workers
+}
+
+// convertTaskParameters converts task parameters to protobuf format
+func convertTaskParameters(params map[string]interface{}) map[string]string {
+ result := make(map[string]string)
+ for key, value := range params {
+ result[key] = fmt.Sprintf("%v", value)
+ }
+ return result
+}
+
+func findClientAddress(ctx context.Context) string {
+ pr, ok := peer.FromContext(ctx)
+ if !ok {
+ glog.Error("failed to get peer from ctx")
+ return ""
+ }
+ if pr.Addr == nil {
+ glog.Error("failed to get peer address")
+ return ""
+ }
+ return pr.Addr.String()
+}
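
For orientation, here is a minimal sketch, not part of this commit, of how the admin component might wire this server into its startup and shutdown path. It assumes it sits in package dash (where fmt, time, and glog are already imported); the startWorkerGrpc name, the port argument, and the one-minute ticker are illustrative assumptions, while NewWorkerGrpcServer, StartWithTLS, Stop, and GetConnectedWorkers are the APIs defined above.

// Sketch only: start the worker gRPC server, report connected workers, and
// stop it when shutdown is requested. Names and intervals are assumptions.
func startWorkerGrpc(adminServer *AdminServer, port int, stop <-chan struct{}) error {
	workerServer := NewWorkerGrpcServer(adminServer)
	if err := workerServer.StartWithTLS(port); err != nil {
		return fmt.Errorf("start worker gRPC server: %v", err)
	}

	// Report connected workers once a minute until shutdown is requested.
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-stop:
			return workerServer.Stop()
		case <-ticker.C:
			glog.V(1).Infof("connected workers: %v", workerServer.GetConnectedWorkers())
		}
	}
}

Since StartWithTLS serves in a background goroutine, the caller only needs to keep the *WorkerGrpcServer around and call Stop during shutdown.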