diff options
Diffstat (limited to 'weed')
| -rw-r--r-- | weed/s3api/cors/cors.go | 649 | ||||
| -rw-r--r-- | weed/s3api/cors/cors_test.go | 526 | ||||
| -rw-r--r-- | weed/s3api/cors/middleware.go | 143 | ||||
| -rw-r--r-- | weed/s3api/s3api_bucket_config.go | 129 | ||||
| -rw-r--r-- | weed/s3api/s3api_bucket_cors_handlers.go | 140 | ||||
| -rw-r--r-- | weed/s3api/s3api_bucket_skip_handlers.go | 18 | ||||
| -rw-r--r-- | weed/s3api/s3api_object_handlers.go | 35 | ||||
| -rw-r--r-- | weed/s3api/s3api_server.go | 84 | ||||
| -rw-r--r-- | weed/s3api/s3err/error_handler.go | 28 |
9 files changed, 1701 insertions, 51 deletions
diff --git a/weed/s3api/cors/cors.go b/weed/s3api/cors/cors.go new file mode 100644 index 000000000..1eef71b72 --- /dev/null +++ b/weed/s3api/cors/cors.go @@ -0,0 +1,649 @@ +package cors + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "net/http" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/seaweedfs/seaweedfs/weed/glog" + "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" +) + +// S3 metadata file name constant to avoid typos and reduce duplication +const S3MetadataFileName = ".s3metadata" + +// CORSRule represents a single CORS rule +type CORSRule struct { + ID string `xml:"ID,omitempty" json:"ID,omitempty"` + AllowedMethods []string `xml:"AllowedMethod" json:"AllowedMethods"` + AllowedOrigins []string `xml:"AllowedOrigin" json:"AllowedOrigins"` + AllowedHeaders []string `xml:"AllowedHeader,omitempty" json:"AllowedHeaders,omitempty"` + ExposeHeaders []string `xml:"ExposeHeader,omitempty" json:"ExposeHeaders,omitempty"` + MaxAgeSeconds *int `xml:"MaxAgeSeconds,omitempty" json:"MaxAgeSeconds,omitempty"` +} + +// CORSConfiguration represents the CORS configuration for a bucket +type CORSConfiguration struct { + XMLName xml.Name `xml:"CORSConfiguration"` + CORSRules []CORSRule `xml:"CORSRule" json:"CORSRules"` +} + +// CORSRequest represents a CORS request +type CORSRequest struct { + Origin string + Method string + RequestHeaders []string + IsPreflightRequest bool + AccessControlRequestMethod string + AccessControlRequestHeaders []string +} + +// CORSResponse represents CORS response headers +type CORSResponse struct { + AllowOrigin string + AllowMethods string + AllowHeaders string + ExposeHeaders string + MaxAge string + AllowCredentials bool +} + +// ValidateConfiguration validates a CORS configuration +func ValidateConfiguration(config *CORSConfiguration) error { + if config == nil { + return fmt.Errorf("CORS configuration cannot be nil") + } + + if len(config.CORSRules) == 0 { + return fmt.Errorf("CORS configuration must have at least one rule") + } + + if len(config.CORSRules) > 100 { + return fmt.Errorf("CORS configuration cannot have more than 100 rules") + } + + for i, rule := range config.CORSRules { + if err := validateRule(&rule); err != nil { + return fmt.Errorf("invalid CORS rule at index %d: %v", i, err) + } + } + + return nil +} + +// validateRule validates a single CORS rule +func validateRule(rule *CORSRule) error { + if len(rule.AllowedMethods) == 0 { + return fmt.Errorf("AllowedMethods cannot be empty") + } + + if len(rule.AllowedOrigins) == 0 { + return fmt.Errorf("AllowedOrigins cannot be empty") + } + + // Validate allowed methods + validMethods := map[string]bool{ + "GET": true, + "PUT": true, + "POST": true, + "DELETE": true, + "HEAD": true, + } + + for _, method := range rule.AllowedMethods { + if !validMethods[method] { + return fmt.Errorf("invalid HTTP method: %s", method) + } + } + + // Validate origins + for _, origin := range rule.AllowedOrigins { + if origin == "*" { + continue + } + if err := validateOrigin(origin); err != nil { + return fmt.Errorf("invalid origin %s: %v", origin, err) + } + } + + // Validate MaxAgeSeconds + if rule.MaxAgeSeconds != nil && *rule.MaxAgeSeconds < 0 { + return fmt.Errorf("MaxAgeSeconds cannot be negative") + } + + return nil +} + +// validateOrigin validates an origin string +func validateOrigin(origin string) error { + if origin == "" { + return fmt.Errorf("origin cannot be empty") + } + + // Special case: "*" is always valid + if origin == "*" { + return nil + } + + // Count wildcards + wildcardCount := strings.Count(origin, "*") + if wildcardCount > 1 { + return fmt.Errorf("origin can contain at most one wildcard") + } + + // If there's a wildcard, it should be in a valid position + if wildcardCount == 1 { + // Must be in the format: http://*.example.com or https://*.example.com + if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") { + return fmt.Errorf("origin with wildcard must start with http:// or https://") + } + } + + return nil +} + +// ParseRequest parses an HTTP request to extract CORS information +func ParseRequest(r *http.Request) *CORSRequest { + corsReq := &CORSRequest{ + Origin: r.Header.Get("Origin"), + Method: r.Method, + } + + // Check if this is a preflight request + if r.Method == "OPTIONS" { + corsReq.IsPreflightRequest = true + corsReq.AccessControlRequestMethod = r.Header.Get("Access-Control-Request-Method") + + if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" { + corsReq.AccessControlRequestHeaders = strings.Split(headers, ",") + for i := range corsReq.AccessControlRequestHeaders { + corsReq.AccessControlRequestHeaders[i] = strings.TrimSpace(corsReq.AccessControlRequestHeaders[i]) + } + } + } + + return corsReq +} + +// EvaluateRequest evaluates a CORS request against a CORS configuration +func EvaluateRequest(config *CORSConfiguration, corsReq *CORSRequest) (*CORSResponse, error) { + if config == nil || corsReq == nil { + return nil, fmt.Errorf("config and corsReq cannot be nil") + } + + if corsReq.Origin == "" { + return nil, fmt.Errorf("origin header is required for CORS requests") + } + + // Find the first rule that matches the origin + for _, rule := range config.CORSRules { + if matchesOrigin(rule.AllowedOrigins, corsReq.Origin) { + // For preflight requests, we need more detailed validation + if corsReq.IsPreflightRequest { + return buildPreflightResponse(&rule, corsReq), nil + } else { + // For actual requests, check method + if contains(rule.AllowedMethods, corsReq.Method) { + return buildResponse(&rule, corsReq), nil + } + } + } + } + + return nil, fmt.Errorf("no matching CORS rule found") +} + +// matchesRule checks if a CORS request matches a CORS rule +func matchesRule(rule *CORSRule, corsReq *CORSRequest) bool { + // Check origin - this is the primary matching criterion + if !matchesOrigin(rule.AllowedOrigins, corsReq.Origin) { + return false + } + + // For preflight requests, we need to validate both the requested method and headers + if corsReq.IsPreflightRequest { + // Check if the requested method is allowed + if corsReq.AccessControlRequestMethod != "" { + if !contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod) { + return false + } + } + + // Check if all requested headers are allowed + if len(corsReq.AccessControlRequestHeaders) > 0 { + for _, requestedHeader := range corsReq.AccessControlRequestHeaders { + if !matchesHeader(rule.AllowedHeaders, requestedHeader) { + return false + } + } + } + + return true + } + + // For non-preflight requests, check method matching + method := corsReq.Method + if !contains(rule.AllowedMethods, method) { + return false + } + + return true +} + +// matchesOrigin checks if an origin matches any of the allowed origins +func matchesOrigin(allowedOrigins []string, origin string) bool { + for _, allowedOrigin := range allowedOrigins { + if allowedOrigin == "*" { + return true + } + + if allowedOrigin == origin { + return true + } + + // Check wildcard matching + if strings.Contains(allowedOrigin, "*") { + if matchesWildcard(allowedOrigin, origin) { + return true + } + } + } + return false +} + +// matchesWildcard checks if an origin matches a wildcard pattern +// Uses string manipulation instead of regex for better performance +func matchesWildcard(pattern, origin string) bool { + // Handle simple cases first + if pattern == "*" { + return true + } + if pattern == origin { + return true + } + + // For CORS, we typically only deal with * wildcards (not ? wildcards) + // Use string manipulation for * wildcards only (more efficient than regex) + + // Split pattern by wildcards + parts := strings.Split(pattern, "*") + if len(parts) == 1 { + // No wildcards, exact match + return pattern == origin + } + + // Check if string starts with first part + if len(parts[0]) > 0 && !strings.HasPrefix(origin, parts[0]) { + return false + } + + // Check if string ends with last part + if len(parts[len(parts)-1]) > 0 && !strings.HasSuffix(origin, parts[len(parts)-1]) { + return false + } + + // Check middle parts + searchStr := origin + if len(parts[0]) > 0 { + searchStr = searchStr[len(parts[0]):] + } + if len(parts[len(parts)-1]) > 0 { + searchStr = searchStr[:len(searchStr)-len(parts[len(parts)-1])] + } + + for i := 1; i < len(parts)-1; i++ { + if len(parts[i]) > 0 { + index := strings.Index(searchStr, parts[i]) + if index == -1 { + return false + } + searchStr = searchStr[index+len(parts[i]):] + } + } + + return true +} + +// matchesHeader checks if a header matches allowed headers +func matchesHeader(allowedHeaders []string, header string) bool { + if len(allowedHeaders) == 0 { + return true // No restrictions + } + + for _, allowedHeader := range allowedHeaders { + if allowedHeader == "*" { + return true + } + + if strings.EqualFold(allowedHeader, header) { + return true + } + + // Check wildcard matching for headers + if strings.Contains(allowedHeader, "*") { + if matchesWildcard(strings.ToLower(allowedHeader), strings.ToLower(header)) { + return true + } + } + } + + return false +} + +// buildPreflightResponse builds a CORS response for preflight requests +// This function allows partial matches - origin can match while methods/headers may not +func buildPreflightResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse { + response := &CORSResponse{ + AllowOrigin: corsReq.Origin, + } + + // Check if the requested method is allowed + methodAllowed := corsReq.AccessControlRequestMethod == "" || contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod) + + // Check requested headers + var allowedRequestHeaders []string + allHeadersAllowed := true + + if len(corsReq.AccessControlRequestHeaders) > 0 { + // Check if wildcard is allowed + hasWildcard := false + for _, header := range rule.AllowedHeaders { + if header == "*" { + hasWildcard = true + break + } + } + + if hasWildcard { + // All requested headers are allowed with wildcard + allowedRequestHeaders = corsReq.AccessControlRequestHeaders + } else { + // Check each requested header individually + for _, requestedHeader := range corsReq.AccessControlRequestHeaders { + if matchesHeader(rule.AllowedHeaders, requestedHeader) { + allowedRequestHeaders = append(allowedRequestHeaders, requestedHeader) + } else { + allHeadersAllowed = false + } + } + } + } + + // Only set method and header info if both method and ALL headers are allowed + if methodAllowed && allHeadersAllowed { + response.AllowMethods = strings.Join(rule.AllowedMethods, ", ") + + if len(allowedRequestHeaders) > 0 { + response.AllowHeaders = strings.Join(allowedRequestHeaders, ", ") + } + + // Set exposed headers + if len(rule.ExposeHeaders) > 0 { + response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ") + } + + // Set max age + if rule.MaxAgeSeconds != nil { + response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds) + } + } + + return response +} + +// buildResponse builds a CORS response from a matching rule +func buildResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse { + response := &CORSResponse{ + AllowOrigin: corsReq.Origin, + } + + // Set allowed methods - for preflight requests, return all allowed methods + if corsReq.IsPreflightRequest { + response.AllowMethods = strings.Join(rule.AllowedMethods, ", ") + } else { + // For non-preflight requests, return all allowed methods + response.AllowMethods = strings.Join(rule.AllowedMethods, ", ") + } + + // Set allowed headers + if corsReq.IsPreflightRequest && len(rule.AllowedHeaders) > 0 { + // For preflight requests, check if wildcard is allowed + hasWildcard := false + for _, header := range rule.AllowedHeaders { + if header == "*" { + hasWildcard = true + break + } + } + + if hasWildcard && len(corsReq.AccessControlRequestHeaders) > 0 { + // Return the specific headers that were requested when wildcard is allowed + response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ") + } else if len(corsReq.AccessControlRequestHeaders) > 0 { + // For non-wildcard cases, return the requested headers (preserving case) + // since we already validated they are allowed in matchesRule + response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ") + } else { + // Fallback to configured headers if no specific headers were requested + response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ") + } + } else if len(rule.AllowedHeaders) > 0 { + // For non-preflight requests, return the allowed headers from the rule + response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ") + } + + // Set exposed headers + if len(rule.ExposeHeaders) > 0 { + response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ") + } + + // Set max age + if rule.MaxAgeSeconds != nil { + response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds) + } + + return response +} + +// contains checks if a slice contains a string +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +// ApplyHeaders applies CORS headers to an HTTP response +func ApplyHeaders(w http.ResponseWriter, corsResp *CORSResponse) { + if corsResp == nil { + return + } + + if corsResp.AllowOrigin != "" { + w.Header().Set("Access-Control-Allow-Origin", corsResp.AllowOrigin) + } + + if corsResp.AllowMethods != "" { + w.Header().Set("Access-Control-Allow-Methods", corsResp.AllowMethods) + } + + if corsResp.AllowHeaders != "" { + w.Header().Set("Access-Control-Allow-Headers", corsResp.AllowHeaders) + } + + if corsResp.ExposeHeaders != "" { + w.Header().Set("Access-Control-Expose-Headers", corsResp.ExposeHeaders) + } + + if corsResp.MaxAge != "" { + w.Header().Set("Access-Control-Max-Age", corsResp.MaxAge) + } + + if corsResp.AllowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } +} + +// FilerClient interface for dependency injection +type FilerClient interface { + WithFilerClient(streamingMode bool, fn func(filer_pb.SeaweedFilerClient) error) error +} + +// EntryGetter interface for getting filer entries +type EntryGetter interface { + GetEntry(directory, name string) (*filer_pb.Entry, error) +} + +// Storage provides CORS configuration storage operations +type Storage struct { + filerClient FilerClient + entryGetter EntryGetter + bucketsPath string +} + +// NewStorage creates a new CORS storage instance +func NewStorage(filerClient FilerClient, entryGetter EntryGetter, bucketsPath string) *Storage { + return &Storage{ + filerClient: filerClient, + entryGetter: entryGetter, + bucketsPath: bucketsPath, + } +} + +// Store stores CORS configuration in the filer +func (s *Storage) Store(bucket string, config *CORSConfiguration) error { + // Store in bucket metadata + bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName) + + // Get existing metadata + existingEntry, err := s.entryGetter.GetEntry("", bucketMetadataPath) + var metadata map[string]interface{} + + if err == nil && existingEntry != nil && len(existingEntry.Content) > 0 { + if err := json.Unmarshal(existingEntry.Content, &metadata); err != nil { + glog.V(1).Infof("Failed to unmarshal existing metadata: %v", err) + metadata = make(map[string]interface{}) + } + } else { + metadata = make(map[string]interface{}) + } + + metadata["cors"] = config + + metadataBytes, err := json.Marshal(metadata) + if err != nil { + return fmt.Errorf("failed to marshal bucket metadata: %v", err) + } + + // Store metadata + return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { + request := &filer_pb.CreateEntryRequest{ + Directory: s.bucketsPath + "/" + bucket, + Entry: &filer_pb.Entry{ + Name: S3MetadataFileName, + IsDirectory: false, + Attributes: &filer_pb.FuseAttributes{ + Crtime: time.Now().Unix(), + Mtime: time.Now().Unix(), + FileMode: 0644, + }, + Content: metadataBytes, + }, + } + + _, err := client.CreateEntry(context.Background(), request) + return err + }) +} + +// Load loads CORS configuration from the filer +func (s *Storage) Load(bucket string) (*CORSConfiguration, error) { + bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName) + + entry, err := s.entryGetter.GetEntry("", bucketMetadataPath) + if err != nil || entry == nil { + return nil, fmt.Errorf("no CORS configuration found") + } + + if len(entry.Content) == 0 { + return nil, fmt.Errorf("no CORS configuration found") + } + + var metadata map[string]interface{} + if err := json.Unmarshal(entry.Content, &metadata); err != nil { + return nil, fmt.Errorf("failed to unmarshal metadata: %v", err) + } + + corsData, exists := metadata["cors"] + if !exists { + return nil, fmt.Errorf("no CORS configuration found") + } + + // Convert back to CORSConfiguration + corsBytes, err := json.Marshal(corsData) + if err != nil { + return nil, fmt.Errorf("failed to marshal CORS data: %v", err) + } + + var config CORSConfiguration + if err := json.Unmarshal(corsBytes, &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal CORS configuration: %v", err) + } + + return &config, nil +} + +// Delete deletes CORS configuration from the filer +func (s *Storage) Delete(bucket string) error { + bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName) + + entry, err := s.entryGetter.GetEntry("", bucketMetadataPath) + if err != nil || entry == nil { + return nil // Already deleted or doesn't exist + } + + var metadata map[string]interface{} + if len(entry.Content) > 0 { + if err := json.Unmarshal(entry.Content, &metadata); err != nil { + return fmt.Errorf("failed to unmarshal metadata: %v", err) + } + } else { + return nil // No metadata to delete + } + + // Remove CORS configuration + delete(metadata, "cors") + + metadataBytes, err := json.Marshal(metadata) + if err != nil { + return fmt.Errorf("failed to marshal metadata: %v", err) + } + + // Update metadata + return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { + request := &filer_pb.CreateEntryRequest{ + Directory: s.bucketsPath + "/" + bucket, + Entry: &filer_pb.Entry{ + Name: S3MetadataFileName, + IsDirectory: false, + Attributes: &filer_pb.FuseAttributes{ + Crtime: time.Now().Unix(), + Mtime: time.Now().Unix(), + FileMode: 0644, + }, + Content: metadataBytes, + }, + } + + _, err := client.CreateEntry(context.Background(), request) + return err + }) +} diff --git a/weed/s3api/cors/cors_test.go b/weed/s3api/cors/cors_test.go new file mode 100644 index 000000000..1b5c54028 --- /dev/null +++ b/weed/s3api/cors/cors_test.go @@ -0,0 +1,526 @@ +package cors + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestValidateConfiguration(t *testing.T) { + tests := []struct { + name string + config *CORSConfiguration + wantErr bool + }{ + { + name: "nil config", + config: nil, + wantErr: true, + }, + { + name: "empty rules", + config: &CORSConfiguration{ + CORSRules: []CORSRule{}, + }, + wantErr: true, + }, + { + name: "valid single rule", + config: &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"GET", "POST"}, + AllowedOrigins: []string{"*"}, + }, + }, + }, + wantErr: false, + }, + { + name: "too many rules", + config: &CORSConfiguration{ + CORSRules: make([]CORSRule, 101), + }, + wantErr: true, + }, + { + name: "invalid method", + config: &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"INVALID"}, + AllowedOrigins: []string{"*"}, + }, + }, + }, + wantErr: true, + }, + { + name: "empty origins", + config: &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"GET"}, + AllowedOrigins: []string{}, + }, + }, + }, + wantErr: true, + }, + { + name: "invalid origin with multiple wildcards", + config: &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"GET"}, + AllowedOrigins: []string{"http://*.*.example.com"}, + }, + }, + }, + wantErr: true, + }, + { + name: "negative MaxAgeSeconds", + config: &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"GET"}, + AllowedOrigins: []string{"*"}, + MaxAgeSeconds: intPtr(-1), + }, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateConfiguration(tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateConfiguration() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateOrigin(t *testing.T) { + tests := []struct { + name string + origin string + wantErr bool + }{ + { + name: "empty origin", + origin: "", + wantErr: true, + }, + { + name: "valid origin", + origin: "http://example.com", + wantErr: false, + }, + { + name: "wildcard origin", + origin: "*", + wantErr: false, + }, + { + name: "valid wildcard origin", + origin: "http://*.example.com", + wantErr: false, + }, + { + name: "https wildcard origin", + origin: "https://*.example.com", + wantErr: false, + }, + { + name: "invalid wildcard origin", + origin: "*.example.com", + wantErr: true, + }, + { + name: "multiple wildcards", + origin: "http://*.*.example.com", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateOrigin(tt.origin) + if (err != nil) != tt.wantErr { + t.Errorf("validateOrigin() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestParseRequest(t *testing.T) { + tests := []struct { + name string + req *http.Request + want *CORSRequest + }{ + { + name: "simple GET request", + req: &http.Request{ + Method: "GET", + Header: http.Header{ + "Origin": []string{"http://example.com"}, + }, + }, + want: &CORSRequest{ + Origin: "http://example.com", + Method: "GET", + IsPreflightRequest: false, + }, + }, + { + name: "OPTIONS preflight request", + req: &http.Request{ + Method: "OPTIONS", + Header: http.Header{ + "Origin": []string{"http://example.com"}, + "Access-Control-Request-Method": []string{"PUT"}, + "Access-Control-Request-Headers": []string{"Content-Type, Authorization"}, + }, + }, + want: &CORSRequest{ + Origin: "http://example.com", + Method: "OPTIONS", + IsPreflightRequest: true, + AccessControlRequestMethod: "PUT", + AccessControlRequestHeaders: []string{"Content-Type", "Authorization"}, + }, + }, + { + name: "request without origin", + req: &http.Request{ + Method: "GET", + Header: http.Header{}, + }, + want: &CORSRequest{ + Origin: "", + Method: "GET", + IsPreflightRequest: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ParseRequest(tt.req) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMatchesOrigin(t *testing.T) { + tests := []struct { + name string + allowedOrigins []string + origin string + want bool + }{ + { + name: "wildcard match", + allowedOrigins: []string{"*"}, + origin: "http://example.com", + want: true, + }, + { + name: "exact match", + allowedOrigins: []string{"http://example.com"}, + origin: "http://example.com", + want: true, + }, + { + name: "no match", + allowedOrigins: []string{"http://example.com"}, + origin: "http://other.com", + want: false, + }, + { + name: "wildcard subdomain match", + allowedOrigins: []string{"http://*.example.com"}, + origin: "http://api.example.com", + want: true, + }, + { + name: "wildcard subdomain no match", + allowedOrigins: []string{"http://*.example.com"}, + origin: "http://example.com", + want: false, + }, + { + name: "multiple origins with match", + allowedOrigins: []string{"http://example.com", "http://other.com"}, + origin: "http://other.com", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := matchesOrigin(tt.allowedOrigins, tt.origin) + if got != tt.want { + t.Errorf("matchesOrigin() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMatchesHeader(t *testing.T) { + tests := []struct { + name string + allowedHeaders []string + header string + want bool + }{ + { + name: "empty allowed headers", + allowedHeaders: []string{}, + header: "Content-Type", + want: true, + }, + { + name: "wildcard match", + allowedHeaders: []string{"*"}, + header: "Content-Type", + want: true, + }, + { + name: "exact match", + allowedHeaders: []string{"Content-Type"}, + header: "Content-Type", + want: true, + }, + { + name: "case insensitive match", + allowedHeaders: []string{"content-type"}, + header: "Content-Type", + want: true, + }, + { + name: "no match", + allowedHeaders: []string{"Authorization"}, + header: "Content-Type", + want: false, + }, + { + name: "wildcard prefix match", + allowedHeaders: []string{"x-amz-*"}, + header: "x-amz-date", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := matchesHeader(tt.allowedHeaders, tt.header) + if got != tt.want { + t.Errorf("matchesHeader() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEvaluateRequest(t *testing.T) { + config := &CORSConfiguration{ + CORSRules: []CORSRule{ + { + AllowedMethods: []string{"GET", "POST"}, + AllowedOrigins: []string{"http://example.com"}, + AllowedHeaders: []string{"Content-Type"}, + ExposeHeaders: []string{"ETag"}, + MaxAgeSeconds: intPtr(3600), + }, + { + AllowedMethods: []string{"PUT"}, + AllowedOrigins: []string{"*"}, + }, + }, + } + + tests := []struct { + name string + config *CORSConfiguration + corsReq *CORSRequest + want *CORSResponse + wantErr bool + }{ + { + name: "matching first rule", + config: config, + corsReq: &CORSRequest{ + Origin: "http://example.com", + Method: "GET", + }, + want: &CORSResponse{ + AllowOrigin: "http://example.com", + AllowMethods: "GET, POST", + AllowHeaders: "Content-Type", + ExposeHeaders: "ETag", + MaxAge: "3600", + }, + wantErr: false, + }, + { + name: "matching second rule", + config: config, + corsReq: &CORSRequest{ + Origin: "http://other.com", + Method: "PUT", + }, + want: &CORSResponse{ + AllowOrigin: "http://other.com", + AllowMethods: "PUT", + }, + wantErr: false, + }, + { + name: "no matching rule", + config: config, + corsReq: &CORSRequest{ + Origin: "http://forbidden.com", + Method: "GET", + }, + want: nil, + wantErr: true, + }, + { + name: "preflight request", + config: config, + corsReq: &CORSRequest{ + Origin: "http://example.com", + Method: "OPTIONS", + IsPreflightRequest: true, + AccessControlRequestMethod: "POST", + AccessControlRequestHeaders: []string{"Content-Type"}, + }, + want: &CORSResponse{ + AllowOrigin: "http://example.com", + AllowMethods: "GET, POST", + AllowHeaders: "Content-Type", + ExposeHeaders: "ETag", + MaxAge: "3600", + }, + wantErr: false, + }, + { + name: "preflight request with forbidden header", + config: config, + corsReq: &CORSRequest{ + Origin: "http://example.com", + Method: "OPTIONS", + IsPreflightRequest: true, + AccessControlRequestMethod: "POST", + AccessControlRequestHeaders: []string{"Authorization"}, + }, + want: &CORSResponse{ + AllowOrigin: "http://example.com", + // No AllowMethods or AllowHeaders because the requested header is forbidden + }, + wantErr: false, + }, + { + name: "request without origin", + config: config, + corsReq: &CORSRequest{ + Origin: "", + Method: "GET", + }, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := EvaluateRequest(tt.config, tt.corsReq) + if (err != nil) != tt.wantErr { + t.Errorf("EvaluateRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("EvaluateRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestApplyHeaders(t *testing.T) { + tests := []struct { + name string + corsResp *CORSResponse + want map[string]string + }{ + { + name: "nil response", + corsResp: nil, + want: map[string]string{}, + }, + { + name: "complete response", + corsResp: &CORSResponse{ + AllowOrigin: "http://example.com", + AllowMethods: "GET, POST", + AllowHeaders: "Content-Type", + ExposeHeaders: "ETag", + MaxAge: "3600", + }, + want: map[string]string{ + "Access-Control-Allow-Origin": "http://example.com", + "Access-Control-Allow-Methods": "GET, POST", + "Access-Control-Allow-Headers": "Content-Type", + "Access-Control-Expose-Headers": "ETag", + "Access-Control-Max-Age": "3600", + }, + }, + { + name: "with credentials", + corsResp: &CORSResponse{ + AllowOrigin: "http://example.com", + AllowMethods: "GET", + AllowCredentials: true, + }, + want: map[string]string{ + "Access-Control-Allow-Origin": "http://example.com", + "Access-Control-Allow-Methods": "GET", + "Access-Control-Allow-Credentials": "true", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a proper response writer using httptest + w := httptest.NewRecorder() + + ApplyHeaders(w, tt.corsResp) + + // Extract headers from the response + headers := make(map[string]string) + for key, values := range w.Header() { + if len(values) > 0 { + headers[key] = values[0] + } + } + + if !reflect.DeepEqual(headers, tt.want) { + t.Errorf("ApplyHeaders() headers = %v, want %v", headers, tt.want) + } + }) + } +} + +// Helper functions and types for testing + +func intPtr(i int) *int { + return &i +} diff --git a/weed/s3api/cors/middleware.go b/weed/s3api/cors/middleware.go new file mode 100644 index 000000000..14ff32355 --- /dev/null +++ b/weed/s3api/cors/middleware.go @@ -0,0 +1,143 @@ +package cors + +import ( + "net/http" + + "github.com/seaweedfs/seaweedfs/weed/glog" + "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" + "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" +) + +// BucketChecker interface for checking bucket existence +type BucketChecker interface { + CheckBucket(r *http.Request, bucket string) s3err.ErrorCode +} + +// CORSConfigGetter interface for getting CORS configuration +type CORSConfigGetter interface { + GetCORSConfiguration(bucket string) (*CORSConfiguration, s3err.ErrorCode) +} + +// Middleware handles CORS evaluation for all S3 API requests +type Middleware struct { + storage *Storage + bucketChecker BucketChecker + corsConfigGetter CORSConfigGetter +} + +// NewMiddleware creates a new CORS middleware instance +func NewMiddleware(storage *Storage, bucketChecker BucketChecker, corsConfigGetter CORSConfigGetter) *Middleware { + return &Middleware{ + storage: storage, + bucketChecker: bucketChecker, + corsConfigGetter: corsConfigGetter, + } +} + +// evaluateCORSRequest performs the common CORS request evaluation logic +// Returns: (corsResponse, responseWritten, shouldContinue) +// - corsResponse: the CORS response if evaluation succeeded +// - responseWritten: true if an error response was already written +// - shouldContinue: true if the request should continue to the next handler +func (m *Middleware) evaluateCORSRequest(w http.ResponseWriter, r *http.Request) (*CORSResponse, bool, bool) { + // Parse CORS request + corsReq := ParseRequest(r) + if corsReq.Origin == "" { + // Not a CORS request + return nil, false, true + } + + // Extract bucket from request + bucket, _ := s3_constants.GetBucketAndObject(r) + if bucket == "" { + return nil, false, true + } + + // Check if bucket exists + if err := m.bucketChecker.CheckBucket(r, bucket); err != s3err.ErrNone { + // For non-existent buckets, let the normal handler deal with it + return nil, false, true + } + + // Load CORS configuration from cache + config, errCode := m.corsConfigGetter.GetCORSConfiguration(bucket) + if errCode != s3err.ErrNone || config == nil { + // No CORS configuration, handle based on request type + if corsReq.IsPreflightRequest { + // Preflight request without CORS config should fail + s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied) + return nil, true, false // Response written, don't continue + } + // Non-preflight request, continue normally + return nil, false, true + } + + // Evaluate CORS request + corsResp, err := EvaluateRequest(config, corsReq) + if err != nil { + glog.V(3).Infof("CORS evaluation failed for bucket %s: %v", bucket, err) + if corsReq.IsPreflightRequest { + // Preflight request that doesn't match CORS rules should fail + s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied) + return nil, true, false // Response written, don't continue + } + // Non-preflight request, continue normally but without CORS headers + return nil, false, true + } + + return corsResp, false, false +} + +// Handler returns the CORS middleware handler +func (m *Middleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Use the common evaluation logic + corsResp, responseWritten, shouldContinue := m.evaluateCORSRequest(w, r) + if responseWritten { + // Response was already written (error case) + return + } + + if shouldContinue { + // Continue with normal request processing + next.ServeHTTP(w, r) + return + } + + // Parse request to check if it's a preflight request + corsReq := ParseRequest(r) + + // Apply CORS headers to response + ApplyHeaders(w, corsResp) + + // Handle preflight requests + if corsReq.IsPreflightRequest { + // Preflight request should return 200 OK with just CORS headers + w.WriteHeader(http.StatusOK) + return + } + + // Continue with normal request processing + next.ServeHTTP(w, r) + }) +} + +// HandleOptionsRequest handles OPTIONS requests for CORS preflight +func (m *Middleware) HandleOptionsRequest(w http.ResponseWriter, r *http.Request) { + // Use the common evaluation logic + corsResp, responseWritten, shouldContinue := m.evaluateCORSRequest(w, r) + if responseWritten { + // Response was already written (error case) + return + } + + if shouldContinue || corsResp == nil { + // Not a CORS request or should continue normally + w.WriteHeader(http.StatusOK) + return + } + + // Apply CORS headers and return success + ApplyHeaders(w, corsResp) + w.WriteHeader(http.StatusOK) +} diff --git a/weed/s3api/s3api_bucket_config.go b/weed/s3api/s3api_bucket_config.go index 273eb6fbd..a157b93e8 100644 --- a/weed/s3api/s3api_bucket_config.go +++ b/weed/s3api/s3api_bucket_config.go @@ -1,12 +1,16 @@ package s3api import ( + "encoding/json" "fmt" + "path/filepath" + "strings" "sync" "time" "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" + "github.com/seaweedfs/seaweedfs/weed/s3api/cors" "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" ) @@ -18,6 +22,7 @@ type BucketConfig struct { Ownership string ACL []byte Owner string + CORS *cors.CORSConfiguration LastModified time.Time Entry *filer_pb.Entry } @@ -118,6 +123,19 @@ func (s3a *S3ApiServer) getBucketConfig(bucket string) (*BucketConfig, s3err.Err } } + // Load CORS configuration from .s3metadata + if corsConfig, err := s3a.loadCORSFromMetadata(bucket); err != nil { + if err == filer_pb.ErrNotFound { + // Missing metadata is not an error; fall back cleanly + glog.V(2).Infof("CORS metadata not found for bucket %s, falling back to default behavior", bucket) + } else { + // Log parsing or validation errors + glog.Errorf("Failed to load CORS configuration for bucket %s: %v", bucket, err) + } + } else { + config.CORS = corsConfig + } + // Cache the result s3a.bucketConfigCache.Set(bucket, config) @@ -244,3 +262,114 @@ func (s3a *S3ApiServer) removeBucketConfigKey(bucket, key string) s3err.ErrorCod return nil }) } + +// loadCORSFromMetadata loads CORS configuration from bucket metadata +func (s3a *S3ApiServer) loadCORSFromMetadata(bucket string) (*cors.CORSConfiguration, error) { + // Validate bucket name to prevent path traversal attacks + if bucket == "" || strings.Contains(bucket, "/") || strings.Contains(bucket, "\\") || + strings.Contains(bucket, "..") || strings.Contains(bucket, "~") { + return nil, fmt.Errorf("invalid bucket name: %s", bucket) + } + + // Clean the bucket name further to prevent any potential path traversal + bucket = filepath.Clean(bucket) + if bucket == "." || bucket == ".." { + return nil, fmt.Errorf("invalid bucket name: %s", bucket) + } + + bucketMetadataPath := filepath.Join(s3a.option.BucketsPath, bucket, cors.S3MetadataFileName) + + entry, err := s3a.getEntry("", bucketMetadataPath) + if err != nil { + glog.V(3).Infof("loadCORSFromMetadata: error retrieving metadata for bucket %s: %v", bucket, err) + return nil, fmt.Errorf("error retrieving metadata for bucket %s: %v", bucket, err) + } + if entry == nil { + glog.V(3).Infof("loadCORSFromMetadata: no metadata entry found for bucket %s", bucket) + return nil, fmt.Errorf("no metadata entry found for bucket %s", bucket) + } + + if len(entry.Content) == 0 { + glog.V(3).Infof("loadCORSFromMetadata: empty metadata content for bucket %s", bucket) + return nil, fmt.Errorf("no metadata content for bucket %s", bucket) + } + + var metadata map[string]json.RawMessage + if err := json.Unmarshal(entry.Content, &metadata); err != nil { + glog.Errorf("loadCORSFromMetadata: failed to unmarshal metadata for bucket %s: %v", bucket, err) + return nil, fmt.Errorf("failed to unmarshal metadata: %v", err) + } + + corsData, exists := metadata["cors"] + if !exists { + glog.V(3).Infof("loadCORSFromMetadata: no CORS configuration found for bucket %s", bucket) + return nil, fmt.Errorf("no CORS configuration found") + } + + // Directly unmarshal the raw JSON to CORSConfiguration to avoid round-trip allocations + var config cors.CORSConfiguration + if err := json.Unmarshal(corsData, &config); err != nil { + glog.Errorf("loadCORSFromMetadata: failed to unmarshal CORS configuration for bucket %s: %v", bucket, err) + return nil, fmt.Errorf("failed to unmarshal CORS configuration: %v", err) + } + + return &config, nil +} + +// getCORSConfiguration retrieves CORS configuration with caching +func (s3a *S3ApiServer) getCORSConfiguration(bucket string) (*cors.CORSConfiguration, s3err.ErrorCode) { + config, errCode := s3a.getBucketConfig(bucket) + if errCode != s3err.ErrNone { + return nil, errCode + } + + return config.CORS, s3err.ErrNone +} + +// getCORSStorage returns a CORS storage instance for persistent operations +func (s3a *S3ApiServer) getCORSStorage() *cors.Storage { + entryGetter := &S3EntryGetter{server: s3a} + return cors.NewStorage(s3a, entryGetter, s3a.option.BucketsPath) +} + +// updateCORSConfiguration updates CORS configuration and invalidates cache +func (s3a *S3ApiServer) updateCORSConfiguration(bucket string, corsConfig *cors.CORSConfiguration) s3err.ErrorCode { + // Update in-memory cache + errCode := s3a.updateBucketConfig(bucket, func(config *BucketConfig) error { + config.CORS = corsConfig + return nil + }) + if errCode != s3err.ErrNone { + return errCode + } + + // Persist to .s3metadata file + storage := s3a.getCORSStorage() + if err := storage.Store(bucket, corsConfig); err != nil { + glog.Errorf("updateCORSConfiguration: failed to persist CORS config to metadata for bucket %s: %v", bucket, err) + return s3err.ErrInternalError + } + + return s3err.ErrNone +} + +// removeCORSConfiguration removes CORS configuration and invalidates cache +func (s3a *S3ApiServer) removeCORSConfiguration(bucket string) s3err.ErrorCode { + // Remove from in-memory cache + errCode := s3a.updateBucketConfig(bucket, func(config *BucketConfig) error { + config.CORS = nil + return nil + }) + if errCode != s3err.ErrNone { + return errCode + } + + // Remove from .s3metadata file + storage := s3a.getCORSStorage() + if err := storage.Delete(bucket); err != nil { + glog.Errorf("removeCORSConfiguration: failed to remove CORS config from metadata for bucket %s: %v", bucket, err) + return s3err.ErrInternalError + } + + return s3err.ErrNone +} diff --git a/weed/s3api/s3api_bucket_cors_handlers.go b/weed/s3api/s3api_bucket_cors_handlers.go new file mode 100644 index 000000000..e46021d7e --- /dev/null +++ b/weed/s3api/s3api_bucket_cors_handlers.go @@ -0,0 +1,140 @@ +package s3api + +import ( + "encoding/xml" + "net/http" + + "github.com/seaweedfs/seaweedfs/weed/glog" + "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" + "github.com/seaweedfs/seaweedfs/weed/s3api/cors" + "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" + "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" +) + +// S3EntryGetter implements cors.EntryGetter interface +type S3EntryGetter struct { + server *S3ApiServer +} + +func (g *S3EntryGetter) GetEntry(directory, name string) (*filer_pb.Entry, error) { + return g.server.getEntry(directory, name) +} + +// S3BucketChecker implements cors.BucketChecker interface +type S3BucketChecker struct { + server *S3ApiServer +} + +func (c *S3BucketChecker) CheckBucket(r *http.Request, bucket string) s3err.ErrorCode { + return c.server.checkBucket(r, bucket) +} + +// S3CORSConfigGetter implements cors.CORSConfigGetter interface +type S3CORSConfigGetter struct { + server *S3ApiServer +} + +func (g *S3CORSConfigGetter) GetCORSConfiguration(bucket string) (*cors.CORSConfiguration, s3err.ErrorCode) { + return g.server.getCORSConfiguration(bucket) +} + +// getCORSMiddleware returns a CORS middleware instance with caching +func (s3a *S3ApiServer) getCORSMiddleware() *cors.Middleware { + storage := s3a.getCORSStorage() + bucketChecker := &S3BucketChecker{server: s3a} + corsConfigGetter := &S3CORSConfigGetter{server: s3a} + + return cors.NewMiddleware(storage, bucketChecker, corsConfigGetter) +} + +// GetBucketCorsHandler handles Get bucket CORS configuration +// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketCors.html +func (s3a *S3ApiServer) GetBucketCorsHandler(w http.ResponseWriter, r *http.Request) { + bucket, _ := s3_constants.GetBucketAndObject(r) + glog.V(3).Infof("GetBucketCorsHandler %s", bucket) + + if err := s3a.checkBucket(r, bucket); err != s3err.ErrNone { + s3err.WriteErrorResponse(w, r, err) + return + } + + // Load CORS configuration from cache + config, errCode := s3a.getCORSConfiguration(bucket) + if errCode != s3err.ErrNone { + if errCode == s3err.ErrNoSuchBucket { + s3err.WriteErrorResponse(w, r, s3err.ErrNoSuchBucket) + } else { + s3err.WriteErrorResponse(w, r, s3err.ErrInternalError) + } + return + } + + if config == nil { + s3err.WriteErrorResponse(w, r, s3err.ErrNoSuchCORSConfiguration) + return + } + + // Return CORS configuration as XML + writeSuccessResponseXML(w, r, config) +} + +// PutBucketCorsHandler handles Put bucket CORS configuration +// https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketCors.html +func (s3a *S3ApiServer) PutBucketCorsHandler(w http.ResponseWriter, r *http.Request) { + bucket, _ := s3_constants.GetBucketAndObject(r) + glog.V(3).Infof("PutBucketCorsHandler %s", bucket) + + if err := s3a.checkBucket(r, bucket); err != s3err.ErrNone { + s3err.WriteErrorResponse(w, r, err) + return + } + + // Parse CORS configuration from request body + var config cors.CORSConfiguration + if err := xml.NewDecoder(r.Body).Decode(&config); err != nil { + glog.V(1).Infof("Failed to parse CORS configuration: %v", err) + s3err.WriteErrorResponse(w, r, s3err.ErrMalformedXML) + return + } + + // Validate CORS configuration + if err := cors.ValidateConfiguration(&config); err != nil { + glog.V(1).Infof("Invalid CORS configuration: %v", err) + s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest) + return + } + + // Store CORS configuration and update cache + // This handles both cache update and persistent storage through the unified bucket config system + if err := s3a.updateCORSConfiguration(bucket, &config); err != s3err.ErrNone { + glog.Errorf("Failed to update CORS configuration: %v", err) + s3err.WriteErrorResponse(w, r, s3err.ErrInternalError) + return + } + + // Return success + writeSuccessResponseEmpty(w, r) +} + +// DeleteBucketCorsHandler handles Delete bucket CORS configuration +// https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteBucketCors.html +func (s3a *S3ApiServer) DeleteBucketCorsHandler(w http.ResponseWriter, r *http.Request) { + bucket, _ := s3_constants.GetBucketAndObject(r) + glog.V(3).Infof("DeleteBucketCorsHandler %s", bucket) + + if err := s3a.checkBucket(r, bucket); err != s3err.ErrNone { + s3err.WriteErrorResponse(w, r, err) + return + } + + // Remove CORS configuration from cache and persistent storage + // This handles both cache invalidation and persistent storage cleanup through the unified bucket config system + if err := s3a.removeCORSConfiguration(bucket); err != s3err.ErrNone { + glog.Errorf("Failed to remove CORS configuration: %v", err) + s3err.WriteErrorResponse(w, r, s3err.ErrInternalError) + return + } + + // Return success (204 No Content) + w.WriteHeader(http.StatusNoContent) +} diff --git a/weed/s3api/s3api_bucket_skip_handlers.go b/weed/s3api/s3api_bucket_skip_handlers.go index 798725203..d51d92b4d 100644 --- a/weed/s3api/s3api_bucket_skip_handlers.go +++ b/weed/s3api/s3api_bucket_skip_handlers.go @@ -8,24 +8,6 @@ import ( "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" ) -// GetBucketCorsHandler Get bucket CORS -// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketCors.html -func (s3a *S3ApiServer) GetBucketCorsHandler(w http.ResponseWriter, r *http.Request) { - s3err.WriteErrorResponse(w, r, s3err.ErrNoSuchCORSConfiguration) -} - -// PutBucketCorsHandler Put bucket CORS -// https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketCors.html -func (s3a *S3ApiServer) PutBucketCorsHandler(w http.ResponseWriter, r *http.Request) { - s3err.WriteErrorResponse(w, r, s3err.ErrNotImplemented) -} - -// DeleteBucketCorsHandler Delete bucket CORS -// https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteBucketCors.html -func (s3a *S3ApiServer) DeleteBucketCorsHandler(w http.ResponseWriter, r *http.Request) { - s3err.WriteErrorResponse(w, r, http.StatusNoContent) -} - // GetBucketPolicyHandler Get bucket Policy // https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketPolicy.html func (s3a *S3ApiServer) GetBucketPolicyHandler(w http.ResponseWriter, r *http.Request) { diff --git a/weed/s3api/s3api_object_handlers.go b/weed/s3api/s3api_object_handlers.go index 5163a72c2..6b811a024 100644 --- a/weed/s3api/s3api_object_handlers.go +++ b/weed/s3api/s3api_object_handlers.go @@ -20,6 +20,17 @@ import ( util_http "github.com/seaweedfs/seaweedfs/weed/util/http" ) +// corsHeaders defines the CORS headers that need to be preserved +// Package-level constant to avoid repeated allocations +var corsHeaders = []string{ + "Access-Control-Allow-Origin", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Headers", + "Access-Control-Expose-Headers", + "Access-Control-Max-Age", + "Access-Control-Allow-Credentials", +} + func mimeDetect(r *http.Request, dataReader io.Reader) io.ReadCloser { mimeBuffer := make([]byte, 512) size, _ := dataReader.Read(mimeBuffer) @@ -381,10 +392,34 @@ func setUserMetadataKeyToLowercase(resp *http.Response) { } } +func captureCORSHeaders(w http.ResponseWriter, headersToCapture []string) map[string]string { + captured := make(map[string]string) + for _, corsHeader := range headersToCapture { + if value := w.Header().Get(corsHeader); value != "" { + captured[corsHeader] = value + } + } + return captured +} + +func restoreCORSHeaders(w http.ResponseWriter, capturedCORSHeaders map[string]string) { + for corsHeader, value := range capturedCORSHeaders { + w.Header().Set(corsHeader, value) + } +} + func passThroughResponse(proxyResponse *http.Response, w http.ResponseWriter) (statusCode int, bytesTransferred int64) { + // Capture existing CORS headers that may have been set by middleware + capturedCORSHeaders := captureCORSHeaders(w, corsHeaders) + + // Copy headers from proxy response for k, v := range proxyResponse.Header { w.Header()[k] = v } + + // Restore CORS headers that were set by middleware + restoreCORSHeaders(w, capturedCORSHeaders) + if proxyResponse.Header.Get("Content-Range") != "" && proxyResponse.StatusCode == 200 { w.WriteHeader(http.StatusPartialContent) statusCode = http.StatusPartialContent diff --git a/weed/s3api/s3api_server.go b/weed/s3api/s3api_server.go index 426535fe0..5d113c645 100644 --- a/weed/s3api/s3api_server.go +++ b/weed/s3api/s3api_server.go @@ -121,6 +121,35 @@ func NewS3ApiServerWithStore(router *mux.Router, option *S3ApiServerOption, expl return s3ApiServer, nil } +// handleCORSOriginValidation handles the common CORS origin validation logic +func (s3a *S3ApiServer) handleCORSOriginValidation(w http.ResponseWriter, r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin != "" { + if len(s3a.option.AllowedOrigins) == 0 || s3a.option.AllowedOrigins[0] == "*" { + origin = "*" + } else { + originFound := false + for _, allowedOrigin := range s3a.option.AllowedOrigins { + if origin == allowedOrigin { + originFound = true + break + } + } + if !originFound { + writeFailureResponse(w, r, http.StatusForbidden) + return false + } + } + } + + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Expose-Headers", "*") + w.Header().Set("Access-Control-Allow-Methods", "*") + w.Header().Set("Access-Control-Allow-Headers", "*") + w.Header().Set("Access-Control-Allow-Credentials", "true") + return true +} + func (s3a *S3ApiServer) registerRouter(router *mux.Router) { // API Router apiRouter := router.PathPrefix("/").Subrouter() @@ -129,33 +158,6 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) { apiRouter.Methods(http.MethodGet).Path("/status").HandlerFunc(s3a.StatusHandler) apiRouter.Methods(http.MethodGet).Path("/healthz").HandlerFunc(s3a.StatusHandler) - apiRouter.Methods(http.MethodOptions).HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get("Origin") - if origin != "" { - if len(s3a.option.AllowedOrigins) == 0 || s3a.option.AllowedOrigins[0] == "*" { - origin = "*" - } else { - originFound := false - for _, allowedOrigin := range s3a.option.AllowedOrigins { - if origin == allowedOrigin { - originFound = true - } - } - if !originFound { - writeFailureResponse(w, r, http.StatusForbidden) - return - } - } - } - - w.Header().Set("Access-Control-Allow-Origin", origin) - w.Header().Set("Access-Control-Expose-Headers", "*") - w.Header().Set("Access-Control-Allow-Methods", "*") - w.Header().Set("Access-Control-Allow-Headers", "*") - writeSuccessResponseEmpty(w, r) - }) - var routers []*mux.Router if s3a.option.DomainName != "" { domainNames := strings.Split(s3a.option.DomainName, ",") @@ -168,7 +170,16 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) { } routers = append(routers, apiRouter.PathPrefix("/{bucket}").Subrouter()) + // Get CORS middleware instance with caching + corsMiddleware := s3a.getCORSMiddleware() + for _, bucket := range routers { + // Apply CORS middleware to bucket routers for automatic CORS header handling + bucket.Use(corsMiddleware.Handler) + + // Bucket-specific OPTIONS handler for CORS preflight requests + // Use PathPrefix to catch all bucket-level preflight routes including /bucket/object + bucket.PathPrefix("/").Methods(http.MethodOptions).HandlerFunc(corsMiddleware.HandleOptionsRequest) // each case should follow the next rule: // - requesting object with query must precede any other methods @@ -330,6 +341,25 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) { } + // Global OPTIONS handler for service-level requests (non-bucket requests) + // This handles requests like OPTIONS /, OPTIONS /status, OPTIONS /healthz + // Place this after bucket handlers to avoid interfering with bucket CORS middleware + apiRouter.Methods(http.MethodOptions).PathPrefix("/").HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // Only handle if this is not a bucket-specific request + vars := mux.Vars(r) + bucket := vars["bucket"] + if bucket != "" { + // This is a bucket-specific request, let bucket CORS middleware handle it + http.NotFound(w, r) + return + } + + if s3a.handleCORSOriginValidation(w, r) { + writeSuccessResponseEmpty(w, r) + } + }) + // ListBuckets apiRouter.Methods(http.MethodGet).Path("/").HandlerFunc(track(s3a.ListBucketsHandler, "LIST")) diff --git a/weed/s3api/s3err/error_handler.go b/weed/s3api/s3err/error_handler.go index 910dab12a..81335c489 100644 --- a/weed/s3api/s3err/error_handler.go +++ b/weed/s3api/s3err/error_handler.go @@ -4,13 +4,14 @@ import ( "bytes" "encoding/xml" "fmt" - "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil" - "github.com/gorilla/mux" - "github.com/seaweedfs/seaweedfs/weed/glog" "net/http" "strconv" "strings" "time" + + "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil" + "github.com/gorilla/mux" + "github.com/seaweedfs/seaweedfs/weed/glog" ) type mimeType string @@ -76,10 +77,25 @@ func EncodeXMLResponse(response interface{}) []byte { func setCommonHeaders(w http.ResponseWriter, r *http.Request) { w.Header().Set("x-amz-request-id", fmt.Sprintf("%d", time.Now().UnixNano())) w.Header().Set("Accept-Ranges", "bytes") + + // Only set static CORS headers for service-level requests, not bucket-specific requests if r.Header.Get("Origin") != "" { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Expose-Headers", "*") - w.Header().Set("Access-Control-Allow-Credentials", "true") + // Use mux.Vars to detect bucket-specific requests more reliably + vars := mux.Vars(r) + bucket := vars["bucket"] + isBucketRequest := bucket != "" + + // Only apply static CORS headers if this is NOT a bucket-specific request + // and no bucket-specific CORS headers were already set + if !isBucketRequest && w.Header().Get("Access-Control-Allow-Origin") == "" { + // This is a service-level request (like OPTIONS /), apply static CORS + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "*") + w.Header().Set("Access-Control-Allow-Headers", "*") + w.Header().Set("Access-Control-Expose-Headers", "*") + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + // For bucket-specific requests, let the CORS middleware handle the headers } } |
