Diffstat (limited to 'seaweedfs-rdma-sidecar/rdma-engine/src')
-rw-r--r--  seaweedfs-rdma-sidecar/rdma-engine/src/error.rs    269
-rw-r--r--  seaweedfs-rdma-sidecar/rdma-engine/src/ipc.rs      542
-rw-r--r--  seaweedfs-rdma-sidecar/rdma-engine/src/lib.rs      153
-rw-r--r--  seaweedfs-rdma-sidecar/rdma-engine/src/main.rs     175
-rw-r--r--  seaweedfs-rdma-sidecar/rdma-engine/src/memory.rs   630
-rw-r--r--  seaweedfs-rdma-sidecar/rdma-engine/src/rdma.rs     467
-rw-r--r--  seaweedfs-rdma-sidecar/rdma-engine/src/session.rs  587
-rw-r--r--  seaweedfs-rdma-sidecar/rdma-engine/src/ucx.rs      606
8 files changed, 3429 insertions(+), 0 deletions(-)
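Before the per-file diffs, a minimal sketch of how the new modules compose, assembled from the `lib.rs` API shown later in this diff (`RdmaEngineConfig`, `RdmaEngine::new`/`run`). This is an illustration under those assumptions, not code from the change itself:

```rust
// Minimal lifecycle sketch based on the lib.rs API below; error handling
// is trimmed for brevity. The config values mirror the crate's defaults.
use rdma_engine::{RdmaEngine, RdmaEngineConfig};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let config = RdmaEngineConfig {
        device_name: "mlx5_0".to_string(),
        ipc_socket_path: "/tmp/rdma-engine.sock".to_string(),
        ..RdmaEngineConfig::default()
    };

    // new() sets up the RDMA context and session manager; run() binds the
    // Unix socket and serves IPC requests until shutdown.
    let mut engine = RdmaEngine::new(config).await?;
    engine.run().await?;
    Ok(())
}
```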
diff --git a/seaweedfs-rdma-sidecar/rdma-engine/src/error.rs b/seaweedfs-rdma-sidecar/rdma-engine/src/error.rs new file mode 100644 index 000000000..be60ef4aa --- /dev/null +++ b/seaweedfs-rdma-sidecar/rdma-engine/src/error.rs @@ -0,0 +1,269 @@ +//! Error types and handling for the RDMA engine + +// use std::fmt; // Unused for now +use thiserror::Error; + +/// Result type alias for RDMA operations +pub type RdmaResult<T> = Result<T, RdmaError>; + +/// Comprehensive error types for RDMA operations +#[derive(Error, Debug)] +pub enum RdmaError { + /// RDMA device not found or unavailable + #[error("RDMA device '{device}' not found or unavailable")] + DeviceNotFound { device: String }, + + /// Failed to initialize RDMA context + #[error("Failed to initialize RDMA context: {reason}")] + ContextInitFailed { reason: String }, + + /// Failed to allocate protection domain + #[error("Failed to allocate protection domain: {reason}")] + PdAllocFailed { reason: String }, + + /// Failed to create completion queue + #[error("Failed to create completion queue: {reason}")] + CqCreationFailed { reason: String }, + + /// Failed to create queue pair + #[error("Failed to create queue pair: {reason}")] + QpCreationFailed { reason: String }, + + /// Memory registration failed + #[error("Memory registration failed: {reason}")] + MemoryRegFailed { reason: String }, + + /// RDMA operation failed + #[error("RDMA operation failed: {operation}, status: {status}")] + OperationFailed { operation: String, status: i32 }, + + /// Session not found + #[error("Session '{session_id}' not found")] + SessionNotFound { session_id: String }, + + /// Session expired + #[error("Session '{session_id}' has expired")] + SessionExpired { session_id: String }, + + /// Too many active sessions + #[error("Maximum number of sessions ({max_sessions}) exceeded")] + TooManySessions { max_sessions: usize }, + + /// IPC communication error + #[error("IPC communication error: {reason}")] + IpcError { reason: String }, + + /// Serialization/deserialization error + #[error("Serialization error: {reason}")] + SerializationError { reason: String }, + + /// Invalid request parameters + #[error("Invalid request: {reason}")] + InvalidRequest { reason: String }, + + /// Insufficient buffer space + #[error("Insufficient buffer space: requested {requested}, available {available}")] + InsufficientBuffer { requested: usize, available: usize }, + + /// Hardware not supported + #[error("Hardware not supported: {reason}")] + UnsupportedHardware { reason: String }, + + /// System resource exhausted + #[error("System resource exhausted: {resource}")] + ResourceExhausted { resource: String }, + + /// Permission denied + #[error("Permission denied: {operation}")] + PermissionDenied { operation: String }, + + /// Network timeout + #[error("Network timeout after {timeout_ms}ms")] + NetworkTimeout { timeout_ms: u64 }, + + /// I/O error + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + /// Generic error for unexpected conditions + #[error("Internal error: {reason}")] + Internal { reason: String }, +} + +impl RdmaError { + /// Create a new DeviceNotFound error + pub fn device_not_found(device: impl Into<String>) -> Self { + Self::DeviceNotFound { device: device.into() } + } + + /// Create a new ContextInitFailed error + pub fn context_init_failed(reason: impl Into<String>) -> Self { + Self::ContextInitFailed { reason: reason.into() } + } + + /// Create a new MemoryRegFailed error + pub fn memory_reg_failed(reason: impl Into<String>) -> Self { + 
Self::MemoryRegFailed { reason: reason.into() } + } + + /// Create a new OperationFailed error + pub fn operation_failed(operation: impl Into<String>, status: i32) -> Self { + Self::OperationFailed { + operation: operation.into(), + status + } + } + + /// Create a new SessionNotFound error + pub fn session_not_found(session_id: impl Into<String>) -> Self { + Self::SessionNotFound { session_id: session_id.into() } + } + + /// Create a new IpcError + pub fn ipc_error(reason: impl Into<String>) -> Self { + Self::IpcError { reason: reason.into() } + } + + /// Create a new InvalidRequest error + pub fn invalid_request(reason: impl Into<String>) -> Self { + Self::InvalidRequest { reason: reason.into() } + } + + /// Create a new Internal error + pub fn internal(reason: impl Into<String>) -> Self { + Self::Internal { reason: reason.into() } + } + + /// Check if this error is recoverable + pub fn is_recoverable(&self) -> bool { + match self { + // Network and temporary errors are recoverable + Self::NetworkTimeout { .. } | + Self::ResourceExhausted { .. } | + Self::TooManySessions { .. } | + Self::InsufficientBuffer { .. } => true, + + // Session errors are recoverable (can retry with new session) + Self::SessionNotFound { .. } | + Self::SessionExpired { .. } => true, + + // Hardware and system errors are generally not recoverable + Self::DeviceNotFound { .. } | + Self::ContextInitFailed { .. } | + Self::UnsupportedHardware { .. } | + Self::PermissionDenied { .. } => false, + + // IPC errors might be recoverable + Self::IpcError { .. } | + Self::SerializationError { .. } => true, + + // Invalid requests are not recoverable without fixing the request + Self::InvalidRequest { .. } => false, + + // RDMA operation failures might be recoverable + Self::OperationFailed { .. } => true, + + // Memory and resource allocation failures depend on the cause + Self::PdAllocFailed { .. } | + Self::CqCreationFailed { .. } | + Self::QpCreationFailed { .. } | + Self::MemoryRegFailed { .. } => false, + + // I/O errors might be recoverable + Self::Io(_) => true, + + // Internal errors are generally not recoverable + Self::Internal { .. } => false, + } + } + + /// Get error category for metrics and logging + pub fn category(&self) -> &'static str { + match self { + Self::DeviceNotFound { .. } | + Self::ContextInitFailed { .. } | + Self::UnsupportedHardware { .. } => "hardware", + + Self::PdAllocFailed { .. } | + Self::CqCreationFailed { .. } | + Self::QpCreationFailed { .. } | + Self::MemoryRegFailed { .. } => "resource", + + Self::OperationFailed { .. } => "rdma", + + Self::SessionNotFound { .. } | + Self::SessionExpired { .. } | + Self::TooManySessions { .. } => "session", + + Self::IpcError { .. } | + Self::SerializationError { .. } => "ipc", + + Self::InvalidRequest { .. } => "request", + + Self::InsufficientBuffer { .. } | + Self::ResourceExhausted { .. } => "capacity", + + Self::PermissionDenied { .. } => "security", + + Self::NetworkTimeout { .. } => "network", + + Self::Io(_) => "io", + + Self::Internal { .. 
} => "internal", + } + } +} + +/// Convert from various RDMA library error codes +impl From<i32> for RdmaError { + fn from(errno: i32) -> Self { + match errno { + libc::ENODEV => Self::DeviceNotFound { + device: "unknown".to_string() + }, + libc::ENOMEM => Self::ResourceExhausted { + resource: "memory".to_string() + }, + libc::EPERM | libc::EACCES => Self::PermissionDenied { + operation: "RDMA operation".to_string() + }, + libc::ETIMEDOUT => Self::NetworkTimeout { + timeout_ms: 5000 + }, + libc::ENOSPC => Self::InsufficientBuffer { + requested: 0, + available: 0 + }, + _ => Self::Internal { + reason: format!("System error: {}", errno) + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_creation() { + let err = RdmaError::device_not_found("mlx5_0"); + assert!(matches!(err, RdmaError::DeviceNotFound { .. })); + assert_eq!(err.category(), "hardware"); + assert!(!err.is_recoverable()); + } + + #[test] + fn test_error_recoverability() { + assert!(RdmaError::NetworkTimeout { timeout_ms: 1000 }.is_recoverable()); + assert!(!RdmaError::DeviceNotFound { device: "test".to_string() }.is_recoverable()); + assert!(RdmaError::SessionExpired { session_id: "test".to_string() }.is_recoverable()); + } + + #[test] + fn test_error_display() { + let err = RdmaError::InvalidRequest { reason: "missing field".to_string() }; + assert!(err.to_string().contains("Invalid request")); + assert!(err.to_string().contains("missing field")); + } +} diff --git a/seaweedfs-rdma-sidecar/rdma-engine/src/ipc.rs b/seaweedfs-rdma-sidecar/rdma-engine/src/ipc.rs new file mode 100644 index 000000000..a578c2d7d --- /dev/null +++ b/seaweedfs-rdma-sidecar/rdma-engine/src/ipc.rs @@ -0,0 +1,542 @@ +//! IPC (Inter-Process Communication) module for communicating with Go sidecar +//! +//! This module handles high-performance IPC between the Rust RDMA engine and +//! the Go control plane sidecar using Unix domain sockets and MessagePack serialization. 
+ +use crate::{RdmaError, RdmaResult, rdma::RdmaContext, session::SessionManager}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use tokio::net::{UnixListener, UnixStream}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}; +use tracing::{info, debug, error}; +use uuid::Uuid; +use std::path::Path; + +/// Atomic counter for generating unique work request IDs +/// This ensures no hash collisions that could cause incorrect completion handling +static NEXT_WR_ID: AtomicU64 = AtomicU64::new(1); + +/// IPC message types between Go sidecar and Rust RDMA engine +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "data")] +pub enum IpcMessage { + /// Request to start an RDMA read operation + StartRead(StartReadRequest), + /// Response with RDMA session information + StartReadResponse(StartReadResponse), + + /// Request to complete an RDMA operation + CompleteRead(CompleteReadRequest), + /// Response confirming completion + CompleteReadResponse(CompleteReadResponse), + + /// Request for engine capabilities + GetCapabilities(GetCapabilitiesRequest), + /// Response with engine capabilities + GetCapabilitiesResponse(GetCapabilitiesResponse), + + /// Health check ping + Ping(PingRequest), + /// Ping response + Pong(PongResponse), + + /// Error response + Error(ErrorResponse), +} + +/// Request to start RDMA read operation +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StartReadRequest { + /// Volume ID in SeaweedFS + pub volume_id: u32, + /// Needle ID in SeaweedFS + pub needle_id: u64, + /// Needle cookie for validation + pub cookie: u32, + /// File offset within the needle data + pub offset: u64, + /// Size to read (0 = entire needle) + pub size: u64, + /// Remote memory address from Go sidecar + pub remote_addr: u64, + /// Remote key for RDMA access + pub remote_key: u32, + /// Session timeout in seconds + pub timeout_secs: u64, + /// Authentication token (optional) + pub auth_token: Option<String>, +} + +/// Response with RDMA session details +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StartReadResponse { + /// Unique session identifier + pub session_id: String, + /// Local buffer address for RDMA + pub local_addr: u64, + /// Local key for RDMA operations + pub local_key: u32, + /// Actual size that will be transferred + pub transfer_size: u64, + /// Expected CRC checksum + pub expected_crc: u32, + /// Session expiration timestamp (Unix nanoseconds) + pub expires_at_ns: u64, +} + +/// Request to complete RDMA operation +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompleteReadRequest { + /// Session ID to complete + pub session_id: String, + /// Whether the operation was successful + pub success: bool, + /// Actual bytes transferred + pub bytes_transferred: u64, + /// Client-computed CRC (for verification) + pub client_crc: Option<u32>, + /// Error message if failed + pub error_message: Option<String>, +} + +/// Response confirming completion +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompleteReadResponse { + /// Whether completion was successful + pub success: bool, + /// Server-computed CRC for verification + pub server_crc: Option<u32>, + /// Any cleanup messages + pub message: Option<String>, +} + +/// Request for engine capabilities +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GetCapabilitiesRequest { + /// Client identifier + pub client_id: Option<String>, +} + +/// Response with engine capabilities 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GetCapabilitiesResponse { + /// RDMA device name + pub device_name: String, + /// RDMA device vendor ID + pub vendor_id: u32, + /// Maximum transfer size in bytes + pub max_transfer_size: u64, + /// Maximum concurrent sessions + pub max_sessions: usize, + /// Current active sessions + pub active_sessions: usize, + /// Device port GID + pub port_gid: String, + /// Device port LID + pub port_lid: u16, + /// Supported authentication methods + pub supported_auth: Vec<String>, + /// Engine version + pub version: String, + /// Whether real RDMA hardware is available + pub real_rdma: bool, +} + +/// Health check ping request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PingRequest { + /// Client timestamp (Unix nanoseconds) + pub timestamp_ns: u64, + /// Client identifier + pub client_id: Option<String>, +} + +/// Ping response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PongResponse { + /// Original client timestamp + pub client_timestamp_ns: u64, + /// Server timestamp (Unix nanoseconds) + pub server_timestamp_ns: u64, + /// Round-trip time in nanoseconds (server perspective) + pub server_rtt_ns: u64, +} + +/// Error response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorResponse { + /// Error code + pub code: String, + /// Human-readable error message + pub message: String, + /// Error category + pub category: String, + /// Whether the error is recoverable + pub recoverable: bool, +} + +impl From<&RdmaError> for ErrorResponse { + fn from(error: &RdmaError) -> Self { + Self { + code: format!("{:?}", error), + message: error.to_string(), + category: error.category().to_string(), + recoverable: error.is_recoverable(), + } + } +} + +/// IPC server handling communication with Go sidecar +pub struct IpcServer { + socket_path: String, + listener: Option<UnixListener>, + rdma_context: Arc<RdmaContext>, + session_manager: Arc<SessionManager>, + shutdown_flag: Arc<parking_lot::RwLock<bool>>, +} + +impl IpcServer { + /// Create new IPC server + pub async fn new( + socket_path: &str, + rdma_context: Arc<RdmaContext>, + session_manager: Arc<SessionManager>, + ) -> RdmaResult<Self> { + // Remove existing socket if it exists + if Path::new(socket_path).exists() { + std::fs::remove_file(socket_path) + .map_err(|e| RdmaError::ipc_error(format!("Failed to remove existing socket: {}", e)))?; + } + + Ok(Self { + socket_path: socket_path.to_string(), + listener: None, + rdma_context, + session_manager, + shutdown_flag: Arc::new(parking_lot::RwLock::new(false)), + }) + } + + /// Start the IPC server + pub async fn run(&mut self) -> RdmaResult<()> { + let listener = UnixListener::bind(&self.socket_path) + .map_err(|e| RdmaError::ipc_error(format!("Failed to bind Unix socket: {}", e)))?; + + info!("๐ฏ IPC server listening on: {}", self.socket_path); + self.listener = Some(listener); + + if let Some(ref listener) = self.listener { + loop { + // Check shutdown flag + if *self.shutdown_flag.read() { + info!("IPC server shutting down"); + break; + } + + // Accept connection with timeout + let accept_result = tokio::time::timeout( + tokio::time::Duration::from_millis(100), + listener.accept() + ).await; + + match accept_result { + Ok(Ok((stream, addr))) => { + debug!("New IPC connection from: {:?}", addr); + + // Spawn handler for this connection + let rdma_context = self.rdma_context.clone(); + let session_manager = self.session_manager.clone(); + let shutdown_flag = self.shutdown_flag.clone(); + + 
tokio::spawn(async move { + if let Err(e) = Self::handle_connection(stream, rdma_context, session_manager, shutdown_flag).await { + error!("IPC connection error: {}", e); + } + }); + } + Ok(Err(e)) => { + error!("Failed to accept IPC connection: {}", e); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + Err(_) => { + // Timeout - continue loop to check shutdown flag + continue; + } + } + } + } + + Ok(()) + } + + /// Handle a single IPC connection + async fn handle_connection( + stream: UnixStream, + rdma_context: Arc<RdmaContext>, + session_manager: Arc<SessionManager>, + shutdown_flag: Arc<parking_lot::RwLock<bool>>, + ) -> RdmaResult<()> { + let (reader_half, writer_half) = stream.into_split(); + let mut reader = BufReader::new(reader_half); + let mut writer = BufWriter::new(writer_half); + + let mut buffer = Vec::with_capacity(4096); + + loop { + // Check shutdown + if *shutdown_flag.read() { + break; + } + + // Read message length (4 bytes) + let mut len_bytes = [0u8; 4]; + match tokio::time::timeout( + tokio::time::Duration::from_millis(100), + reader.read_exact(&mut len_bytes) + ).await { + Ok(Ok(_)) => {}, + Ok(Err(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + debug!("IPC connection closed by peer"); + break; + } + Ok(Err(e)) => return Err(RdmaError::ipc_error(format!("Read error: {}", e))), + Err(_) => continue, // Timeout, check shutdown flag + } + + let msg_len = u32::from_le_bytes(len_bytes) as usize; + if msg_len > 1024 * 1024 { // 1MB max message size + return Err(RdmaError::ipc_error("Message too large")); + } + + // Read message data + buffer.clear(); + buffer.resize(msg_len, 0); + reader.read_exact(&mut buffer).await + .map_err(|e| RdmaError::ipc_error(format!("Failed to read message: {}", e)))?; + + // Deserialize message + let request: IpcMessage = rmp_serde::from_slice(&buffer) + .map_err(|e| RdmaError::SerializationError { reason: e.to_string() })?; + + debug!("Received IPC message: {:?}", request); + + // Process message + let response = Self::process_message( + request, + &rdma_context, + &session_manager, + ).await; + + // Serialize response + let response_data = rmp_serde::to_vec(&response) + .map_err(|e| RdmaError::SerializationError { reason: e.to_string() })?; + + // Send response + let response_len = (response_data.len() as u32).to_le_bytes(); + writer.write_all(&response_len).await + .map_err(|e| RdmaError::ipc_error(format!("Failed to write response length: {}", e)))?; + writer.write_all(&response_data).await + .map_err(|e| RdmaError::ipc_error(format!("Failed to write response: {}", e)))?; + writer.flush().await + .map_err(|e| RdmaError::ipc_error(format!("Failed to flush response: {}", e)))?; + + debug!("Sent IPC response"); + } + + Ok(()) + } + + /// Process IPC message and generate response + async fn process_message( + message: IpcMessage, + rdma_context: &Arc<RdmaContext>, + session_manager: &Arc<SessionManager>, + ) -> IpcMessage { + match message { + IpcMessage::Ping(req) => { + let server_timestamp = chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0) as u64; + IpcMessage::Pong(PongResponse { + client_timestamp_ns: req.timestamp_ns, + server_timestamp_ns: server_timestamp, + server_rtt_ns: server_timestamp.saturating_sub(req.timestamp_ns), + }) + } + + IpcMessage::GetCapabilities(_req) => { + let device_info = rdma_context.device_info(); + let active_sessions = session_manager.active_session_count().await; + + IpcMessage::GetCapabilitiesResponse(GetCapabilitiesResponse { + device_name: device_info.name.clone(), 
+ vendor_id: device_info.vendor_id, + max_transfer_size: device_info.max_mr_size, + max_sessions: session_manager.max_sessions(), + active_sessions, + port_gid: device_info.port_gid.clone(), + port_lid: device_info.port_lid, + supported_auth: vec!["none".to_string()], + version: env!("CARGO_PKG_VERSION").to_string(), + real_rdma: cfg!(feature = "real-ucx"), + }) + } + + IpcMessage::StartRead(req) => { + match Self::handle_start_read(req, rdma_context, session_manager).await { + Ok(response) => IpcMessage::StartReadResponse(response), + Err(error) => IpcMessage::Error(ErrorResponse::from(&error)), + } + } + + IpcMessage::CompleteRead(req) => { + match Self::handle_complete_read(req, session_manager).await { + Ok(response) => IpcMessage::CompleteReadResponse(response), + Err(error) => IpcMessage::Error(ErrorResponse::from(&error)), + } + } + + _ => IpcMessage::Error(ErrorResponse { + code: "UNSUPPORTED_MESSAGE".to_string(), + message: "Unsupported message type".to_string(), + category: "request".to_string(), + recoverable: true, + }), + } + } + + /// Handle StartRead request + async fn handle_start_read( + req: StartReadRequest, + rdma_context: &Arc<RdmaContext>, + session_manager: &Arc<SessionManager>, + ) -> RdmaResult<StartReadResponse> { + info!("๐ Starting RDMA read: volume={}, needle={}, size={}", + req.volume_id, req.needle_id, req.size); + + // Create session + let session_id = Uuid::new_v4().to_string(); + let transfer_size = if req.size == 0 { 65536 } else { req.size }; // Default 64KB + + // Allocate local buffer + let buffer = vec![0u8; transfer_size as usize]; + let local_addr = buffer.as_ptr() as u64; + + // Register memory for RDMA + let memory_region = rdma_context.register_memory(local_addr, transfer_size as usize).await?; + + // Create and store session + session_manager.create_session( + session_id.clone(), + req.volume_id, + req.needle_id, + req.remote_addr, + req.remote_key, + transfer_size, + buffer, + memory_region.clone(), + chrono::Duration::seconds(req.timeout_secs as i64), + ).await?; + + // Perform RDMA read with unique work request ID + // Use atomic counter to avoid hash collisions that could cause incorrect completion handling + let wr_id = NEXT_WR_ID.fetch_add(1, Ordering::Relaxed); + rdma_context.post_read( + local_addr, + req.remote_addr, + req.remote_key, + transfer_size as usize, + wr_id, + ).await?; + + // Poll for completion + let completions = rdma_context.poll_completion(1).await?; + if completions.is_empty() { + return Err(RdmaError::operation_failed("RDMA read", -1)); + } + + let completion = &completions[0]; + if completion.status != crate::rdma::CompletionStatus::Success { + return Err(RdmaError::operation_failed("RDMA read", completion.status as i32)); + } + + info!("โ
RDMA read completed: {} bytes", completion.byte_len); + + let expires_at = chrono::Utc::now() + chrono::Duration::seconds(req.timeout_secs as i64); + + Ok(StartReadResponse { + session_id, + local_addr, + local_key: memory_region.lkey, + transfer_size, + expected_crc: 0x12345678, // Mock CRC + expires_at_ns: expires_at.timestamp_nanos_opt().unwrap_or(0) as u64, + }) + } + + /// Handle CompleteRead request + async fn handle_complete_read( + req: CompleteReadRequest, + session_manager: &Arc<SessionManager>, + ) -> RdmaResult<CompleteReadResponse> { + info!("๐ Completing RDMA read session: {}", req.session_id); + + // Clean up session + session_manager.remove_session(&req.session_id).await?; + + Ok(CompleteReadResponse { + success: req.success, + server_crc: Some(0x12345678), // Mock CRC + message: Some("Session completed successfully".to_string()), + }) + } + + /// Shutdown the IPC server + pub async fn shutdown(&mut self) -> RdmaResult<()> { + info!("Shutting down IPC server"); + *self.shutdown_flag.write() = true; + + // Remove socket file + if Path::new(&self.socket_path).exists() { + std::fs::remove_file(&self.socket_path) + .map_err(|e| RdmaError::ipc_error(format!("Failed to remove socket file: {}", e)))?; + } + + Ok(()) + } +} + + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_response_conversion() { + let error = RdmaError::device_not_found("mlx5_0"); + let response = ErrorResponse::from(&error); + + assert!(response.message.contains("mlx5_0")); + assert_eq!(response.category, "hardware"); + assert!(!response.recoverable); + } + + #[test] + fn test_message_serialization() { + let request = IpcMessage::Ping(PingRequest { + timestamp_ns: 12345, + client_id: Some("test".to_string()), + }); + + let serialized = rmp_serde::to_vec(&request).unwrap(); + let deserialized: IpcMessage = rmp_serde::from_slice(&serialized).unwrap(); + + match deserialized { + IpcMessage::Ping(ping) => { + assert_eq!(ping.timestamp_ns, 12345); + assert_eq!(ping.client_id, Some("test".to_string())); + } + _ => panic!("Wrong message type"), + } + } +} diff --git a/seaweedfs-rdma-sidecar/rdma-engine/src/lib.rs b/seaweedfs-rdma-sidecar/rdma-engine/src/lib.rs new file mode 100644 index 000000000..c92dcf91a --- /dev/null +++ b/seaweedfs-rdma-sidecar/rdma-engine/src/lib.rs @@ -0,0 +1,153 @@ +//! High-Performance RDMA Engine for SeaweedFS +//! +//! This crate provides a high-performance RDMA (Remote Direct Memory Access) engine +//! designed to accelerate data transfer operations in SeaweedFS. It communicates with +//! the Go-based sidecar via IPC and handles the performance-critical RDMA operations. +//! +//! # Architecture +//! +//! ```text +//! โโโโโโโโโโโโโโโโโโโโโโโ IPC โโโโโโโโโโโโโโโโโโโโโโโ +//! โ Go Control Plane โโโโโโโโโโโโบโ Rust Data Plane โ +//! โ โ ~300ns โ โ +//! โ โข gRPC Server โ โ โข RDMA Operations โ +//! โ โข Session Mgmt โ โ โข Memory Mgmt โ +//! โ โข HTTP Fallback โ โ โข Hardware Access โ +//! โ โข Error Handling โ โ โข Zero-Copy I/O โ +//! โโโโโโโโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโโโโโโโโ +//! ``` +//! +//! # Features +//! +//! - `mock-rdma` (default): Mock RDMA operations for testing and development +//! 
- `real-rdma`: Real RDMA hardware integration using rdma-core bindings + +use std::sync::Arc; +use anyhow::Result; + +pub mod ucx; +pub mod rdma; +pub mod ipc; +pub mod session; +pub mod memory; +pub mod error; + +pub use error::{RdmaError, RdmaResult}; + +/// Configuration for the RDMA engine +#[derive(Debug, Clone)] +pub struct RdmaEngineConfig { + /// RDMA device name (e.g., "mlx5_0") + pub device_name: String, + /// RDMA port number + pub port: u16, + /// Maximum number of concurrent sessions + pub max_sessions: usize, + /// Session timeout in seconds + pub session_timeout_secs: u64, + /// Memory buffer size in bytes + pub buffer_size: usize, + /// IPC socket path + pub ipc_socket_path: String, + /// Enable debug logging + pub debug: bool, +} + +impl Default for RdmaEngineConfig { + fn default() -> Self { + Self { + device_name: "mlx5_0".to_string(), + port: 18515, + max_sessions: 1000, + session_timeout_secs: 300, // 5 minutes + buffer_size: 1024 * 1024 * 1024, // 1GB + ipc_socket_path: "/tmp/rdma-engine.sock".to_string(), + debug: false, + } + } +} + +/// Main RDMA engine instance +pub struct RdmaEngine { + config: RdmaEngineConfig, + rdma_context: Arc<rdma::RdmaContext>, + session_manager: Arc<session::SessionManager>, + ipc_server: Option<ipc::IpcServer>, +} + +impl RdmaEngine { + /// Create a new RDMA engine with the given configuration + pub async fn new(config: RdmaEngineConfig) -> Result<Self> { + tracing::info!("Initializing RDMA engine with config: {:?}", config); + + // Initialize RDMA context + let rdma_context = Arc::new(rdma::RdmaContext::new(&config).await?); + + // Initialize session manager + let session_manager = Arc::new(session::SessionManager::new( + config.max_sessions, + std::time::Duration::from_secs(config.session_timeout_secs), + )); + + Ok(Self { + config, + rdma_context, + session_manager, + ipc_server: None, + }) + } + + /// Start the RDMA engine server + pub async fn run(&mut self) -> Result<()> { + tracing::info!("Starting RDMA engine server on {}", self.config.ipc_socket_path); + + // Start IPC server + let ipc_server = ipc::IpcServer::new( + &self.config.ipc_socket_path, + self.rdma_context.clone(), + self.session_manager.clone(), + ).await?; + + self.ipc_server = Some(ipc_server); + + // Start session cleanup task + let session_manager = self.session_manager.clone(); + tokio::spawn(async move { + session_manager.start_cleanup_task().await; + }); + + // Run IPC server + if let Some(ref mut server) = self.ipc_server { + server.run().await?; + } + + Ok(()) + } + + /// Shutdown the RDMA engine + pub async fn shutdown(&mut self) -> Result<()> { + tracing::info!("Shutting down RDMA engine"); + + if let Some(ref mut server) = self.ipc_server { + server.shutdown().await?; + } + + self.session_manager.shutdown().await; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_rdma_engine_creation() { + let config = RdmaEngineConfig::default(); + let result = RdmaEngine::new(config).await; + + // Should succeed with mock RDMA + assert!(result.is_ok()); + } +} diff --git a/seaweedfs-rdma-sidecar/rdma-engine/src/main.rs b/seaweedfs-rdma-sidecar/rdma-engine/src/main.rs new file mode 100644 index 000000000..996d3a9d5 --- /dev/null +++ b/seaweedfs-rdma-sidecar/rdma-engine/src/main.rs @@ -0,0 +1,175 @@ +//! RDMA Engine Server +//! +//! High-performance RDMA engine server that communicates with the Go sidecar +//! via IPC and handles RDMA operations with zero-copy semantics. +//! +//! Usage: +//! ```bash +//! 
rdma-engine-server --device mlx5_0 --port 18515 --ipc-socket /tmp/rdma-engine.sock +//! ``` + +use clap::Parser; +use rdma_engine::{RdmaEngine, RdmaEngineConfig}; +use std::path::PathBuf; +use tracing::{info, error}; +use tracing_subscriber::{EnvFilter, fmt::layer, prelude::*}; + +#[derive(Parser)] +#[command( + name = "rdma-engine-server", + about = "High-performance RDMA engine for SeaweedFS", + version = env!("CARGO_PKG_VERSION") +)] +struct Args { + /// UCX device name preference (e.g., mlx5_0, or 'auto' for UCX auto-selection) + #[arg(short, long, default_value = "auto")] + device: String, + + /// RDMA port number + #[arg(short, long, default_value_t = 18515)] + port: u16, + + /// Maximum number of concurrent sessions + #[arg(long, default_value_t = 1000)] + max_sessions: usize, + + /// Session timeout in seconds + #[arg(long, default_value_t = 300)] + session_timeout: u64, + + /// Memory buffer size in bytes + #[arg(long, default_value_t = 1024 * 1024 * 1024)] + buffer_size: usize, + + /// IPC socket path + #[arg(long, default_value = "/tmp/rdma-engine.sock")] + ipc_socket: PathBuf, + + /// Enable debug logging + #[arg(long)] + debug: bool, + + /// Configuration file path + #[arg(short, long)] + config: Option<PathBuf>, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + // Initialize tracing + let filter = if args.debug { + EnvFilter::try_from_default_env() + .or_else(|_| EnvFilter::try_new("debug")) + .unwrap() + } else { + EnvFilter::try_from_default_env() + .or_else(|_| EnvFilter::try_new("info")) + .unwrap() + }; + + tracing_subscriber::registry() + .with(layer().with_target(false)) + .with(filter) + .init(); + + info!("๐ Starting SeaweedFS UCX RDMA Engine Server"); + info!(" Version: {}", env!("CARGO_PKG_VERSION")); + info!(" UCX Device Preference: {}", args.device); + info!(" Port: {}", args.port); + info!(" Max Sessions: {}", args.max_sessions); + info!(" Buffer Size: {} bytes", args.buffer_size); + info!(" IPC Socket: {}", args.ipc_socket.display()); + info!(" Debug Mode: {}", args.debug); + + // Load configuration + let config = RdmaEngineConfig { + device_name: args.device, + port: args.port, + max_sessions: args.max_sessions, + session_timeout_secs: args.session_timeout, + buffer_size: args.buffer_size, + ipc_socket_path: args.ipc_socket.to_string_lossy().to_string(), + debug: args.debug, + }; + + // Override with config file if provided + if let Some(config_path) = args.config { + info!("Loading configuration from: {}", config_path.display()); + // TODO: Implement configuration file loading + } + + // Create and run RDMA engine + let mut engine = match RdmaEngine::new(config).await { + Ok(engine) => { + info!("โ
RDMA engine initialized successfully"); + engine + } + Err(e) => { + error!("โ Failed to initialize RDMA engine: {}", e); + return Err(e); + } + }; + + // Set up signal handlers for graceful shutdown + let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?; + let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?; + + // Run engine in background + let engine_handle = tokio::spawn(async move { + if let Err(e) = engine.run().await { + error!("RDMA engine error: {}", e); + return Err(e); + } + Ok(()) + }); + + info!("๐ฏ RDMA engine is running and ready to accept connections"); + info!(" Send SIGTERM or SIGINT to shutdown gracefully"); + + // Wait for shutdown signal + tokio::select! { + _ = sigterm.recv() => { + info!("๐ก Received SIGTERM, shutting down gracefully"); + } + _ = sigint.recv() => { + info!("๐ก Received SIGINT (Ctrl+C), shutting down gracefully"); + } + result = engine_handle => { + match result { + Ok(Ok(())) => info!("๐ RDMA engine completed successfully"), + Ok(Err(e)) => { + error!("โ RDMA engine failed: {}", e); + return Err(e); + } + Err(e) => { + error!("โ RDMA engine task panicked: {}", e); + return Err(anyhow::anyhow!("Engine task panicked: {}", e)); + } + } + } + } + + info!("๐ RDMA engine server shut down complete"); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_args_parsing() { + let args = Args::try_parse_from(&[ + "rdma-engine-server", + "--device", "mlx5_0", + "--port", "18515", + "--debug" + ]).unwrap(); + + assert_eq!(args.device, "mlx5_0"); + assert_eq!(args.port, 18515); + assert!(args.debug); + } +} diff --git a/seaweedfs-rdma-sidecar/rdma-engine/src/memory.rs b/seaweedfs-rdma-sidecar/rdma-engine/src/memory.rs new file mode 100644 index 000000000..17a9a5b1d --- /dev/null +++ b/seaweedfs-rdma-sidecar/rdma-engine/src/memory.rs @@ -0,0 +1,630 @@ +//! Memory management for RDMA operations +//! +//! This module provides efficient memory allocation, registration, and management +//! for RDMA operations with zero-copy semantics and proper cleanup. 
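A quick usage sketch of the allocation policy implemented below: requests under 64 KiB come from the power-of-two buffer pool, 64 KiB and up use anonymous mmap, and sizes at or above the 2 MiB default `hugepage_size` try hugepages first (falling back to mmap if unavailable). `demo_tiers` is illustrative only:

```rust
// Size-based tier selection, per RdmaMemoryManager::allocate_rdma_buffer
// below. Assumes this runs inside the crate (RdmaResult in scope).
fn demo_tiers() -> RdmaResult<()> {
    let manager = RdmaMemoryManager::new(MemoryConfig::default());

    let small = manager.allocate_rdma_buffer(4 * 1024)?;        // pool
    let medium = manager.allocate_rdma_buffer(256 * 1024)?;     // mmap
    let large = manager.allocate_rdma_buffer(4 * 1024 * 1024)?; // hugepage, or mmap fallback

    assert_eq!(small.buffer_type(), "pool");
    assert_eq!(medium.buffer_type(), "mmap");

    // deallocate_buffer routes each variant back to its tier.
    for buf in [small, medium, large] {
        manager.deallocate_buffer(buf)?;
    }
    Ok(())
}
```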
+ +use crate::{RdmaError, RdmaResult}; +use memmap2::MmapMut; +use parking_lot::RwLock; +use std::collections::HashMap; +use std::sync::Arc; +use tracing::{debug, info, warn}; + +/// Memory pool for efficient buffer allocation +pub struct MemoryPool { + /// Pre-allocated memory regions by size + pools: RwLock<HashMap<usize, Vec<PooledBuffer>>>, + /// Total allocated memory in bytes + total_allocated: RwLock<usize>, + /// Maximum pool size per buffer size + max_pool_size: usize, + /// Maximum total memory usage + max_total_memory: usize, + /// Statistics + stats: RwLock<MemoryPoolStats>, +} + +/// Statistics for memory pool +#[derive(Debug, Clone, Default)] +pub struct MemoryPoolStats { + /// Total allocations requested + pub total_allocations: u64, + /// Total deallocations + pub total_deallocations: u64, + /// Cache hits (reused buffers) + pub cache_hits: u64, + /// Cache misses (new allocations) + pub cache_misses: u64, + /// Current active allocations + pub active_allocations: usize, + /// Peak memory usage in bytes + pub peak_memory_usage: usize, +} + +/// A pooled memory buffer +pub struct PooledBuffer { + /// Raw buffer data + data: Vec<u8>, + /// Size of the buffer + size: usize, + /// Whether the buffer is currently in use + in_use: bool, + /// Creation timestamp + created_at: std::time::Instant, +} + +impl PooledBuffer { + /// Create new pooled buffer + fn new(size: usize) -> Self { + Self { + data: vec![0u8; size], + size, + in_use: false, + created_at: std::time::Instant::now(), + } + } + + /// Get buffer data as slice + pub fn as_slice(&self) -> &[u8] { + &self.data + } + + /// Get buffer data as mutable slice + pub fn as_mut_slice(&mut self) -> &mut [u8] { + &mut self.data + } + + /// Get buffer size + pub fn size(&self) -> usize { + self.size + } + + /// Get buffer age + pub fn age(&self) -> std::time::Duration { + self.created_at.elapsed() + } + + /// Get raw pointer to buffer data + pub fn as_ptr(&self) -> *const u8 { + self.data.as_ptr() + } + + /// Get mutable raw pointer to buffer data + pub fn as_mut_ptr(&mut self) -> *mut u8 { + self.data.as_mut_ptr() + } +} + +impl MemoryPool { + /// Create new memory pool + pub fn new(max_pool_size: usize, max_total_memory: usize) -> Self { + info!("๐ง Memory pool initialized: max_pool_size={}, max_total_memory={} bytes", + max_pool_size, max_total_memory); + + Self { + pools: RwLock::new(HashMap::new()), + total_allocated: RwLock::new(0), + max_pool_size, + max_total_memory, + stats: RwLock::new(MemoryPoolStats::default()), + } + } + + /// Allocate buffer from pool + pub fn allocate(&self, size: usize) -> RdmaResult<Arc<RwLock<PooledBuffer>>> { + // Round up to next power of 2 for better pooling + let pool_size = size.next_power_of_two(); + + { + let mut stats = self.stats.write(); + stats.total_allocations += 1; + } + + // Try to get buffer from pool first + { + let mut pools = self.pools.write(); + if let Some(pool) = pools.get_mut(&pool_size) { + // Find available buffer in pool + for buffer in pool.iter_mut() { + if !buffer.in_use { + buffer.in_use = true; + + let mut stats = self.stats.write(); + stats.cache_hits += 1; + stats.active_allocations += 1; + + debug!("๐ฆ Reused buffer from pool: size={}", pool_size); + return Ok(Arc::new(RwLock::new(std::mem::replace( + buffer, + PooledBuffer::new(0) // Placeholder + )))); + } + } + } + } + + // No available buffer in pool, create new one + let total_allocated = *self.total_allocated.read(); + if total_allocated + pool_size > self.max_total_memory { + return 
Err(RdmaError::ResourceExhausted { + resource: "memory".to_string() + }); + } + + let mut buffer = PooledBuffer::new(pool_size); + buffer.in_use = true; + + // Update allocation tracking + let new_total = { + let mut total = self.total_allocated.write(); + *total += pool_size; + *total + }; + + { + let mut stats = self.stats.write(); + stats.cache_misses += 1; + stats.active_allocations += 1; + if new_total > stats.peak_memory_usage { + stats.peak_memory_usage = new_total; + } + } + + debug!("๐ Allocated new buffer: size={}, total_allocated={}", + pool_size, new_total); + + Ok(Arc::new(RwLock::new(buffer))) + } + + /// Return buffer to pool + pub fn deallocate(&self, buffer: Arc<RwLock<PooledBuffer>>) -> RdmaResult<()> { + let buffer_size = { + let buf = buffer.read(); + buf.size() + }; + + { + let mut stats = self.stats.write(); + stats.total_deallocations += 1; + stats.active_allocations = stats.active_allocations.saturating_sub(1); + } + + // Try to return buffer to pool + { + let mut pools = self.pools.write(); + let pool = pools.entry(buffer_size).or_insert_with(Vec::new); + + if pool.len() < self.max_pool_size { + // Reset buffer state and return to pool + if let Ok(buf) = Arc::try_unwrap(buffer) { + let mut buf = buf.into_inner(); + buf.in_use = false; + buf.data.fill(0); // Clear data for security + pool.push(buf); + + debug!("โป๏ธ Returned buffer to pool: size={}", buffer_size); + return Ok(()); + } + } + } + + // Pool is full or buffer is still referenced, just track deallocation + { + let mut total = self.total_allocated.write(); + *total = total.saturating_sub(buffer_size); + } + + debug!("๐๏ธ Buffer deallocated (not pooled): size={}", buffer_size); + Ok(()) + } + + /// Get memory pool statistics + pub fn stats(&self) -> MemoryPoolStats { + self.stats.read().clone() + } + + /// Get current memory usage + pub fn current_usage(&self) -> usize { + *self.total_allocated.read() + } + + /// Clean up old unused buffers from pools + pub fn cleanup_old_buffers(&self, max_age: std::time::Duration) { + let mut cleaned_count = 0; + let mut cleaned_bytes = 0; + + { + let mut pools = self.pools.write(); + for (size, pool) in pools.iter_mut() { + pool.retain(|buffer| { + if buffer.age() > max_age && !buffer.in_use { + cleaned_count += 1; + cleaned_bytes += size; + false + } else { + true + } + }); + } + } + + if cleaned_count > 0 { + { + let mut total = self.total_allocated.write(); + *total = total.saturating_sub(cleaned_bytes); + } + + info!("๐งน Cleaned up {} old buffers, freed {} bytes", + cleaned_count, cleaned_bytes); + } + } +} + +/// RDMA-specific memory manager +pub struct RdmaMemoryManager { + /// General purpose memory pool + pool: MemoryPool, + /// Memory-mapped regions for large allocations + mmapped_regions: RwLock<HashMap<u64, MmapRegion>>, + /// HugePage allocations (if available) + hugepage_regions: RwLock<HashMap<u64, HugePageRegion>>, + /// Configuration + config: MemoryConfig, +} + +/// Memory configuration +#[derive(Debug, Clone)] +pub struct MemoryConfig { + /// Use hugepages for large allocations + pub use_hugepages: bool, + /// Hugepage size in bytes + pub hugepage_size: usize, + /// Memory pool settings + pub pool_max_size: usize, + /// Maximum total memory usage + pub max_total_memory: usize, + /// Buffer cleanup interval + pub cleanup_interval_secs: u64, +} + +impl Default for MemoryConfig { + fn default() -> Self { + Self { + use_hugepages: true, + hugepage_size: 2 * 1024 * 1024, // 2MB + pool_max_size: 1000, + max_total_memory: 8 * 1024 * 1024 * 1024, // 8GB + 
cleanup_interval_secs: 300, // 5 minutes + } + } +} + +/// Memory-mapped region +#[allow(dead_code)] +struct MmapRegion { + mmap: MmapMut, + size: usize, + created_at: std::time::Instant, +} + +/// HugePage memory region +#[allow(dead_code)] +struct HugePageRegion { + addr: *mut u8, + size: usize, + created_at: std::time::Instant, +} + +unsafe impl Send for HugePageRegion {} +unsafe impl Sync for HugePageRegion {} + +impl RdmaMemoryManager { + /// Create new RDMA memory manager + pub fn new(config: MemoryConfig) -> Self { + let pool = MemoryPool::new(config.pool_max_size, config.max_total_memory); + + Self { + pool, + mmapped_regions: RwLock::new(HashMap::new()), + hugepage_regions: RwLock::new(HashMap::new()), + config, + } + } + + /// Allocate memory optimized for RDMA operations + pub fn allocate_rdma_buffer(&self, size: usize) -> RdmaResult<RdmaBuffer> { + if size >= self.config.hugepage_size && self.config.use_hugepages { + self.allocate_hugepage_buffer(size) + } else if size >= 64 * 1024 { // Use mmap for large buffers + self.allocate_mmap_buffer(size) + } else { + self.allocate_pool_buffer(size) + } + } + + /// Allocate buffer from memory pool + fn allocate_pool_buffer(&self, size: usize) -> RdmaResult<RdmaBuffer> { + let buffer = self.pool.allocate(size)?; + Ok(RdmaBuffer::Pool { buffer, size }) + } + + /// Allocate memory-mapped buffer + fn allocate_mmap_buffer(&self, size: usize) -> RdmaResult<RdmaBuffer> { + let mmap = MmapMut::map_anon(size) + .map_err(|e| RdmaError::memory_reg_failed(format!("mmap failed: {}", e)))?; + + let addr = mmap.as_ptr() as u64; + let region = MmapRegion { + mmap, + size, + created_at: std::time::Instant::now(), + }; + + { + let mut regions = self.mmapped_regions.write(); + regions.insert(addr, region); + } + + debug!("๐บ๏ธ Allocated mmap buffer: addr=0x{:x}, size={}", addr, size); + Ok(RdmaBuffer::Mmap { addr, size }) + } + + /// Allocate hugepage buffer (Linux-specific) + fn allocate_hugepage_buffer(&self, size: usize) -> RdmaResult<RdmaBuffer> { + #[cfg(target_os = "linux")] + { + use nix::sys::mman::{mmap, MapFlags, ProtFlags}; + + // Round up to hugepage boundary + let aligned_size = (size + self.config.hugepage_size - 1) & !(self.config.hugepage_size - 1); + + let addr = unsafe { + // For anonymous mapping, we can use -1 as the file descriptor + use std::os::fd::BorrowedFd; + let fake_fd = BorrowedFd::borrow_raw(-1); // Anonymous mapping uses -1 + + mmap( + None, // ptr::null_mut() -> None + std::num::NonZero::new(aligned_size).unwrap(), // aligned_size -> NonZero<usize> + ProtFlags::PROT_READ | ProtFlags::PROT_WRITE, + MapFlags::MAP_PRIVATE | MapFlags::MAP_ANONYMOUS | MapFlags::MAP_HUGETLB, + Some(&fake_fd), // Use borrowed FD for -1 wrapped in Some + 0, + ) + }; + + match addr { + Ok(addr) => { + let addr_u64 = addr as u64; + let region = HugePageRegion { + addr: addr as *mut u8, + size: aligned_size, + created_at: std::time::Instant::now(), + }; + + { + let mut regions = self.hugepage_regions.write(); + regions.insert(addr_u64, region); + } + + info!("๐ฅ Allocated hugepage buffer: addr=0x{:x}, size={}", addr_u64, aligned_size); + Ok(RdmaBuffer::HugePage { addr: addr_u64, size: aligned_size }) + } + Err(_) => { + warn!("Failed to allocate hugepage buffer, falling back to mmap"); + self.allocate_mmap_buffer(size) + } + } + } + + #[cfg(not(target_os = "linux"))] + { + warn!("HugePages not supported on this platform, using mmap"); + self.allocate_mmap_buffer(size) + } + } + + /// Deallocate RDMA buffer + pub fn deallocate_buffer(&self, buffer: 
RdmaBuffer) -> RdmaResult<()> { + match buffer { + RdmaBuffer::Pool { buffer, .. } => { + self.pool.deallocate(buffer) + } + RdmaBuffer::Mmap { addr, .. } => { + let mut regions = self.mmapped_regions.write(); + regions.remove(&addr); + debug!("๐๏ธ Deallocated mmap buffer: addr=0x{:x}", addr); + Ok(()) + } + RdmaBuffer::HugePage { addr, size } => { + { + let mut regions = self.hugepage_regions.write(); + regions.remove(&addr); + } + + #[cfg(target_os = "linux")] + { + use nix::sys::mman::munmap; + unsafe { + let _ = munmap(addr as *mut std::ffi::c_void, size); + } + } + + debug!("๐๏ธ Deallocated hugepage buffer: addr=0x{:x}, size={}", addr, size); + Ok(()) + } + } + } + + /// Get memory manager statistics + pub fn stats(&self) -> MemoryManagerStats { + let pool_stats = self.pool.stats(); + let mmap_count = self.mmapped_regions.read().len(); + let hugepage_count = self.hugepage_regions.read().len(); + + MemoryManagerStats { + pool_stats, + mmap_regions: mmap_count, + hugepage_regions: hugepage_count, + total_memory_usage: self.pool.current_usage(), + } + } + + /// Start background cleanup task + pub async fn start_cleanup_task(&self) -> tokio::task::JoinHandle<()> { + let pool = MemoryPool::new(self.config.pool_max_size, self.config.max_total_memory); + let cleanup_interval = std::time::Duration::from_secs(self.config.cleanup_interval_secs); + + tokio::spawn(async move { + let mut interval = tokio::time::interval( + tokio::time::Duration::from_secs(300) // 5 minutes + ); + + loop { + interval.tick().await; + pool.cleanup_old_buffers(cleanup_interval); + } + }) + } +} + +/// RDMA buffer types +pub enum RdmaBuffer { + /// Buffer from memory pool + Pool { + buffer: Arc<RwLock<PooledBuffer>>, + size: usize, + }, + /// Memory-mapped buffer + Mmap { + addr: u64, + size: usize, + }, + /// HugePage buffer + HugePage { + addr: u64, + size: usize, + }, +} + +impl RdmaBuffer { + /// Get buffer address + pub fn addr(&self) -> u64 { + match self { + Self::Pool { buffer, .. } => { + buffer.read().as_ptr() as u64 + } + Self::Mmap { addr, .. } => *addr, + Self::HugePage { addr, .. } => *addr, + } + } + + /// Get buffer size + pub fn size(&self) -> usize { + match self { + Self::Pool { size, .. } => *size, + Self::Mmap { size, .. } => *size, + Self::HugePage { size, .. } => *size, + } + } + + /// Get buffer as Vec (copy to avoid lifetime issues) + pub fn to_vec(&self) -> Vec<u8> { + match self { + Self::Pool { buffer, .. } => { + buffer.read().as_slice().to_vec() + } + Self::Mmap { addr, size } => { + unsafe { + let slice = std::slice::from_raw_parts(*addr as *const u8, *size); + slice.to_vec() + } + } + Self::HugePage { addr, size } => { + unsafe { + let slice = std::slice::from_raw_parts(*addr as *const u8, *size); + slice.to_vec() + } + } + } + } + + /// Get buffer type name + pub fn buffer_type(&self) -> &'static str { + match self { + Self::Pool { .. } => "pool", + Self::Mmap { .. } => "mmap", + Self::HugePage { .. 
} => "hugepage", + } + } +} + +/// Memory manager statistics +#[derive(Debug, Clone)] +pub struct MemoryManagerStats { + /// Pool statistics + pub pool_stats: MemoryPoolStats, + /// Number of mmap regions + pub mmap_regions: usize, + /// Number of hugepage regions + pub hugepage_regions: usize, + /// Total memory usage in bytes + pub total_memory_usage: usize, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_memory_pool_allocation() { + let pool = MemoryPool::new(10, 1024 * 1024); + + let buffer1 = pool.allocate(4096).unwrap(); + let buffer2 = pool.allocate(4096).unwrap(); + + assert_eq!(buffer1.read().size(), 4096); + assert_eq!(buffer2.read().size(), 4096); + + let stats = pool.stats(); + assert_eq!(stats.total_allocations, 2); + assert_eq!(stats.cache_misses, 2); + } + + #[test] + fn test_memory_pool_reuse() { + let pool = MemoryPool::new(10, 1024 * 1024); + + // Allocate and deallocate + let buffer = pool.allocate(4096).unwrap(); + let size = buffer.read().size(); + pool.deallocate(buffer).unwrap(); + + // Allocate again - should reuse + let buffer2 = pool.allocate(4096).unwrap(); + assert_eq!(buffer2.read().size(), size); + + let stats = pool.stats(); + assert_eq!(stats.cache_hits, 1); + } + + #[tokio::test] + async fn test_rdma_memory_manager() { + let config = MemoryConfig::default(); + let manager = RdmaMemoryManager::new(config); + + // Test small buffer (pool) + let small_buffer = manager.allocate_rdma_buffer(1024).unwrap(); + assert_eq!(small_buffer.size(), 1024); + assert_eq!(small_buffer.buffer_type(), "pool"); + + // Test large buffer (mmap) + let large_buffer = manager.allocate_rdma_buffer(128 * 1024).unwrap(); + assert_eq!(large_buffer.size(), 128 * 1024); + assert_eq!(large_buffer.buffer_type(), "mmap"); + + // Clean up + manager.deallocate_buffer(small_buffer).unwrap(); + manager.deallocate_buffer(large_buffer).unwrap(); + } +} diff --git a/seaweedfs-rdma-sidecar/rdma-engine/src/rdma.rs b/seaweedfs-rdma-sidecar/rdma-engine/src/rdma.rs new file mode 100644 index 000000000..7549a217e --- /dev/null +++ b/seaweedfs-rdma-sidecar/rdma-engine/src/rdma.rs @@ -0,0 +1,467 @@ +//! RDMA operations and context management +//! +//! This module provides both mock and real RDMA implementations: +//! - Mock implementation for development and testing +//! 
- Real implementation using libibverbs for production + +use crate::{RdmaResult, RdmaEngineConfig}; +use tracing::{debug, warn, info}; +use parking_lot::RwLock; + +/// RDMA completion status +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum CompletionStatus { + Success, + LocalLengthError, + LocalQpOperationError, + LocalEecOperationError, + LocalProtectionError, + WrFlushError, + MemoryWindowBindError, + BadResponseError, + LocalAccessError, + RemoteInvalidRequestError, + RemoteAccessError, + RemoteOperationError, + TransportRetryCounterExceeded, + RnrRetryCounterExceeded, + LocalRddViolationError, + RemoteInvalidRdRequest, + RemoteAbortedError, + InvalidEecnError, + InvalidEecStateError, + FatalError, + ResponseTimeoutError, + GeneralError, +} + +impl From<u32> for CompletionStatus { + fn from(status: u32) -> Self { + match status { + 0 => Self::Success, + 1 => Self::LocalLengthError, + 2 => Self::LocalQpOperationError, + 3 => Self::LocalEecOperationError, + 4 => Self::LocalProtectionError, + 5 => Self::WrFlushError, + 6 => Self::MemoryWindowBindError, + 7 => Self::BadResponseError, + 8 => Self::LocalAccessError, + 9 => Self::RemoteInvalidRequestError, + 10 => Self::RemoteAccessError, + 11 => Self::RemoteOperationError, + 12 => Self::TransportRetryCounterExceeded, + 13 => Self::RnrRetryCounterExceeded, + 14 => Self::LocalRddViolationError, + 15 => Self::RemoteInvalidRdRequest, + 16 => Self::RemoteAbortedError, + 17 => Self::InvalidEecnError, + 18 => Self::InvalidEecStateError, + 19 => Self::FatalError, + 20 => Self::ResponseTimeoutError, + _ => Self::GeneralError, + } + } +} + +/// RDMA operation types +#[derive(Debug, Clone, Copy)] +pub enum RdmaOp { + Read, + Write, + Send, + Receive, + Atomic, +} + +/// RDMA memory region information +#[derive(Debug, Clone)] +pub struct MemoryRegion { + /// Local virtual address + pub addr: u64, + /// Remote key for RDMA operations + pub rkey: u32, + /// Local key for local operations + pub lkey: u32, + /// Size of the memory region + pub size: usize, + /// Whether the region is registered with RDMA hardware + pub registered: bool, +} + +/// RDMA work completion +#[derive(Debug)] +pub struct WorkCompletion { + /// Work request ID + pub wr_id: u64, + /// Completion status + pub status: CompletionStatus, + /// Operation type + pub opcode: RdmaOp, + /// Number of bytes transferred + pub byte_len: u32, + /// Immediate data (if any) + pub imm_data: Option<u32>, +} + +/// RDMA context implementation (simplified enum approach) +#[derive(Debug)] +pub enum RdmaContextImpl { + Mock(MockRdmaContext), + // Ucx(UcxRdmaContext), // TODO: Add UCX implementation +} + +/// RDMA device information +#[derive(Debug, Clone)] +pub struct RdmaDeviceInfo { + pub name: String, + pub vendor_id: u32, + pub vendor_part_id: u32, + pub hw_ver: u32, + pub max_mr: u32, + pub max_qp: u32, + pub max_cq: u32, + pub max_mr_size: u64, + pub port_gid: String, + pub port_lid: u16, +} + +/// Main RDMA context +pub struct RdmaContext { + inner: RdmaContextImpl, + #[allow(dead_code)] + config: RdmaEngineConfig, +} + +impl RdmaContext { + /// Create new RDMA context + pub async fn new(config: &RdmaEngineConfig) -> RdmaResult<Self> { + let inner = if cfg!(feature = "real-ucx") { + RdmaContextImpl::Mock(MockRdmaContext::new(config).await?) // TODO: Use UCX when ready + } else { + RdmaContextImpl::Mock(MockRdmaContext::new(config).await?) 
+ }; + + Ok(Self { + inner, + config: config.clone(), + }) + } + + /// Register memory for RDMA operations + pub async fn register_memory(&self, addr: u64, size: usize) -> RdmaResult<MemoryRegion> { + match &self.inner { + RdmaContextImpl::Mock(ctx) => ctx.register_memory(addr, size).await, + } + } + + /// Deregister memory region + pub async fn deregister_memory(&self, region: &MemoryRegion) -> RdmaResult<()> { + match &self.inner { + RdmaContextImpl::Mock(ctx) => ctx.deregister_memory(region).await, + } + } + + /// Post RDMA read operation + pub async fn post_read(&self, + local_addr: u64, + remote_addr: u64, + rkey: u32, + size: usize, + wr_id: u64, + ) -> RdmaResult<()> { + match &self.inner { + RdmaContextImpl::Mock(ctx) => ctx.post_read(local_addr, remote_addr, rkey, size, wr_id).await, + } + } + + /// Post RDMA write operation + pub async fn post_write(&self, + local_addr: u64, + remote_addr: u64, + rkey: u32, + size: usize, + wr_id: u64, + ) -> RdmaResult<()> { + match &self.inner { + RdmaContextImpl::Mock(ctx) => ctx.post_write(local_addr, remote_addr, rkey, size, wr_id).await, + } + } + + /// Poll for work completions + pub async fn poll_completion(&self, max_completions: usize) -> RdmaResult<Vec<WorkCompletion>> { + match &self.inner { + RdmaContextImpl::Mock(ctx) => ctx.poll_completion(max_completions).await, + } + } + + /// Get device information + pub fn device_info(&self) -> &RdmaDeviceInfo { + match &self.inner { + RdmaContextImpl::Mock(ctx) => ctx.device_info(), + } + } +} + +/// Mock RDMA context for testing and development +#[derive(Debug)] +pub struct MockRdmaContext { + device_info: RdmaDeviceInfo, + registered_regions: RwLock<Vec<MemoryRegion>>, + pending_operations: RwLock<Vec<WorkCompletion>>, + #[allow(dead_code)] + config: RdmaEngineConfig, +} + +impl MockRdmaContext { + pub async fn new(config: &RdmaEngineConfig) -> RdmaResult<Self> { + warn!("๐ก Using MOCK RDMA implementation - for development only!"); + info!(" Device: {} (mock)", config.device_name); + info!(" Port: {} (mock)", config.port); + + let device_info = RdmaDeviceInfo { + name: config.device_name.clone(), + vendor_id: 0x02c9, // Mellanox mock vendor ID + vendor_part_id: 0x1017, // ConnectX-5 mock part ID + hw_ver: 0, + max_mr: 131072, + max_qp: 262144, + max_cq: 65536, + max_mr_size: 1024 * 1024 * 1024 * 1024, // 1TB mock + port_gid: "fe80:0000:0000:0000:0200:5eff:fe12:3456".to_string(), + port_lid: 1, + }; + + Ok(Self { + device_info, + registered_regions: RwLock::new(Vec::new()), + pending_operations: RwLock::new(Vec::new()), + config: config.clone(), + }) + } +} + +impl MockRdmaContext { + pub async fn register_memory(&self, addr: u64, size: usize) -> RdmaResult<MemoryRegion> { + debug!("๐ก Mock: Registering memory region addr=0x{:x}, size={}", addr, size); + + // Simulate registration delay + tokio::time::sleep(tokio::time::Duration::from_micros(10)).await; + + let region = MemoryRegion { + addr, + rkey: 0x12345678, // Mock remote key + lkey: 0x87654321, // Mock local key + size, + registered: true, + }; + + self.registered_regions.write().push(region.clone()); + + Ok(region) + } + + pub async fn deregister_memory(&self, region: &MemoryRegion) -> RdmaResult<()> { + debug!("๐ก Mock: Deregistering memory region rkey=0x{:x}", region.rkey); + + let mut regions = self.registered_regions.write(); + regions.retain(|r| r.rkey != region.rkey); + + Ok(()) + } + + pub async fn post_read(&self, + local_addr: u64, + remote_addr: u64, + rkey: u32, + size: usize, + wr_id: u64, + ) -> RdmaResult<()> { + 
debug!("๐ก Mock: RDMA READ local=0x{:x}, remote=0x{:x}, rkey=0x{:x}, size={}", + local_addr, remote_addr, rkey, size); + + // Simulate RDMA read latency (much faster than real network, but realistic for mock) + tokio::time::sleep(tokio::time::Duration::from_nanos(150)).await; + + // Mock data transfer - copy pattern data to local address + let data_ptr = local_addr as *mut u8; + unsafe { + for i in 0..size { + *data_ptr.add(i) = (i % 256) as u8; // Pattern: 0,1,2,...,255,0,1,2... + } + } + + // Create completion + let completion = WorkCompletion { + wr_id, + status: CompletionStatus::Success, + opcode: RdmaOp::Read, + byte_len: size as u32, + imm_data: None, + }; + + self.pending_operations.write().push(completion); + + Ok(()) + } + + pub async fn post_write(&self, + local_addr: u64, + remote_addr: u64, + rkey: u32, + size: usize, + wr_id: u64, + ) -> RdmaResult<()> { + debug!("๐ก Mock: RDMA WRITE local=0x{:x}, remote=0x{:x}, rkey=0x{:x}, size={}", + local_addr, remote_addr, rkey, size); + + // Simulate RDMA write latency + tokio::time::sleep(tokio::time::Duration::from_nanos(100)).await; + + // Create completion + let completion = WorkCompletion { + wr_id, + status: CompletionStatus::Success, + opcode: RdmaOp::Write, + byte_len: size as u32, + imm_data: None, + }; + + self.pending_operations.write().push(completion); + + Ok(()) + } + + pub async fn poll_completion(&self, max_completions: usize) -> RdmaResult<Vec<WorkCompletion>> { + let mut operations = self.pending_operations.write(); + let available = operations.len().min(max_completions); + let completions = operations.drain(..available).collect(); + + Ok(completions) + } + + pub fn device_info(&self) -> &RdmaDeviceInfo { + &self.device_info + } +} + +/// Real RDMA context using libibverbs +#[cfg(feature = "real-ucx")] +pub struct RealRdmaContext { + // Real implementation would contain: + // ibv_context: *mut ibv_context, + // ibv_pd: *mut ibv_pd, + // ibv_cq: *mut ibv_cq, + // ibv_qp: *mut ibv_qp, + device_info: RdmaDeviceInfo, + config: RdmaEngineConfig, +} + +#[cfg(feature = "real-ucx")] +impl RealRdmaContext { + pub async fn new(config: &RdmaEngineConfig) -> RdmaResult<Self> { + info!("โ
+        info!("Initializing REAL RDMA context for device: {}", config.device_name);
+
+        // Real implementation would:
+        // 1. Get device list with ibv_get_device_list()
+        // 2. Find device by name
+        // 3. Open device with ibv_open_device()
+        // 4. Create protection domain with ibv_alloc_pd()
+        // 5. Create completion queue with ibv_create_cq()
+        // 6. Create queue pair with ibv_create_qp()
+        // 7. Transition QP to RTS state
+
+        todo!("Real RDMA implementation using libibverbs");
+    }
+}
+
+// These methods mirror the RdmaContext API; they are kept as inherent methods
+// because dispatch happens through the RdmaContextImpl enum above rather than
+// a trait object.
+#[cfg(feature = "real-ucx")]
+impl RealRdmaContext {
+    pub async fn register_memory(&self, _addr: u64, _size: usize) -> RdmaResult<MemoryRegion> {
+        // Real implementation would use ibv_reg_mr()
+        todo!("Real memory registration")
+    }
+
+    pub async fn deregister_memory(&self, _region: &MemoryRegion) -> RdmaResult<()> {
+        // Real implementation would use ibv_dereg_mr()
+        todo!("Real memory deregistration")
+    }
+
+    pub async fn post_read(&self,
+        _local_addr: u64,
+        _remote_addr: u64,
+        _rkey: u32,
+        _size: usize,
+        _wr_id: u64,
+    ) -> RdmaResult<()> {
+        // Real implementation would use ibv_post_send() with IBV_WR_RDMA_READ
+        todo!("Real RDMA read")
+    }
+
+    pub async fn post_write(&self,
+        _local_addr: u64,
+        _remote_addr: u64,
+        _rkey: u32,
+        _size: usize,
+        _wr_id: u64,
+    ) -> RdmaResult<()> {
+        // Real implementation would use ibv_post_send() with IBV_WR_RDMA_WRITE
+        todo!("Real RDMA write")
+    }
+
+    pub async fn poll_completion(&self, _max_completions: usize) -> RdmaResult<Vec<WorkCompletion>> {
+        // Real implementation would use ibv_poll_cq()
+        todo!("Real completion polling")
+    }
+
+    pub fn device_info(&self) -> &RdmaDeviceInfo {
+        &self.device_info
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_mock_rdma_context() {
+        let config = RdmaEngineConfig::default();
+        let ctx = RdmaContext::new(&config).await.unwrap();
+
+        // Test device info
+        let info = ctx.device_info();
+        assert_eq!(info.name, "mlx5_0");
+        assert!(info.max_mr > 0);
+
+        // Test memory registration
+        let addr = 0x7f000000u64;
+        let size = 4096;
+        let region = ctx.register_memory(addr, size).await.unwrap();
+        assert_eq!(region.addr, addr);
+        assert_eq!(region.size, size);
+        assert!(region.registered);
+
+        // Test RDMA read (buffer must be mutable: the mock writes through this pointer)
+        let mut local_buf = vec![0u8; 1024];
+        let local_addr = local_buf.as_mut_ptr() as u64;
+        let result = ctx.post_read(local_addr, 0x8000000, region.rkey, 1024, 1).await;
+        assert!(result.is_ok());
+
+        // Test completion polling
+        let completions = ctx.poll_completion(10).await.unwrap();
+        assert_eq!(completions.len(), 1);
+        assert_eq!(completions[0].status, CompletionStatus::Success);
+
+        // Test memory deregistration
+        let result = ctx.deregister_memory(&region).await;
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_completion_status_conversion() {
+        assert_eq!(CompletionStatus::from(0), CompletionStatus::Success);
+        assert_eq!(CompletionStatus::from(1), CompletionStatus::LocalLengthError);
+        assert_eq!(CompletionStatus::from(999), CompletionStatus::GeneralError);
+    }
+}
diff --git a/seaweedfs-rdma-sidecar/rdma-engine/src/session.rs b/seaweedfs-rdma-sidecar/rdma-engine/src/session.rs
new file mode 100644
index 000000000..fa089c72a
--- /dev/null
+++ b/seaweedfs-rdma-sidecar/rdma-engine/src/session.rs
@@ -0,0 +1,587 @@
+//! Session management for RDMA operations
+//!
+//! This module manages the lifecycle of RDMA sessions, including creation,
+//! storage, expiration, and cleanup of resources.
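+//!
+//! # Example (illustrative sketch, not compiled)
+//!
+//! The `MemoryRegion` would normally come from `RdmaContext::register_memory`,
+//! and the IDs from the incoming IPC request:
+//!
+//! ```ignore
+//! let manager = SessionManager::new(1024, Duration::from_secs(30));
+//! manager.start_cleanup_task().await;
+//!
+//! let session = manager.create_session(
+//!     "sess-1".to_string(), volume_id, needle_id,
+//!     remote_addr, remote_key, transfer_size, vec![0; transfer_size as usize],
+//!     memory_region, chrono::Duration::seconds(30),
+//! ).await?;
+//!
+//! session.write().set_state(SessionState::Active);
+//! // ... perform RDMA operations, record_operation(...), etc. ...
+//! manager.remove_session("sess-1").await?;
+//! ```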
+ +use crate::{RdmaError, RdmaResult, rdma::MemoryRegion}; +use parking_lot::RwLock; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::time::{Duration, Instant}; +use tracing::{debug, info}; +// use uuid::Uuid; // Unused for now + +/// RDMA session state +#[derive(Debug, Clone)] +pub struct RdmaSession { + /// Unique session identifier + pub id: String, + /// SeaweedFS volume ID + pub volume_id: u32, + /// SeaweedFS needle ID + pub needle_id: u64, + /// Remote memory address + pub remote_addr: u64, + /// Remote key for RDMA access + pub remote_key: u32, + /// Transfer size in bytes + pub transfer_size: u64, + /// Local data buffer + pub buffer: Vec<u8>, + /// RDMA memory region + pub memory_region: MemoryRegion, + /// Session creation time + pub created_at: Instant, + /// Session expiration time + pub expires_at: Instant, + /// Current session state + pub state: SessionState, + /// Operation statistics + pub stats: SessionStats, +} + +/// Session state enum +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum SessionState { + /// Session created but not yet active + Created, + /// RDMA operation in progress + Active, + /// Operation completed successfully + Completed, + /// Operation failed + Failed, + /// Session expired + Expired, + /// Session being cleaned up + CleaningUp, +} + +/// Session operation statistics +#[derive(Debug, Clone, Default)] +pub struct SessionStats { + /// Number of RDMA operations performed + pub operations_count: u64, + /// Total bytes transferred + pub bytes_transferred: u64, + /// Time spent in RDMA operations (nanoseconds) + pub rdma_time_ns: u64, + /// Number of completion polling attempts + pub poll_attempts: u64, + /// Time of last operation + pub last_operation_at: Option<Instant>, +} + +impl RdmaSession { + /// Create a new RDMA session + pub fn new( + id: String, + volume_id: u32, + needle_id: u64, + remote_addr: u64, + remote_key: u32, + transfer_size: u64, + buffer: Vec<u8>, + memory_region: MemoryRegion, + timeout: Duration, + ) -> Self { + let now = Instant::now(); + + Self { + id, + volume_id, + needle_id, + remote_addr, + remote_key, + transfer_size, + buffer, + memory_region, + created_at: now, + expires_at: now + timeout, + state: SessionState::Created, + stats: SessionStats::default(), + } + } + + /// Check if session has expired + pub fn is_expired(&self) -> bool { + Instant::now() > self.expires_at + } + + /// Get session age in seconds + pub fn age_secs(&self) -> f64 { + self.created_at.elapsed().as_secs_f64() + } + + /// Get time until expiration in seconds + pub fn time_to_expiration_secs(&self) -> f64 { + if self.is_expired() { + 0.0 + } else { + (self.expires_at - Instant::now()).as_secs_f64() + } + } + + /// Update session state + pub fn set_state(&mut self, state: SessionState) { + debug!("Session {} state: {:?} -> {:?}", self.id, self.state, state); + self.state = state; + } + + /// Record RDMA operation statistics + pub fn record_operation(&mut self, bytes_transferred: u64, duration_ns: u64) { + self.stats.operations_count += 1; + self.stats.bytes_transferred += bytes_transferred; + self.stats.rdma_time_ns += duration_ns; + self.stats.last_operation_at = Some(Instant::now()); + } + + /// Get average operation latency in nanoseconds + pub fn avg_operation_latency_ns(&self) -> u64 { + if self.stats.operations_count > 0 { + self.stats.rdma_time_ns / self.stats.operations_count + } else { + 0 + } + } + + /// Get throughput in bytes per second + pub fn throughput_bps(&self) -> f64 { + let age_secs = self.age_secs(); + if 
age_secs > 0.0 {
+            self.stats.bytes_transferred as f64 / age_secs
+        } else {
+            0.0
+        }
+    }
+}
+
+/// Session manager for handling multiple concurrent RDMA sessions
+pub struct SessionManager {
+    /// Active sessions
+    sessions: Arc<RwLock<HashMap<String, Arc<RwLock<RdmaSession>>>>>,
+    /// Maximum number of concurrent sessions
+    max_sessions: usize,
+    /// Default session timeout
+    #[allow(dead_code)]
+    default_timeout: Duration,
+    /// Cleanup task handle
+    cleanup_task: RwLock<Option<tokio::task::JoinHandle<()>>>,
+    /// Shutdown flag
+    shutdown_flag: Arc<RwLock<bool>>,
+    /// Statistics
+    stats: Arc<RwLock<SessionManagerStats>>,
+}
+
+/// Session manager statistics
+#[derive(Debug, Clone, Default)]
+pub struct SessionManagerStats {
+    /// Total sessions created
+    pub total_sessions_created: u64,
+    /// Total sessions completed
+    pub total_sessions_completed: u64,
+    /// Total sessions failed
+    pub total_sessions_failed: u64,
+    /// Total sessions expired
+    pub total_sessions_expired: u64,
+    /// Total bytes transferred across all sessions
+    pub total_bytes_transferred: u64,
+    /// Manager start time
+    pub started_at: Option<Instant>,
+}
+
+impl SessionManager {
+    /// Create new session manager
+    pub fn new(max_sessions: usize, default_timeout: Duration) -> Self {
+        info!("Session manager initialized: max_sessions={}, timeout={:?}",
+              max_sessions, default_timeout);
+
+        let stats = SessionManagerStats {
+            started_at: Some(Instant::now()),
+            ..Default::default()
+        };
+
+        Self {
+            sessions: Arc::new(RwLock::new(HashMap::new())),
+            max_sessions,
+            default_timeout,
+            cleanup_task: RwLock::new(None),
+            shutdown_flag: Arc::new(RwLock::new(false)),
+            stats: Arc::new(RwLock::new(stats)),
+        }
+    }
+
+    /// Create a new RDMA session
+    pub async fn create_session(
+        &self,
+        session_id: String,
+        volume_id: u32,
+        needle_id: u64,
+        remote_addr: u64,
+        remote_key: u32,
+        transfer_size: u64,
+        buffer: Vec<u8>,
+        memory_region: MemoryRegion,
+        timeout: chrono::Duration,
+    ) -> RdmaResult<Arc<RwLock<RdmaSession>>> {
+        let timeout_duration = Duration::from_millis(timeout.num_milliseconds().max(1) as u64);
+
+        let session = Arc::new(RwLock::new(RdmaSession::new(
+            session_id.clone(),
+            volume_id,
+            needle_id,
+            remote_addr,
+            remote_key,
+            transfer_size,
+            buffer,
+            memory_region,
+            timeout_duration,
+        )));
+
+        // Check limits and insert under a single write lock so concurrent
+        // callers cannot race past the session cap or duplicate an ID
+        {
+            let mut sessions = self.sessions.write();
+            if sessions.len() >= self.max_sessions {
+                return Err(RdmaError::TooManySessions {
+                    max_sessions: self.max_sessions
+                });
+            }
+            if sessions.contains_key(&session_id) {
+                return Err(RdmaError::invalid_request(
+                    format!("Session {} already exists", session_id)
+                ));
+            }
+            sessions.insert(session_id.clone(), session.clone());
+        }
+
+        // Update stats
+        {
+            let mut stats = self.stats.write();
+            stats.total_sessions_created += 1;
+        }
+
+        info!("Created session {}: volume={}, needle={}, size={}",
+              session_id, volume_id, needle_id, transfer_size);
+
+        Ok(session)
+    }
+
+    /// Get session by ID
+    pub async fn get_session(&self, session_id: &str) -> RdmaResult<Arc<RwLock<RdmaSession>>> {
+        let sessions = self.sessions.read();
+        match sessions.get(session_id) {
+            Some(session) => {
+                if session.read().is_expired() {
+                    Err(RdmaError::SessionExpired {
+                        session_id: session_id.to_string()
+                    })
+                } else {
+                    Ok(session.clone())
+                }
+            }
+            None => Err(RdmaError::SessionNotFound {
+                session_id: session_id.to_string()
+            }),
+        }
+    }
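+    // NOTE: get_session only *reports* expiry; the expired entry itself is
+    // reaped by the background task started in start_cleanup_task() below,
+    // which keeps the lookup path cheap.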
+    /// Remove and cleanup session
+    pub async fn remove_session(&self, session_id: &str) -> RdmaResult<()> {
+        let session = {
+            let mut sessions = self.sessions.write();
+            sessions.remove(session_id)
+        };
+
+        if let Some(session) = session {
+            let session_data = session.read();
+            info!("Removed session {}: stats={:?}", session_id, session_data.stats);
+
+            // Update manager stats
+            {
+                let mut stats = self.stats.write();
+                match session_data.state {
+                    SessionState::Completed => stats.total_sessions_completed += 1,
+                    SessionState::Failed => stats.total_sessions_failed += 1,
+                    SessionState::Expired => stats.total_sessions_expired += 1,
+                    _ => {}
+                }
+                stats.total_bytes_transferred += session_data.stats.bytes_transferred;
+            }
+
+            Ok(())
+        } else {
+            Err(RdmaError::SessionNotFound {
+                session_id: session_id.to_string()
+            })
+        }
+    }
+
+    /// Get active session count
+    pub async fn active_session_count(&self) -> usize {
+        self.sessions.read().len()
+    }
+
+    /// Get maximum sessions allowed
+    pub fn max_sessions(&self) -> usize {
+        self.max_sessions
+    }
+
+    /// List active sessions
+    pub async fn list_sessions(&self) -> Vec<String> {
+        self.sessions.read().keys().cloned().collect()
+    }
+
+    /// Get session statistics
+    pub async fn get_session_stats(&self, session_id: &str) -> RdmaResult<SessionStats> {
+        let session = self.get_session(session_id).await?;
+        let stats = {
+            let session_data = session.read();
+            session_data.stats.clone()
+        };
+        Ok(stats)
+    }
+
+    /// Get manager statistics
+    pub fn get_manager_stats(&self) -> SessionManagerStats {
+        self.stats.read().clone()
+    }
+
+    /// Start background cleanup task
+    pub async fn start_cleanup_task(&self) {
+        info!("Session cleanup task initialized");
+
+        let sessions = Arc::clone(&self.sessions);
+        let shutdown_flag = Arc::clone(&self.shutdown_flag);
+        let stats = Arc::clone(&self.stats);
+
+        let task = tokio::spawn(async move {
+            let mut interval = tokio::time::interval(Duration::from_secs(30)); // Check every 30 seconds
+
+            loop {
+                interval.tick().await;
+
+                // Check shutdown flag
+                if *shutdown_flag.read() {
+                    debug!("Session cleanup task shutting down");
+                    break;
+                }
+
+                let now = Instant::now();
+                let mut expired_sessions = Vec::new();
+
+                // Find expired sessions
+                {
+                    let sessions_guard = sessions.read();
+                    for (session_id, session) in sessions_guard.iter() {
+                        if now > session.read().expires_at {
+                            expired_sessions.push(session_id.clone());
+                        }
+                    }
+                }
+
+                // Remove expired sessions
+                if !expired_sessions.is_empty() {
+                    let mut sessions_guard = sessions.write();
+                    let mut stats_guard = stats.write();
+
+                    for session_id in expired_sessions {
+                        if let Some(session) = sessions_guard.remove(&session_id) {
+                            let session_data = session.read();
+                            info!("Cleaned up expired session: {} (volume={}, needle={})",
+                                  session_id, session_data.volume_id, session_data.needle_id);
+                            stats_guard.total_sessions_expired += 1;
+                        }
+                    }
+
+                    debug!("Active sessions: {}", sessions_guard.len());
+                }
+            }
+        });
+
+        *self.cleanup_task.write() = Some(task);
+    }
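+    // Usage sketch (illustrative; assumes a tokio runtime and a shared
+    // Arc<SessionManager>, e.g. owned by the IPC server):
+    //
+    //     let manager = Arc::new(SessionManager::new(1024, Duration::from_secs(30)));
+    //     manager.start_cleanup_task().await;
+    //     // ... serve requests ...
+    //     manager.shutdown().await;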
+
+    /// Shutdown session manager
+    pub async fn shutdown(&self) {
+        info!("Shutting down session manager");
+        *self.shutdown_flag.write() = true;
+
+        // Wait for cleanup task to finish
+        if let Some(task) = self.cleanup_task.write().take() {
+            let _ = task.await;
+        }
+
+        // Clean up all remaining sessions
+        let session_ids: Vec<String> = {
+            self.sessions.read().keys().cloned().collect()
+        };
+
+        for session_id in session_ids {
+            let _ = self.remove_session(&session_id).await;
+        }
+
+        let final_stats = self.get_manager_stats();
+        info!("Final session manager stats: {:?}", final_stats);
+    }
+
+    /// Force cleanup of all sessions (for testing)
+    #[cfg(test)]
+    pub async fn cleanup_all_sessions(&self) {
+        let session_ids: Vec<String> = {
+            self.sessions.read().keys().cloned().collect()
+        };
+
+        for session_id in session_ids {
+            let _ = self.remove_session(&session_id).await;
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::rdma::MemoryRegion;
+
+    #[tokio::test]
+    async fn test_session_creation() {
+        let manager = SessionManager::new(10, Duration::from_secs(60));
+
+        let memory_region = MemoryRegion {
+            addr: 0x1000,
+            rkey: 0x12345678,
+            lkey: 0x87654321,
+            size: 4096,
+            registered: true,
+        };
+
+        let session = manager.create_session(
+            "test-session".to_string(),
+            1,
+            100,
+            0x2000,
+            0xabcd,
+            4096,
+            vec![0; 4096],
+            memory_region,
+            chrono::Duration::seconds(60),
+        ).await.unwrap();
+
+        let session_data = session.read();
+        assert_eq!(session_data.id, "test-session");
+        assert_eq!(session_data.volume_id, 1);
+        assert_eq!(session_data.needle_id, 100);
+        assert_eq!(session_data.state, SessionState::Created);
+        assert!(!session_data.is_expired());
+    }
+
+    #[tokio::test]
+    async fn test_session_expiration() {
+        let manager = SessionManager::new(10, Duration::from_millis(10));
+
+        let memory_region = MemoryRegion {
+            addr: 0x1000,
+            rkey: 0x12345678,
+            lkey: 0x87654321,
+            size: 4096,
+            registered: true,
+        };
+
+        let _session = manager.create_session(
+            "expire-test".to_string(),
+            1,
+            100,
+            0x2000,
+            0xabcd,
+            4096,
+            vec![0; 4096],
+            memory_region,
+            chrono::Duration::milliseconds(10),
+        ).await.unwrap();
+
+        // Wait for expiration
+        tokio::time::sleep(Duration::from_millis(20)).await;
+
+        let result = manager.get_session("expire-test").await;
+        assert!(matches!(result, Err(RdmaError::SessionExpired { .. })));
+    }
+
+    #[tokio::test]
+    async fn test_session_limit() {
+        let manager = SessionManager::new(2, Duration::from_secs(60));
+
+        let memory_region = MemoryRegion {
+            addr: 0x1000,
+            rkey: 0x12345678,
+            lkey: 0x87654321,
+            size: 4096,
+            registered: true,
+        };
+
+        // Create first session
+        let _session1 = manager.create_session(
+            "session1".to_string(),
+            1, 100, 0x2000, 0xabcd, 4096,
+            vec![0; 4096],
+            memory_region.clone(),
+            chrono::Duration::seconds(60),
+        ).await.unwrap();
+
+        // Create second session
+        let _session2 = manager.create_session(
+            "session2".to_string(),
+            1, 101, 0x3000, 0xabcd, 4096,
+            vec![0; 4096],
+            memory_region.clone(),
+            chrono::Duration::seconds(60),
+        ).await.unwrap();
+
+        // Third session should fail
+        let result = manager.create_session(
+            "session3".to_string(),
+            1, 102, 0x4000, 0xabcd, 4096,
+            vec![0; 4096],
+            memory_region,
+            chrono::Duration::seconds(60),
+        ).await;
+
+        assert!(matches!(result, Err(RdmaError::TooManySessions { ..
}))); + } + + #[tokio::test] + async fn test_session_stats() { + let manager = SessionManager::new(10, Duration::from_secs(60)); + + let memory_region = MemoryRegion { + addr: 0x1000, + rkey: 0x12345678, + lkey: 0x87654321, + size: 4096, + registered: true, + }; + + let session = manager.create_session( + "stats-test".to_string(), + 1, 100, 0x2000, 0xabcd, 4096, + vec![0; 4096], + memory_region, + chrono::Duration::seconds(60), + ).await.unwrap(); + + // Simulate some operations - now using proper interior mutability + { + let mut session_data = session.write(); + session_data.record_operation(1024, 1000000); // 1KB in 1ms + session_data.record_operation(2048, 2000000); // 2KB in 2ms + } + + let stats = manager.get_session_stats("stats-test").await.unwrap(); + assert_eq!(stats.operations_count, 2); + assert_eq!(stats.bytes_transferred, 3072); + assert_eq!(stats.rdma_time_ns, 3000000); + } +} diff --git a/seaweedfs-rdma-sidecar/rdma-engine/src/ucx.rs b/seaweedfs-rdma-sidecar/rdma-engine/src/ucx.rs new file mode 100644 index 000000000..901149858 --- /dev/null +++ b/seaweedfs-rdma-sidecar/rdma-engine/src/ucx.rs @@ -0,0 +1,606 @@ +//! UCX (Unified Communication X) FFI bindings and high-level wrapper +//! +//! UCX is a superior alternative to direct libibverbs for RDMA programming. +//! It provides production-proven abstractions and automatic transport selection. +//! +//! References: +//! - UCX Documentation: https://openucx.readthedocs.io/ +//! - UCX GitHub: https://github.com/openucx/ucx +//! - UCX Paper: "UCX: an open source framework for HPC network APIs and beyond" + +use crate::{RdmaError, RdmaResult}; +use libc::{c_char, c_int, c_void, size_t}; +use libloading::{Library, Symbol}; +use parking_lot::Mutex; +use std::collections::HashMap; +use std::ffi::CStr; +use std::ptr; +use std::sync::Arc; +use tracing::{debug, info, warn, error}; + +/// UCX context handle +pub type UcpContext = *mut c_void; +/// UCX worker handle +pub type UcpWorker = *mut c_void; +/// UCX endpoint handle +pub type UcpEp = *mut c_void; +/// UCX memory handle +pub type UcpMem = *mut c_void; +/// UCX request handle +pub type UcpRequest = *mut c_void; + +/// UCX configuration parameters +#[repr(C)] +pub struct UcpParams { + pub field_mask: u64, + pub features: u64, + pub request_size: size_t, + pub request_init: extern "C" fn(*mut c_void), + pub request_cleanup: extern "C" fn(*mut c_void), + pub tag_sender_mask: u64, +} + +/// UCX worker parameters +#[repr(C)] +pub struct UcpWorkerParams { + pub field_mask: u64, + pub thread_mode: c_int, + pub cpu_mask: u64, + pub events: c_int, + pub user_data: *mut c_void, +} + +/// UCX endpoint parameters +#[repr(C)] +pub struct UcpEpParams { + pub field_mask: u64, + pub address: *const c_void, + pub flags: u64, + pub sock_addr: *const c_void, + pub err_handler: UcpErrHandler, + pub user_data: *mut c_void, +} + +/// UCX memory mapping parameters +#[repr(C)] +pub struct UcpMemMapParams { + pub field_mask: u64, + pub address: *mut c_void, + pub length: size_t, + pub flags: u64, + pub prot: c_int, +} + +/// UCX error handler callback +pub type UcpErrHandler = extern "C" fn( + arg: *mut c_void, + ep: UcpEp, + status: c_int, +); + +/// UCX request callback +pub type UcpSendCallback = extern "C" fn( + request: *mut c_void, + status: c_int, + user_data: *mut c_void, +); + +/// UCX feature flags +pub const UCP_FEATURE_TAG: u64 = 1 << 0; +pub const UCP_FEATURE_RMA: u64 = 1 << 1; +pub const UCP_FEATURE_ATOMIC32: u64 = 1 << 2; +pub const UCP_FEATURE_ATOMIC64: u64 = 1 << 3; +pub const 
UCP_FEATURE_WAKEUP: u64 = 1 << 4; +pub const UCP_FEATURE_STREAM: u64 = 1 << 5; + +/// UCX parameter field masks +pub const UCP_PARAM_FIELD_FEATURES: u64 = 1 << 0; +pub const UCP_PARAM_FIELD_REQUEST_SIZE: u64 = 1 << 1; +pub const UCP_PARAM_FIELD_REQUEST_INIT: u64 = 1 << 2; +pub const UCP_PARAM_FIELD_REQUEST_CLEANUP: u64 = 1 << 3; +pub const UCP_PARAM_FIELD_TAG_SENDER_MASK: u64 = 1 << 4; + +pub const UCP_WORKER_PARAM_FIELD_THREAD_MODE: u64 = 1 << 0; +pub const UCP_WORKER_PARAM_FIELD_CPU_MASK: u64 = 1 << 1; +pub const UCP_WORKER_PARAM_FIELD_EVENTS: u64 = 1 << 2; +pub const UCP_WORKER_PARAM_FIELD_USER_DATA: u64 = 1 << 3; + +pub const UCP_EP_PARAM_FIELD_REMOTE_ADDRESS: u64 = 1 << 0; +pub const UCP_EP_PARAM_FIELD_FLAGS: u64 = 1 << 1; +pub const UCP_EP_PARAM_FIELD_SOCK_ADDR: u64 = 1 << 2; +pub const UCP_EP_PARAM_FIELD_ERR_HANDLER: u64 = 1 << 3; +pub const UCP_EP_PARAM_FIELD_USER_DATA: u64 = 1 << 4; + +pub const UCP_MEM_MAP_PARAM_FIELD_ADDRESS: u64 = 1 << 0; +pub const UCP_MEM_MAP_PARAM_FIELD_LENGTH: u64 = 1 << 1; +pub const UCP_MEM_MAP_PARAM_FIELD_FLAGS: u64 = 1 << 2; +pub const UCP_MEM_MAP_PARAM_FIELD_PROT: u64 = 1 << 3; + +/// UCX status codes +pub const UCS_OK: c_int = 0; +pub const UCS_INPROGRESS: c_int = 1; +pub const UCS_ERR_NO_MESSAGE: c_int = -1; +pub const UCS_ERR_NO_RESOURCE: c_int = -2; +pub const UCS_ERR_IO_ERROR: c_int = -3; +pub const UCS_ERR_NO_MEMORY: c_int = -4; +pub const UCS_ERR_INVALID_PARAM: c_int = -5; +pub const UCS_ERR_UNREACHABLE: c_int = -6; +pub const UCS_ERR_INVALID_ADDR: c_int = -7; +pub const UCS_ERR_NOT_IMPLEMENTED: c_int = -8; +pub const UCS_ERR_MESSAGE_TRUNCATED: c_int = -9; +pub const UCS_ERR_NO_PROGRESS: c_int = -10; +pub const UCS_ERR_BUFFER_TOO_SMALL: c_int = -11; +pub const UCS_ERR_NO_ELEM: c_int = -12; +pub const UCS_ERR_SOME_CONNECTS_FAILED: c_int = -13; +pub const UCS_ERR_NO_DEVICE: c_int = -14; +pub const UCS_ERR_BUSY: c_int = -15; +pub const UCS_ERR_CANCELED: c_int = -16; +pub const UCS_ERR_SHMEM_SEGMENT: c_int = -17; +pub const UCS_ERR_ALREADY_EXISTS: c_int = -18; +pub const UCS_ERR_OUT_OF_RANGE: c_int = -19; +pub const UCS_ERR_TIMED_OUT: c_int = -20; + +/// UCX memory protection flags +pub const UCP_MEM_MAP_NONBLOCK: u64 = 1 << 0; +pub const UCP_MEM_MAP_ALLOCATE: u64 = 1 << 1; +pub const UCP_MEM_MAP_FIXED: u64 = 1 << 2; + +/// UCX FFI function signatures +pub struct UcxApi { + pub ucp_init: Symbol<'static, unsafe extern "C" fn(*const UcpParams, *const c_void, *mut UcpContext) -> c_int>, + pub ucp_cleanup: Symbol<'static, unsafe extern "C" fn(UcpContext)>, + pub ucp_worker_create: Symbol<'static, unsafe extern "C" fn(UcpContext, *const UcpWorkerParams, *mut UcpWorker) -> c_int>, + pub ucp_worker_destroy: Symbol<'static, unsafe extern "C" fn(UcpWorker)>, + pub ucp_ep_create: Symbol<'static, unsafe extern "C" fn(UcpWorker, *const UcpEpParams, *mut UcpEp) -> c_int>, + pub ucp_ep_destroy: Symbol<'static, unsafe extern "C" fn(UcpEp)>, + pub ucp_mem_map: Symbol<'static, unsafe extern "C" fn(UcpContext, *const UcpMemMapParams, *mut UcpMem) -> c_int>, + pub ucp_mem_unmap: Symbol<'static, unsafe extern "C" fn(UcpContext, UcpMem) -> c_int>, + pub ucp_put_nb: Symbol<'static, unsafe extern "C" fn(UcpEp, *const c_void, size_t, u64, u64, UcpSendCallback) -> UcpRequest>, + pub ucp_get_nb: Symbol<'static, unsafe extern "C" fn(UcpEp, *mut c_void, size_t, u64, u64, UcpSendCallback) -> UcpRequest>, + pub ucp_worker_progress: Symbol<'static, unsafe extern "C" fn(UcpWorker) -> c_int>, + pub ucp_request_check_status: Symbol<'static, unsafe extern "C" fn(UcpRequest) -> 
c_int>,
+    pub ucp_request_free: Symbol<'static, unsafe extern "C" fn(UcpRequest)>,
+    pub ucp_worker_get_address: Symbol<'static, unsafe extern "C" fn(UcpWorker, *mut *mut c_void, *mut size_t) -> c_int>,
+    pub ucp_worker_release_address: Symbol<'static, unsafe extern "C" fn(UcpWorker, *mut c_void)>,
+    pub ucs_status_string: Symbol<'static, unsafe extern "C" fn(c_int) -> *const c_char>,
+}
+
+impl UcxApi {
+    /// Load UCX library and resolve symbols
+    pub fn load() -> RdmaResult<Self> {
+        info!("Loading UCX library");
+
+        // Try to load UCX library
+        let lib_names = [
+            "libucp.so.0",                           // Most common
+            "libucp.so",                             // Generic
+            "libucp.dylib",                          // macOS
+            "/usr/lib/x86_64-linux-gnu/libucp.so.0", // Ubuntu/Debian
+            "/usr/lib64/libucp.so.0",                // RHEL/CentOS
+        ];
+
+        let library = lib_names.iter()
+            .find_map(|name| {
+                debug!("Trying to load UCX library: {}", name);
+                match unsafe { Library::new(name) } {
+                    Ok(lib) => {
+                        info!("Successfully loaded UCX library: {}", name);
+                        Some(lib)
+                    }
+                    Err(e) => {
+                        debug!("Failed to load {}: {}", name, e);
+                        None
+                    }
+                }
+            })
+            .ok_or_else(|| RdmaError::context_init_failed("UCX library not found"))?;
+
+        // Leak the library to get 'static lifetime for symbols
+        let library: &'static Library = Box::leak(Box::new(library));
+
+        unsafe {
+            Ok(UcxApi {
+                ucp_init: library.get(b"ucp_init")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_init symbol: {}", e)))?,
+                ucp_cleanup: library.get(b"ucp_cleanup")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_cleanup symbol: {}", e)))?,
+                ucp_worker_create: library.get(b"ucp_worker_create")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_worker_create symbol: {}", e)))?,
+                ucp_worker_destroy: library.get(b"ucp_worker_destroy")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_worker_destroy symbol: {}", e)))?,
+                ucp_ep_create: library.get(b"ucp_ep_create")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_ep_create symbol: {}", e)))?,
+                ucp_ep_destroy: library.get(b"ucp_ep_destroy")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_ep_destroy symbol: {}", e)))?,
+                ucp_mem_map: library.get(b"ucp_mem_map")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_mem_map symbol: {}", e)))?,
+                ucp_mem_unmap: library.get(b"ucp_mem_unmap")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_mem_unmap symbol: {}", e)))?,
+                ucp_put_nb: library.get(b"ucp_put_nb")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_put_nb symbol: {}", e)))?,
+                ucp_get_nb: library.get(b"ucp_get_nb")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_get_nb symbol: {}", e)))?,
+                ucp_worker_progress: library.get(b"ucp_worker_progress")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_worker_progress symbol: {}", e)))?,
+                ucp_request_check_status: library.get(b"ucp_request_check_status")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_request_check_status symbol: {}", e)))?,
+                ucp_request_free: library.get(b"ucp_request_free")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_request_free symbol: {}", e)))?,
+                ucp_worker_get_address: library.get(b"ucp_worker_get_address")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_worker_get_address symbol: {}", e)))?,
+                ucp_worker_release_address: library.get(b"ucp_worker_release_address")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucp_worker_release_address symbol: {}", e)))?,
+                ucs_status_string: library.get(b"ucs_status_string")
+                    .map_err(|e| RdmaError::context_init_failed(format!("ucs_status_string symbol: {}", e)))?,
+            })
+        }
+    }
+
+    /// Convert UCX status code to human-readable string
+    pub fn status_string(&self, status: c_int) -> String {
+        unsafe {
+            let c_str = (self.ucs_status_string)(status);
+            if c_str.is_null() {
+                format!("Unknown status: {}", status)
+            } else {
+                CStr::from_ptr(c_str).to_string_lossy().to_string()
+            }
+        }
+    }
+}
+
+/// High-level UCX context wrapper
+pub struct UcxContext {
+    api: Arc<UcxApi>,
+    context: UcpContext,
+    worker: UcpWorker,
+    worker_address: Vec<u8>,
+    endpoints: Mutex<HashMap<String, UcpEp>>,
+    memory_regions: Mutex<HashMap<u64, UcpMem>>,
+}
+
+impl UcxContext {
+    /// Initialize UCX context with RMA support
+    pub async fn new() -> RdmaResult<Self> {
+        info!("Initializing UCX context for RDMA operations");
+
+        let api = Arc::new(UcxApi::load()?);
+
+        // Initialize UCP context
+        let params = UcpParams {
+            field_mask: UCP_PARAM_FIELD_FEATURES,
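+            // NOTE: request_init/request_cleanup below are only read by UCX when
+            // the matching UCP_PARAM_FIELD_REQUEST_* bits are set in field_mask;
+            // with only FEATURES set, UCX ignores them.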
+            features: UCP_FEATURE_RMA | UCP_FEATURE_WAKEUP,
+            request_size: 0,
+            request_init: request_init_cb,
+            request_cleanup: request_cleanup_cb,
+            tag_sender_mask: 0,
+        };
+
+        let mut context = ptr::null_mut();
+        let status = unsafe { (api.ucp_init)(&params, ptr::null(), &mut context) };
+        if status != UCS_OK {
+            return Err(RdmaError::context_init_failed(format!(
+                "ucp_init failed: {} ({})",
+                api.status_string(status), status
+            )));
+        }
+
+        info!("UCX context initialized successfully");
+
+        // Create worker
+        let worker_params = UcpWorkerParams {
+            field_mask: UCP_WORKER_PARAM_FIELD_THREAD_MODE,
+            thread_mode: 0, // Single-threaded
+            cpu_mask: 0,
+            events: 0,
+            user_data: ptr::null_mut(),
+        };
+
+        let mut worker = ptr::null_mut();
+        let status = unsafe { (api.ucp_worker_create)(context, &worker_params, &mut worker) };
+        if status != UCS_OK {
+            unsafe { (api.ucp_cleanup)(context) };
+            return Err(RdmaError::context_init_failed(format!(
+                "ucp_worker_create failed: {} ({})",
+                api.status_string(status), status
+            )));
+        }
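+        // thread_mode 0 corresponds to UCS_THREAD_MODE_SINGLE: the worker may
+        // only be progressed from one thread at a time, which get()/put() below
+        // respect by progressing inline on the calling task.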
+
+        info!("UCX worker created successfully");
+
+        // Get worker address for connection establishment
+        let mut address_ptr = ptr::null_mut();
+        let mut address_len = 0;
+        let status = unsafe { (api.ucp_worker_get_address)(worker, &mut address_ptr, &mut address_len) };
+        if status != UCS_OK {
+            unsafe {
+                (api.ucp_worker_destroy)(worker);
+                (api.ucp_cleanup)(context);
+            }
+            return Err(RdmaError::context_init_failed(format!(
+                "ucp_worker_get_address failed: {} ({})",
+                api.status_string(status), status
+            )));
+        }
+
+        let worker_address = unsafe {
+            std::slice::from_raw_parts(address_ptr as *const u8, address_len).to_vec()
+        };
+
+        unsafe { (api.ucp_worker_release_address)(worker, address_ptr) };
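+        // The packed worker address is what a remote peer needs to create an
+        // endpoint back to this worker (UcpEpParams.address); in a full
+        // deployment it would be exchanged out-of-band, e.g. over the existing
+        // IPC channel.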
+
+        info!("UCX worker address obtained ({} bytes)", worker_address.len());
+
+        Ok(UcxContext {
+            api,
+            context,
+            worker,
+            worker_address,
+            endpoints: Mutex::new(HashMap::new()),
+            memory_regions: Mutex::new(HashMap::new()),
+        })
+    }
+
+    /// Map memory for RDMA operations
+    pub async fn map_memory(&self, addr: u64, size: usize) -> RdmaResult<u64> {
+        debug!("Mapping memory for RDMA: addr=0x{:x}, size={}", addr, size);
+
+        let params = UcpMemMapParams {
+            field_mask: UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH,
+            address: addr as *mut c_void,
+            length: size,
+            flags: 0,
+            prot: libc::PROT_READ | libc::PROT_WRITE,
+        };
+
+        let mut mem_handle = ptr::null_mut();
+        let status = unsafe { (self.api.ucp_mem_map)(self.context, &params, &mut mem_handle) };
+
+        if status != UCS_OK {
+            return Err(RdmaError::memory_reg_failed(format!(
+                "ucp_mem_map failed: {} ({})",
+                self.api.status_string(status), status
+            )));
+        }
+
+        // Store memory handle for cleanup
+        {
+            let mut regions = self.memory_regions.lock();
+            regions.insert(addr, mem_handle);
+        }
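+        // NOTE: real UCX derives a transferable remote key via ucp_rkey_pack()
+        // on the mapped region; returning the raw address below is a placeholder
+        // until rkey exchange is implemented.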
+
+        info!("Memory mapped successfully: addr=0x{:x}, size={}", addr, size);
+        Ok(addr) // Return the same address as remote key equivalent
+    }
+
+    /// Unmap memory
+    pub async fn unmap_memory(&self, addr: u64) -> RdmaResult<()> {
+        debug!("Unmapping memory: addr=0x{:x}", addr);
+
+        let mem_handle = {
+            let mut regions = self.memory_regions.lock();
+            regions.remove(&addr)
+        };
+
+        if let Some(handle) = mem_handle {
+            let status = unsafe { (self.api.ucp_mem_unmap)(self.context, handle) };
+            if status != UCS_OK {
+                warn!("ucp_mem_unmap failed: {} ({})",
+                      self.api.status_string(status), status);
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Perform RDMA GET (read from remote memory)
+    pub async fn get(&self, local_addr: u64, remote_addr: u64, size: usize) -> RdmaResult<()> {
+        debug!("RDMA GET: local=0x{:x}, remote=0x{:x}, size={}",
+               local_addr, remote_addr, size);
+
+        // For now, use a simple synchronous approach
+        // In production, this would be properly async with completion callbacks
+
+        // Find or create endpoint (simplified - would need proper address resolution)
+        let ep = self.get_or_create_endpoint("default").await?;
+
+        let request = unsafe {
+            (self.api.ucp_get_nb)(
+                ep,
+                local_addr as *mut c_void,
+                size,
+                remote_addr,
+                0, // rkey placeholder - real UCX expects a ucp_rkey_h (see ucp_ep_rkey_unpack)
+                get_completion_cb,
+            )
+        };
+
+        // Wait for completion
+        // NOTE: a fuller implementation would also treat an error pointer
+        // returned by ucp_get_nb (UCS_PTR_IS_ERR) as an immediate failure.
+        if !request.is_null() {
+            loop {
+                let status = unsafe { (self.api.ucp_request_check_status)(request) };
+                if status != UCS_INPROGRESS {
+                    unsafe { (self.api.ucp_request_free)(request) };
+                    if status == UCS_OK {
+                        break;
+                    } else {
+                        return Err(RdmaError::operation_failed(
+                            "RDMA GET", status
+                        ));
+                    }
+                }
+
+                // Progress the worker
+                unsafe { (self.api.ucp_worker_progress)(self.worker) };
+                tokio::task::yield_now().await;
+            }
+        }
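+        // The loop above busy-polls ucp_worker_progress() and yields to tokio
+        // between attempts. A production build would instead block on the
+        // worker's event fd (UCP_FEATURE_WAKEUP was requested for exactly that;
+        // see ucp_worker_get_efd()/ucp_worker_arm()).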
+
+        info!("RDMA GET completed successfully");
+        Ok(())
+    }
+
+    /// Perform RDMA PUT (write to remote memory)
+    pub async fn put(&self, local_addr: u64, remote_addr: u64, size: usize) -> RdmaResult<()> {
+        debug!("RDMA PUT: local=0x{:x}, remote=0x{:x}, size={}",
+               local_addr, remote_addr, size);
+
+        let ep = self.get_or_create_endpoint("default").await?;
+
+        let request = unsafe {
+            (self.api.ucp_put_nb)(
+                ep,
+                local_addr as *const c_void,
+                size,
+                remote_addr,
+                0, // rkey placeholder (see note in get())
+                put_completion_cb,
+            )
+        };
+
+        // Wait for completion (same pattern as GET)
+        if !request.is_null() {
+            loop {
+                let status = unsafe { (self.api.ucp_request_check_status)(request) };
+                if status != UCS_INPROGRESS {
+                    unsafe { (self.api.ucp_request_free)(request) };
+                    if status == UCS_OK {
+                        break;
+                    } else {
+                        return Err(RdmaError::operation_failed(
+                            "RDMA PUT", status
+                        ));
+                    }
+                }
+
+                unsafe { (self.api.ucp_worker_progress)(self.worker) };
+                tokio::task::yield_now().await;
+            }
+        }
+
+        info!("RDMA PUT completed successfully");
+        Ok(())
+    }
+
+    /// Get worker address for connection establishment
+    pub fn worker_address(&self) -> &[u8] {
+        &self.worker_address
+    }
+
+    /// Create endpoint for communication (simplified version)
+    async fn get_or_create_endpoint(&self, key: &str) -> RdmaResult<UcpEp> {
+        let mut endpoints = self.endpoints.lock();
+
+        if let Some(&ep) = endpoints.get(key) {
+            return Ok(ep);
+        }
+
+        // For simplicity, create a dummy endpoint
+        // In production, this would use actual peer address
+        let ep_params = UcpEpParams {
+            field_mask: 0, // Simplified for mock
+            address: ptr::null(),
+            flags: 0,
+            sock_addr: ptr::null(),
+            err_handler: error_handler_cb,
+            user_data: ptr::null_mut(),
+        };
+
+        let mut endpoint = ptr::null_mut();
+        let status = unsafe { (self.api.ucp_ep_create)(self.worker, &ep_params, &mut endpoint) };
+
+        if status != UCS_OK {
+            return Err(RdmaError::context_init_failed(format!(
+                "ucp_ep_create failed: {} ({})",
+                self.api.status_string(status), status
+            )));
+        }
+
+        endpoints.insert(key.to_string(), endpoint);
+        Ok(endpoint)
+    }
+}
+
+impl Drop for UcxContext {
+    fn drop(&mut self) {
+        info!("Cleaning up UCX context");
+
+        // Clean up endpoints
+        {
+            let mut endpoints = self.endpoints.lock();
+            for (_, ep) in endpoints.drain() {
+                unsafe { (self.api.ucp_ep_destroy)(ep) };
+            }
+        }
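+        // NOTE: ucp_ep_destroy() is the legacy teardown path; newer UCX code
+        // would use ucp_ep_close_nb()/ucp_ep_close_nbx() and progress the
+        // worker until the close request completes.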
+
+        // Clean up memory regions
+        {
+            let mut regions = self.memory_regions.lock();
+            for (_, handle) in regions.drain() {
+                unsafe { (self.api.ucp_mem_unmap)(self.context, handle) };
+            }
+        }
+
+        // Clean up worker and context
+        unsafe {
+            (self.api.ucp_worker_destroy)(self.worker);
+            (self.api.ucp_cleanup)(self.context);
+        }
+
+        info!("UCX context cleanup completed");
+    }
+}
+
+// UCX callback functions
+extern "C" fn request_init_cb(_request: *mut c_void) {
+    // Request initialization callback
+}
+
+extern "C" fn request_cleanup_cb(_request: *mut c_void) {
+    // Request cleanup callback
+}
+
+extern "C" fn get_completion_cb(_request: *mut c_void, status: c_int, _user_data: *mut c_void) {
+    if status != UCS_OK {
+        error!("RDMA GET completion error: {}", status);
+    }
+}
+
+extern "C" fn put_completion_cb(_request: *mut c_void, status: c_int, _user_data: *mut c_void) {
+    if status != UCS_OK {
+        error!("RDMA PUT completion error: {}", status);
+    }
+}
+
+extern "C" fn error_handler_cb(
+    _arg: *mut c_void,
+    _ep: UcpEp,
+    status: c_int,
+) {
+    error!("UCX endpoint error: {}", status);
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_ucx_api_loading() {
+        // This test will fail without UCX installed, which is expected
+        match UcxApi::load() {
+            Ok(api) => {
+                info!("UCX API loaded successfully");
+                assert_eq!(api.status_string(UCS_OK), "Success");
+            }
+            Err(_) => {
+                warn!("UCX library not found - expected in development environment");
+            }
+        }
+    }
+
+    #[tokio::test]
+    async fn test_ucx_context_mock() {
+        // This would test the mock implementation
+        // Real test requires UCX installation
+    }
+}
