aboutsummaryrefslogtreecommitdiff
path: root/weed/s3api/cors/cors_test.go
diff options
context:
space:
mode:
authorChris Lu <chrislusf@users.noreply.github.com>2025-07-15 00:23:54 -0700
committerGitHub <noreply@github.com>2025-07-15 00:23:54 -0700
commit4b040e8a8701199d4c680bb6f241c4751c8210a2 (patch)
tree45d76546220c8d6f3287e3f5498ddf598079cc8e /weed/s3api/cors/cors_test.go
parent548fa0b50a2a57de538d6f6961bfe819128d0ee5 (diff)
downloadseaweedfs-4b040e8a8701199d4c680bb6f241c4751c8210a2.tar.xz
seaweedfs-4b040e8a8701199d4c680bb6f241c4751c8210a2.zip
adding cors support (#6987)
* adding cors support * address some comments * optimize matchesWildcard * address comments * fix for tests * address comments * address comments * address comments * path building * refactor * Update weed/s3api/s3api_bucket_config.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * address comment Service-level responses need both Access-Control-Allow-Methods and Access-Control-Allow-Headers. After setting Access-Control-Allow-Origin and Access-Control-Expose-Headers, also set Access-Control-Allow-Methods: * and Access-Control-Allow-Headers: * so service endpoints satisfy CORS preflight requirements. * Update weed/s3api/s3api_bucket_config.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_object_handlers.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_object_handlers.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix * refactor * Update weed/s3api/s3api_bucket_config.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_object_handlers.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update weed/s3api/s3api_server.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * simplify * add cors tests * fix tests * fix tests --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Diffstat (limited to 'weed/s3api/cors/cors_test.go')
-rw-r--r--weed/s3api/cors/cors_test.go526
1 files changed, 526 insertions, 0 deletions
diff --git a/weed/s3api/cors/cors_test.go b/weed/s3api/cors/cors_test.go
new file mode 100644
index 000000000..1b5c54028
--- /dev/null
+++ b/weed/s3api/cors/cors_test.go
@@ -0,0 +1,526 @@
+package cors
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "reflect"
+ "testing"
+)
+
+func TestValidateConfiguration(t *testing.T) {
+ tests := []struct {
+ name string
+ config *CORSConfiguration
+ wantErr bool
+ }{
+ {
+ name: "nil config",
+ config: nil,
+ wantErr: true,
+ },
+ {
+ name: "empty rules",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{},
+ },
+ wantErr: true,
+ },
+ {
+ name: "valid single rule",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"GET", "POST"},
+ AllowedOrigins: []string{"*"},
+ },
+ },
+ },
+ wantErr: false,
+ },
+ {
+ name: "too many rules",
+ config: &CORSConfiguration{
+ CORSRules: make([]CORSRule, 101),
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid method",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"INVALID"},
+ AllowedOrigins: []string{"*"},
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ name: "empty origins",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"GET"},
+ AllowedOrigins: []string{},
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid origin with multiple wildcards",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"GET"},
+ AllowedOrigins: []string{"http://*.*.example.com"},
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ name: "negative MaxAgeSeconds",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"GET"},
+ AllowedOrigins: []string{"*"},
+ MaxAgeSeconds: intPtr(-1),
+ },
+ },
+ },
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := ValidateConfiguration(tt.config)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("ValidateConfiguration() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestValidateOrigin(t *testing.T) {
+ tests := []struct {
+ name string
+ origin string
+ wantErr bool
+ }{
+ {
+ name: "empty origin",
+ origin: "",
+ wantErr: true,
+ },
+ {
+ name: "valid origin",
+ origin: "http://example.com",
+ wantErr: false,
+ },
+ {
+ name: "wildcard origin",
+ origin: "*",
+ wantErr: false,
+ },
+ {
+ name: "valid wildcard origin",
+ origin: "http://*.example.com",
+ wantErr: false,
+ },
+ {
+ name: "https wildcard origin",
+ origin: "https://*.example.com",
+ wantErr: false,
+ },
+ {
+ name: "invalid wildcard origin",
+ origin: "*.example.com",
+ wantErr: true,
+ },
+ {
+ name: "multiple wildcards",
+ origin: "http://*.*.example.com",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateOrigin(tt.origin)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("validateOrigin() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestParseRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ req *http.Request
+ want *CORSRequest
+ }{
+ {
+ name: "simple GET request",
+ req: &http.Request{
+ Method: "GET",
+ Header: http.Header{
+ "Origin": []string{"http://example.com"},
+ },
+ },
+ want: &CORSRequest{
+ Origin: "http://example.com",
+ Method: "GET",
+ IsPreflightRequest: false,
+ },
+ },
+ {
+ name: "OPTIONS preflight request",
+ req: &http.Request{
+ Method: "OPTIONS",
+ Header: http.Header{
+ "Origin": []string{"http://example.com"},
+ "Access-Control-Request-Method": []string{"PUT"},
+ "Access-Control-Request-Headers": []string{"Content-Type, Authorization"},
+ },
+ },
+ want: &CORSRequest{
+ Origin: "http://example.com",
+ Method: "OPTIONS",
+ IsPreflightRequest: true,
+ AccessControlRequestMethod: "PUT",
+ AccessControlRequestHeaders: []string{"Content-Type", "Authorization"},
+ },
+ },
+ {
+ name: "request without origin",
+ req: &http.Request{
+ Method: "GET",
+ Header: http.Header{},
+ },
+ want: &CORSRequest{
+ Origin: "",
+ Method: "GET",
+ IsPreflightRequest: false,
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := ParseRequest(tt.req)
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("ParseRequest() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestMatchesOrigin(t *testing.T) {
+ tests := []struct {
+ name string
+ allowedOrigins []string
+ origin string
+ want bool
+ }{
+ {
+ name: "wildcard match",
+ allowedOrigins: []string{"*"},
+ origin: "http://example.com",
+ want: true,
+ },
+ {
+ name: "exact match",
+ allowedOrigins: []string{"http://example.com"},
+ origin: "http://example.com",
+ want: true,
+ },
+ {
+ name: "no match",
+ allowedOrigins: []string{"http://example.com"},
+ origin: "http://other.com",
+ want: false,
+ },
+ {
+ name: "wildcard subdomain match",
+ allowedOrigins: []string{"http://*.example.com"},
+ origin: "http://api.example.com",
+ want: true,
+ },
+ {
+ name: "wildcard subdomain no match",
+ allowedOrigins: []string{"http://*.example.com"},
+ origin: "http://example.com",
+ want: false,
+ },
+ {
+ name: "multiple origins with match",
+ allowedOrigins: []string{"http://example.com", "http://other.com"},
+ origin: "http://other.com",
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := matchesOrigin(tt.allowedOrigins, tt.origin)
+ if got != tt.want {
+ t.Errorf("matchesOrigin() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestMatchesHeader(t *testing.T) {
+ tests := []struct {
+ name string
+ allowedHeaders []string
+ header string
+ want bool
+ }{
+ {
+ name: "empty allowed headers",
+ allowedHeaders: []string{},
+ header: "Content-Type",
+ want: true,
+ },
+ {
+ name: "wildcard match",
+ allowedHeaders: []string{"*"},
+ header: "Content-Type",
+ want: true,
+ },
+ {
+ name: "exact match",
+ allowedHeaders: []string{"Content-Type"},
+ header: "Content-Type",
+ want: true,
+ },
+ {
+ name: "case insensitive match",
+ allowedHeaders: []string{"content-type"},
+ header: "Content-Type",
+ want: true,
+ },
+ {
+ name: "no match",
+ allowedHeaders: []string{"Authorization"},
+ header: "Content-Type",
+ want: false,
+ },
+ {
+ name: "wildcard prefix match",
+ allowedHeaders: []string{"x-amz-*"},
+ header: "x-amz-date",
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := matchesHeader(tt.allowedHeaders, tt.header)
+ if got != tt.want {
+ t.Errorf("matchesHeader() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestEvaluateRequest(t *testing.T) {
+ config := &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"GET", "POST"},
+ AllowedOrigins: []string{"http://example.com"},
+ AllowedHeaders: []string{"Content-Type"},
+ ExposeHeaders: []string{"ETag"},
+ MaxAgeSeconds: intPtr(3600),
+ },
+ {
+ AllowedMethods: []string{"PUT"},
+ AllowedOrigins: []string{"*"},
+ },
+ },
+ }
+
+ tests := []struct {
+ name string
+ config *CORSConfiguration
+ corsReq *CORSRequest
+ want *CORSResponse
+ wantErr bool
+ }{
+ {
+ name: "matching first rule",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "http://example.com",
+ Method: "GET",
+ },
+ want: &CORSResponse{
+ AllowOrigin: "http://example.com",
+ AllowMethods: "GET, POST",
+ AllowHeaders: "Content-Type",
+ ExposeHeaders: "ETag",
+ MaxAge: "3600",
+ },
+ wantErr: false,
+ },
+ {
+ name: "matching second rule",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "http://other.com",
+ Method: "PUT",
+ },
+ want: &CORSResponse{
+ AllowOrigin: "http://other.com",
+ AllowMethods: "PUT",
+ },
+ wantErr: false,
+ },
+ {
+ name: "no matching rule",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "http://forbidden.com",
+ Method: "GET",
+ },
+ want: nil,
+ wantErr: true,
+ },
+ {
+ name: "preflight request",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "http://example.com",
+ Method: "OPTIONS",
+ IsPreflightRequest: true,
+ AccessControlRequestMethod: "POST",
+ AccessControlRequestHeaders: []string{"Content-Type"},
+ },
+ want: &CORSResponse{
+ AllowOrigin: "http://example.com",
+ AllowMethods: "GET, POST",
+ AllowHeaders: "Content-Type",
+ ExposeHeaders: "ETag",
+ MaxAge: "3600",
+ },
+ wantErr: false,
+ },
+ {
+ name: "preflight request with forbidden header",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "http://example.com",
+ Method: "OPTIONS",
+ IsPreflightRequest: true,
+ AccessControlRequestMethod: "POST",
+ AccessControlRequestHeaders: []string{"Authorization"},
+ },
+ want: &CORSResponse{
+ AllowOrigin: "http://example.com",
+ // No AllowMethods or AllowHeaders because the requested header is forbidden
+ },
+ wantErr: false,
+ },
+ {
+ name: "request without origin",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "",
+ Method: "GET",
+ },
+ want: nil,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := EvaluateRequest(tt.config, tt.corsReq)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("EvaluateRequest() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("EvaluateRequest() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestApplyHeaders(t *testing.T) {
+ tests := []struct {
+ name string
+ corsResp *CORSResponse
+ want map[string]string
+ }{
+ {
+ name: "nil response",
+ corsResp: nil,
+ want: map[string]string{},
+ },
+ {
+ name: "complete response",
+ corsResp: &CORSResponse{
+ AllowOrigin: "http://example.com",
+ AllowMethods: "GET, POST",
+ AllowHeaders: "Content-Type",
+ ExposeHeaders: "ETag",
+ MaxAge: "3600",
+ },
+ want: map[string]string{
+ "Access-Control-Allow-Origin": "http://example.com",
+ "Access-Control-Allow-Methods": "GET, POST",
+ "Access-Control-Allow-Headers": "Content-Type",
+ "Access-Control-Expose-Headers": "ETag",
+ "Access-Control-Max-Age": "3600",
+ },
+ },
+ {
+ name: "with credentials",
+ corsResp: &CORSResponse{
+ AllowOrigin: "http://example.com",
+ AllowMethods: "GET",
+ AllowCredentials: true,
+ },
+ want: map[string]string{
+ "Access-Control-Allow-Origin": "http://example.com",
+ "Access-Control-Allow-Methods": "GET",
+ "Access-Control-Allow-Credentials": "true",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create a proper response writer using httptest
+ w := httptest.NewRecorder()
+
+ ApplyHeaders(w, tt.corsResp)
+
+ // Extract headers from the response
+ headers := make(map[string]string)
+ for key, values := range w.Header() {
+ if len(values) > 0 {
+ headers[key] = values[0]
+ }
+ }
+
+ if !reflect.DeepEqual(headers, tt.want) {
+ t.Errorf("ApplyHeaders() headers = %v, want %v", headers, tt.want)
+ }
+ })
+ }
+}
+
+// Helper functions and types for testing
+
+func intPtr(i int) *int {
+ return &i
+}