diff options
Diffstat (limited to 'weed/iam/sts/cross_instance_token_test.go')
| -rw-r--r-- | weed/iam/sts/cross_instance_token_test.go | 503 |
1 files changed, 503 insertions, 0 deletions
diff --git a/weed/iam/sts/cross_instance_token_test.go b/weed/iam/sts/cross_instance_token_test.go new file mode 100644 index 000000000..243951d82 --- /dev/null +++ b/weed/iam/sts/cross_instance_token_test.go @@ -0,0 +1,503 @@ +package sts + +import ( + "context" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/seaweedfs/seaweedfs/weed/iam/oidc" + "github.com/seaweedfs/seaweedfs/weed/iam/providers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test-only constants for mock providers +const ( + ProviderTypeMock = "mock" +) + +// createMockOIDCProvider creates a mock OIDC provider for testing +// This is only available in test builds +func createMockOIDCProvider(name string, config map[string]interface{}) (providers.IdentityProvider, error) { + // Convert config to OIDC format + factory := NewProviderFactory() + oidcConfig, err := factory.convertToOIDCConfig(config) + if err != nil { + return nil, err + } + + // Set default values for mock provider if not provided + if oidcConfig.Issuer == "" { + oidcConfig.Issuer = "http://localhost:9999" + } + + provider := oidc.NewMockOIDCProvider(name) + if err := provider.Initialize(oidcConfig); err != nil { + return nil, err + } + + // Set up default test data for the mock provider + provider.SetupDefaultTestData() + + return provider, nil +} + +// createMockJWT creates a test JWT token with the specified issuer for mock provider testing +func createMockJWT(t *testing.T, issuer, subject string) string { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "iss": issuer, + "sub": subject, + "aud": "test-client", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }) + + tokenString, err := token.SignedString([]byte("test-signing-key")) + require.NoError(t, err) + return tokenString +} + +// TestCrossInstanceTokenUsage verifies that tokens generated by one STS instance +// can be used and validated by other STS instances in a distributed environment +func TestCrossInstanceTokenUsage(t *testing.T) { + ctx := context.Background() + // Dummy filer address for testing + + // Common configuration that would be shared across all instances in production + sharedConfig := &STSConfig{ + TokenDuration: FlexibleDuration{time.Hour}, + MaxSessionLength: FlexibleDuration{12 * time.Hour}, + Issuer: "distributed-sts-cluster", // SAME across all instances + SigningKey: []byte(TestSigningKey32Chars), // SAME across all instances + Providers: []*ProviderConfig{ + { + Name: "company-oidc", + Type: ProviderTypeOIDC, + Enabled: true, + Config: map[string]interface{}{ + ConfigFieldIssuer: "https://sso.company.com/realms/production", + ConfigFieldClientID: "seaweedfs-cluster", + ConfigFieldJWKSUri: "https://sso.company.com/realms/production/protocol/openid-connect/certs", + }, + }, + }, + } + + // Create multiple STS instances simulating different S3 gateway instances + instanceA := NewSTSService() // e.g., s3-gateway-1 + instanceB := NewSTSService() // e.g., s3-gateway-2 + instanceC := NewSTSService() // e.g., s3-gateway-3 + + // Initialize all instances with IDENTICAL configuration + err := instanceA.Initialize(sharedConfig) + require.NoError(t, err, "Instance A should initialize") + + err = instanceB.Initialize(sharedConfig) + require.NoError(t, err, "Instance B should initialize") + + err = instanceC.Initialize(sharedConfig) + require.NoError(t, err, "Instance C should initialize") + + // Set up mock trust policy validator for all instances (required for STS testing) + mockValidator := &MockTrustPolicyValidator{} + instanceA.SetTrustPolicyValidator(mockValidator) + instanceB.SetTrustPolicyValidator(mockValidator) + instanceC.SetTrustPolicyValidator(mockValidator) + + // Manually register mock provider for testing (not available in production) + mockProviderConfig := map[string]interface{}{ + ConfigFieldIssuer: "http://test-mock:9999", + ConfigFieldClientID: TestClientID, + } + mockProviderA, err := createMockOIDCProvider("test-mock", mockProviderConfig) + require.NoError(t, err) + mockProviderB, err := createMockOIDCProvider("test-mock", mockProviderConfig) + require.NoError(t, err) + mockProviderC, err := createMockOIDCProvider("test-mock", mockProviderConfig) + require.NoError(t, err) + + instanceA.RegisterProvider(mockProviderA) + instanceB.RegisterProvider(mockProviderB) + instanceC.RegisterProvider(mockProviderC) + + // Test 1: Token generated on Instance A can be validated on Instance B & C + t.Run("cross_instance_token_validation", func(t *testing.T) { + // Generate session token on Instance A + sessionId := TestSessionID + expiresAt := time.Now().Add(time.Hour) + + tokenFromA, err := instanceA.tokenGenerator.GenerateSessionToken(sessionId, expiresAt) + require.NoError(t, err, "Instance A should generate token") + + // Validate token on Instance B + claimsFromB, err := instanceB.tokenGenerator.ValidateSessionToken(tokenFromA) + require.NoError(t, err, "Instance B should validate token from Instance A") + assert.Equal(t, sessionId, claimsFromB.SessionId, "Session ID should match") + + // Validate same token on Instance C + claimsFromC, err := instanceC.tokenGenerator.ValidateSessionToken(tokenFromA) + require.NoError(t, err, "Instance C should validate token from Instance A") + assert.Equal(t, sessionId, claimsFromC.SessionId, "Session ID should match") + + // All instances should extract identical claims + assert.Equal(t, claimsFromB.SessionId, claimsFromC.SessionId) + assert.Equal(t, claimsFromB.ExpiresAt.Unix(), claimsFromC.ExpiresAt.Unix()) + assert.Equal(t, claimsFromB.IssuedAt.Unix(), claimsFromC.IssuedAt.Unix()) + }) + + // Test 2: Complete assume role flow across instances + t.Run("cross_instance_assume_role_flow", func(t *testing.T) { + // Step 1: User authenticates and assumes role on Instance A + // Create a valid JWT token for the mock provider + mockToken := createMockJWT(t, "http://test-mock:9999", "test-user") + + assumeRequest := &AssumeRoleWithWebIdentityRequest{ + RoleArn: "arn:seaweed:iam::role/CrossInstanceTestRole", + WebIdentityToken: mockToken, // JWT token for mock provider + RoleSessionName: "cross-instance-test-session", + DurationSeconds: int64ToPtr(3600), + } + + // Instance A processes assume role request + responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest) + require.NoError(t, err, "Instance A should process assume role") + + sessionToken := responseFromA.Credentials.SessionToken + accessKeyId := responseFromA.Credentials.AccessKeyId + secretAccessKey := responseFromA.Credentials.SecretAccessKey + + // Verify response structure + assert.NotEmpty(t, sessionToken, "Should have session token") + assert.NotEmpty(t, accessKeyId, "Should have access key ID") + assert.NotEmpty(t, secretAccessKey, "Should have secret access key") + assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user") + + // Step 2: Use session token on Instance B (different instance) + sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken) + require.NoError(t, err, "Instance B should validate session token from Instance A") + + assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName) + assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn) + + // Step 3: Use same session token on Instance C (yet another instance) + sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken) + require.NoError(t, err, "Instance C should validate session token from Instance A") + + // All instances should return identical session information + assert.Equal(t, sessionInfoFromB.SessionId, sessionInfoFromC.SessionId) + assert.Equal(t, sessionInfoFromB.SessionName, sessionInfoFromC.SessionName) + assert.Equal(t, sessionInfoFromB.RoleArn, sessionInfoFromC.RoleArn) + assert.Equal(t, sessionInfoFromB.Subject, sessionInfoFromC.Subject) + assert.Equal(t, sessionInfoFromB.Provider, sessionInfoFromC.Provider) + }) + + // Test 3: Session revocation across instances + t.Run("cross_instance_session_revocation", func(t *testing.T) { + // Create session on Instance A + mockToken := createMockJWT(t, "http://test-mock:9999", "test-user") + + assumeRequest := &AssumeRoleWithWebIdentityRequest{ + RoleArn: "arn:seaweed:iam::role/RevocationTestRole", + WebIdentityToken: mockToken, + RoleSessionName: "revocation-test-session", + } + + response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest) + require.NoError(t, err) + sessionToken := response.Credentials.SessionToken + + // Verify token works on Instance B + _, err = instanceB.ValidateSessionToken(ctx, sessionToken) + require.NoError(t, err, "Token should be valid on Instance B initially") + + // Validate session on Instance C to verify cross-instance token compatibility + _, err = instanceC.ValidateSessionToken(ctx, sessionToken) + require.NoError(t, err, "Instance C should be able to validate session token") + + // In a stateless JWT system, tokens remain valid on all instances since they're self-contained + // No revocation is possible without breaking the stateless architecture + _, err = instanceA.ValidateSessionToken(ctx, sessionToken) + assert.NoError(t, err, "Token should still be valid on Instance A (stateless system)") + + // Verify token is still valid on Instance B + _, err = instanceB.ValidateSessionToken(ctx, sessionToken) + assert.NoError(t, err, "Token should still be valid on Instance B (stateless system)") + }) + + // Test 4: Provider consistency across instances + t.Run("provider_consistency_affects_token_generation", func(t *testing.T) { + // All instances should have same providers and be able to process same OIDC tokens + providerNamesA := instanceA.getProviderNames() + providerNamesB := instanceB.getProviderNames() + providerNamesC := instanceC.getProviderNames() + + assert.ElementsMatch(t, providerNamesA, providerNamesB, "Instance A and B should have same providers") + assert.ElementsMatch(t, providerNamesB, providerNamesC, "Instance B and C should have same providers") + + // All instances should be able to process same web identity token + testToken := createMockJWT(t, "http://test-mock:9999", "test-user") + + // Try to assume role with same token on different instances + assumeRequest := &AssumeRoleWithWebIdentityRequest{ + RoleArn: "arn:seaweed:iam::role/ProviderTestRole", + WebIdentityToken: testToken, + RoleSessionName: "provider-consistency-test", + } + + // Should work on any instance + responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest) + responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest) + responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest) + + require.NoError(t, errA, "Instance A should process OIDC token") + require.NoError(t, errB, "Instance B should process OIDC token") + require.NoError(t, errC, "Instance C should process OIDC token") + + // All should return valid responses (sessions will have different IDs but same structure) + assert.NotEmpty(t, responseA.Credentials.SessionToken) + assert.NotEmpty(t, responseB.Credentials.SessionToken) + assert.NotEmpty(t, responseC.Credentials.SessionToken) + }) +} + +// TestSTSDistributedConfigurationRequirements tests the configuration requirements +// for cross-instance token compatibility +func TestSTSDistributedConfigurationRequirements(t *testing.T) { + _ = "localhost:8888" // Dummy filer address for testing (not used in these tests) + + t.Run("same_signing_key_required", func(t *testing.T) { + // Instance A with signing key 1 + configA := &STSConfig{ + TokenDuration: FlexibleDuration{time.Hour}, + MaxSessionLength: FlexibleDuration{12 * time.Hour}, + Issuer: "test-sts", + SigningKey: []byte("signing-key-1-32-characters-long"), + } + + // Instance B with different signing key + configB := &STSConfig{ + TokenDuration: FlexibleDuration{time.Hour}, + MaxSessionLength: FlexibleDuration{12 * time.Hour}, + Issuer: "test-sts", + SigningKey: []byte("signing-key-2-32-characters-long"), // DIFFERENT! + } + + instanceA := NewSTSService() + instanceB := NewSTSService() + + err := instanceA.Initialize(configA) + require.NoError(t, err) + + err = instanceB.Initialize(configB) + require.NoError(t, err) + + // Generate token on Instance A + sessionId := "test-session" + expiresAt := time.Now().Add(time.Hour) + tokenFromA, err := instanceA.tokenGenerator.GenerateSessionToken(sessionId, expiresAt) + require.NoError(t, err) + + // Instance A should validate its own token + _, err = instanceA.tokenGenerator.ValidateSessionToken(tokenFromA) + assert.NoError(t, err, "Instance A should validate own token") + + // Instance B should REJECT token due to different signing key + _, err = instanceB.tokenGenerator.ValidateSessionToken(tokenFromA) + assert.Error(t, err, "Instance B should reject token with different signing key") + assert.Contains(t, err.Error(), "invalid token", "Should be signature validation error") + }) + + t.Run("same_issuer_required", func(t *testing.T) { + sharedSigningKey := []byte("shared-signing-key-32-characters-lo") + + // Instance A with issuer 1 + configA := &STSConfig{ + TokenDuration: FlexibleDuration{time.Hour}, + MaxSessionLength: FlexibleDuration{12 * time.Hour}, + Issuer: "sts-cluster-1", + SigningKey: sharedSigningKey, + } + + // Instance B with different issuer + configB := &STSConfig{ + TokenDuration: FlexibleDuration{time.Hour}, + MaxSessionLength: FlexibleDuration{12 * time.Hour}, + Issuer: "sts-cluster-2", // DIFFERENT! + SigningKey: sharedSigningKey, + } + + instanceA := NewSTSService() + instanceB := NewSTSService() + + err := instanceA.Initialize(configA) + require.NoError(t, err) + + err = instanceB.Initialize(configB) + require.NoError(t, err) + + // Generate token on Instance A + sessionId := "test-session" + expiresAt := time.Now().Add(time.Hour) + tokenFromA, err := instanceA.tokenGenerator.GenerateSessionToken(sessionId, expiresAt) + require.NoError(t, err) + + // Instance B should REJECT token due to different issuer + _, err = instanceB.tokenGenerator.ValidateSessionToken(tokenFromA) + assert.Error(t, err, "Instance B should reject token with different issuer") + assert.Contains(t, err.Error(), "invalid issuer", "Should be issuer validation error") + }) + + t.Run("identical_configuration_required", func(t *testing.T) { + // Identical configuration + identicalConfig := &STSConfig{ + TokenDuration: FlexibleDuration{time.Hour}, + MaxSessionLength: FlexibleDuration{12 * time.Hour}, + Issuer: "production-sts-cluster", + SigningKey: []byte("production-signing-key-32-chars-l"), + } + + // Create multiple instances with identical config + instances := make([]*STSService, 5) + for i := 0; i < 5; i++ { + instances[i] = NewSTSService() + err := instances[i].Initialize(identicalConfig) + require.NoError(t, err, "Instance %d should initialize", i) + } + + // Generate token on Instance 0 + sessionId := "multi-instance-test" + expiresAt := time.Now().Add(time.Hour) + token, err := instances[0].tokenGenerator.GenerateSessionToken(sessionId, expiresAt) + require.NoError(t, err) + + // All other instances should validate the token + for i := 1; i < 5; i++ { + claims, err := instances[i].tokenGenerator.ValidateSessionToken(token) + require.NoError(t, err, "Instance %d should validate token", i) + assert.Equal(t, sessionId, claims.SessionId, "Instance %d should extract correct session ID", i) + } + }) +} + +// TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios +func TestSTSRealWorldDistributedScenarios(t *testing.T) { + ctx := context.Background() + + t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) { + // Simulate real production scenario: + // 1. User authenticates with OIDC provider + // 2. User calls AssumeRoleWithWebIdentity on S3 Gateway 1 + // 3. User makes S3 requests that hit S3 Gateway 2 & 3 via load balancer + // 4. All instances should handle the session token correctly + + productionConfig := &STSConfig{ + TokenDuration: FlexibleDuration{2 * time.Hour}, + MaxSessionLength: FlexibleDuration{24 * time.Hour}, + Issuer: "seaweedfs-production-sts", + SigningKey: []byte("prod-signing-key-32-characters-lon"), + + Providers: []*ProviderConfig{ + { + Name: "corporate-oidc", + Type: "oidc", + Enabled: true, + Config: map[string]interface{}{ + "issuer": "https://sso.company.com/realms/production", + "clientId": "seaweedfs-prod-cluster", + "clientSecret": "supersecret-prod-key", + "scopes": []string{"openid", "profile", "email", "groups"}, + }, + }, + }, + } + + // Create 3 S3 Gateway instances behind load balancer + gateway1 := NewSTSService() + gateway2 := NewSTSService() + gateway3 := NewSTSService() + + err := gateway1.Initialize(productionConfig) + require.NoError(t, err) + + err = gateway2.Initialize(productionConfig) + require.NoError(t, err) + + err = gateway3.Initialize(productionConfig) + require.NoError(t, err) + + // Set up mock trust policy validator for all gateway instances + mockValidator := &MockTrustPolicyValidator{} + gateway1.SetTrustPolicyValidator(mockValidator) + gateway2.SetTrustPolicyValidator(mockValidator) + gateway3.SetTrustPolicyValidator(mockValidator) + + // Manually register mock provider for testing (not available in production) + mockProviderConfig := map[string]interface{}{ + ConfigFieldIssuer: "http://test-mock:9999", + ConfigFieldClientID: "test-client-id", + } + mockProvider1, err := createMockOIDCProvider("test-mock", mockProviderConfig) + require.NoError(t, err) + mockProvider2, err := createMockOIDCProvider("test-mock", mockProviderConfig) + require.NoError(t, err) + mockProvider3, err := createMockOIDCProvider("test-mock", mockProviderConfig) + require.NoError(t, err) + + gateway1.RegisterProvider(mockProvider1) + gateway2.RegisterProvider(mockProvider2) + gateway3.RegisterProvider(mockProvider3) + + // Step 1: User authenticates and hits Gateway 1 for AssumeRole + mockToken := createMockJWT(t, "http://test-mock:9999", "production-user") + + assumeRequest := &AssumeRoleWithWebIdentityRequest{ + RoleArn: "arn:seaweed:iam::role/ProductionS3User", + WebIdentityToken: mockToken, // JWT token from mock provider + RoleSessionName: "user-production-session", + DurationSeconds: int64ToPtr(7200), // 2 hours + } + + stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest) + require.NoError(t, err, "Gateway 1 should handle AssumeRole") + + sessionToken := stsResponse.Credentials.SessionToken + accessKey := stsResponse.Credentials.AccessKeyId + secretKey := stsResponse.Credentials.SecretAccessKey + + // Step 2: User makes S3 requests that hit different gateways via load balancer + // Simulate S3 request validation on Gateway 2 + sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken) + require.NoError(t, err, "Gateway 2 should validate session from Gateway 1") + assert.Equal(t, "user-production-session", sessionInfo2.SessionName) + assert.Equal(t, "arn:seaweed:iam::role/ProductionS3User", sessionInfo2.RoleArn) + + // Simulate S3 request validation on Gateway 3 + sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken) + require.NoError(t, err, "Gateway 3 should validate session from Gateway 1") + assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session") + + // Step 3: Verify credentials are consistent + assert.Equal(t, accessKey, stsResponse.Credentials.AccessKeyId, "Access key should be consistent") + assert.Equal(t, secretKey, stsResponse.Credentials.SecretAccessKey, "Secret key should be consistent") + + // Step 4: Session expiration should be honored across all instances + assert.True(t, sessionInfo2.ExpiresAt.After(time.Now()), "Session should not be expired") + assert.True(t, sessionInfo3.ExpiresAt.After(time.Now()), "Session should not be expired") + + // Step 5: Token should be identical when parsed + claims2, err := gateway2.tokenGenerator.ValidateSessionToken(sessionToken) + require.NoError(t, err) + + claims3, err := gateway3.tokenGenerator.ValidateSessionToken(sessionToken) + require.NoError(t, err) + + assert.Equal(t, claims2.SessionId, claims3.SessionId, "Session IDs should match") + assert.Equal(t, claims2.ExpiresAt.Unix(), claims3.ExpiresAt.Unix(), "Expiration should match") + }) +} + +// Helper function to convert int64 to pointer +func int64ToPtr(i int64) *int64 { + return &i +} |
