diff options
| author | Chris Lu <chrislusf@users.noreply.github.com> | 2025-07-15 00:23:54 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-15 00:23:54 -0700 |
| commit | 4b040e8a8701199d4c680bb6f241c4751c8210a2 (patch) | |
| tree | 45d76546220c8d6f3287e3f5498ddf598079cc8e /weed/s3api/cors/cors_test.go | |
| parent | 548fa0b50a2a57de538d6f6961bfe819128d0ee5 (diff) | |
| download | seaweedfs-4b040e8a8701199d4c680bb6f241c4751c8210a2.tar.xz seaweedfs-4b040e8a8701199d4c680bb6f241c4751c8210a2.zip | |
adding cors support (#6987)
* adding cors support
* address some comments
* optimize matchesWildcard
* address comments
* fix for tests
* address comments
* address comments
* address comments
* path building
* refactor
* Update weed/s3api/s3api_bucket_config.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* address comment
Service-level responses need both Access-Control-Allow-Methods and Access-Control-Allow-Headers. After setting Access-Control-Allow-Origin and Access-Control-Expose-Headers, also set Access-Control-Allow-Methods: * and Access-Control-Allow-Headers: * so service endpoints satisfy CORS preflight requirements.
* Update weed/s3api/s3api_bucket_config.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update weed/s3api/s3api_object_handlers.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update weed/s3api/s3api_object_handlers.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* fix
* refactor
* Update weed/s3api/s3api_bucket_config.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update weed/s3api/s3api_object_handlers.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update weed/s3api/s3api_server.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* simplify
* add cors tests
* fix tests
* fix tests
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Diffstat (limited to 'weed/s3api/cors/cors_test.go')
| -rw-r--r-- | weed/s3api/cors/cors_test.go | 526 |
1 files changed, 526 insertions, 0 deletions
diff --git a/weed/s3api/cors/cors_test.go b/weed/s3api/cors/cors_test.go new file mode 100644 index 000000000..1b5c54028 --- /dev/null +++ b/weed/s3api/cors/cors_test.go @@ -0,0 +1,526 @@ +package cors + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestValidateConfiguration(t *testing.T) { + tests := []struct { + name string + config *CORSConfiguration + wantErr bool + }{ + { + name: "nil config", + config: nil, + wantErr: true, + }, + { + name: "empty rules", + config: &CORSConfiguration{ + CORSRules: []CORSRule{}, + }, + wantErr: true, + }, + { + name: "valid single rule", + config: &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"GET", "POST"}, + AllowedOrigins: []string{"*"}, + }, + }, + }, + wantErr: false, + }, + { + name: "too many rules", + config: &CORSConfiguration{ + CORSRules: make([]CORSRule, 101), + }, + wantErr: true, + }, + { + name: "invalid method", + config: &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"INVALID"}, + AllowedOrigins: []string{"*"}, + }, + }, + }, + wantErr: true, + }, + { + name: "empty origins", + config: &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"GET"}, + AllowedOrigins: []string{}, + }, + }, + }, + wantErr: true, + }, + { + name: "invalid origin with multiple wildcards", + config: &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"GET"}, + AllowedOrigins: []string{"http://*.*.example.com"}, + }, + }, + }, + wantErr: true, + }, + { + name: "negative MaxAgeSeconds", + config: &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"GET"}, + AllowedOrigins: []string{"*"}, + MaxAgeSeconds: intPtr(-1), + }, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateConfiguration(tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateConfiguration() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateOrigin(t *testing.T) { + tests := []struct { + name string + origin string + wantErr bool + }{ + { + name: "empty origin", + origin: "", + wantErr: true, + }, + { + name: "valid origin", + origin: "http://example.com", + wantErr: false, + }, + { + name: "wildcard origin", + origin: "*", + wantErr: false, + }, + { + name: "valid wildcard origin", + origin: "http://*.example.com", + wantErr: false, + }, + { + name: "https wildcard origin", + origin: "https://*.example.com", + wantErr: false, + }, + { + name: "invalid wildcard origin", + origin: "*.example.com", + wantErr: true, + }, + { + name: "multiple wildcards", + origin: "http://*.*.example.com", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateOrigin(tt.origin) + if (err != nil) != tt.wantErr { + t.Errorf("validateOrigin() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestParseRequest(t *testing.T) { + tests := []struct { + name string + req *http.Request + want *CORSRequest + }{ + { + name: "simple GET request", + req: &http.Request{ + Method: "GET", + Header: http.Header{ + "Origin": []string{"http://example.com"}, + }, + }, + want: &CORSRequest{ + Origin: "http://example.com", + Method: "GET", + IsPreflightRequest: false, + }, + }, + { + name: "OPTIONS preflight request", + req: &http.Request{ + Method: "OPTIONS", + Header: http.Header{ + "Origin": []string{"http://example.com"}, + "Access-Control-Request-Method": []string{"PUT"}, + "Access-Control-Request-Headers": []string{"Content-Type, Authorization"}, + }, + }, + want: &CORSRequest{ + Origin: "http://example.com", + Method: "OPTIONS", + IsPreflightRequest: true, + AccessControlRequestMethod: "PUT", + AccessControlRequestHeaders: []string{"Content-Type", "Authorization"}, + }, + }, + { + name: "request without origin", + req: &http.Request{ + Method: "GET", + Header: http.Header{}, + }, + want: &CORSRequest{ + Origin: "", + Method: "GET", + IsPreflightRequest: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ParseRequest(tt.req) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMatchesOrigin(t *testing.T) { + tests := []struct { + name string + allowedOrigins []string + origin string + want bool + }{ + { + name: "wildcard match", + allowedOrigins: []string{"*"}, + origin: "http://example.com", + want: true, + }, + { + name: "exact match", + allowedOrigins: []string{"http://example.com"}, + origin: "http://example.com", + want: true, + }, + { + name: "no match", + allowedOrigins: []string{"http://example.com"}, + origin: "http://other.com", + want: false, + }, + { + name: "wildcard subdomain match", + allowedOrigins: []string{"http://*.example.com"}, + origin: "http://api.example.com", + want: true, + }, + { + name: "wildcard subdomain no match", + allowedOrigins: []string{"http://*.example.com"}, + origin: "http://example.com", + want: false, + }, + { + name: "multiple origins with match", + allowedOrigins: []string{"http://example.com", "http://other.com"}, + origin: "http://other.com", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := matchesOrigin(tt.allowedOrigins, tt.origin) + if got != tt.want { + t.Errorf("matchesOrigin() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMatchesHeader(t *testing.T) { + tests := []struct { + name string + allowedHeaders []string + header string + want bool + }{ + { + name: "empty allowed headers", + allowedHeaders: []string{}, + header: "Content-Type", + want: true, + }, + { + name: "wildcard match", + allowedHeaders: []string{"*"}, + header: "Content-Type", + want: true, + }, + { + name: "exact match", + allowedHeaders: []string{"Content-Type"}, + header: "Content-Type", + want: true, + }, + { + name: "case insensitive match", + allowedHeaders: []string{"content-type"}, + header: "Content-Type", + want: true, + }, + { + name: "no match", + allowedHeaders: []string{"Authorization"}, + header: "Content-Type", + want: false, + }, + { + name: "wildcard prefix match", + allowedHeaders: []string{"x-amz-*"}, + header: "x-amz-date", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := matchesHeader(tt.allowedHeaders, tt.header) + if got != tt.want { + t.Errorf("matchesHeader() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEvaluateRequest(t *testing.T) { + config := &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"GET", "POST"}, + AllowedOrigins: []string{"http://example.com"}, + AllowedHeaders: []string{"Content-Type"}, + ExposeHeaders: []string{"ETag"}, + MaxAgeSeconds: intPtr(3600), + }, + { + AllowedMethods: []string{"PUT"}, + AllowedOrigins: []string{"*"}, + }, + }, + } + + tests := []struct { + name string + config *CORSConfiguration + corsReq *CORSRequest + want *CORSResponse + wantErr bool + }{ + { + name: "matching first rule", + config: config, + corsReq: &CORSRequest{ + Origin: "http://example.com", + Method: "GET", + }, + want: &CORSResponse{ + AllowOrigin: "http://example.com", + AllowMethods: "GET, POST", + AllowHeaders: "Content-Type", + ExposeHeaders: "ETag", + MaxAge: "3600", + }, + wantErr: false, + }, + { + name: "matching second rule", + config: config, + corsReq: &CORSRequest{ + Origin: "http://other.com", + Method: "PUT", + }, + want: &CORSResponse{ + AllowOrigin: "http://other.com", + AllowMethods: "PUT", + }, + wantErr: false, + }, + { + name: "no matching rule", + config: config, + corsReq: &CORSRequest{ + Origin: "http://forbidden.com", + Method: "GET", + }, + want: nil, + wantErr: true, + }, + { + name: "preflight request", + config: config, + corsReq: &CORSRequest{ + Origin: "http://example.com", + Method: "OPTIONS", + IsPreflightRequest: true, + AccessControlRequestMethod: "POST", + AccessControlRequestHeaders: []string{"Content-Type"}, + }, + want: &CORSResponse{ + AllowOrigin: "http://example.com", + AllowMethods: "GET, POST", + AllowHeaders: "Content-Type", + ExposeHeaders: "ETag", + MaxAge: "3600", + }, + wantErr: false, + }, + { + name: "preflight request with forbidden header", + config: config, + corsReq: &CORSRequest{ + Origin: "http://example.com", + Method: "OPTIONS", + IsPreflightRequest: true, + AccessControlRequestMethod: "POST", + AccessControlRequestHeaders: []string{"Authorization"}, + }, + want: &CORSResponse{ + AllowOrigin: "http://example.com", + // No AllowMethods or AllowHeaders because the requested header is forbidden + }, + wantErr: false, + }, + { + name: "request without origin", + config: config, + corsReq: &CORSRequest{ + Origin: "", + Method: "GET", + }, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := EvaluateRequest(tt.config, tt.corsReq) + if (err != nil) != tt.wantErr { + t.Errorf("EvaluateRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("EvaluateRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestApplyHeaders(t *testing.T) { + tests := []struct { + name string + corsResp *CORSResponse + want map[string]string + }{ + { + name: "nil response", + corsResp: nil, + want: map[string]string{}, + }, + { + name: "complete response", + corsResp: &CORSResponse{ + AllowOrigin: "http://example.com", + AllowMethods: "GET, POST", + AllowHeaders: "Content-Type", + ExposeHeaders: "ETag", + MaxAge: "3600", + }, + want: map[string]string{ + "Access-Control-Allow-Origin": "http://example.com", + "Access-Control-Allow-Methods": "GET, POST", + "Access-Control-Allow-Headers": "Content-Type", + "Access-Control-Expose-Headers": "ETag", + "Access-Control-Max-Age": "3600", + }, + }, + { + name: "with credentials", + corsResp: &CORSResponse{ + AllowOrigin: "http://example.com", + AllowMethods: "GET", + AllowCredentials: true, + }, + want: map[string]string{ + "Access-Control-Allow-Origin": "http://example.com", + "Access-Control-Allow-Methods": "GET", + "Access-Control-Allow-Credentials": "true", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a proper response writer using httptest + w := httptest.NewRecorder() + + ApplyHeaders(w, tt.corsResp) + + // Extract headers from the response + headers := make(map[string]string) + for key, values := range w.Header() { + if len(values) > 0 { + headers[key] = values[0] + } + } + + if !reflect.DeepEqual(headers, tt.want) { + t.Errorf("ApplyHeaders() headers = %v, want %v", headers, tt.want) + } + }) + } +} + +// Helper functions and types for testing + +func intPtr(i int) *int { + return &i +} |
