aboutsummaryrefslogtreecommitdiff
path: root/weed/iam/sts/sts_service_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'weed/iam/sts/sts_service_test.go')
-rw-r--r--weed/iam/sts/sts_service_test.go453
1 files changed, 453 insertions, 0 deletions
diff --git a/weed/iam/sts/sts_service_test.go b/weed/iam/sts/sts_service_test.go
new file mode 100644
index 000000000..60d78118f
--- /dev/null
+++ b/weed/iam/sts/sts_service_test.go
@@ -0,0 +1,453 @@
+package sts
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/seaweedfs/seaweedfs/weed/iam/providers"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// createSTSTestJWT creates a test JWT token for STS service tests
+func createSTSTestJWT(t *testing.T, issuer, subject string) string {
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "iss": issuer,
+ "sub": subject,
+ "aud": "test-client",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ tokenString, err := token.SignedString([]byte("test-signing-key"))
+ require.NoError(t, err)
+ return tokenString
+}
+
+// TestSTSServiceInitialization tests STS service initialization
+func TestSTSServiceInitialization(t *testing.T) {
+ tests := []struct {
+ name string
+ config *STSConfig
+ wantErr bool
+ }{
+ {
+ name: "valid config",
+ config: &STSConfig{
+ TokenDuration: FlexibleDuration{time.Hour},
+ MaxSessionLength: FlexibleDuration{time.Hour * 12},
+ Issuer: "seaweedfs-sts",
+ SigningKey: []byte("test-signing-key"),
+ },
+ wantErr: false,
+ },
+ {
+ name: "missing signing key",
+ config: &STSConfig{
+ TokenDuration: FlexibleDuration{time.Hour},
+ Issuer: "seaweedfs-sts",
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid token duration",
+ config: &STSConfig{
+ TokenDuration: FlexibleDuration{-time.Hour},
+ Issuer: "seaweedfs-sts",
+ SigningKey: []byte("test-key"),
+ },
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ service := NewSTSService()
+
+ err := service.Initialize(tt.config)
+
+ if tt.wantErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.True(t, service.IsInitialized())
+ }
+ })
+ }
+}
+
+// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
+func TestAssumeRoleWithWebIdentity(t *testing.T) {
+ service := setupTestSTSService(t)
+
+ tests := []struct {
+ name string
+ roleArn string
+ webIdentityToken string
+ sessionName string
+ durationSeconds *int64
+ wantErr bool
+ expectedSubject string
+ }{
+ {
+ name: "successful role assumption",
+ roleArn: "arn:seaweed:iam::role/TestRole",
+ webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user-id"),
+ sessionName: "test-session",
+ durationSeconds: nil, // Use default
+ wantErr: false,
+ expectedSubject: "test-user-id",
+ },
+ {
+ name: "invalid web identity token",
+ roleArn: "arn:seaweed:iam::role/TestRole",
+ webIdentityToken: "invalid-token",
+ sessionName: "test-session",
+ wantErr: true,
+ },
+ {
+ name: "non-existent role",
+ roleArn: "arn:seaweed:iam::role/NonExistentRole",
+ webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
+ sessionName: "test-session",
+ wantErr: true,
+ },
+ {
+ name: "custom session duration",
+ roleArn: "arn:seaweed:iam::role/TestRole",
+ webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
+ sessionName: "test-session",
+ durationSeconds: int64Ptr(7200), // 2 hours
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx := context.Background()
+
+ request := &AssumeRoleWithWebIdentityRequest{
+ RoleArn: tt.roleArn,
+ WebIdentityToken: tt.webIdentityToken,
+ RoleSessionName: tt.sessionName,
+ DurationSeconds: tt.durationSeconds,
+ }
+
+ response, err := service.AssumeRoleWithWebIdentity(ctx, request)
+
+ if tt.wantErr {
+ assert.Error(t, err)
+ assert.Nil(t, response)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, response)
+ assert.NotNil(t, response.Credentials)
+ assert.NotNil(t, response.AssumedRoleUser)
+
+ // Verify credentials
+ creds := response.Credentials
+ assert.NotEmpty(t, creds.AccessKeyId)
+ assert.NotEmpty(t, creds.SecretAccessKey)
+ assert.NotEmpty(t, creds.SessionToken)
+ assert.True(t, creds.Expiration.After(time.Now()))
+
+ // Verify assumed role user
+ user := response.AssumedRoleUser
+ assert.Equal(t, tt.roleArn, user.AssumedRoleId)
+ assert.Contains(t, user.Arn, tt.sessionName)
+
+ if tt.expectedSubject != "" {
+ assert.Equal(t, tt.expectedSubject, user.Subject)
+ }
+ }
+ })
+ }
+}
+
+// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
+func TestAssumeRoleWithLDAP(t *testing.T) {
+ service := setupTestSTSService(t)
+
+ tests := []struct {
+ name string
+ roleArn string
+ username string
+ password string
+ sessionName string
+ wantErr bool
+ }{
+ {
+ name: "successful LDAP role assumption",
+ roleArn: "arn:seaweed:iam::role/LDAPRole",
+ username: "testuser",
+ password: "testpass",
+ sessionName: "ldap-session",
+ wantErr: false,
+ },
+ {
+ name: "invalid LDAP credentials",
+ roleArn: "arn:seaweed:iam::role/LDAPRole",
+ username: "testuser",
+ password: "wrongpass",
+ sessionName: "ldap-session",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx := context.Background()
+
+ request := &AssumeRoleWithCredentialsRequest{
+ RoleArn: tt.roleArn,
+ Username: tt.username,
+ Password: tt.password,
+ RoleSessionName: tt.sessionName,
+ ProviderName: "test-ldap",
+ }
+
+ response, err := service.AssumeRoleWithCredentials(ctx, request)
+
+ if tt.wantErr {
+ assert.Error(t, err)
+ assert.Nil(t, response)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, response)
+ assert.NotNil(t, response.Credentials)
+ }
+ })
+ }
+}
+
+// TestSessionTokenValidation tests session token validation
+func TestSessionTokenValidation(t *testing.T) {
+ service := setupTestSTSService(t)
+ ctx := context.Background()
+
+ // First, create a session
+ request := &AssumeRoleWithWebIdentityRequest{
+ RoleArn: "arn:seaweed:iam::role/TestRole",
+ WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
+ RoleSessionName: "test-session",
+ }
+
+ response, err := service.AssumeRoleWithWebIdentity(ctx, request)
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ sessionToken := response.Credentials.SessionToken
+
+ tests := []struct {
+ name string
+ token string
+ wantErr bool
+ }{
+ {
+ name: "valid session token",
+ token: sessionToken,
+ wantErr: false,
+ },
+ {
+ name: "invalid session token",
+ token: "invalid-session-token",
+ wantErr: true,
+ },
+ {
+ name: "empty session token",
+ token: "",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ session, err := service.ValidateSessionToken(ctx, tt.token)
+
+ if tt.wantErr {
+ assert.Error(t, err)
+ assert.Nil(t, session)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, session)
+ assert.Equal(t, "test-session", session.SessionName)
+ assert.Equal(t, "arn:seaweed:iam::role/TestRole", session.RoleArn)
+ }
+ })
+ }
+}
+
+// TestSessionTokenPersistence tests that JWT tokens remain valid throughout their lifetime
+// Note: In the stateless JWT design, tokens cannot be revoked and remain valid until expiration
+func TestSessionTokenPersistence(t *testing.T) {
+ service := setupTestSTSService(t)
+ ctx := context.Background()
+
+ // Create a session first
+ request := &AssumeRoleWithWebIdentityRequest{
+ RoleArn: "arn:seaweed:iam::role/TestRole",
+ WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
+ RoleSessionName: "test-session",
+ }
+
+ response, err := service.AssumeRoleWithWebIdentity(ctx, request)
+ require.NoError(t, err)
+
+ sessionToken := response.Credentials.SessionToken
+
+ // Verify token is valid initially
+ session, err := service.ValidateSessionToken(ctx, sessionToken)
+ assert.NoError(t, err)
+ assert.NotNil(t, session)
+ assert.Equal(t, "test-session", session.SessionName)
+
+ // In a stateless JWT system, tokens remain valid throughout their lifetime
+ // Multiple validations should all succeed as long as the token hasn't expired
+ session2, err := service.ValidateSessionToken(ctx, sessionToken)
+ assert.NoError(t, err, "Token should remain valid in stateless system")
+ assert.NotNil(t, session2, "Session should be returned from JWT token")
+ assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent")
+}
+
+// Helper functions
+
+func setupTestSTSService(t *testing.T) *STSService {
+ service := NewSTSService()
+
+ config := &STSConfig{
+ TokenDuration: FlexibleDuration{time.Hour},
+ MaxSessionLength: FlexibleDuration{time.Hour * 12},
+ Issuer: "test-sts",
+ SigningKey: []byte("test-signing-key-32-characters-long"),
+ }
+
+ err := service.Initialize(config)
+ require.NoError(t, err)
+
+ // Set up mock trust policy validator (required for STS testing)
+ mockValidator := &MockTrustPolicyValidator{}
+ service.SetTrustPolicyValidator(mockValidator)
+
+ // Register test providers
+ mockOIDCProvider := &MockIdentityProvider{
+ name: "test-oidc",
+ validTokens: map[string]*providers.TokenClaims{
+ createSTSTestJWT(t, "test-issuer", "test-user"): {
+ Subject: "test-user-id",
+ Issuer: "test-issuer",
+ Claims: map[string]interface{}{
+ "email": "test@example.com",
+ "name": "Test User",
+ },
+ },
+ },
+ }
+
+ mockLDAPProvider := &MockIdentityProvider{
+ name: "test-ldap",
+ validCredentials: map[string]string{
+ "testuser": "testpass",
+ },
+ }
+
+ service.RegisterProvider(mockOIDCProvider)
+ service.RegisterProvider(mockLDAPProvider)
+
+ return service
+}
+
+func int64Ptr(v int64) *int64 {
+ return &v
+}
+
+// Mock identity provider for testing
+type MockIdentityProvider struct {
+ name string
+ validTokens map[string]*providers.TokenClaims
+ validCredentials map[string]string
+}
+
+func (m *MockIdentityProvider) Name() string {
+ return m.name
+}
+
+func (m *MockIdentityProvider) GetIssuer() string {
+ return "test-issuer" // This matches the issuer in the token claims
+}
+
+func (m *MockIdentityProvider) Initialize(config interface{}) error {
+ return nil
+}
+
+func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
+ // First try to parse as JWT token
+ if len(token) > 20 && strings.Count(token, ".") >= 2 {
+ parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
+ if err == nil {
+ if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
+ issuer, _ := claims["iss"].(string)
+ subject, _ := claims["sub"].(string)
+
+ // Verify the issuer matches what we expect
+ if issuer == "test-issuer" && subject != "" {
+ return &providers.ExternalIdentity{
+ UserID: subject,
+ Email: subject + "@test-domain.com",
+ DisplayName: "Test User " + subject,
+ Provider: m.name,
+ }, nil
+ }
+ }
+ }
+ }
+
+ // Handle legacy OIDC tokens (for backwards compatibility)
+ if claims, exists := m.validTokens[token]; exists {
+ email, _ := claims.GetClaimString("email")
+ name, _ := claims.GetClaimString("name")
+
+ return &providers.ExternalIdentity{
+ UserID: claims.Subject,
+ Email: email,
+ DisplayName: name,
+ Provider: m.name,
+ }, nil
+ }
+
+ // Handle LDAP credentials (username:password format)
+ if m.validCredentials != nil {
+ parts := strings.Split(token, ":")
+ if len(parts) == 2 {
+ username, password := parts[0], parts[1]
+ if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password {
+ return &providers.ExternalIdentity{
+ UserID: username,
+ Email: username + "@" + m.name + ".com",
+ DisplayName: "Test User " + username,
+ Provider: m.name,
+ }, nil
+ }
+ }
+ }
+
+ return nil, fmt.Errorf("unknown test token: %s", token)
+}
+
+func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
+ return &providers.ExternalIdentity{
+ UserID: userID,
+ Email: userID + "@" + m.name + ".com",
+ Provider: m.name,
+ }, nil
+}
+
+func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
+ if claims, exists := m.validTokens[token]; exists {
+ return claims, nil
+ }
+ return nil, fmt.Errorf("invalid token")
+}