aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--weed/iam/oidc/oidc_provider.go30
-rw-r--r--weed/iam/providers/provider.go4
-rw-r--r--weed/iam/sts/sts_service.go46
-rw-r--r--weed/iam/sts/sts_service_test.go209
4 files changed, 279 insertions, 10 deletions
diff --git a/weed/iam/oidc/oidc_provider.go b/weed/iam/oidc/oidc_provider.go
index d31f322b0..fe1cdaccb 100644
--- a/weed/iam/oidc/oidc_provider.go
+++ b/weed/iam/oidc/oidc_provider.go
@@ -186,14 +186,22 @@ func (p *OIDCProvider) Authenticate(ctx context.Context, token string) (*provide
attributes["roles"] = strings.Join(roles, ",")
}
- return &providers.ExternalIdentity{
+ identity := &providers.ExternalIdentity{
UserID: claims.Subject,
Email: email,
DisplayName: displayName,
Groups: groups,
Attributes: attributes,
Provider: p.name,
- }, nil
+ }
+
+ // Pass the token expiration to limit session duration
+ // This ensures the STS session doesn't exceed the source token's validity
+ if !claims.ExpiresAt.IsZero() {
+ identity.TokenExpiration = &claims.ExpiresAt
+ }
+
+ return identity, nil
}
// GetUserInfo retrieves user information from the UserInfo endpoint
@@ -372,6 +380,24 @@ func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) (*provid
Claims: make(map[string]interface{}),
}
+ // Extract time-based claims (exp, iat, nbf)
+ for key, target := range map[string]*time.Time{
+ "exp": &tokenClaims.ExpiresAt,
+ "iat": &tokenClaims.IssuedAt,
+ "nbf": &tokenClaims.NotBefore,
+ } {
+ if val, ok := claims[key]; ok {
+ switch v := val.(type) {
+ case float64:
+ *target = time.Unix(int64(v), 0)
+ case json.Number:
+ if intVal, err := v.Int64(); err == nil {
+ *target = time.Unix(intVal, 0)
+ }
+ }
+ }
+ }
+
// Copy all claims
for key, value := range claims {
tokenClaims.Claims[key] = value
diff --git a/weed/iam/providers/provider.go b/weed/iam/providers/provider.go
index 5c1deb03d..3b7affc8e 100644
--- a/weed/iam/providers/provider.go
+++ b/weed/iam/providers/provider.go
@@ -47,6 +47,10 @@ type ExternalIdentity struct {
// Provider is the name of the identity provider
Provider string `json:"provider"`
+
+ // TokenExpiration is the expiration time of the source identity token
+ // This is used to limit session duration to not exceed the token's exp claim
+ TokenExpiration *time.Time `json:"tokenExpiration,omitempty"`
}
// Validate validates the external identity structure
diff --git a/weed/iam/sts/sts_service.go b/weed/iam/sts/sts_service.go
index 3d9f9af35..e28340f30 100644
--- a/weed/iam/sts/sts_service.go
+++ b/weed/iam/sts/sts_service.go
@@ -422,8 +422,9 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
return nil, fmt.Errorf("role assumption denied: %w", err)
}
- // 3. Calculate session duration
- sessionDuration := s.calculateSessionDuration(request.DurationSeconds)
+ // 3. Calculate session duration, capping at the source token's expiration
+ // This ensures sessions from short-lived tokens (e.g., GitLab CI job tokens) don't outlive their source
+ sessionDuration := s.calculateSessionDuration(request.DurationSeconds, externalIdentity.TokenExpiration)
expiresAt := time.Now().Add(sessionDuration)
// 4. Generate session ID and credentials
@@ -502,7 +503,8 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass
}
// 4. Calculate session duration
- sessionDuration := s.calculateSessionDuration(request.DurationSeconds)
+ // For credential-based auth, there's no source token with expiration to cap against
+ sessionDuration := s.calculateSessionDuration(request.DurationSeconds, nil)
expiresAt := time.Now().Add(sessionDuration)
// 5. Generate session ID and temporary credentials
@@ -745,14 +747,42 @@ func (s *STSService) validateRoleAssumptionForCredentials(ctx context.Context, r
return nil
}
-// calculateSessionDuration calculates the session duration
-func (s *STSService) calculateSessionDuration(durationSeconds *int64) time.Duration {
+// calculateSessionDuration calculates the session duration, respecting the source token's expiration
+// If the incoming web identity token has an exp claim, the session duration is capped to not exceed it
+// This ensures that sessions from short-lived tokens (e.g., GitLab CI job tokens) don't outlive their source
+func (s *STSService) calculateSessionDuration(durationSeconds *int64, tokenExpiration *time.Time) time.Duration {
+ var duration time.Duration
if durationSeconds != nil {
- return time.Duration(*durationSeconds) * time.Second
+ duration = time.Duration(*durationSeconds) * time.Second
+ } else {
+ // Use default from config
+ duration = s.Config.TokenDuration.Duration
+ }
+
+ // If the source token has an expiration, cap the session duration to not exceed it
+ // This follows the principle: "if calculated exp > incoming exp claim, then limit outgoing exp to incoming exp"
+ if tokenExpiration != nil && !tokenExpiration.IsZero() {
+ timeUntilTokenExpiry := time.Until(*tokenExpiration)
+ if timeUntilTokenExpiry <= 0 {
+ // Token already expired - use minimal duration as defense-in-depth
+ // The token should have been rejected during validation, but we handle this defensively
+ glog.V(2).Infof("Source token already expired, using minimal session duration")
+ duration = time.Minute
+ } else if timeUntilTokenExpiry < duration {
+ glog.V(2).Infof("Limiting session duration from %v to %v based on source token expiration",
+ duration, timeUntilTokenExpiry)
+ duration = timeUntilTokenExpiry
+ }
+ }
+
+ // Cap at MaxSessionLength if configured
+ if s.Config.MaxSessionLength.Duration > 0 && duration > s.Config.MaxSessionLength.Duration {
+ glog.V(2).Infof("Limiting session duration from %v to %v based on MaxSessionLength config",
+ duration, s.Config.MaxSessionLength.Duration)
+ duration = s.Config.MaxSessionLength.Duration
}
- // Use default from config
- return s.Config.TokenDuration.Duration
+ return duration
}
// extractSessionIdFromToken extracts session ID from JWT session token
diff --git a/weed/iam/sts/sts_service_test.go b/weed/iam/sts/sts_service_test.go
index 72d69c8c8..56b6755de 100644
--- a/weed/iam/sts/sts_service_test.go
+++ b/weed/iam/sts/sts_service_test.go
@@ -451,3 +451,212 @@ func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string)
}
return nil, fmt.Errorf("invalid token")
}
+
+// TestSessionDurationCappedByTokenExpiration tests that session duration is capped by the source token's exp claim
+func TestSessionDurationCappedByTokenExpiration(t *testing.T) {
+ service := NewSTSService()
+
+ config := &STSConfig{
+ TokenDuration: FlexibleDuration{time.Hour}, // Default: 1 hour
+ MaxSessionLength: FlexibleDuration{time.Hour * 12},
+ Issuer: "test-sts",
+ SigningKey: []byte("test-signing-key-32-characters-long"),
+ }
+
+ err := service.Initialize(config)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ durationSeconds *int64
+ tokenExpiration *time.Time
+ expectedMaxSeconds int64
+ description string
+ }{
+ {
+ name: "no token expiration - use default duration",
+ durationSeconds: nil,
+ tokenExpiration: nil,
+ expectedMaxSeconds: 3600, // 1 hour default
+ description: "When no token expiration is set, use the configured default duration",
+ },
+ {
+ name: "token expires before default duration",
+ durationSeconds: nil,
+ tokenExpiration: timePtr(time.Now().Add(30 * time.Minute)),
+ expectedMaxSeconds: 30 * 60, // 30 minutes
+ description: "When token expires in 30 min, session should be capped at 30 min",
+ },
+ {
+ name: "token expires after default duration - use default",
+ durationSeconds: nil,
+ tokenExpiration: timePtr(time.Now().Add(2 * time.Hour)),
+ expectedMaxSeconds: 3600, // 1 hour default, since it's less than 2 hour token expiry
+ description: "When token expires after default duration, use the default duration",
+ },
+ {
+ name: "requested duration shorter than token expiry",
+ durationSeconds: int64Ptr(1800), // 30 min requested
+ tokenExpiration: timePtr(time.Now().Add(time.Hour)),
+ expectedMaxSeconds: 1800, // 30 minutes as requested
+ description: "When requested duration is shorter than token expiry, use requested duration",
+ },
+ {
+ name: "requested duration longer than token expiry - cap at token expiry",
+ durationSeconds: int64Ptr(3600), // 1 hour requested
+ tokenExpiration: timePtr(time.Now().Add(15 * time.Minute)),
+ expectedMaxSeconds: 15 * 60, // Capped at 15 minutes
+ description: "When requested duration exceeds token expiry, cap at token expiry",
+ },
+ {
+ name: "GitLab CI short-lived token scenario",
+ durationSeconds: nil,
+ tokenExpiration: timePtr(time.Now().Add(5 * time.Minute)),
+ expectedMaxSeconds: 5 * 60, // 5 minutes
+ description: "GitLab CI job with 5 minute timeout should result in 5 minute session",
+ },
+ {
+ name: "already expired token - defense in depth",
+ durationSeconds: nil,
+ tokenExpiration: timePtr(time.Now().Add(-5 * time.Minute)), // Expired 5 minutes ago
+ expectedMaxSeconds: 60, // 1 minute minimum
+ description: "Already expired token should result in minimal 1 minute session",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ duration := service.calculateSessionDuration(tt.durationSeconds, tt.tokenExpiration)
+
+ // Allow 5 second tolerance for time calculations
+ maxExpected := time.Duration(tt.expectedMaxSeconds+5) * time.Second
+ minExpected := time.Duration(tt.expectedMaxSeconds-5) * time.Second
+
+ assert.GreaterOrEqual(t, duration, minExpected,
+ "%s: duration %v should be >= %v", tt.description, duration, minExpected)
+ assert.LessOrEqual(t, duration, maxExpected,
+ "%s: duration %v should be <= %v", tt.description, duration, maxExpected)
+ })
+ }
+}
+
+// TestAssumeRoleWithWebIdentityRespectsTokenExpiration tests end-to-end that session duration is capped
+func TestAssumeRoleWithWebIdentityRespectsTokenExpiration(t *testing.T) {
+ service := NewSTSService()
+
+ config := &STSConfig{
+ TokenDuration: FlexibleDuration{time.Hour},
+ MaxSessionLength: FlexibleDuration{time.Hour * 12},
+ Issuer: "test-sts",
+ SigningKey: []byte("test-signing-key-32-characters-long"),
+ }
+
+ err := service.Initialize(config)
+ require.NoError(t, err)
+
+ // Set up mock trust policy validator
+ mockValidator := &MockTrustPolicyValidator{}
+ service.SetTrustPolicyValidator(mockValidator)
+
+ // Create a mock provider that returns tokens with short expiration
+ shortLivedTokenExpiration := time.Now().Add(10 * time.Minute)
+ mockProvider := &MockIdentityProviderWithExpiration{
+ name: "short-lived-issuer",
+ tokenExpiration: &shortLivedTokenExpiration,
+ }
+ service.RegisterProvider(mockProvider)
+
+ ctx := context.Background()
+
+ // Create a JWT token with short expiration
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+ "iss": "short-lived-issuer",
+ "sub": "test-user",
+ "aud": "test-client",
+ "exp": shortLivedTokenExpiration.Unix(),
+ "iat": time.Now().Unix(),
+ })
+ tokenString, err := token.SignedString([]byte("test-signing-key"))
+ require.NoError(t, err)
+
+ request := &AssumeRoleWithWebIdentityRequest{
+ RoleArn: "arn:aws:iam::role/TestRole",
+ WebIdentityToken: tokenString,
+ RoleSessionName: "test-session",
+ }
+
+ response, err := service.AssumeRoleWithWebIdentity(ctx, request)
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ // Verify the session expires at or before the token expiration
+ // Allow 5 second tolerance
+ assert.True(t, response.Credentials.Expiration.Before(shortLivedTokenExpiration.Add(5*time.Second)),
+ "Session expiration (%v) should not exceed token expiration (%v)",
+ response.Credentials.Expiration, shortLivedTokenExpiration)
+}
+
+// MockIdentityProviderWithExpiration is a mock provider that returns tokens with configurable expiration
+type MockIdentityProviderWithExpiration struct {
+ name string
+ tokenExpiration *time.Time
+}
+
+func (m *MockIdentityProviderWithExpiration) Name() string {
+ return m.name
+}
+
+func (m *MockIdentityProviderWithExpiration) GetIssuer() string {
+ return m.name
+}
+
+func (m *MockIdentityProviderWithExpiration) Initialize(config interface{}) error {
+ return nil
+}
+
+func (m *MockIdentityProviderWithExpiration) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
+ // Parse the token to get subject
+ parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse token: %w", err)
+ }
+
+ claims, ok := parsedToken.Claims.(jwt.MapClaims)
+ if !ok {
+ return nil, fmt.Errorf("invalid claims")
+ }
+
+ subject, _ := claims["sub"].(string)
+
+ identity := &providers.ExternalIdentity{
+ UserID: subject,
+ Email: subject + "@example.com",
+ DisplayName: "Test User",
+ Provider: m.name,
+ TokenExpiration: m.tokenExpiration,
+ }
+
+ return identity, nil
+}
+
+func (m *MockIdentityProviderWithExpiration) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
+ return &providers.ExternalIdentity{
+ UserID: userID,
+ Provider: m.name,
+ }, nil
+}
+
+func (m *MockIdentityProviderWithExpiration) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
+ claims := &providers.TokenClaims{
+ Subject: "test-user",
+ Issuer: m.name,
+ }
+ if m.tokenExpiration != nil {
+ claims.ExpiresAt = *m.tokenExpiration
+ }
+ return claims, nil
+}
+
+func timePtr(t time.Time) *time.Time {
+ return &t
+}