aboutsummaryrefslogtreecommitdiff
path: root/weed/iam/oidc/oidc_provider_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'weed/iam/oidc/oidc_provider_test.go')
-rw-r--r--weed/iam/oidc/oidc_provider_test.go460
1 files changed, 460 insertions, 0 deletions
diff --git a/weed/iam/oidc/oidc_provider_test.go b/weed/iam/oidc/oidc_provider_test.go
new file mode 100644
index 000000000..d37bee1f0
--- /dev/null
+++ b/weed/iam/oidc/oidc_provider_test.go
@@ -0,0 +1,460 @@
+package oidc
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "encoding/base64"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/seaweedfs/seaweedfs/weed/iam/providers"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestOIDCProviderInitialization tests OIDC provider initialization
+func TestOIDCProviderInitialization(t *testing.T) {
+ tests := []struct {
+ name string
+ config *OIDCConfig
+ wantErr bool
+ }{
+ {
+ name: "valid config",
+ config: &OIDCConfig{
+ Issuer: "https://accounts.google.com",
+ ClientID: "test-client-id",
+ JWKSUri: "https://www.googleapis.com/oauth2/v3/certs",
+ },
+ wantErr: false,
+ },
+ {
+ name: "missing issuer",
+ config: &OIDCConfig{
+ ClientID: "test-client-id",
+ },
+ wantErr: true,
+ },
+ {
+ name: "missing client id",
+ config: &OIDCConfig{
+ Issuer: "https://accounts.google.com",
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid issuer url",
+ config: &OIDCConfig{
+ Issuer: "not-a-url",
+ ClientID: "test-client-id",
+ },
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ provider := NewOIDCProvider("test-provider")
+
+ err := provider.Initialize(tt.config)
+
+ if tt.wantErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, "test-provider", provider.Name())
+ }
+ })
+ }
+}
+
+// TestOIDCProviderJWTValidation tests JWT token validation
+func TestOIDCProviderJWTValidation(t *testing.T) {
+ // Set up test server with JWKS endpoint
+ privateKey, publicKey := generateTestKeys(t)
+
+ jwks := map[string]interface{}{
+ "keys": []map[string]interface{}{
+ {
+ "kty": "RSA",
+ "kid": "test-key-id",
+ "use": "sig",
+ "alg": "RS256",
+ "n": encodePublicKey(t, publicKey),
+ "e": "AQAB",
+ },
+ },
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/.well-known/openid_configuration" {
+ config := map[string]interface{}{
+ "issuer": "http://" + r.Host,
+ "jwks_uri": "http://" + r.Host + "/jwks",
+ }
+ json.NewEncoder(w).Encode(config)
+ } else if r.URL.Path == "/jwks" {
+ json.NewEncoder(w).Encode(jwks)
+ }
+ }))
+ defer server.Close()
+
+ provider := NewOIDCProvider("test-oidc")
+ config := &OIDCConfig{
+ Issuer: server.URL,
+ ClientID: "test-client",
+ JWKSUri: server.URL + "/jwks",
+ }
+
+ err := provider.Initialize(config)
+ require.NoError(t, err)
+
+ t.Run("valid token", func(t *testing.T) {
+ // Create valid JWT token
+ token := createTestJWT(t, privateKey, jwt.MapClaims{
+ "iss": server.URL,
+ "aud": "test-client",
+ "sub": "user123",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ "email": "user@example.com",
+ "name": "Test User",
+ })
+
+ claims, err := provider.ValidateToken(context.Background(), token)
+ require.NoError(t, err)
+ require.NotNil(t, claims)
+ assert.Equal(t, "user123", claims.Subject)
+ assert.Equal(t, server.URL, claims.Issuer)
+
+ email, exists := claims.GetClaimString("email")
+ assert.True(t, exists)
+ assert.Equal(t, "user@example.com", email)
+ })
+
+ t.Run("valid token with array audience", func(t *testing.T) {
+ // Create valid JWT token with audience as an array (per RFC 7519)
+ token := createTestJWT(t, privateKey, jwt.MapClaims{
+ "iss": server.URL,
+ "aud": []string{"test-client", "another-client"},
+ "sub": "user456",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ "email": "user2@example.com",
+ "name": "Test User 2",
+ })
+
+ claims, err := provider.ValidateToken(context.Background(), token)
+ require.NoError(t, err)
+ require.NotNil(t, claims)
+ assert.Equal(t, "user456", claims.Subject)
+ assert.Equal(t, server.URL, claims.Issuer)
+
+ email, exists := claims.GetClaimString("email")
+ assert.True(t, exists)
+ assert.Equal(t, "user2@example.com", email)
+ })
+
+ t.Run("expired token", func(t *testing.T) {
+ // Create expired JWT token
+ token := createTestJWT(t, privateKey, jwt.MapClaims{
+ "iss": server.URL,
+ "aud": "test-client",
+ "sub": "user123",
+ "exp": time.Now().Add(-time.Hour).Unix(), // Expired
+ "iat": time.Now().Add(-time.Hour * 2).Unix(),
+ })
+
+ _, err := provider.ValidateToken(context.Background(), token)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "expired")
+ })
+
+ t.Run("invalid signature", func(t *testing.T) {
+ // Create token with wrong key
+ wrongKey, _ := generateTestKeys(t)
+ token := createTestJWT(t, wrongKey, jwt.MapClaims{
+ "iss": server.URL,
+ "aud": "test-client",
+ "sub": "user123",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ })
+
+ _, err := provider.ValidateToken(context.Background(), token)
+ assert.Error(t, err)
+ })
+}
+
+// TestOIDCProviderAuthentication tests authentication flow
+func TestOIDCProviderAuthentication(t *testing.T) {
+ // Set up test OIDC provider
+ privateKey, publicKey := generateTestKeys(t)
+
+ server := setupOIDCTestServer(t, publicKey)
+ defer server.Close()
+
+ provider := NewOIDCProvider("test-oidc")
+ config := &OIDCConfig{
+ Issuer: server.URL,
+ ClientID: "test-client",
+ JWKSUri: server.URL + "/jwks",
+ RoleMapping: &providers.RoleMapping{
+ Rules: []providers.MappingRule{
+ {
+ Claim: "email",
+ Value: "*@example.com",
+ Role: "arn:seaweed:iam::role/UserRole",
+ },
+ {
+ Claim: "groups",
+ Value: "admins",
+ Role: "arn:seaweed:iam::role/AdminRole",
+ },
+ },
+ DefaultRole: "arn:seaweed:iam::role/GuestRole",
+ },
+ }
+
+ err := provider.Initialize(config)
+ require.NoError(t, err)
+
+ t.Run("successful authentication", func(t *testing.T) {
+ token := createTestJWT(t, privateKey, jwt.MapClaims{
+ "iss": server.URL,
+ "aud": "test-client",
+ "sub": "user123",
+ "exp": time.Now().Add(time.Hour).Unix(),
+ "iat": time.Now().Unix(),
+ "email": "user@example.com",
+ "name": "Test User",
+ "groups": []string{"users", "developers"},
+ })
+
+ identity, err := provider.Authenticate(context.Background(), token)
+ require.NoError(t, err)
+ require.NotNil(t, identity)
+ assert.Equal(t, "user123", identity.UserID)
+ assert.Equal(t, "user@example.com", identity.Email)
+ assert.Equal(t, "Test User", identity.DisplayName)
+ assert.Equal(t, "test-oidc", identity.Provider)
+ assert.Contains(t, identity.Groups, "users")
+ assert.Contains(t, identity.Groups, "developers")
+ })
+
+ t.Run("authentication with invalid token", func(t *testing.T) {
+ _, err := provider.Authenticate(context.Background(), "invalid-token")
+ assert.Error(t, err)
+ })
+}
+
+// TestOIDCProviderUserInfo tests user info retrieval
+func TestOIDCProviderUserInfo(t *testing.T) {
+ // Set up test server with UserInfo endpoint
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/userinfo" {
+ // Check for Authorization header
+ authHeader := r.Header.Get("Authorization")
+ if !strings.HasPrefix(authHeader, "Bearer ") {
+ w.WriteHeader(http.StatusUnauthorized)
+ w.Write([]byte(`{"error": "unauthorized"}`))
+ return
+ }
+
+ accessToken := strings.TrimPrefix(authHeader, "Bearer ")
+
+ // Return 401 for explicitly invalid tokens
+ if accessToken == "invalid-token" {
+ w.WriteHeader(http.StatusUnauthorized)
+ w.Write([]byte(`{"error": "invalid_token"}`))
+ return
+ }
+
+ // Mock user info response
+ userInfo := map[string]interface{}{
+ "sub": "user123",
+ "email": "user@example.com",
+ "name": "Test User",
+ "groups": []string{"users", "developers"},
+ }
+
+ // Customize response based on token
+ if strings.Contains(accessToken, "admin") {
+ userInfo["groups"] = []string{"admins"}
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(userInfo)
+ }
+ }))
+ defer server.Close()
+
+ provider := NewOIDCProvider("test-oidc")
+ config := &OIDCConfig{
+ Issuer: server.URL,
+ ClientID: "test-client",
+ UserInfoUri: server.URL + "/userinfo",
+ }
+
+ err := provider.Initialize(config)
+ require.NoError(t, err)
+
+ t.Run("get user info with access token", func(t *testing.T) {
+ // Test using access token (real UserInfo endpoint call)
+ identity, err := provider.GetUserInfoWithToken(context.Background(), "valid-access-token")
+ require.NoError(t, err)
+ require.NotNil(t, identity)
+ assert.Equal(t, "user123", identity.UserID)
+ assert.Equal(t, "user@example.com", identity.Email)
+ assert.Equal(t, "Test User", identity.DisplayName)
+ assert.Contains(t, identity.Groups, "users")
+ assert.Contains(t, identity.Groups, "developers")
+ assert.Equal(t, "test-oidc", identity.Provider)
+ })
+
+ t.Run("get admin user info", func(t *testing.T) {
+ // Test admin token response
+ identity, err := provider.GetUserInfoWithToken(context.Background(), "admin-access-token")
+ require.NoError(t, err)
+ require.NotNil(t, identity)
+ assert.Equal(t, "user123", identity.UserID)
+ assert.Contains(t, identity.Groups, "admins")
+ })
+
+ t.Run("get user info without token", func(t *testing.T) {
+ // Test without access token (should fail)
+ _, err := provider.GetUserInfoWithToken(context.Background(), "")
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "access token cannot be empty")
+ })
+
+ t.Run("get user info with invalid token", func(t *testing.T) {
+ // Test with invalid access token (should get 401)
+ _, err := provider.GetUserInfoWithToken(context.Background(), "invalid-token")
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "UserInfo endpoint returned status 401")
+ })
+
+ t.Run("get user info with custom claims mapping", func(t *testing.T) {
+ // Create provider with custom claims mapping
+ customProvider := NewOIDCProvider("test-custom-oidc")
+ customConfig := &OIDCConfig{
+ Issuer: server.URL,
+ ClientID: "test-client",
+ UserInfoUri: server.URL + "/userinfo",
+ ClaimsMapping: map[string]string{
+ "customEmail": "email",
+ "customName": "name",
+ },
+ }
+
+ err := customProvider.Initialize(customConfig)
+ require.NoError(t, err)
+
+ identity, err := customProvider.GetUserInfoWithToken(context.Background(), "valid-access-token")
+ require.NoError(t, err)
+ require.NotNil(t, identity)
+
+ // Standard claims should still work
+ assert.Equal(t, "user123", identity.UserID)
+ assert.Equal(t, "user@example.com", identity.Email)
+ assert.Equal(t, "Test User", identity.DisplayName)
+ })
+
+ t.Run("get user info with empty id", func(t *testing.T) {
+ _, err := provider.GetUserInfo(context.Background(), "")
+ assert.Error(t, err)
+ })
+}
+
+// Helper functions for testing
+
+func generateTestKeys(t *testing.T) (*rsa.PrivateKey, *rsa.PublicKey) {
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+ return privateKey, &privateKey.PublicKey
+}
+
+func createTestJWT(t *testing.T, privateKey *rsa.PrivateKey, claims jwt.MapClaims) string {
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ token.Header["kid"] = "test-key-id"
+
+ tokenString, err := token.SignedString(privateKey)
+ require.NoError(t, err)
+ return tokenString
+}
+
+func encodePublicKey(t *testing.T, publicKey *rsa.PublicKey) string {
+ // Properly encode the RSA modulus (N) as base64url
+ return base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes())
+}
+
+func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Server {
+ jwks := map[string]interface{}{
+ "keys": []map[string]interface{}{
+ {
+ "kty": "RSA",
+ "kid": "test-key-id",
+ "use": "sig",
+ "alg": "RS256",
+ "n": encodePublicKey(t, publicKey),
+ "e": "AQAB",
+ },
+ },
+ }
+
+ return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/.well-known/openid_configuration":
+ config := map[string]interface{}{
+ "issuer": "http://" + r.Host,
+ "jwks_uri": "http://" + r.Host + "/jwks",
+ "userinfo_endpoint": "http://" + r.Host + "/userinfo",
+ }
+ json.NewEncoder(w).Encode(config)
+ case "/jwks":
+ json.NewEncoder(w).Encode(jwks)
+ case "/userinfo":
+ // Mock UserInfo endpoint
+ authHeader := r.Header.Get("Authorization")
+ if !strings.HasPrefix(authHeader, "Bearer ") {
+ w.WriteHeader(http.StatusUnauthorized)
+ w.Write([]byte(`{"error": "unauthorized"}`))
+ return
+ }
+
+ accessToken := strings.TrimPrefix(authHeader, "Bearer ")
+
+ // Return 401 for explicitly invalid tokens
+ if accessToken == "invalid-token" {
+ w.WriteHeader(http.StatusUnauthorized)
+ w.Write([]byte(`{"error": "invalid_token"}`))
+ return
+ }
+
+ // Mock user info response based on access token
+ userInfo := map[string]interface{}{
+ "sub": "user123",
+ "email": "user@example.com",
+ "name": "Test User",
+ "groups": []string{"users", "developers"},
+ }
+
+ // Customize response based on token
+ if strings.Contains(accessToken, "admin") {
+ userInfo["groups"] = []string{"admins"}
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(userInfo)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+}