aboutsummaryrefslogtreecommitdiff
path: root/weed/server/postgres/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'weed/server/postgres/server.go')
-rw-r--r--weed/server/postgres/server.go704
1 files changed, 704 insertions, 0 deletions
diff --git a/weed/server/postgres/server.go b/weed/server/postgres/server.go
new file mode 100644
index 000000000..f35d3704e
--- /dev/null
+++ b/weed/server/postgres/server.go
@@ -0,0 +1,704 @@
+package postgres
+
+import (
+ "bufio"
+ "crypto/md5"
+ "crypto/rand"
+ "crypto/tls"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "net"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/seaweedfs/seaweedfs/weed/glog"
+ "github.com/seaweedfs/seaweedfs/weed/query/engine"
+ "github.com/seaweedfs/seaweedfs/weed/util/version"
+)
+
+// PostgreSQL protocol constants
+const (
+ // Protocol versions
+ PG_PROTOCOL_VERSION_3 = 196608 // PostgreSQL 3.0 protocol (0x00030000)
+ PG_SSL_REQUEST = 80877103 // SSL request (0x04d2162f)
+ PG_GSSAPI_REQUEST = 80877104 // GSSAPI request (0x04d21630)
+
+ // Message types from client
+ PG_MSG_STARTUP = 0x00
+ PG_MSG_QUERY = 'Q'
+ PG_MSG_PARSE = 'P'
+ PG_MSG_BIND = 'B'
+ PG_MSG_EXECUTE = 'E'
+ PG_MSG_DESCRIBE = 'D'
+ PG_MSG_CLOSE = 'C'
+ PG_MSG_FLUSH = 'H'
+ PG_MSG_SYNC = 'S'
+ PG_MSG_TERMINATE = 'X'
+ PG_MSG_PASSWORD = 'p'
+
+ // Response types to client
+ PG_RESP_AUTH_OK = 'R'
+ PG_RESP_BACKEND_KEY = 'K'
+ PG_RESP_PARAMETER = 'S'
+ PG_RESP_READY = 'Z'
+ PG_RESP_COMMAND = 'C'
+ PG_RESP_DATA_ROW = 'D'
+ PG_RESP_ROW_DESC = 'T'
+ PG_RESP_PARSE_COMPLETE = '1'
+ PG_RESP_BIND_COMPLETE = '2'
+ PG_RESP_CLOSE_COMPLETE = '3'
+ PG_RESP_ERROR = 'E'
+ PG_RESP_NOTICE = 'N'
+
+ // Transaction states
+ PG_TRANS_IDLE = 'I'
+ PG_TRANS_INTRANS = 'T'
+ PG_TRANS_ERROR = 'E'
+
+ // Authentication methods
+ AUTH_OK = 0
+ AUTH_CLEAR = 3
+ AUTH_MD5 = 5
+ AUTH_TRUST = 10
+
+ // PostgreSQL data types
+ PG_TYPE_BOOL = 16
+ PG_TYPE_BYTEA = 17
+ PG_TYPE_INT8 = 20
+ PG_TYPE_INT4 = 23
+ PG_TYPE_TEXT = 25
+ PG_TYPE_FLOAT4 = 700
+ PG_TYPE_FLOAT8 = 701
+ PG_TYPE_VARCHAR = 1043
+ PG_TYPE_TIMESTAMP = 1114
+ PG_TYPE_JSON = 114
+ PG_TYPE_JSONB = 3802
+
+ // Default values
+ DEFAULT_POSTGRES_PORT = 5432
+)
+
+// Authentication method type
+type AuthMethod int
+
+const (
+ AuthTrust AuthMethod = iota
+ AuthPassword
+ AuthMD5
+)
+
+// PostgreSQL server configuration
+type PostgreSQLServerConfig struct {
+ Host string
+ Port int
+ AuthMethod AuthMethod
+ Users map[string]string
+ TLSConfig *tls.Config
+ MaxConns int
+ IdleTimeout time.Duration
+ StartupTimeout time.Duration // Timeout for client startup handshake
+ Database string
+}
+
+// PostgreSQL server
+type PostgreSQLServer struct {
+ config *PostgreSQLServerConfig
+ listener net.Listener
+ sqlEngine *engine.SQLEngine
+ sessions map[uint32]*PostgreSQLSession
+ sessionMux sync.RWMutex
+ shutdown chan struct{}
+ wg sync.WaitGroup
+ nextConnID uint32
+}
+
+// PostgreSQL session
+type PostgreSQLSession struct {
+ conn net.Conn
+ reader *bufio.Reader
+ writer *bufio.Writer
+ authenticated bool
+ username string
+ database string
+ parameters map[string]string
+ preparedStmts map[string]*PreparedStatement
+ portals map[string]*Portal
+ transactionState byte
+ processID uint32
+ secretKey uint32
+ created time.Time
+ lastActivity time.Time
+ mutex sync.Mutex
+}
+
+// Prepared statement
+type PreparedStatement struct {
+ Name string
+ Query string
+ ParamTypes []uint32
+ Fields []FieldDescription
+}
+
+// Portal (cursor)
+type Portal struct {
+ Name string
+ Statement string
+ Parameters [][]byte
+ Suspended bool
+}
+
+// Field description
+type FieldDescription struct {
+ Name string
+ TableOID uint32
+ AttrNum int16
+ TypeOID uint32
+ TypeSize int16
+ TypeMod int32
+ Format int16
+}
+
+// NewPostgreSQLServer creates a new PostgreSQL protocol server
+func NewPostgreSQLServer(config *PostgreSQLServerConfig, masterAddr string) (*PostgreSQLServer, error) {
+ if config.Port <= 0 {
+ config.Port = DEFAULT_POSTGRES_PORT
+ }
+ if config.Host == "" {
+ config.Host = "localhost"
+ }
+ if config.Database == "" {
+ config.Database = "default"
+ }
+ if config.MaxConns <= 0 {
+ config.MaxConns = 100
+ }
+ if config.IdleTimeout <= 0 {
+ config.IdleTimeout = time.Hour
+ }
+ if config.StartupTimeout <= 0 {
+ config.StartupTimeout = 30 * time.Second
+ }
+
+ // Create SQL engine (now uses CockroachDB parser for PostgreSQL compatibility)
+ sqlEngine := engine.NewSQLEngine(masterAddr)
+
+ server := &PostgreSQLServer{
+ config: config,
+ sqlEngine: sqlEngine,
+ sessions: make(map[uint32]*PostgreSQLSession),
+ shutdown: make(chan struct{}),
+ nextConnID: 1,
+ }
+
+ return server, nil
+}
+
+// Start begins listening for PostgreSQL connections
+func (s *PostgreSQLServer) Start() error {
+ addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
+
+ var listener net.Listener
+ var err error
+
+ if s.config.TLSConfig != nil {
+ listener, err = tls.Listen("tcp", addr, s.config.TLSConfig)
+ glog.Infof("PostgreSQL Server with TLS listening on %s", addr)
+ } else {
+ listener, err = net.Listen("tcp", addr)
+ glog.Infof("PostgreSQL Server listening on %s", addr)
+ }
+
+ if err != nil {
+ return fmt.Errorf("failed to start PostgreSQL server on %s: %v", addr, err)
+ }
+
+ s.listener = listener
+
+ // Start accepting connections
+ s.wg.Add(1)
+ go s.acceptConnections()
+
+ // Start cleanup routine
+ s.wg.Add(1)
+ go s.cleanupSessions()
+
+ return nil
+}
+
+// Stop gracefully shuts down the PostgreSQL server
+func (s *PostgreSQLServer) Stop() error {
+ close(s.shutdown)
+
+ if s.listener != nil {
+ s.listener.Close()
+ }
+
+ // Close all sessions
+ s.sessionMux.Lock()
+ for _, session := range s.sessions {
+ session.close()
+ }
+ s.sessions = make(map[uint32]*PostgreSQLSession)
+ s.sessionMux.Unlock()
+
+ s.wg.Wait()
+ glog.Infof("PostgreSQL Server stopped")
+ return nil
+}
+
+// acceptConnections handles incoming PostgreSQL connections
+func (s *PostgreSQLServer) acceptConnections() {
+ defer s.wg.Done()
+
+ for {
+ select {
+ case <-s.shutdown:
+ return
+ default:
+ }
+
+ conn, err := s.listener.Accept()
+ if err != nil {
+ select {
+ case <-s.shutdown:
+ return
+ default:
+ glog.Errorf("Failed to accept PostgreSQL connection: %v", err)
+ continue
+ }
+ }
+
+ // Check connection limit
+ s.sessionMux.RLock()
+ sessionCount := len(s.sessions)
+ s.sessionMux.RUnlock()
+
+ if sessionCount >= s.config.MaxConns {
+ glog.Warningf("Maximum connections reached (%d), rejecting connection from %s",
+ s.config.MaxConns, conn.RemoteAddr())
+ conn.Close()
+ continue
+ }
+
+ s.wg.Add(1)
+ go s.handleConnection(conn)
+ }
+}
+
+// handleConnection processes a single PostgreSQL connection
+func (s *PostgreSQLServer) handleConnection(conn net.Conn) {
+ defer s.wg.Done()
+ defer conn.Close()
+
+ // Generate unique connection ID
+ connID := s.generateConnectionID()
+ secretKey := s.generateSecretKey()
+
+ // Create session
+ session := &PostgreSQLSession{
+ conn: conn,
+ reader: bufio.NewReader(conn),
+ writer: bufio.NewWriter(conn),
+ authenticated: false,
+ database: s.config.Database,
+ parameters: make(map[string]string),
+ preparedStmts: make(map[string]*PreparedStatement),
+ portals: make(map[string]*Portal),
+ transactionState: PG_TRANS_IDLE,
+ processID: connID,
+ secretKey: secretKey,
+ created: time.Now(),
+ lastActivity: time.Now(),
+ }
+
+ // Register session
+ s.sessionMux.Lock()
+ s.sessions[connID] = session
+ s.sessionMux.Unlock()
+
+ // Clean up on exit
+ defer func() {
+ s.sessionMux.Lock()
+ delete(s.sessions, connID)
+ s.sessionMux.Unlock()
+ }()
+
+ glog.V(2).Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID)
+
+ // Handle startup
+ err := s.handleStartup(session)
+ if err != nil {
+ // Handle common disconnection scenarios more gracefully
+ if strings.Contains(err.Error(), "client disconnected") {
+ glog.V(1).Infof("Client startup disconnected from %s (ID: %d): %v", conn.RemoteAddr(), connID, err)
+ } else if strings.Contains(err.Error(), "timeout") {
+ glog.Warningf("Startup timeout for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
+ } else {
+ glog.Errorf("Startup failed for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
+ }
+ return
+ }
+
+ // Handle messages
+ for {
+ select {
+ case <-s.shutdown:
+ return
+ default:
+ }
+
+ // Set read timeout
+ conn.SetReadDeadline(time.Now().Add(30 * time.Second))
+
+ err := s.handleMessage(session)
+ if err != nil {
+ if err == io.EOF {
+ glog.Infof("PostgreSQL client disconnected (ID: %d)", connID)
+ } else {
+ glog.Errorf("Error handling PostgreSQL message (ID: %d): %v", connID, err)
+ }
+ return
+ }
+
+ session.lastActivity = time.Now()
+ }
+}
+
+// handleStartup processes the PostgreSQL startup sequence
+func (s *PostgreSQLServer) handleStartup(session *PostgreSQLSession) error {
+ // Set a startup timeout to prevent hanging connections
+ startupTimeout := s.config.StartupTimeout
+ session.conn.SetReadDeadline(time.Now().Add(startupTimeout))
+ defer session.conn.SetReadDeadline(time.Time{}) // Clear timeout
+
+ for {
+ // Read startup message length
+ length := make([]byte, 4)
+ _, err := io.ReadFull(session.reader, length)
+ if err != nil {
+ if err == io.EOF {
+ // Client disconnected during startup - this is common for health checks
+ return fmt.Errorf("client disconnected during startup handshake")
+ }
+ if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+ return fmt.Errorf("startup handshake timeout after %v", startupTimeout)
+ }
+ return fmt.Errorf("failed to read message length during startup: %v", err)
+ }
+
+ msgLength := binary.BigEndian.Uint32(length) - 4
+ if msgLength > 10000 { // Reasonable limit for startup messages
+ return fmt.Errorf("startup message too large: %d bytes", msgLength)
+ }
+
+ // Read startup message content
+ msg := make([]byte, msgLength)
+ _, err = io.ReadFull(session.reader, msg)
+ if err != nil {
+ if err == io.EOF {
+ return fmt.Errorf("client disconnected while reading startup message")
+ }
+ if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+ return fmt.Errorf("startup message read timeout")
+ }
+ return fmt.Errorf("failed to read startup message: %v", err)
+ }
+
+ // Parse protocol version
+ protocolVersion := binary.BigEndian.Uint32(msg[0:4])
+
+ switch protocolVersion {
+ case PG_SSL_REQUEST:
+ // Reject SSL request - send 'N' to indicate SSL not supported
+ _, err = session.conn.Write([]byte{'N'})
+ if err != nil {
+ return fmt.Errorf("failed to reject SSL request: %v", err)
+ }
+ // Continue loop to read the actual startup message
+ continue
+
+ case PG_GSSAPI_REQUEST:
+ // Reject GSSAPI request - send 'N' to indicate GSSAPI not supported
+ _, err = session.conn.Write([]byte{'N'})
+ if err != nil {
+ return fmt.Errorf("failed to reject GSSAPI request: %v", err)
+ }
+ // Continue loop to read the actual startup message
+ continue
+
+ case PG_PROTOCOL_VERSION_3:
+ // This is the actual startup message, break out of loop
+ break
+
+ default:
+ return fmt.Errorf("unsupported protocol version: %d", protocolVersion)
+ }
+
+ // Parse parameters
+ params := strings.Split(string(msg[4:]), "\x00")
+ for i := 0; i < len(params)-1; i += 2 {
+ if params[i] == "user" {
+ session.username = params[i+1]
+ } else if params[i] == "database" {
+ session.database = params[i+1]
+ }
+ session.parameters[params[i]] = params[i+1]
+ }
+
+ // Break out of the main loop - we have the startup message
+ break
+ }
+
+ // Handle authentication
+ err := s.handleAuthentication(session)
+ if err != nil {
+ return err
+ }
+
+ // Send parameter status messages
+ err = s.sendParameterStatus(session, "server_version", fmt.Sprintf("%s (SeaweedFS)", version.VERSION_NUMBER))
+ if err != nil {
+ return err
+ }
+ err = s.sendParameterStatus(session, "server_encoding", "UTF8")
+ if err != nil {
+ return err
+ }
+ err = s.sendParameterStatus(session, "client_encoding", "UTF8")
+ if err != nil {
+ return err
+ }
+ err = s.sendParameterStatus(session, "DateStyle", "ISO, MDY")
+ if err != nil {
+ return err
+ }
+ err = s.sendParameterStatus(session, "integer_datetimes", "on")
+ if err != nil {
+ return err
+ }
+
+ // Send backend key data
+ err = s.sendBackendKeyData(session)
+ if err != nil {
+ return err
+ }
+
+ // Send ready for query
+ err = s.sendReadyForQuery(session)
+ if err != nil {
+ return err
+ }
+
+ session.authenticated = true
+ return nil
+}
+
+// handleAuthentication processes authentication
+func (s *PostgreSQLServer) handleAuthentication(session *PostgreSQLSession) error {
+ switch s.config.AuthMethod {
+ case AuthTrust:
+ return s.sendAuthenticationOk(session)
+ case AuthPassword:
+ return s.handlePasswordAuth(session)
+ case AuthMD5:
+ return s.handleMD5Auth(session)
+ default:
+ return fmt.Errorf("unsupported authentication method")
+ }
+}
+
+// sendAuthenticationOk sends authentication OK message
+func (s *PostgreSQLServer) sendAuthenticationOk(session *PostgreSQLSession) error {
+ msg := make([]byte, 9)
+ msg[0] = PG_RESP_AUTH_OK
+ binary.BigEndian.PutUint32(msg[1:5], 8)
+ binary.BigEndian.PutUint32(msg[5:9], AUTH_OK)
+
+ _, err := session.writer.Write(msg)
+ if err == nil {
+ err = session.writer.Flush()
+ }
+ return err
+}
+
+// handlePasswordAuth handles clear password authentication
+func (s *PostgreSQLServer) handlePasswordAuth(session *PostgreSQLSession) error {
+ // Send password request
+ msg := make([]byte, 9)
+ msg[0] = PG_RESP_AUTH_OK
+ binary.BigEndian.PutUint32(msg[1:5], 8)
+ binary.BigEndian.PutUint32(msg[5:9], AUTH_CLEAR)
+
+ _, err := session.writer.Write(msg)
+ if err != nil {
+ return err
+ }
+ err = session.writer.Flush()
+ if err != nil {
+ return err
+ }
+
+ // Read password response
+ msgType := make([]byte, 1)
+ _, err = io.ReadFull(session.reader, msgType)
+ if err != nil {
+ return err
+ }
+
+ if msgType[0] != PG_MSG_PASSWORD {
+ return fmt.Errorf("expected password message, got %c", msgType[0])
+ }
+
+ length := make([]byte, 4)
+ _, err = io.ReadFull(session.reader, length)
+ if err != nil {
+ return err
+ }
+
+ msgLength := binary.BigEndian.Uint32(length) - 4
+ password := make([]byte, msgLength)
+ _, err = io.ReadFull(session.reader, password)
+ if err != nil {
+ return err
+ }
+
+ // Verify password
+ expectedPassword, exists := s.config.Users[session.username]
+ if !exists || string(password[:len(password)-1]) != expectedPassword { // Remove null terminator
+ return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
+ }
+
+ return s.sendAuthenticationOk(session)
+}
+
+// handleMD5Auth handles MD5 password authentication
+func (s *PostgreSQLServer) handleMD5Auth(session *PostgreSQLSession) error {
+ // Generate salt
+ salt := make([]byte, 4)
+ _, err := rand.Read(salt)
+ if err != nil {
+ return err
+ }
+
+ // Send MD5 request
+ msg := make([]byte, 13)
+ msg[0] = PG_RESP_AUTH_OK
+ binary.BigEndian.PutUint32(msg[1:5], 12)
+ binary.BigEndian.PutUint32(msg[5:9], AUTH_MD5)
+ copy(msg[9:13], salt)
+
+ _, err = session.writer.Write(msg)
+ if err != nil {
+ return err
+ }
+ err = session.writer.Flush()
+ if err != nil {
+ return err
+ }
+
+ // Read password response
+ msgType := make([]byte, 1)
+ _, err = io.ReadFull(session.reader, msgType)
+ if err != nil {
+ return err
+ }
+
+ if msgType[0] != PG_MSG_PASSWORD {
+ return fmt.Errorf("expected password message, got %c", msgType[0])
+ }
+
+ length := make([]byte, 4)
+ _, err = io.ReadFull(session.reader, length)
+ if err != nil {
+ return err
+ }
+
+ msgLength := binary.BigEndian.Uint32(length) - 4
+ response := make([]byte, msgLength)
+ _, err = io.ReadFull(session.reader, response)
+ if err != nil {
+ return err
+ }
+
+ // Verify MD5 hash
+ expectedPassword, exists := s.config.Users[session.username]
+ if !exists {
+ return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
+ }
+
+ // Calculate expected hash: md5(md5(password + username) + salt)
+ inner := md5.Sum([]byte(expectedPassword + session.username))
+ expected := fmt.Sprintf("md5%x", md5.Sum(append([]byte(fmt.Sprintf("%x", inner)), salt...)))
+
+ if string(response[:len(response)-1]) != expected { // Remove null terminator
+ return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
+ }
+
+ return s.sendAuthenticationOk(session)
+}
+
+// generateConnectionID generates a unique connection ID
+func (s *PostgreSQLServer) generateConnectionID() uint32 {
+ s.sessionMux.Lock()
+ defer s.sessionMux.Unlock()
+ id := s.nextConnID
+ s.nextConnID++
+ return id
+}
+
+// generateSecretKey generates a secret key for the connection
+func (s *PostgreSQLServer) generateSecretKey() uint32 {
+ key := make([]byte, 4)
+ rand.Read(key)
+ return binary.BigEndian.Uint32(key)
+}
+
+// close marks the session as closed
+func (s *PostgreSQLSession) close() {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+ if s.conn != nil {
+ s.conn.Close()
+ s.conn = nil
+ }
+}
+
+// cleanupSessions periodically cleans up idle sessions
+func (s *PostgreSQLServer) cleanupSessions() {
+ defer s.wg.Done()
+
+ ticker := time.NewTicker(time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-s.shutdown:
+ return
+ case <-ticker.C:
+ s.cleanupIdleSessions()
+ }
+ }
+}
+
+// cleanupIdleSessions removes sessions that have been idle too long
+func (s *PostgreSQLServer) cleanupIdleSessions() {
+ now := time.Now()
+
+ s.sessionMux.Lock()
+ defer s.sessionMux.Unlock()
+
+ for id, session := range s.sessions {
+ if now.Sub(session.lastActivity) > s.config.IdleTimeout {
+ glog.Infof("Closing idle PostgreSQL session %d", id)
+ session.close()
+ delete(s.sessions, id)
+ }
+ }
+}
+
+// GetAddress returns the server address
+func (s *PostgreSQLServer) GetAddress() string {
+ return fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
+}