diff options
Diffstat (limited to 'weed/iam/sts/token_utils.go')
| -rw-r--r-- | weed/iam/sts/token_utils.go | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/weed/iam/sts/token_utils.go b/weed/iam/sts/token_utils.go new file mode 100644 index 000000000..07c195326 --- /dev/null +++ b/weed/iam/sts/token_utils.go @@ -0,0 +1,217 @@ +package sts + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/seaweedfs/seaweedfs/weed/iam/utils" +) + +// TokenGenerator handles token generation and validation +type TokenGenerator struct { + signingKey []byte + issuer string +} + +// NewTokenGenerator creates a new token generator +func NewTokenGenerator(signingKey []byte, issuer string) *TokenGenerator { + return &TokenGenerator{ + signingKey: signingKey, + issuer: issuer, + } +} + +// GenerateSessionToken creates a signed JWT session token (legacy method for compatibility) +func (t *TokenGenerator) GenerateSessionToken(sessionId string, expiresAt time.Time) (string, error) { + claims := NewSTSSessionClaims(sessionId, t.issuer, expiresAt) + return t.GenerateJWTWithClaims(claims) +} + +// GenerateJWTWithClaims creates a signed JWT token with comprehensive session claims +func (t *TokenGenerator) GenerateJWTWithClaims(claims *STSSessionClaims) (string, error) { + if claims == nil { + return "", fmt.Errorf("claims cannot be nil") + } + + // Ensure issuer is set from token generator + if claims.Issuer == "" { + claims.Issuer = t.issuer + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(t.signingKey) +} + +// ValidateSessionToken validates and extracts claims from a session token +func (t *TokenGenerator) ValidateSessionToken(tokenString string) (*SessionTokenClaims, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return t.signingKey, nil + }) + + if err != nil { + return nil, fmt.Errorf(ErrInvalidToken, err) + } + + if !token.Valid { + return nil, fmt.Errorf(ErrTokenNotValid) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf(ErrInvalidTokenClaims) + } + + // Verify issuer + if iss, ok := claims[JWTClaimIssuer].(string); !ok || iss != t.issuer { + return nil, fmt.Errorf(ErrInvalidIssuer) + } + + // Extract session ID + sessionId, ok := claims[JWTClaimSubject].(string) + if !ok { + return nil, fmt.Errorf(ErrMissingSessionID) + } + + return &SessionTokenClaims{ + SessionId: sessionId, + ExpiresAt: time.Unix(int64(claims[JWTClaimExpiration].(float64)), 0), + IssuedAt: time.Unix(int64(claims[JWTClaimIssuedAt].(float64)), 0), + }, nil +} + +// ValidateJWTWithClaims validates and extracts comprehensive session claims from a JWT token +func (t *TokenGenerator) ValidateJWTWithClaims(tokenString string) (*STSSessionClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &STSSessionClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return t.signingKey, nil + }) + + if err != nil { + return nil, fmt.Errorf(ErrInvalidToken, err) + } + + if !token.Valid { + return nil, fmt.Errorf(ErrTokenNotValid) + } + + claims, ok := token.Claims.(*STSSessionClaims) + if !ok { + return nil, fmt.Errorf(ErrInvalidTokenClaims) + } + + // Validate issuer + if claims.Issuer != t.issuer { + return nil, fmt.Errorf(ErrInvalidIssuer) + } + + // Validate that required fields are present + if claims.SessionId == "" { + return nil, fmt.Errorf(ErrMissingSessionID) + } + + // Additional validation using the claims' own validation method + if !claims.IsValid() { + return nil, fmt.Errorf(ErrTokenNotValid) + } + + return claims, nil +} + +// SessionTokenClaims represents parsed session token claims +type SessionTokenClaims struct { + SessionId string + ExpiresAt time.Time + IssuedAt time.Time +} + +// CredentialGenerator generates AWS-compatible temporary credentials +type CredentialGenerator struct{} + +// NewCredentialGenerator creates a new credential generator +func NewCredentialGenerator() *CredentialGenerator { + return &CredentialGenerator{} +} + +// GenerateTemporaryCredentials creates temporary AWS credentials +func (c *CredentialGenerator) GenerateTemporaryCredentials(sessionId string, expiration time.Time) (*Credentials, error) { + accessKeyId, err := c.generateAccessKeyId(sessionId) + if err != nil { + return nil, fmt.Errorf("failed to generate access key ID: %w", err) + } + + secretAccessKey, err := c.generateSecretAccessKey() + if err != nil { + return nil, fmt.Errorf("failed to generate secret access key: %w", err) + } + + sessionToken, err := c.generateSessionTokenId(sessionId) + if err != nil { + return nil, fmt.Errorf("failed to generate session token: %w", err) + } + + return &Credentials{ + AccessKeyId: accessKeyId, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + Expiration: expiration, + }, nil +} + +// generateAccessKeyId generates an AWS-style access key ID +func (c *CredentialGenerator) generateAccessKeyId(sessionId string) (string, error) { + // Create a deterministic but unique access key ID based on session + hash := sha256.Sum256([]byte("access-key:" + sessionId)) + return "AKIA" + hex.EncodeToString(hash[:8]), nil // AWS format: AKIA + 16 chars +} + +// generateSecretAccessKey generates a random secret access key +func (c *CredentialGenerator) generateSecretAccessKey() (string, error) { + // Generate 32 random bytes for secret key + secretBytes := make([]byte, 32) + _, err := rand.Read(secretBytes) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(secretBytes), nil +} + +// generateSessionTokenId generates a session token identifier +func (c *CredentialGenerator) generateSessionTokenId(sessionId string) (string, error) { + // Create session token with session ID embedded + hash := sha256.Sum256([]byte("session-token:" + sessionId)) + return "ST" + hex.EncodeToString(hash[:16]), nil // Custom format +} + +// generateSessionId generates a unique session ID +func GenerateSessionId() (string, error) { + randomBytes := make([]byte, 16) + _, err := rand.Read(randomBytes) + if err != nil { + return "", err + } + + return hex.EncodeToString(randomBytes), nil +} + +// generateAssumedRoleArn generates the ARN for an assumed role user +func GenerateAssumedRoleArn(roleArn, sessionName string) string { + // Convert role ARN to assumed role user ARN + // arn:seaweed:iam::role/RoleName -> arn:seaweed:sts::assumed-role/RoleName/SessionName + roleName := utils.ExtractRoleNameFromArn(roleArn) + if roleName == "" { + // This should not happen if validation is done properly upstream + return fmt.Sprintf("arn:seaweed:sts::assumed-role/INVALID-ARN/%s", sessionName) + } + return fmt.Sprintf("arn:seaweed:sts::assumed-role/%s/%s", roleName, sessionName) +} |
