diff options
Diffstat (limited to 'weed/iam/sts/sts_service_test.go')
| -rw-r--r-- | weed/iam/sts/sts_service_test.go | 209 |
1 files changed, 209 insertions, 0 deletions
diff --git a/weed/iam/sts/sts_service_test.go b/weed/iam/sts/sts_service_test.go index 72d69c8c8..56b6755de 100644 --- a/weed/iam/sts/sts_service_test.go +++ b/weed/iam/sts/sts_service_test.go @@ -451,3 +451,212 @@ func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) } return nil, fmt.Errorf("invalid token") } + +// TestSessionDurationCappedByTokenExpiration tests that session duration is capped by the source token's exp claim +func TestSessionDurationCappedByTokenExpiration(t *testing.T) { + service := NewSTSService() + + config := &STSConfig{ + TokenDuration: FlexibleDuration{time.Hour}, // Default: 1 hour + MaxSessionLength: FlexibleDuration{time.Hour * 12}, + Issuer: "test-sts", + SigningKey: []byte("test-signing-key-32-characters-long"), + } + + err := service.Initialize(config) + require.NoError(t, err) + + tests := []struct { + name string + durationSeconds *int64 + tokenExpiration *time.Time + expectedMaxSeconds int64 + description string + }{ + { + name: "no token expiration - use default duration", + durationSeconds: nil, + tokenExpiration: nil, + expectedMaxSeconds: 3600, // 1 hour default + description: "When no token expiration is set, use the configured default duration", + }, + { + name: "token expires before default duration", + durationSeconds: nil, + tokenExpiration: timePtr(time.Now().Add(30 * time.Minute)), + expectedMaxSeconds: 30 * 60, // 30 minutes + description: "When token expires in 30 min, session should be capped at 30 min", + }, + { + name: "token expires after default duration - use default", + durationSeconds: nil, + tokenExpiration: timePtr(time.Now().Add(2 * time.Hour)), + expectedMaxSeconds: 3600, // 1 hour default, since it's less than 2 hour token expiry + description: "When token expires after default duration, use the default duration", + }, + { + name: "requested duration shorter than token expiry", + durationSeconds: int64Ptr(1800), // 30 min requested + tokenExpiration: timePtr(time.Now().Add(time.Hour)), + expectedMaxSeconds: 1800, // 30 minutes as requested + description: "When requested duration is shorter than token expiry, use requested duration", + }, + { + name: "requested duration longer than token expiry - cap at token expiry", + durationSeconds: int64Ptr(3600), // 1 hour requested + tokenExpiration: timePtr(time.Now().Add(15 * time.Minute)), + expectedMaxSeconds: 15 * 60, // Capped at 15 minutes + description: "When requested duration exceeds token expiry, cap at token expiry", + }, + { + name: "GitLab CI short-lived token scenario", + durationSeconds: nil, + tokenExpiration: timePtr(time.Now().Add(5 * time.Minute)), + expectedMaxSeconds: 5 * 60, // 5 minutes + description: "GitLab CI job with 5 minute timeout should result in 5 minute session", + }, + { + name: "already expired token - defense in depth", + durationSeconds: nil, + tokenExpiration: timePtr(time.Now().Add(-5 * time.Minute)), // Expired 5 minutes ago + expectedMaxSeconds: 60, // 1 minute minimum + description: "Already expired token should result in minimal 1 minute session", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + duration := service.calculateSessionDuration(tt.durationSeconds, tt.tokenExpiration) + + // Allow 5 second tolerance for time calculations + maxExpected := time.Duration(tt.expectedMaxSeconds+5) * time.Second + minExpected := time.Duration(tt.expectedMaxSeconds-5) * time.Second + + assert.GreaterOrEqual(t, duration, minExpected, + "%s: duration %v should be >= %v", tt.description, duration, minExpected) + assert.LessOrEqual(t, duration, maxExpected, + "%s: duration %v should be <= %v", tt.description, duration, maxExpected) + }) + } +} + +// TestAssumeRoleWithWebIdentityRespectsTokenExpiration tests end-to-end that session duration is capped +func TestAssumeRoleWithWebIdentityRespectsTokenExpiration(t *testing.T) { + service := NewSTSService() + + config := &STSConfig{ + TokenDuration: FlexibleDuration{time.Hour}, + MaxSessionLength: FlexibleDuration{time.Hour * 12}, + Issuer: "test-sts", + SigningKey: []byte("test-signing-key-32-characters-long"), + } + + err := service.Initialize(config) + require.NoError(t, err) + + // Set up mock trust policy validator + mockValidator := &MockTrustPolicyValidator{} + service.SetTrustPolicyValidator(mockValidator) + + // Create a mock provider that returns tokens with short expiration + shortLivedTokenExpiration := time.Now().Add(10 * time.Minute) + mockProvider := &MockIdentityProviderWithExpiration{ + name: "short-lived-issuer", + tokenExpiration: &shortLivedTokenExpiration, + } + service.RegisterProvider(mockProvider) + + ctx := context.Background() + + // Create a JWT token with short expiration + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "iss": "short-lived-issuer", + "sub": "test-user", + "aud": "test-client", + "exp": shortLivedTokenExpiration.Unix(), + "iat": time.Now().Unix(), + }) + tokenString, err := token.SignedString([]byte("test-signing-key")) + require.NoError(t, err) + + request := &AssumeRoleWithWebIdentityRequest{ + RoleArn: "arn:aws:iam::role/TestRole", + WebIdentityToken: tokenString, + RoleSessionName: "test-session", + } + + response, err := service.AssumeRoleWithWebIdentity(ctx, request) + require.NoError(t, err) + require.NotNil(t, response) + + // Verify the session expires at or before the token expiration + // Allow 5 second tolerance + assert.True(t, response.Credentials.Expiration.Before(shortLivedTokenExpiration.Add(5*time.Second)), + "Session expiration (%v) should not exceed token expiration (%v)", + response.Credentials.Expiration, shortLivedTokenExpiration) +} + +// MockIdentityProviderWithExpiration is a mock provider that returns tokens with configurable expiration +type MockIdentityProviderWithExpiration struct { + name string + tokenExpiration *time.Time +} + +func (m *MockIdentityProviderWithExpiration) Name() string { + return m.name +} + +func (m *MockIdentityProviderWithExpiration) GetIssuer() string { + return m.name +} + +func (m *MockIdentityProviderWithExpiration) Initialize(config interface{}) error { + return nil +} + +func (m *MockIdentityProviderWithExpiration) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { + // Parse the token to get subject + parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + claims, ok := parsedToken.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("invalid claims") + } + + subject, _ := claims["sub"].(string) + + identity := &providers.ExternalIdentity{ + UserID: subject, + Email: subject + "@example.com", + DisplayName: "Test User", + Provider: m.name, + TokenExpiration: m.tokenExpiration, + } + + return identity, nil +} + +func (m *MockIdentityProviderWithExpiration) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { + return &providers.ExternalIdentity{ + UserID: userID, + Provider: m.name, + }, nil +} + +func (m *MockIdentityProviderWithExpiration) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { + claims := &providers.TokenClaims{ + Subject: "test-user", + Issuer: m.name, + } + if m.tokenExpiration != nil { + claims.ExpiresAt = *m.tokenExpiration + } + return claims, nil +} + +func timePtr(t time.Time) *time.Time { + return &t +} |
