aboutsummaryrefslogtreecommitdiff
path: root/weed/iam/sts/token_utils.go
diff options
context:
space:
mode:
Diffstat (limited to 'weed/iam/sts/token_utils.go')
-rw-r--r--weed/iam/sts/token_utils.go217
1 files changed, 217 insertions, 0 deletions
diff --git a/weed/iam/sts/token_utils.go b/weed/iam/sts/token_utils.go
new file mode 100644
index 000000000..07c195326
--- /dev/null
+++ b/weed/iam/sts/token_utils.go
@@ -0,0 +1,217 @@
+package sts
+
+import (
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/hex"
+ "fmt"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/seaweedfs/seaweedfs/weed/iam/utils"
+)
+
+// TokenGenerator handles token generation and validation
+type TokenGenerator struct {
+ signingKey []byte
+ issuer string
+}
+
+// NewTokenGenerator creates a new token generator
+func NewTokenGenerator(signingKey []byte, issuer string) *TokenGenerator {
+ return &TokenGenerator{
+ signingKey: signingKey,
+ issuer: issuer,
+ }
+}
+
+// GenerateSessionToken creates a signed JWT session token (legacy method for compatibility)
+func (t *TokenGenerator) GenerateSessionToken(sessionId string, expiresAt time.Time) (string, error) {
+ claims := NewSTSSessionClaims(sessionId, t.issuer, expiresAt)
+ return t.GenerateJWTWithClaims(claims)
+}
+
+// GenerateJWTWithClaims creates a signed JWT token with comprehensive session claims
+func (t *TokenGenerator) GenerateJWTWithClaims(claims *STSSessionClaims) (string, error) {
+ if claims == nil {
+ return "", fmt.Errorf("claims cannot be nil")
+ }
+
+ // Ensure issuer is set from token generator
+ if claims.Issuer == "" {
+ claims.Issuer = t.issuer
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ return token.SignedString(t.signingKey)
+}
+
+// ValidateSessionToken validates and extracts claims from a session token
+func (t *TokenGenerator) ValidateSessionToken(tokenString string) (*SessionTokenClaims, error) {
+ token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
+ if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+ return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+ }
+ return t.signingKey, nil
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf(ErrInvalidToken, err)
+ }
+
+ if !token.Valid {
+ return nil, fmt.Errorf(ErrTokenNotValid)
+ }
+
+ claims, ok := token.Claims.(jwt.MapClaims)
+ if !ok {
+ return nil, fmt.Errorf(ErrInvalidTokenClaims)
+ }
+
+ // Verify issuer
+ if iss, ok := claims[JWTClaimIssuer].(string); !ok || iss != t.issuer {
+ return nil, fmt.Errorf(ErrInvalidIssuer)
+ }
+
+ // Extract session ID
+ sessionId, ok := claims[JWTClaimSubject].(string)
+ if !ok {
+ return nil, fmt.Errorf(ErrMissingSessionID)
+ }
+
+ return &SessionTokenClaims{
+ SessionId: sessionId,
+ ExpiresAt: time.Unix(int64(claims[JWTClaimExpiration].(float64)), 0),
+ IssuedAt: time.Unix(int64(claims[JWTClaimIssuedAt].(float64)), 0),
+ }, nil
+}
+
+// ValidateJWTWithClaims validates and extracts comprehensive session claims from a JWT token
+func (t *TokenGenerator) ValidateJWTWithClaims(tokenString string) (*STSSessionClaims, error) {
+ token, err := jwt.ParseWithClaims(tokenString, &STSSessionClaims{}, func(token *jwt.Token) (interface{}, error) {
+ if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+ return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
+ }
+ return t.signingKey, nil
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf(ErrInvalidToken, err)
+ }
+
+ if !token.Valid {
+ return nil, fmt.Errorf(ErrTokenNotValid)
+ }
+
+ claims, ok := token.Claims.(*STSSessionClaims)
+ if !ok {
+ return nil, fmt.Errorf(ErrInvalidTokenClaims)
+ }
+
+ // Validate issuer
+ if claims.Issuer != t.issuer {
+ return nil, fmt.Errorf(ErrInvalidIssuer)
+ }
+
+ // Validate that required fields are present
+ if claims.SessionId == "" {
+ return nil, fmt.Errorf(ErrMissingSessionID)
+ }
+
+ // Additional validation using the claims' own validation method
+ if !claims.IsValid() {
+ return nil, fmt.Errorf(ErrTokenNotValid)
+ }
+
+ return claims, nil
+}
+
+// SessionTokenClaims represents parsed session token claims
+type SessionTokenClaims struct {
+ SessionId string
+ ExpiresAt time.Time
+ IssuedAt time.Time
+}
+
+// CredentialGenerator generates AWS-compatible temporary credentials
+type CredentialGenerator struct{}
+
+// NewCredentialGenerator creates a new credential generator
+func NewCredentialGenerator() *CredentialGenerator {
+ return &CredentialGenerator{}
+}
+
+// GenerateTemporaryCredentials creates temporary AWS credentials
+func (c *CredentialGenerator) GenerateTemporaryCredentials(sessionId string, expiration time.Time) (*Credentials, error) {
+ accessKeyId, err := c.generateAccessKeyId(sessionId)
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate access key ID: %w", err)
+ }
+
+ secretAccessKey, err := c.generateSecretAccessKey()
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate secret access key: %w", err)
+ }
+
+ sessionToken, err := c.generateSessionTokenId(sessionId)
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate session token: %w", err)
+ }
+
+ return &Credentials{
+ AccessKeyId: accessKeyId,
+ SecretAccessKey: secretAccessKey,
+ SessionToken: sessionToken,
+ Expiration: expiration,
+ }, nil
+}
+
+// generateAccessKeyId generates an AWS-style access key ID
+func (c *CredentialGenerator) generateAccessKeyId(sessionId string) (string, error) {
+ // Create a deterministic but unique access key ID based on session
+ hash := sha256.Sum256([]byte("access-key:" + sessionId))
+ return "AKIA" + hex.EncodeToString(hash[:8]), nil // AWS format: AKIA + 16 chars
+}
+
+// generateSecretAccessKey generates a random secret access key
+func (c *CredentialGenerator) generateSecretAccessKey() (string, error) {
+ // Generate 32 random bytes for secret key
+ secretBytes := make([]byte, 32)
+ _, err := rand.Read(secretBytes)
+ if err != nil {
+ return "", err
+ }
+
+ return base64.StdEncoding.EncodeToString(secretBytes), nil
+}
+
+// generateSessionTokenId generates a session token identifier
+func (c *CredentialGenerator) generateSessionTokenId(sessionId string) (string, error) {
+ // Create session token with session ID embedded
+ hash := sha256.Sum256([]byte("session-token:" + sessionId))
+ return "ST" + hex.EncodeToString(hash[:16]), nil // Custom format
+}
+
+// generateSessionId generates a unique session ID
+func GenerateSessionId() (string, error) {
+ randomBytes := make([]byte, 16)
+ _, err := rand.Read(randomBytes)
+ if err != nil {
+ return "", err
+ }
+
+ return hex.EncodeToString(randomBytes), nil
+}
+
+// generateAssumedRoleArn generates the ARN for an assumed role user
+func GenerateAssumedRoleArn(roleArn, sessionName string) string {
+ // Convert role ARN to assumed role user ARN
+ // arn:seaweed:iam::role/RoleName -> arn:seaweed:sts::assumed-role/RoleName/SessionName
+ roleName := utils.ExtractRoleNameFromArn(roleArn)
+ if roleName == "" {
+ // This should not happen if validation is done properly upstream
+ return fmt.Sprintf("arn:seaweed:sts::assumed-role/INVALID-ARN/%s", sessionName)
+ }
+ return fmt.Sprintf("arn:seaweed:sts::assumed-role/%s/%s", roleName, sessionName)
+}