aboutsummaryrefslogtreecommitdiff
path: root/weed
diff options
context:
space:
mode:
Diffstat (limited to 'weed')
-rw-r--r--weed/s3api/cors/cors.go649
-rw-r--r--weed/s3api/cors/cors_test.go526
-rw-r--r--weed/s3api/cors/middleware.go143
-rw-r--r--weed/s3api/s3api_bucket_config.go129
-rw-r--r--weed/s3api/s3api_bucket_cors_handlers.go140
-rw-r--r--weed/s3api/s3api_bucket_skip_handlers.go18
-rw-r--r--weed/s3api/s3api_object_handlers.go35
-rw-r--r--weed/s3api/s3api_server.go84
-rw-r--r--weed/s3api/s3err/error_handler.go28
9 files changed, 1701 insertions, 51 deletions
diff --git a/weed/s3api/cors/cors.go b/weed/s3api/cors/cors.go
new file mode 100644
index 000000000..1eef71b72
--- /dev/null
+++ b/weed/s3api/cors/cors.go
@@ -0,0 +1,649 @@
+package cors
+
+import (
+ "context"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
+ "net/http"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/seaweedfs/seaweedfs/weed/glog"
+ "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
+)
+
+// S3 metadata file name constant to avoid typos and reduce duplication
+const S3MetadataFileName = ".s3metadata"
+
+// CORSRule represents a single CORS rule
+type CORSRule struct {
+ ID string `xml:"ID,omitempty" json:"ID,omitempty"`
+ AllowedMethods []string `xml:"AllowedMethod" json:"AllowedMethods"`
+ AllowedOrigins []string `xml:"AllowedOrigin" json:"AllowedOrigins"`
+ AllowedHeaders []string `xml:"AllowedHeader,omitempty" json:"AllowedHeaders,omitempty"`
+ ExposeHeaders []string `xml:"ExposeHeader,omitempty" json:"ExposeHeaders,omitempty"`
+ MaxAgeSeconds *int `xml:"MaxAgeSeconds,omitempty" json:"MaxAgeSeconds,omitempty"`
+}
+
+// CORSConfiguration represents the CORS configuration for a bucket
+type CORSConfiguration struct {
+ XMLName xml.Name `xml:"CORSConfiguration"`
+ CORSRules []CORSRule `xml:"CORSRule" json:"CORSRules"`
+}
+
+// CORSRequest represents a CORS request
+type CORSRequest struct {
+ Origin string
+ Method string
+ RequestHeaders []string
+ IsPreflightRequest bool
+ AccessControlRequestMethod string
+ AccessControlRequestHeaders []string
+}
+
+// CORSResponse represents CORS response headers
+type CORSResponse struct {
+ AllowOrigin string
+ AllowMethods string
+ AllowHeaders string
+ ExposeHeaders string
+ MaxAge string
+ AllowCredentials bool
+}
+
+// ValidateConfiguration validates a CORS configuration
+func ValidateConfiguration(config *CORSConfiguration) error {
+ if config == nil {
+ return fmt.Errorf("CORS configuration cannot be nil")
+ }
+
+ if len(config.CORSRules) == 0 {
+ return fmt.Errorf("CORS configuration must have at least one rule")
+ }
+
+ if len(config.CORSRules) > 100 {
+ return fmt.Errorf("CORS configuration cannot have more than 100 rules")
+ }
+
+ for i, rule := range config.CORSRules {
+ if err := validateRule(&rule); err != nil {
+ return fmt.Errorf("invalid CORS rule at index %d: %v", i, err)
+ }
+ }
+
+ return nil
+}
+
+// validateRule validates a single CORS rule
+func validateRule(rule *CORSRule) error {
+ if len(rule.AllowedMethods) == 0 {
+ return fmt.Errorf("AllowedMethods cannot be empty")
+ }
+
+ if len(rule.AllowedOrigins) == 0 {
+ return fmt.Errorf("AllowedOrigins cannot be empty")
+ }
+
+ // Validate allowed methods
+ validMethods := map[string]bool{
+ "GET": true,
+ "PUT": true,
+ "POST": true,
+ "DELETE": true,
+ "HEAD": true,
+ }
+
+ for _, method := range rule.AllowedMethods {
+ if !validMethods[method] {
+ return fmt.Errorf("invalid HTTP method: %s", method)
+ }
+ }
+
+ // Validate origins
+ for _, origin := range rule.AllowedOrigins {
+ if origin == "*" {
+ continue
+ }
+ if err := validateOrigin(origin); err != nil {
+ return fmt.Errorf("invalid origin %s: %v", origin, err)
+ }
+ }
+
+ // Validate MaxAgeSeconds
+ if rule.MaxAgeSeconds != nil && *rule.MaxAgeSeconds < 0 {
+ return fmt.Errorf("MaxAgeSeconds cannot be negative")
+ }
+
+ return nil
+}
+
+// validateOrigin validates an origin string
+func validateOrigin(origin string) error {
+ if origin == "" {
+ return fmt.Errorf("origin cannot be empty")
+ }
+
+ // Special case: "*" is always valid
+ if origin == "*" {
+ return nil
+ }
+
+ // Count wildcards
+ wildcardCount := strings.Count(origin, "*")
+ if wildcardCount > 1 {
+ return fmt.Errorf("origin can contain at most one wildcard")
+ }
+
+ // If there's a wildcard, it should be in a valid position
+ if wildcardCount == 1 {
+ // Must be in the format: http://*.example.com or https://*.example.com
+ if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
+ return fmt.Errorf("origin with wildcard must start with http:// or https://")
+ }
+ }
+
+ return nil
+}
+
+// ParseRequest parses an HTTP request to extract CORS information
+func ParseRequest(r *http.Request) *CORSRequest {
+ corsReq := &CORSRequest{
+ Origin: r.Header.Get("Origin"),
+ Method: r.Method,
+ }
+
+ // Check if this is a preflight request
+ if r.Method == "OPTIONS" {
+ corsReq.IsPreflightRequest = true
+ corsReq.AccessControlRequestMethod = r.Header.Get("Access-Control-Request-Method")
+
+ if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
+ corsReq.AccessControlRequestHeaders = strings.Split(headers, ",")
+ for i := range corsReq.AccessControlRequestHeaders {
+ corsReq.AccessControlRequestHeaders[i] = strings.TrimSpace(corsReq.AccessControlRequestHeaders[i])
+ }
+ }
+ }
+
+ return corsReq
+}
+
+// EvaluateRequest evaluates a CORS request against a CORS configuration
+func EvaluateRequest(config *CORSConfiguration, corsReq *CORSRequest) (*CORSResponse, error) {
+ if config == nil || corsReq == nil {
+ return nil, fmt.Errorf("config and corsReq cannot be nil")
+ }
+
+ if corsReq.Origin == "" {
+ return nil, fmt.Errorf("origin header is required for CORS requests")
+ }
+
+ // Find the first rule that matches the origin
+ for _, rule := range config.CORSRules {
+ if matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
+ // For preflight requests, we need more detailed validation
+ if corsReq.IsPreflightRequest {
+ return buildPreflightResponse(&rule, corsReq), nil
+ } else {
+ // For actual requests, check method
+ if contains(rule.AllowedMethods, corsReq.Method) {
+ return buildResponse(&rule, corsReq), nil
+ }
+ }
+ }
+ }
+
+ return nil, fmt.Errorf("no matching CORS rule found")
+}
+
+// matchesRule checks if a CORS request matches a CORS rule
+func matchesRule(rule *CORSRule, corsReq *CORSRequest) bool {
+ // Check origin - this is the primary matching criterion
+ if !matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
+ return false
+ }
+
+ // For preflight requests, we need to validate both the requested method and headers
+ if corsReq.IsPreflightRequest {
+ // Check if the requested method is allowed
+ if corsReq.AccessControlRequestMethod != "" {
+ if !contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod) {
+ return false
+ }
+ }
+
+ // Check if all requested headers are allowed
+ if len(corsReq.AccessControlRequestHeaders) > 0 {
+ for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
+ if !matchesHeader(rule.AllowedHeaders, requestedHeader) {
+ return false
+ }
+ }
+ }
+
+ return true
+ }
+
+ // For non-preflight requests, check method matching
+ method := corsReq.Method
+ if !contains(rule.AllowedMethods, method) {
+ return false
+ }
+
+ return true
+}
+
+// matchesOrigin checks if an origin matches any of the allowed origins
+func matchesOrigin(allowedOrigins []string, origin string) bool {
+ for _, allowedOrigin := range allowedOrigins {
+ if allowedOrigin == "*" {
+ return true
+ }
+
+ if allowedOrigin == origin {
+ return true
+ }
+
+ // Check wildcard matching
+ if strings.Contains(allowedOrigin, "*") {
+ if matchesWildcard(allowedOrigin, origin) {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+// matchesWildcard checks if an origin matches a wildcard pattern
+// Uses string manipulation instead of regex for better performance
+func matchesWildcard(pattern, origin string) bool {
+ // Handle simple cases first
+ if pattern == "*" {
+ return true
+ }
+ if pattern == origin {
+ return true
+ }
+
+ // For CORS, we typically only deal with * wildcards (not ? wildcards)
+ // Use string manipulation for * wildcards only (more efficient than regex)
+
+ // Split pattern by wildcards
+ parts := strings.Split(pattern, "*")
+ if len(parts) == 1 {
+ // No wildcards, exact match
+ return pattern == origin
+ }
+
+ // Check if string starts with first part
+ if len(parts[0]) > 0 && !strings.HasPrefix(origin, parts[0]) {
+ return false
+ }
+
+ // Check if string ends with last part
+ if len(parts[len(parts)-1]) > 0 && !strings.HasSuffix(origin, parts[len(parts)-1]) {
+ return false
+ }
+
+ // Check middle parts
+ searchStr := origin
+ if len(parts[0]) > 0 {
+ searchStr = searchStr[len(parts[0]):]
+ }
+ if len(parts[len(parts)-1]) > 0 {
+ searchStr = searchStr[:len(searchStr)-len(parts[len(parts)-1])]
+ }
+
+ for i := 1; i < len(parts)-1; i++ {
+ if len(parts[i]) > 0 {
+ index := strings.Index(searchStr, parts[i])
+ if index == -1 {
+ return false
+ }
+ searchStr = searchStr[index+len(parts[i]):]
+ }
+ }
+
+ return true
+}
+
+// matchesHeader checks if a header matches allowed headers
+func matchesHeader(allowedHeaders []string, header string) bool {
+ if len(allowedHeaders) == 0 {
+ return true // No restrictions
+ }
+
+ for _, allowedHeader := range allowedHeaders {
+ if allowedHeader == "*" {
+ return true
+ }
+
+ if strings.EqualFold(allowedHeader, header) {
+ return true
+ }
+
+ // Check wildcard matching for headers
+ if strings.Contains(allowedHeader, "*") {
+ if matchesWildcard(strings.ToLower(allowedHeader), strings.ToLower(header)) {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+// buildPreflightResponse builds a CORS response for preflight requests
+// This function allows partial matches - origin can match while methods/headers may not
+func buildPreflightResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
+ response := &CORSResponse{
+ AllowOrigin: corsReq.Origin,
+ }
+
+ // Check if the requested method is allowed
+ methodAllowed := corsReq.AccessControlRequestMethod == "" || contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod)
+
+ // Check requested headers
+ var allowedRequestHeaders []string
+ allHeadersAllowed := true
+
+ if len(corsReq.AccessControlRequestHeaders) > 0 {
+ // Check if wildcard is allowed
+ hasWildcard := false
+ for _, header := range rule.AllowedHeaders {
+ if header == "*" {
+ hasWildcard = true
+ break
+ }
+ }
+
+ if hasWildcard {
+ // All requested headers are allowed with wildcard
+ allowedRequestHeaders = corsReq.AccessControlRequestHeaders
+ } else {
+ // Check each requested header individually
+ for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
+ if matchesHeader(rule.AllowedHeaders, requestedHeader) {
+ allowedRequestHeaders = append(allowedRequestHeaders, requestedHeader)
+ } else {
+ allHeadersAllowed = false
+ }
+ }
+ }
+ }
+
+ // Only set method and header info if both method and ALL headers are allowed
+ if methodAllowed && allHeadersAllowed {
+ response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
+
+ if len(allowedRequestHeaders) > 0 {
+ response.AllowHeaders = strings.Join(allowedRequestHeaders, ", ")
+ }
+
+ // Set exposed headers
+ if len(rule.ExposeHeaders) > 0 {
+ response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
+ }
+
+ // Set max age
+ if rule.MaxAgeSeconds != nil {
+ response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
+ }
+ }
+
+ return response
+}
+
+// buildResponse builds a CORS response from a matching rule
+func buildResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
+ response := &CORSResponse{
+ AllowOrigin: corsReq.Origin,
+ }
+
+ // Set allowed methods - for preflight requests, return all allowed methods
+ if corsReq.IsPreflightRequest {
+ response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
+ } else {
+ // For non-preflight requests, return all allowed methods
+ response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
+ }
+
+ // Set allowed headers
+ if corsReq.IsPreflightRequest && len(rule.AllowedHeaders) > 0 {
+ // For preflight requests, check if wildcard is allowed
+ hasWildcard := false
+ for _, header := range rule.AllowedHeaders {
+ if header == "*" {
+ hasWildcard = true
+ break
+ }
+ }
+
+ if hasWildcard && len(corsReq.AccessControlRequestHeaders) > 0 {
+ // Return the specific headers that were requested when wildcard is allowed
+ response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ")
+ } else if len(corsReq.AccessControlRequestHeaders) > 0 {
+ // For non-wildcard cases, return the requested headers (preserving case)
+ // since we already validated they are allowed in matchesRule
+ response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ")
+ } else {
+ // Fallback to configured headers if no specific headers were requested
+ response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ")
+ }
+ } else if len(rule.AllowedHeaders) > 0 {
+ // For non-preflight requests, return the allowed headers from the rule
+ response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ")
+ }
+
+ // Set exposed headers
+ if len(rule.ExposeHeaders) > 0 {
+ response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
+ }
+
+ // Set max age
+ if rule.MaxAgeSeconds != nil {
+ response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
+ }
+
+ return response
+}
+
+// contains checks if a slice contains a string
+func contains(slice []string, item string) bool {
+ for _, s := range slice {
+ if s == item {
+ return true
+ }
+ }
+ return false
+}
+
+// ApplyHeaders applies CORS headers to an HTTP response
+func ApplyHeaders(w http.ResponseWriter, corsResp *CORSResponse) {
+ if corsResp == nil {
+ return
+ }
+
+ if corsResp.AllowOrigin != "" {
+ w.Header().Set("Access-Control-Allow-Origin", corsResp.AllowOrigin)
+ }
+
+ if corsResp.AllowMethods != "" {
+ w.Header().Set("Access-Control-Allow-Methods", corsResp.AllowMethods)
+ }
+
+ if corsResp.AllowHeaders != "" {
+ w.Header().Set("Access-Control-Allow-Headers", corsResp.AllowHeaders)
+ }
+
+ if corsResp.ExposeHeaders != "" {
+ w.Header().Set("Access-Control-Expose-Headers", corsResp.ExposeHeaders)
+ }
+
+ if corsResp.MaxAge != "" {
+ w.Header().Set("Access-Control-Max-Age", corsResp.MaxAge)
+ }
+
+ if corsResp.AllowCredentials {
+ w.Header().Set("Access-Control-Allow-Credentials", "true")
+ }
+}
+
+// FilerClient interface for dependency injection
+type FilerClient interface {
+ WithFilerClient(streamingMode bool, fn func(filer_pb.SeaweedFilerClient) error) error
+}
+
+// EntryGetter interface for getting filer entries
+type EntryGetter interface {
+ GetEntry(directory, name string) (*filer_pb.Entry, error)
+}
+
+// Storage provides CORS configuration storage operations
+type Storage struct {
+ filerClient FilerClient
+ entryGetter EntryGetter
+ bucketsPath string
+}
+
+// NewStorage creates a new CORS storage instance
+func NewStorage(filerClient FilerClient, entryGetter EntryGetter, bucketsPath string) *Storage {
+ return &Storage{
+ filerClient: filerClient,
+ entryGetter: entryGetter,
+ bucketsPath: bucketsPath,
+ }
+}
+
+// Store stores CORS configuration in the filer
+func (s *Storage) Store(bucket string, config *CORSConfiguration) error {
+ // Store in bucket metadata
+ bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
+
+ // Get existing metadata
+ existingEntry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
+ var metadata map[string]interface{}
+
+ if err == nil && existingEntry != nil && len(existingEntry.Content) > 0 {
+ if err := json.Unmarshal(existingEntry.Content, &metadata); err != nil {
+ glog.V(1).Infof("Failed to unmarshal existing metadata: %v", err)
+ metadata = make(map[string]interface{})
+ }
+ } else {
+ metadata = make(map[string]interface{})
+ }
+
+ metadata["cors"] = config
+
+ metadataBytes, err := json.Marshal(metadata)
+ if err != nil {
+ return fmt.Errorf("failed to marshal bucket metadata: %v", err)
+ }
+
+ // Store metadata
+ return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
+ request := &filer_pb.CreateEntryRequest{
+ Directory: s.bucketsPath + "/" + bucket,
+ Entry: &filer_pb.Entry{
+ Name: S3MetadataFileName,
+ IsDirectory: false,
+ Attributes: &filer_pb.FuseAttributes{
+ Crtime: time.Now().Unix(),
+ Mtime: time.Now().Unix(),
+ FileMode: 0644,
+ },
+ Content: metadataBytes,
+ },
+ }
+
+ _, err := client.CreateEntry(context.Background(), request)
+ return err
+ })
+}
+
+// Load loads CORS configuration from the filer
+func (s *Storage) Load(bucket string) (*CORSConfiguration, error) {
+ bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
+
+ entry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
+ if err != nil || entry == nil {
+ return nil, fmt.Errorf("no CORS configuration found")
+ }
+
+ if len(entry.Content) == 0 {
+ return nil, fmt.Errorf("no CORS configuration found")
+ }
+
+ var metadata map[string]interface{}
+ if err := json.Unmarshal(entry.Content, &metadata); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal metadata: %v", err)
+ }
+
+ corsData, exists := metadata["cors"]
+ if !exists {
+ return nil, fmt.Errorf("no CORS configuration found")
+ }
+
+ // Convert back to CORSConfiguration
+ corsBytes, err := json.Marshal(corsData)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal CORS data: %v", err)
+ }
+
+ var config CORSConfiguration
+ if err := json.Unmarshal(corsBytes, &config); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal CORS configuration: %v", err)
+ }
+
+ return &config, nil
+}
+
+// Delete deletes CORS configuration from the filer
+func (s *Storage) Delete(bucket string) error {
+ bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
+
+ entry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
+ if err != nil || entry == nil {
+ return nil // Already deleted or doesn't exist
+ }
+
+ var metadata map[string]interface{}
+ if len(entry.Content) > 0 {
+ if err := json.Unmarshal(entry.Content, &metadata); err != nil {
+ return fmt.Errorf("failed to unmarshal metadata: %v", err)
+ }
+ } else {
+ return nil // No metadata to delete
+ }
+
+ // Remove CORS configuration
+ delete(metadata, "cors")
+
+ metadataBytes, err := json.Marshal(metadata)
+ if err != nil {
+ return fmt.Errorf("failed to marshal metadata: %v", err)
+ }
+
+ // Update metadata
+ return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
+ request := &filer_pb.CreateEntryRequest{
+ Directory: s.bucketsPath + "/" + bucket,
+ Entry: &filer_pb.Entry{
+ Name: S3MetadataFileName,
+ IsDirectory: false,
+ Attributes: &filer_pb.FuseAttributes{
+ Crtime: time.Now().Unix(),
+ Mtime: time.Now().Unix(),
+ FileMode: 0644,
+ },
+ Content: metadataBytes,
+ },
+ }
+
+ _, err := client.CreateEntry(context.Background(), request)
+ return err
+ })
+}
diff --git a/weed/s3api/cors/cors_test.go b/weed/s3api/cors/cors_test.go
new file mode 100644
index 000000000..1b5c54028
--- /dev/null
+++ b/weed/s3api/cors/cors_test.go
@@ -0,0 +1,526 @@
+package cors
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "reflect"
+ "testing"
+)
+
+func TestValidateConfiguration(t *testing.T) {
+ tests := []struct {
+ name string
+ config *CORSConfiguration
+ wantErr bool
+ }{
+ {
+ name: "nil config",
+ config: nil,
+ wantErr: true,
+ },
+ {
+ name: "empty rules",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{},
+ },
+ wantErr: true,
+ },
+ {
+ name: "valid single rule",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"GET", "POST"},
+ AllowedOrigins: []string{"*"},
+ },
+ },
+ },
+ wantErr: false,
+ },
+ {
+ name: "too many rules",
+ config: &CORSConfiguration{
+ CORSRules: make([]CORSRule, 101),
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid method",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"INVALID"},
+ AllowedOrigins: []string{"*"},
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ name: "empty origins",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"GET"},
+ AllowedOrigins: []string{},
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid origin with multiple wildcards",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"GET"},
+ AllowedOrigins: []string{"http://*.*.example.com"},
+ },
+ },
+ },
+ wantErr: true,
+ },
+ {
+ name: "negative MaxAgeSeconds",
+ config: &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"GET"},
+ AllowedOrigins: []string{"*"},
+ MaxAgeSeconds: intPtr(-1),
+ },
+ },
+ },
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := ValidateConfiguration(tt.config)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("ValidateConfiguration() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestValidateOrigin(t *testing.T) {
+ tests := []struct {
+ name string
+ origin string
+ wantErr bool
+ }{
+ {
+ name: "empty origin",
+ origin: "",
+ wantErr: true,
+ },
+ {
+ name: "valid origin",
+ origin: "http://example.com",
+ wantErr: false,
+ },
+ {
+ name: "wildcard origin",
+ origin: "*",
+ wantErr: false,
+ },
+ {
+ name: "valid wildcard origin",
+ origin: "http://*.example.com",
+ wantErr: false,
+ },
+ {
+ name: "https wildcard origin",
+ origin: "https://*.example.com",
+ wantErr: false,
+ },
+ {
+ name: "invalid wildcard origin",
+ origin: "*.example.com",
+ wantErr: true,
+ },
+ {
+ name: "multiple wildcards",
+ origin: "http://*.*.example.com",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateOrigin(tt.origin)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("validateOrigin() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestParseRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ req *http.Request
+ want *CORSRequest
+ }{
+ {
+ name: "simple GET request",
+ req: &http.Request{
+ Method: "GET",
+ Header: http.Header{
+ "Origin": []string{"http://example.com"},
+ },
+ },
+ want: &CORSRequest{
+ Origin: "http://example.com",
+ Method: "GET",
+ IsPreflightRequest: false,
+ },
+ },
+ {
+ name: "OPTIONS preflight request",
+ req: &http.Request{
+ Method: "OPTIONS",
+ Header: http.Header{
+ "Origin": []string{"http://example.com"},
+ "Access-Control-Request-Method": []string{"PUT"},
+ "Access-Control-Request-Headers": []string{"Content-Type, Authorization"},
+ },
+ },
+ want: &CORSRequest{
+ Origin: "http://example.com",
+ Method: "OPTIONS",
+ IsPreflightRequest: true,
+ AccessControlRequestMethod: "PUT",
+ AccessControlRequestHeaders: []string{"Content-Type", "Authorization"},
+ },
+ },
+ {
+ name: "request without origin",
+ req: &http.Request{
+ Method: "GET",
+ Header: http.Header{},
+ },
+ want: &CORSRequest{
+ Origin: "",
+ Method: "GET",
+ IsPreflightRequest: false,
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := ParseRequest(tt.req)
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("ParseRequest() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestMatchesOrigin(t *testing.T) {
+ tests := []struct {
+ name string
+ allowedOrigins []string
+ origin string
+ want bool
+ }{
+ {
+ name: "wildcard match",
+ allowedOrigins: []string{"*"},
+ origin: "http://example.com",
+ want: true,
+ },
+ {
+ name: "exact match",
+ allowedOrigins: []string{"http://example.com"},
+ origin: "http://example.com",
+ want: true,
+ },
+ {
+ name: "no match",
+ allowedOrigins: []string{"http://example.com"},
+ origin: "http://other.com",
+ want: false,
+ },
+ {
+ name: "wildcard subdomain match",
+ allowedOrigins: []string{"http://*.example.com"},
+ origin: "http://api.example.com",
+ want: true,
+ },
+ {
+ name: "wildcard subdomain no match",
+ allowedOrigins: []string{"http://*.example.com"},
+ origin: "http://example.com",
+ want: false,
+ },
+ {
+ name: "multiple origins with match",
+ allowedOrigins: []string{"http://example.com", "http://other.com"},
+ origin: "http://other.com",
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := matchesOrigin(tt.allowedOrigins, tt.origin)
+ if got != tt.want {
+ t.Errorf("matchesOrigin() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestMatchesHeader(t *testing.T) {
+ tests := []struct {
+ name string
+ allowedHeaders []string
+ header string
+ want bool
+ }{
+ {
+ name: "empty allowed headers",
+ allowedHeaders: []string{},
+ header: "Content-Type",
+ want: true,
+ },
+ {
+ name: "wildcard match",
+ allowedHeaders: []string{"*"},
+ header: "Content-Type",
+ want: true,
+ },
+ {
+ name: "exact match",
+ allowedHeaders: []string{"Content-Type"},
+ header: "Content-Type",
+ want: true,
+ },
+ {
+ name: "case insensitive match",
+ allowedHeaders: []string{"content-type"},
+ header: "Content-Type",
+ want: true,
+ },
+ {
+ name: "no match",
+ allowedHeaders: []string{"Authorization"},
+ header: "Content-Type",
+ want: false,
+ },
+ {
+ name: "wildcard prefix match",
+ allowedHeaders: []string{"x-amz-*"},
+ header: "x-amz-date",
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := matchesHeader(tt.allowedHeaders, tt.header)
+ if got != tt.want {
+ t.Errorf("matchesHeader() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestEvaluateRequest(t *testing.T) {
+ config := &CORSConfiguration{
+ CORSRules: []CORSRule{
+ {
+ AllowedMethods: []string{"GET", "POST"},
+ AllowedOrigins: []string{"http://example.com"},
+ AllowedHeaders: []string{"Content-Type"},
+ ExposeHeaders: []string{"ETag"},
+ MaxAgeSeconds: intPtr(3600),
+ },
+ {
+ AllowedMethods: []string{"PUT"},
+ AllowedOrigins: []string{"*"},
+ },
+ },
+ }
+
+ tests := []struct {
+ name string
+ config *CORSConfiguration
+ corsReq *CORSRequest
+ want *CORSResponse
+ wantErr bool
+ }{
+ {
+ name: "matching first rule",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "http://example.com",
+ Method: "GET",
+ },
+ want: &CORSResponse{
+ AllowOrigin: "http://example.com",
+ AllowMethods: "GET, POST",
+ AllowHeaders: "Content-Type",
+ ExposeHeaders: "ETag",
+ MaxAge: "3600",
+ },
+ wantErr: false,
+ },
+ {
+ name: "matching second rule",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "http://other.com",
+ Method: "PUT",
+ },
+ want: &CORSResponse{
+ AllowOrigin: "http://other.com",
+ AllowMethods: "PUT",
+ },
+ wantErr: false,
+ },
+ {
+ name: "no matching rule",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "http://forbidden.com",
+ Method: "GET",
+ },
+ want: nil,
+ wantErr: true,
+ },
+ {
+ name: "preflight request",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "http://example.com",
+ Method: "OPTIONS",
+ IsPreflightRequest: true,
+ AccessControlRequestMethod: "POST",
+ AccessControlRequestHeaders: []string{"Content-Type"},
+ },
+ want: &CORSResponse{
+ AllowOrigin: "http://example.com",
+ AllowMethods: "GET, POST",
+ AllowHeaders: "Content-Type",
+ ExposeHeaders: "ETag",
+ MaxAge: "3600",
+ },
+ wantErr: false,
+ },
+ {
+ name: "preflight request with forbidden header",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "http://example.com",
+ Method: "OPTIONS",
+ IsPreflightRequest: true,
+ AccessControlRequestMethod: "POST",
+ AccessControlRequestHeaders: []string{"Authorization"},
+ },
+ want: &CORSResponse{
+ AllowOrigin: "http://example.com",
+ // No AllowMethods or AllowHeaders because the requested header is forbidden
+ },
+ wantErr: false,
+ },
+ {
+ name: "request without origin",
+ config: config,
+ corsReq: &CORSRequest{
+ Origin: "",
+ Method: "GET",
+ },
+ want: nil,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := EvaluateRequest(tt.config, tt.corsReq)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("EvaluateRequest() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("EvaluateRequest() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestApplyHeaders(t *testing.T) {
+ tests := []struct {
+ name string
+ corsResp *CORSResponse
+ want map[string]string
+ }{
+ {
+ name: "nil response",
+ corsResp: nil,
+ want: map[string]string{},
+ },
+ {
+ name: "complete response",
+ corsResp: &CORSResponse{
+ AllowOrigin: "http://example.com",
+ AllowMethods: "GET, POST",
+ AllowHeaders: "Content-Type",
+ ExposeHeaders: "ETag",
+ MaxAge: "3600",
+ },
+ want: map[string]string{
+ "Access-Control-Allow-Origin": "http://example.com",
+ "Access-Control-Allow-Methods": "GET, POST",
+ "Access-Control-Allow-Headers": "Content-Type",
+ "Access-Control-Expose-Headers": "ETag",
+ "Access-Control-Max-Age": "3600",
+ },
+ },
+ {
+ name: "with credentials",
+ corsResp: &CORSResponse{
+ AllowOrigin: "http://example.com",
+ AllowMethods: "GET",
+ AllowCredentials: true,
+ },
+ want: map[string]string{
+ "Access-Control-Allow-Origin": "http://example.com",
+ "Access-Control-Allow-Methods": "GET",
+ "Access-Control-Allow-Credentials": "true",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create a proper response writer using httptest
+ w := httptest.NewRecorder()
+
+ ApplyHeaders(w, tt.corsResp)
+
+ // Extract headers from the response
+ headers := make(map[string]string)
+ for key, values := range w.Header() {
+ if len(values) > 0 {
+ headers[key] = values[0]
+ }
+ }
+
+ if !reflect.DeepEqual(headers, tt.want) {
+ t.Errorf("ApplyHeaders() headers = %v, want %v", headers, tt.want)
+ }
+ })
+ }
+}
+
+// Helper functions and types for testing
+
+func intPtr(i int) *int {
+ return &i
+}
diff --git a/weed/s3api/cors/middleware.go b/weed/s3api/cors/middleware.go
new file mode 100644
index 000000000..14ff32355
--- /dev/null
+++ b/weed/s3api/cors/middleware.go
@@ -0,0 +1,143 @@
+package cors
+
+import (
+ "net/http"
+
+ "github.com/seaweedfs/seaweedfs/weed/glog"
+ "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
+ "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
+)
+
+// BucketChecker interface for checking bucket existence
+type BucketChecker interface {
+ CheckBucket(r *http.Request, bucket string) s3err.ErrorCode
+}
+
+// CORSConfigGetter interface for getting CORS configuration
+type CORSConfigGetter interface {
+ GetCORSConfiguration(bucket string) (*CORSConfiguration, s3err.ErrorCode)
+}
+
+// Middleware handles CORS evaluation for all S3 API requests
+type Middleware struct {
+ storage *Storage
+ bucketChecker BucketChecker
+ corsConfigGetter CORSConfigGetter
+}
+
+// NewMiddleware creates a new CORS middleware instance
+func NewMiddleware(storage *Storage, bucketChecker BucketChecker, corsConfigGetter CORSConfigGetter) *Middleware {
+ return &Middleware{
+ storage: storage,
+ bucketChecker: bucketChecker,
+ corsConfigGetter: corsConfigGetter,
+ }
+}
+
+// evaluateCORSRequest performs the common CORS request evaluation logic
+// Returns: (corsResponse, responseWritten, shouldContinue)
+// - corsResponse: the CORS response if evaluation succeeded
+// - responseWritten: true if an error response was already written
+// - shouldContinue: true if the request should continue to the next handler
+func (m *Middleware) evaluateCORSRequest(w http.ResponseWriter, r *http.Request) (*CORSResponse, bool, bool) {
+ // Parse CORS request
+ corsReq := ParseRequest(r)
+ if corsReq.Origin == "" {
+ // Not a CORS request
+ return nil, false, true
+ }
+
+ // Extract bucket from request
+ bucket, _ := s3_constants.GetBucketAndObject(r)
+ if bucket == "" {
+ return nil, false, true
+ }
+
+ // Check if bucket exists
+ if err := m.bucketChecker.CheckBucket(r, bucket); err != s3err.ErrNone {
+ // For non-existent buckets, let the normal handler deal with it
+ return nil, false, true
+ }
+
+ // Load CORS configuration from cache
+ config, errCode := m.corsConfigGetter.GetCORSConfiguration(bucket)
+ if errCode != s3err.ErrNone || config == nil {
+ // No CORS configuration, handle based on request type
+ if corsReq.IsPreflightRequest {
+ // Preflight request without CORS config should fail
+ s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied)
+ return nil, true, false // Response written, don't continue
+ }
+ // Non-preflight request, continue normally
+ return nil, false, true
+ }
+
+ // Evaluate CORS request
+ corsResp, err := EvaluateRequest(config, corsReq)
+ if err != nil {
+ glog.V(3).Infof("CORS evaluation failed for bucket %s: %v", bucket, err)
+ if corsReq.IsPreflightRequest {
+ // Preflight request that doesn't match CORS rules should fail
+ s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied)
+ return nil, true, false // Response written, don't continue
+ }
+ // Non-preflight request, continue normally but without CORS headers
+ return nil, false, true
+ }
+
+ return corsResp, false, false
+}
+
+// Handler returns the CORS middleware handler
+func (m *Middleware) Handler(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Use the common evaluation logic
+ corsResp, responseWritten, shouldContinue := m.evaluateCORSRequest(w, r)
+ if responseWritten {
+ // Response was already written (error case)
+ return
+ }
+
+ if shouldContinue {
+ // Continue with normal request processing
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ // Parse request to check if it's a preflight request
+ corsReq := ParseRequest(r)
+
+ // Apply CORS headers to response
+ ApplyHeaders(w, corsResp)
+
+ // Handle preflight requests
+ if corsReq.IsPreflightRequest {
+ // Preflight request should return 200 OK with just CORS headers
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ // Continue with normal request processing
+ next.ServeHTTP(w, r)
+ })
+}
+
+// HandleOptionsRequest handles OPTIONS requests for CORS preflight
+func (m *Middleware) HandleOptionsRequest(w http.ResponseWriter, r *http.Request) {
+ // Use the common evaluation logic
+ corsResp, responseWritten, shouldContinue := m.evaluateCORSRequest(w, r)
+ if responseWritten {
+ // Response was already written (error case)
+ return
+ }
+
+ if shouldContinue || corsResp == nil {
+ // Not a CORS request or should continue normally
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ // Apply CORS headers and return success
+ ApplyHeaders(w, corsResp)
+ w.WriteHeader(http.StatusOK)
+}
diff --git a/weed/s3api/s3api_bucket_config.go b/weed/s3api/s3api_bucket_config.go
index 273eb6fbd..a157b93e8 100644
--- a/weed/s3api/s3api_bucket_config.go
+++ b/weed/s3api/s3api_bucket_config.go
@@ -1,12 +1,16 @@
package s3api
import (
+ "encoding/json"
"fmt"
+ "path/filepath"
+ "strings"
"sync"
"time"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
+ "github.com/seaweedfs/seaweedfs/weed/s3api/cors"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
)
@@ -18,6 +22,7 @@ type BucketConfig struct {
Ownership string
ACL []byte
Owner string
+ CORS *cors.CORSConfiguration
LastModified time.Time
Entry *filer_pb.Entry
}
@@ -118,6 +123,19 @@ func (s3a *S3ApiServer) getBucketConfig(bucket string) (*BucketConfig, s3err.Err
}
}
+ // Load CORS configuration from .s3metadata
+ if corsConfig, err := s3a.loadCORSFromMetadata(bucket); err != nil {
+ if err == filer_pb.ErrNotFound {
+ // Missing metadata is not an error; fall back cleanly
+ glog.V(2).Infof("CORS metadata not found for bucket %s, falling back to default behavior", bucket)
+ } else {
+ // Log parsing or validation errors
+ glog.Errorf("Failed to load CORS configuration for bucket %s: %v", bucket, err)
+ }
+ } else {
+ config.CORS = corsConfig
+ }
+
// Cache the result
s3a.bucketConfigCache.Set(bucket, config)
@@ -244,3 +262,114 @@ func (s3a *S3ApiServer) removeBucketConfigKey(bucket, key string) s3err.ErrorCod
return nil
})
}
+
+// loadCORSFromMetadata loads CORS configuration from bucket metadata
+func (s3a *S3ApiServer) loadCORSFromMetadata(bucket string) (*cors.CORSConfiguration, error) {
+ // Validate bucket name to prevent path traversal attacks
+ if bucket == "" || strings.Contains(bucket, "/") || strings.Contains(bucket, "\\") ||
+ strings.Contains(bucket, "..") || strings.Contains(bucket, "~") {
+ return nil, fmt.Errorf("invalid bucket name: %s", bucket)
+ }
+
+ // Clean the bucket name further to prevent any potential path traversal
+ bucket = filepath.Clean(bucket)
+ if bucket == "." || bucket == ".." {
+ return nil, fmt.Errorf("invalid bucket name: %s", bucket)
+ }
+
+ bucketMetadataPath := filepath.Join(s3a.option.BucketsPath, bucket, cors.S3MetadataFileName)
+
+ entry, err := s3a.getEntry("", bucketMetadataPath)
+ if err != nil {
+ glog.V(3).Infof("loadCORSFromMetadata: error retrieving metadata for bucket %s: %v", bucket, err)
+ return nil, fmt.Errorf("error retrieving metadata for bucket %s: %v", bucket, err)
+ }
+ if entry == nil {
+ glog.V(3).Infof("loadCORSFromMetadata: no metadata entry found for bucket %s", bucket)
+ return nil, fmt.Errorf("no metadata entry found for bucket %s", bucket)
+ }
+
+ if len(entry.Content) == 0 {
+ glog.V(3).Infof("loadCORSFromMetadata: empty metadata content for bucket %s", bucket)
+ return nil, fmt.Errorf("no metadata content for bucket %s", bucket)
+ }
+
+ var metadata map[string]json.RawMessage
+ if err := json.Unmarshal(entry.Content, &metadata); err != nil {
+ glog.Errorf("loadCORSFromMetadata: failed to unmarshal metadata for bucket %s: %v", bucket, err)
+ return nil, fmt.Errorf("failed to unmarshal metadata: %v", err)
+ }
+
+ corsData, exists := metadata["cors"]
+ if !exists {
+ glog.V(3).Infof("loadCORSFromMetadata: no CORS configuration found for bucket %s", bucket)
+ return nil, fmt.Errorf("no CORS configuration found")
+ }
+
+ // Directly unmarshal the raw JSON to CORSConfiguration to avoid round-trip allocations
+ var config cors.CORSConfiguration
+ if err := json.Unmarshal(corsData, &config); err != nil {
+ glog.Errorf("loadCORSFromMetadata: failed to unmarshal CORS configuration for bucket %s: %v", bucket, err)
+ return nil, fmt.Errorf("failed to unmarshal CORS configuration: %v", err)
+ }
+
+ return &config, nil
+}
+
+// getCORSConfiguration retrieves CORS configuration with caching
+func (s3a *S3ApiServer) getCORSConfiguration(bucket string) (*cors.CORSConfiguration, s3err.ErrorCode) {
+ config, errCode := s3a.getBucketConfig(bucket)
+ if errCode != s3err.ErrNone {
+ return nil, errCode
+ }
+
+ return config.CORS, s3err.ErrNone
+}
+
+// getCORSStorage returns a CORS storage instance for persistent operations
+func (s3a *S3ApiServer) getCORSStorage() *cors.Storage {
+ entryGetter := &S3EntryGetter{server: s3a}
+ return cors.NewStorage(s3a, entryGetter, s3a.option.BucketsPath)
+}
+
+// updateCORSConfiguration updates CORS configuration and invalidates cache
+func (s3a *S3ApiServer) updateCORSConfiguration(bucket string, corsConfig *cors.CORSConfiguration) s3err.ErrorCode {
+ // Update in-memory cache
+ errCode := s3a.updateBucketConfig(bucket, func(config *BucketConfig) error {
+ config.CORS = corsConfig
+ return nil
+ })
+ if errCode != s3err.ErrNone {
+ return errCode
+ }
+
+ // Persist to .s3metadata file
+ storage := s3a.getCORSStorage()
+ if err := storage.Store(bucket, corsConfig); err != nil {
+ glog.Errorf("updateCORSConfiguration: failed to persist CORS config to metadata for bucket %s: %v", bucket, err)
+ return s3err.ErrInternalError
+ }
+
+ return s3err.ErrNone
+}
+
+// removeCORSConfiguration removes CORS configuration and invalidates cache
+func (s3a *S3ApiServer) removeCORSConfiguration(bucket string) s3err.ErrorCode {
+ // Remove from in-memory cache
+ errCode := s3a.updateBucketConfig(bucket, func(config *BucketConfig) error {
+ config.CORS = nil
+ return nil
+ })
+ if errCode != s3err.ErrNone {
+ return errCode
+ }
+
+ // Remove from .s3metadata file
+ storage := s3a.getCORSStorage()
+ if err := storage.Delete(bucket); err != nil {
+ glog.Errorf("removeCORSConfiguration: failed to remove CORS config from metadata for bucket %s: %v", bucket, err)
+ return s3err.ErrInternalError
+ }
+
+ return s3err.ErrNone
+}
diff --git a/weed/s3api/s3api_bucket_cors_handlers.go b/weed/s3api/s3api_bucket_cors_handlers.go
new file mode 100644
index 000000000..e46021d7e
--- /dev/null
+++ b/weed/s3api/s3api_bucket_cors_handlers.go
@@ -0,0 +1,140 @@
+package s3api
+
+import (
+ "encoding/xml"
+ "net/http"
+
+ "github.com/seaweedfs/seaweedfs/weed/glog"
+ "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
+ "github.com/seaweedfs/seaweedfs/weed/s3api/cors"
+ "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
+ "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
+)
+
+// S3EntryGetter implements cors.EntryGetter interface
+type S3EntryGetter struct {
+ server *S3ApiServer
+}
+
+func (g *S3EntryGetter) GetEntry(directory, name string) (*filer_pb.Entry, error) {
+ return g.server.getEntry(directory, name)
+}
+
+// S3BucketChecker implements cors.BucketChecker interface
+type S3BucketChecker struct {
+ server *S3ApiServer
+}
+
+func (c *S3BucketChecker) CheckBucket(r *http.Request, bucket string) s3err.ErrorCode {
+ return c.server.checkBucket(r, bucket)
+}
+
+// S3CORSConfigGetter implements cors.CORSConfigGetter interface
+type S3CORSConfigGetter struct {
+ server *S3ApiServer
+}
+
+func (g *S3CORSConfigGetter) GetCORSConfiguration(bucket string) (*cors.CORSConfiguration, s3err.ErrorCode) {
+ return g.server.getCORSConfiguration(bucket)
+}
+
+// getCORSMiddleware returns a CORS middleware instance with caching
+func (s3a *S3ApiServer) getCORSMiddleware() *cors.Middleware {
+ storage := s3a.getCORSStorage()
+ bucketChecker := &S3BucketChecker{server: s3a}
+ corsConfigGetter := &S3CORSConfigGetter{server: s3a}
+
+ return cors.NewMiddleware(storage, bucketChecker, corsConfigGetter)
+}
+
+// GetBucketCorsHandler handles Get bucket CORS configuration
+// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketCors.html
+func (s3a *S3ApiServer) GetBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
+ bucket, _ := s3_constants.GetBucketAndObject(r)
+ glog.V(3).Infof("GetBucketCorsHandler %s", bucket)
+
+ if err := s3a.checkBucket(r, bucket); err != s3err.ErrNone {
+ s3err.WriteErrorResponse(w, r, err)
+ return
+ }
+
+ // Load CORS configuration from cache
+ config, errCode := s3a.getCORSConfiguration(bucket)
+ if errCode != s3err.ErrNone {
+ if errCode == s3err.ErrNoSuchBucket {
+ s3err.WriteErrorResponse(w, r, s3err.ErrNoSuchBucket)
+ } else {
+ s3err.WriteErrorResponse(w, r, s3err.ErrInternalError)
+ }
+ return
+ }
+
+ if config == nil {
+ s3err.WriteErrorResponse(w, r, s3err.ErrNoSuchCORSConfiguration)
+ return
+ }
+
+ // Return CORS configuration as XML
+ writeSuccessResponseXML(w, r, config)
+}
+
+// PutBucketCorsHandler handles Put bucket CORS configuration
+// https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketCors.html
+func (s3a *S3ApiServer) PutBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
+ bucket, _ := s3_constants.GetBucketAndObject(r)
+ glog.V(3).Infof("PutBucketCorsHandler %s", bucket)
+
+ if err := s3a.checkBucket(r, bucket); err != s3err.ErrNone {
+ s3err.WriteErrorResponse(w, r, err)
+ return
+ }
+
+ // Parse CORS configuration from request body
+ var config cors.CORSConfiguration
+ if err := xml.NewDecoder(r.Body).Decode(&config); err != nil {
+ glog.V(1).Infof("Failed to parse CORS configuration: %v", err)
+ s3err.WriteErrorResponse(w, r, s3err.ErrMalformedXML)
+ return
+ }
+
+ // Validate CORS configuration
+ if err := cors.ValidateConfiguration(&config); err != nil {
+ glog.V(1).Infof("Invalid CORS configuration: %v", err)
+ s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest)
+ return
+ }
+
+ // Store CORS configuration and update cache
+ // This handles both cache update and persistent storage through the unified bucket config system
+ if err := s3a.updateCORSConfiguration(bucket, &config); err != s3err.ErrNone {
+ glog.Errorf("Failed to update CORS configuration: %v", err)
+ s3err.WriteErrorResponse(w, r, s3err.ErrInternalError)
+ return
+ }
+
+ // Return success
+ writeSuccessResponseEmpty(w, r)
+}
+
+// DeleteBucketCorsHandler handles Delete bucket CORS configuration
+// https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteBucketCors.html
+func (s3a *S3ApiServer) DeleteBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
+ bucket, _ := s3_constants.GetBucketAndObject(r)
+ glog.V(3).Infof("DeleteBucketCorsHandler %s", bucket)
+
+ if err := s3a.checkBucket(r, bucket); err != s3err.ErrNone {
+ s3err.WriteErrorResponse(w, r, err)
+ return
+ }
+
+ // Remove CORS configuration from cache and persistent storage
+ // This handles both cache invalidation and persistent storage cleanup through the unified bucket config system
+ if err := s3a.removeCORSConfiguration(bucket); err != s3err.ErrNone {
+ glog.Errorf("Failed to remove CORS configuration: %v", err)
+ s3err.WriteErrorResponse(w, r, s3err.ErrInternalError)
+ return
+ }
+
+ // Return success (204 No Content)
+ w.WriteHeader(http.StatusNoContent)
+}
diff --git a/weed/s3api/s3api_bucket_skip_handlers.go b/weed/s3api/s3api_bucket_skip_handlers.go
index 798725203..d51d92b4d 100644
--- a/weed/s3api/s3api_bucket_skip_handlers.go
+++ b/weed/s3api/s3api_bucket_skip_handlers.go
@@ -8,24 +8,6 @@ import (
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
)
-// GetBucketCorsHandler Get bucket CORS
-// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketCors.html
-func (s3a *S3ApiServer) GetBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
- s3err.WriteErrorResponse(w, r, s3err.ErrNoSuchCORSConfiguration)
-}
-
-// PutBucketCorsHandler Put bucket CORS
-// https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketCors.html
-func (s3a *S3ApiServer) PutBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
- s3err.WriteErrorResponse(w, r, s3err.ErrNotImplemented)
-}
-
-// DeleteBucketCorsHandler Delete bucket CORS
-// https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteBucketCors.html
-func (s3a *S3ApiServer) DeleteBucketCorsHandler(w http.ResponseWriter, r *http.Request) {
- s3err.WriteErrorResponse(w, r, http.StatusNoContent)
-}
-
// GetBucketPolicyHandler Get bucket Policy
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketPolicy.html
func (s3a *S3ApiServer) GetBucketPolicyHandler(w http.ResponseWriter, r *http.Request) {
diff --git a/weed/s3api/s3api_object_handlers.go b/weed/s3api/s3api_object_handlers.go
index 5163a72c2..6b811a024 100644
--- a/weed/s3api/s3api_object_handlers.go
+++ b/weed/s3api/s3api_object_handlers.go
@@ -20,6 +20,17 @@ import (
util_http "github.com/seaweedfs/seaweedfs/weed/util/http"
)
+// corsHeaders defines the CORS headers that need to be preserved
+// Package-level constant to avoid repeated allocations
+var corsHeaders = []string{
+ "Access-Control-Allow-Origin",
+ "Access-Control-Allow-Methods",
+ "Access-Control-Allow-Headers",
+ "Access-Control-Expose-Headers",
+ "Access-Control-Max-Age",
+ "Access-Control-Allow-Credentials",
+}
+
func mimeDetect(r *http.Request, dataReader io.Reader) io.ReadCloser {
mimeBuffer := make([]byte, 512)
size, _ := dataReader.Read(mimeBuffer)
@@ -381,10 +392,34 @@ func setUserMetadataKeyToLowercase(resp *http.Response) {
}
}
+func captureCORSHeaders(w http.ResponseWriter, headersToCapture []string) map[string]string {
+ captured := make(map[string]string)
+ for _, corsHeader := range headersToCapture {
+ if value := w.Header().Get(corsHeader); value != "" {
+ captured[corsHeader] = value
+ }
+ }
+ return captured
+}
+
+func restoreCORSHeaders(w http.ResponseWriter, capturedCORSHeaders map[string]string) {
+ for corsHeader, value := range capturedCORSHeaders {
+ w.Header().Set(corsHeader, value)
+ }
+}
+
func passThroughResponse(proxyResponse *http.Response, w http.ResponseWriter) (statusCode int, bytesTransferred int64) {
+ // Capture existing CORS headers that may have been set by middleware
+ capturedCORSHeaders := captureCORSHeaders(w, corsHeaders)
+
+ // Copy headers from proxy response
for k, v := range proxyResponse.Header {
w.Header()[k] = v
}
+
+ // Restore CORS headers that were set by middleware
+ restoreCORSHeaders(w, capturedCORSHeaders)
+
if proxyResponse.Header.Get("Content-Range") != "" && proxyResponse.StatusCode == 200 {
w.WriteHeader(http.StatusPartialContent)
statusCode = http.StatusPartialContent
diff --git a/weed/s3api/s3api_server.go b/weed/s3api/s3api_server.go
index 426535fe0..5d113c645 100644
--- a/weed/s3api/s3api_server.go
+++ b/weed/s3api/s3api_server.go
@@ -121,6 +121,35 @@ func NewS3ApiServerWithStore(router *mux.Router, option *S3ApiServerOption, expl
return s3ApiServer, nil
}
+// handleCORSOriginValidation handles the common CORS origin validation logic
+func (s3a *S3ApiServer) handleCORSOriginValidation(w http.ResponseWriter, r *http.Request) bool {
+ origin := r.Header.Get("Origin")
+ if origin != "" {
+ if len(s3a.option.AllowedOrigins) == 0 || s3a.option.AllowedOrigins[0] == "*" {
+ origin = "*"
+ } else {
+ originFound := false
+ for _, allowedOrigin := range s3a.option.AllowedOrigins {
+ if origin == allowedOrigin {
+ originFound = true
+ break
+ }
+ }
+ if !originFound {
+ writeFailureResponse(w, r, http.StatusForbidden)
+ return false
+ }
+ }
+ }
+
+ w.Header().Set("Access-Control-Allow-Origin", origin)
+ w.Header().Set("Access-Control-Expose-Headers", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "*")
+ w.Header().Set("Access-Control-Allow-Headers", "*")
+ w.Header().Set("Access-Control-Allow-Credentials", "true")
+ return true
+}
+
func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
// API Router
apiRouter := router.PathPrefix("/").Subrouter()
@@ -129,33 +158,6 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
apiRouter.Methods(http.MethodGet).Path("/status").HandlerFunc(s3a.StatusHandler)
apiRouter.Methods(http.MethodGet).Path("/healthz").HandlerFunc(s3a.StatusHandler)
- apiRouter.Methods(http.MethodOptions).HandlerFunc(
- func(w http.ResponseWriter, r *http.Request) {
- origin := r.Header.Get("Origin")
- if origin != "" {
- if len(s3a.option.AllowedOrigins) == 0 || s3a.option.AllowedOrigins[0] == "*" {
- origin = "*"
- } else {
- originFound := false
- for _, allowedOrigin := range s3a.option.AllowedOrigins {
- if origin == allowedOrigin {
- originFound = true
- }
- }
- if !originFound {
- writeFailureResponse(w, r, http.StatusForbidden)
- return
- }
- }
- }
-
- w.Header().Set("Access-Control-Allow-Origin", origin)
- w.Header().Set("Access-Control-Expose-Headers", "*")
- w.Header().Set("Access-Control-Allow-Methods", "*")
- w.Header().Set("Access-Control-Allow-Headers", "*")
- writeSuccessResponseEmpty(w, r)
- })
-
var routers []*mux.Router
if s3a.option.DomainName != "" {
domainNames := strings.Split(s3a.option.DomainName, ",")
@@ -168,7 +170,16 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
}
routers = append(routers, apiRouter.PathPrefix("/{bucket}").Subrouter())
+ // Get CORS middleware instance with caching
+ corsMiddleware := s3a.getCORSMiddleware()
+
for _, bucket := range routers {
+ // Apply CORS middleware to bucket routers for automatic CORS header handling
+ bucket.Use(corsMiddleware.Handler)
+
+ // Bucket-specific OPTIONS handler for CORS preflight requests
+ // Use PathPrefix to catch all bucket-level preflight routes including /bucket/object
+ bucket.PathPrefix("/").Methods(http.MethodOptions).HandlerFunc(corsMiddleware.HandleOptionsRequest)
// each case should follow the next rule:
// - requesting object with query must precede any other methods
@@ -330,6 +341,25 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
}
+ // Global OPTIONS handler for service-level requests (non-bucket requests)
+ // This handles requests like OPTIONS /, OPTIONS /status, OPTIONS /healthz
+ // Place this after bucket handlers to avoid interfering with bucket CORS middleware
+ apiRouter.Methods(http.MethodOptions).PathPrefix("/").HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ // Only handle if this is not a bucket-specific request
+ vars := mux.Vars(r)
+ bucket := vars["bucket"]
+ if bucket != "" {
+ // This is a bucket-specific request, let bucket CORS middleware handle it
+ http.NotFound(w, r)
+ return
+ }
+
+ if s3a.handleCORSOriginValidation(w, r) {
+ writeSuccessResponseEmpty(w, r)
+ }
+ })
+
// ListBuckets
apiRouter.Methods(http.MethodGet).Path("/").HandlerFunc(track(s3a.ListBucketsHandler, "LIST"))
diff --git a/weed/s3api/s3err/error_handler.go b/weed/s3api/s3err/error_handler.go
index 910dab12a..81335c489 100644
--- a/weed/s3api/s3err/error_handler.go
+++ b/weed/s3api/s3err/error_handler.go
@@ -4,13 +4,14 @@ import (
"bytes"
"encoding/xml"
"fmt"
- "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
- "github.com/gorilla/mux"
- "github.com/seaweedfs/seaweedfs/weed/glog"
"net/http"
"strconv"
"strings"
"time"
+
+ "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
+ "github.com/gorilla/mux"
+ "github.com/seaweedfs/seaweedfs/weed/glog"
)
type mimeType string
@@ -76,10 +77,25 @@ func EncodeXMLResponse(response interface{}) []byte {
func setCommonHeaders(w http.ResponseWriter, r *http.Request) {
w.Header().Set("x-amz-request-id", fmt.Sprintf("%d", time.Now().UnixNano()))
w.Header().Set("Accept-Ranges", "bytes")
+
+ // Only set static CORS headers for service-level requests, not bucket-specific requests
if r.Header.Get("Origin") != "" {
- w.Header().Set("Access-Control-Allow-Origin", "*")
- w.Header().Set("Access-Control-Expose-Headers", "*")
- w.Header().Set("Access-Control-Allow-Credentials", "true")
+ // Use mux.Vars to detect bucket-specific requests more reliably
+ vars := mux.Vars(r)
+ bucket := vars["bucket"]
+ isBucketRequest := bucket != ""
+
+ // Only apply static CORS headers if this is NOT a bucket-specific request
+ // and no bucket-specific CORS headers were already set
+ if !isBucketRequest && w.Header().Get("Access-Control-Allow-Origin") == "" {
+ // This is a service-level request (like OPTIONS /), apply static CORS
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "*")
+ w.Header().Set("Access-Control-Allow-Headers", "*")
+ w.Header().Set("Access-Control-Expose-Headers", "*")
+ w.Header().Set("Access-Control-Allow-Credentials", "true")
+ }
+ // For bucket-specific requests, let the CORS middleware handle the headers
}
}