diff options
Diffstat (limited to 'weed/s3api/cors/cors.go')
| -rw-r--r-- | weed/s3api/cors/cors.go | 649 |
1 files changed, 649 insertions, 0 deletions
diff --git a/weed/s3api/cors/cors.go b/weed/s3api/cors/cors.go new file mode 100644 index 000000000..1eef71b72 --- /dev/null +++ b/weed/s3api/cors/cors.go @@ -0,0 +1,649 @@ +package cors + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "net/http" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/seaweedfs/seaweedfs/weed/glog" + "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" +) + +// S3 metadata file name constant to avoid typos and reduce duplication +const S3MetadataFileName = ".s3metadata" + +// CORSRule represents a single CORS rule +type CORSRule struct { + ID string `xml:"ID,omitempty" json:"ID,omitempty"` + AllowedMethods []string `xml:"AllowedMethod" json:"AllowedMethods"` + AllowedOrigins []string `xml:"AllowedOrigin" json:"AllowedOrigins"` + AllowedHeaders []string `xml:"AllowedHeader,omitempty" json:"AllowedHeaders,omitempty"` + ExposeHeaders []string `xml:"ExposeHeader,omitempty" json:"ExposeHeaders,omitempty"` + MaxAgeSeconds *int `xml:"MaxAgeSeconds,omitempty" json:"MaxAgeSeconds,omitempty"` +} + +// CORSConfiguration represents the CORS configuration for a bucket +type CORSConfiguration struct { + XMLName xml.Name `xml:"CORSConfiguration"` + CORSRules []CORSRule `xml:"CORSRule" json:"CORSRules"` +} + +// CORSRequest represents a CORS request +type CORSRequest struct { + Origin string + Method string + RequestHeaders []string + IsPreflightRequest bool + AccessControlRequestMethod string + AccessControlRequestHeaders []string +} + +// CORSResponse represents CORS response headers +type CORSResponse struct { + AllowOrigin string + AllowMethods string + AllowHeaders string + ExposeHeaders string + MaxAge string + AllowCredentials bool +} + +// ValidateConfiguration validates a CORS configuration +func ValidateConfiguration(config *CORSConfiguration) error { + if config == nil { + return fmt.Errorf("CORS configuration cannot be nil") + } + + if len(config.CORSRules) == 0 { + return fmt.Errorf("CORS configuration must have at least one rule") + } + + if len(config.CORSRules) > 100 { + return fmt.Errorf("CORS configuration cannot have more than 100 rules") + } + + for i, rule := range config.CORSRules { + if err := validateRule(&rule); err != nil { + return fmt.Errorf("invalid CORS rule at index %d: %v", i, err) + } + } + + return nil +} + +// validateRule validates a single CORS rule +func validateRule(rule *CORSRule) error { + if len(rule.AllowedMethods) == 0 { + return fmt.Errorf("AllowedMethods cannot be empty") + } + + if len(rule.AllowedOrigins) == 0 { + return fmt.Errorf("AllowedOrigins cannot be empty") + } + + // Validate allowed methods + validMethods := map[string]bool{ + "GET": true, + "PUT": true, + "POST": true, + "DELETE": true, + "HEAD": true, + } + + for _, method := range rule.AllowedMethods { + if !validMethods[method] { + return fmt.Errorf("invalid HTTP method: %s", method) + } + } + + // Validate origins + for _, origin := range rule.AllowedOrigins { + if origin == "*" { + continue + } + if err := validateOrigin(origin); err != nil { + return fmt.Errorf("invalid origin %s: %v", origin, err) + } + } + + // Validate MaxAgeSeconds + if rule.MaxAgeSeconds != nil && *rule.MaxAgeSeconds < 0 { + return fmt.Errorf("MaxAgeSeconds cannot be negative") + } + + return nil +} + +// validateOrigin validates an origin string +func validateOrigin(origin string) error { + if origin == "" { + return fmt.Errorf("origin cannot be empty") + } + + // Special case: "*" is always valid + if origin == "*" { + return nil + } + + // Count wildcards + wildcardCount := strings.Count(origin, "*") + if wildcardCount > 1 { + return fmt.Errorf("origin can contain at most one wildcard") + } + + // If there's a wildcard, it should be in a valid position + if wildcardCount == 1 { + // Must be in the format: http://*.example.com or https://*.example.com + if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") { + return fmt.Errorf("origin with wildcard must start with http:// or https://") + } + } + + return nil +} + +// ParseRequest parses an HTTP request to extract CORS information +func ParseRequest(r *http.Request) *CORSRequest { + corsReq := &CORSRequest{ + Origin: r.Header.Get("Origin"), + Method: r.Method, + } + + // Check if this is a preflight request + if r.Method == "OPTIONS" { + corsReq.IsPreflightRequest = true + corsReq.AccessControlRequestMethod = r.Header.Get("Access-Control-Request-Method") + + if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" { + corsReq.AccessControlRequestHeaders = strings.Split(headers, ",") + for i := range corsReq.AccessControlRequestHeaders { + corsReq.AccessControlRequestHeaders[i] = strings.TrimSpace(corsReq.AccessControlRequestHeaders[i]) + } + } + } + + return corsReq +} + +// EvaluateRequest evaluates a CORS request against a CORS configuration +func EvaluateRequest(config *CORSConfiguration, corsReq *CORSRequest) (*CORSResponse, error) { + if config == nil || corsReq == nil { + return nil, fmt.Errorf("config and corsReq cannot be nil") + } + + if corsReq.Origin == "" { + return nil, fmt.Errorf("origin header is required for CORS requests") + } + + // Find the first rule that matches the origin + for _, rule := range config.CORSRules { + if matchesOrigin(rule.AllowedOrigins, corsReq.Origin) { + // For preflight requests, we need more detailed validation + if corsReq.IsPreflightRequest { + return buildPreflightResponse(&rule, corsReq), nil + } else { + // For actual requests, check method + if contains(rule.AllowedMethods, corsReq.Method) { + return buildResponse(&rule, corsReq), nil + } + } + } + } + + return nil, fmt.Errorf("no matching CORS rule found") +} + +// matchesRule checks if a CORS request matches a CORS rule +func matchesRule(rule *CORSRule, corsReq *CORSRequest) bool { + // Check origin - this is the primary matching criterion + if !matchesOrigin(rule.AllowedOrigins, corsReq.Origin) { + return false + } + + // For preflight requests, we need to validate both the requested method and headers + if corsReq.IsPreflightRequest { + // Check if the requested method is allowed + if corsReq.AccessControlRequestMethod != "" { + if !contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod) { + return false + } + } + + // Check if all requested headers are allowed + if len(corsReq.AccessControlRequestHeaders) > 0 { + for _, requestedHeader := range corsReq.AccessControlRequestHeaders { + if !matchesHeader(rule.AllowedHeaders, requestedHeader) { + return false + } + } + } + + return true + } + + // For non-preflight requests, check method matching + method := corsReq.Method + if !contains(rule.AllowedMethods, method) { + return false + } + + return true +} + +// matchesOrigin checks if an origin matches any of the allowed origins +func matchesOrigin(allowedOrigins []string, origin string) bool { + for _, allowedOrigin := range allowedOrigins { + if allowedOrigin == "*" { + return true + } + + if allowedOrigin == origin { + return true + } + + // Check wildcard matching + if strings.Contains(allowedOrigin, "*") { + if matchesWildcard(allowedOrigin, origin) { + return true + } + } + } + return false +} + +// matchesWildcard checks if an origin matches a wildcard pattern +// Uses string manipulation instead of regex for better performance +func matchesWildcard(pattern, origin string) bool { + // Handle simple cases first + if pattern == "*" { + return true + } + if pattern == origin { + return true + } + + // For CORS, we typically only deal with * wildcards (not ? wildcards) + // Use string manipulation for * wildcards only (more efficient than regex) + + // Split pattern by wildcards + parts := strings.Split(pattern, "*") + if len(parts) == 1 { + // No wildcards, exact match + return pattern == origin + } + + // Check if string starts with first part + if len(parts[0]) > 0 && !strings.HasPrefix(origin, parts[0]) { + return false + } + + // Check if string ends with last part + if len(parts[len(parts)-1]) > 0 && !strings.HasSuffix(origin, parts[len(parts)-1]) { + return false + } + + // Check middle parts + searchStr := origin + if len(parts[0]) > 0 { + searchStr = searchStr[len(parts[0]):] + } + if len(parts[len(parts)-1]) > 0 { + searchStr = searchStr[:len(searchStr)-len(parts[len(parts)-1])] + } + + for i := 1; i < len(parts)-1; i++ { + if len(parts[i]) > 0 { + index := strings.Index(searchStr, parts[i]) + if index == -1 { + return false + } + searchStr = searchStr[index+len(parts[i]):] + } + } + + return true +} + +// matchesHeader checks if a header matches allowed headers +func matchesHeader(allowedHeaders []string, header string) bool { + if len(allowedHeaders) == 0 { + return true // No restrictions + } + + for _, allowedHeader := range allowedHeaders { + if allowedHeader == "*" { + return true + } + + if strings.EqualFold(allowedHeader, header) { + return true + } + + // Check wildcard matching for headers + if strings.Contains(allowedHeader, "*") { + if matchesWildcard(strings.ToLower(allowedHeader), strings.ToLower(header)) { + return true + } + } + } + + return false +} + +// buildPreflightResponse builds a CORS response for preflight requests +// This function allows partial matches - origin can match while methods/headers may not +func buildPreflightResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse { + response := &CORSResponse{ + AllowOrigin: corsReq.Origin, + } + + // Check if the requested method is allowed + methodAllowed := corsReq.AccessControlRequestMethod == "" || contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod) + + // Check requested headers + var allowedRequestHeaders []string + allHeadersAllowed := true + + if len(corsReq.AccessControlRequestHeaders) > 0 { + // Check if wildcard is allowed + hasWildcard := false + for _, header := range rule.AllowedHeaders { + if header == "*" { + hasWildcard = true + break + } + } + + if hasWildcard { + // All requested headers are allowed with wildcard + allowedRequestHeaders = corsReq.AccessControlRequestHeaders + } else { + // Check each requested header individually + for _, requestedHeader := range corsReq.AccessControlRequestHeaders { + if matchesHeader(rule.AllowedHeaders, requestedHeader) { + allowedRequestHeaders = append(allowedRequestHeaders, requestedHeader) + } else { + allHeadersAllowed = false + } + } + } + } + + // Only set method and header info if both method and ALL headers are allowed + if methodAllowed && allHeadersAllowed { + response.AllowMethods = strings.Join(rule.AllowedMethods, ", ") + + if len(allowedRequestHeaders) > 0 { + response.AllowHeaders = strings.Join(allowedRequestHeaders, ", ") + } + + // Set exposed headers + if len(rule.ExposeHeaders) > 0 { + response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ") + } + + // Set max age + if rule.MaxAgeSeconds != nil { + response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds) + } + } + + return response +} + +// buildResponse builds a CORS response from a matching rule +func buildResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse { + response := &CORSResponse{ + AllowOrigin: corsReq.Origin, + } + + // Set allowed methods - for preflight requests, return all allowed methods + if corsReq.IsPreflightRequest { + response.AllowMethods = strings.Join(rule.AllowedMethods, ", ") + } else { + // For non-preflight requests, return all allowed methods + response.AllowMethods = strings.Join(rule.AllowedMethods, ", ") + } + + // Set allowed headers + if corsReq.IsPreflightRequest && len(rule.AllowedHeaders) > 0 { + // For preflight requests, check if wildcard is allowed + hasWildcard := false + for _, header := range rule.AllowedHeaders { + if header == "*" { + hasWildcard = true + break + } + } + + if hasWildcard && len(corsReq.AccessControlRequestHeaders) > 0 { + // Return the specific headers that were requested when wildcard is allowed + response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ") + } else if len(corsReq.AccessControlRequestHeaders) > 0 { + // For non-wildcard cases, return the requested headers (preserving case) + // since we already validated they are allowed in matchesRule + response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ") + } else { + // Fallback to configured headers if no specific headers were requested + response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ") + } + } else if len(rule.AllowedHeaders) > 0 { + // For non-preflight requests, return the allowed headers from the rule + response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ") + } + + // Set exposed headers + if len(rule.ExposeHeaders) > 0 { + response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ") + } + + // Set max age + if rule.MaxAgeSeconds != nil { + response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds) + } + + return response +} + +// contains checks if a slice contains a string +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +// ApplyHeaders applies CORS headers to an HTTP response +func ApplyHeaders(w http.ResponseWriter, corsResp *CORSResponse) { + if corsResp == nil { + return + } + + if corsResp.AllowOrigin != "" { + w.Header().Set("Access-Control-Allow-Origin", corsResp.AllowOrigin) + } + + if corsResp.AllowMethods != "" { + w.Header().Set("Access-Control-Allow-Methods", corsResp.AllowMethods) + } + + if corsResp.AllowHeaders != "" { + w.Header().Set("Access-Control-Allow-Headers", corsResp.AllowHeaders) + } + + if corsResp.ExposeHeaders != "" { + w.Header().Set("Access-Control-Expose-Headers", corsResp.ExposeHeaders) + } + + if corsResp.MaxAge != "" { + w.Header().Set("Access-Control-Max-Age", corsResp.MaxAge) + } + + if corsResp.AllowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } +} + +// FilerClient interface for dependency injection +type FilerClient interface { + WithFilerClient(streamingMode bool, fn func(filer_pb.SeaweedFilerClient) error) error +} + +// EntryGetter interface for getting filer entries +type EntryGetter interface { + GetEntry(directory, name string) (*filer_pb.Entry, error) +} + +// Storage provides CORS configuration storage operations +type Storage struct { + filerClient FilerClient + entryGetter EntryGetter + bucketsPath string +} + +// NewStorage creates a new CORS storage instance +func NewStorage(filerClient FilerClient, entryGetter EntryGetter, bucketsPath string) *Storage { + return &Storage{ + filerClient: filerClient, + entryGetter: entryGetter, + bucketsPath: bucketsPath, + } +} + +// Store stores CORS configuration in the filer +func (s *Storage) Store(bucket string, config *CORSConfiguration) error { + // Store in bucket metadata + bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName) + + // Get existing metadata + existingEntry, err := s.entryGetter.GetEntry("", bucketMetadataPath) + var metadata map[string]interface{} + + if err == nil && existingEntry != nil && len(existingEntry.Content) > 0 { + if err := json.Unmarshal(existingEntry.Content, &metadata); err != nil { + glog.V(1).Infof("Failed to unmarshal existing metadata: %v", err) + metadata = make(map[string]interface{}) + } + } else { + metadata = make(map[string]interface{}) + } + + metadata["cors"] = config + + metadataBytes, err := json.Marshal(metadata) + if err != nil { + return fmt.Errorf("failed to marshal bucket metadata: %v", err) + } + + // Store metadata + return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { + request := &filer_pb.CreateEntryRequest{ + Directory: s.bucketsPath + "/" + bucket, + Entry: &filer_pb.Entry{ + Name: S3MetadataFileName, + IsDirectory: false, + Attributes: &filer_pb.FuseAttributes{ + Crtime: time.Now().Unix(), + Mtime: time.Now().Unix(), + FileMode: 0644, + }, + Content: metadataBytes, + }, + } + + _, err := client.CreateEntry(context.Background(), request) + return err + }) +} + +// Load loads CORS configuration from the filer +func (s *Storage) Load(bucket string) (*CORSConfiguration, error) { + bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName) + + entry, err := s.entryGetter.GetEntry("", bucketMetadataPath) + if err != nil || entry == nil { + return nil, fmt.Errorf("no CORS configuration found") + } + + if len(entry.Content) == 0 { + return nil, fmt.Errorf("no CORS configuration found") + } + + var metadata map[string]interface{} + if err := json.Unmarshal(entry.Content, &metadata); err != nil { + return nil, fmt.Errorf("failed to unmarshal metadata: %v", err) + } + + corsData, exists := metadata["cors"] + if !exists { + return nil, fmt.Errorf("no CORS configuration found") + } + + // Convert back to CORSConfiguration + corsBytes, err := json.Marshal(corsData) + if err != nil { + return nil, fmt.Errorf("failed to marshal CORS data: %v", err) + } + + var config CORSConfiguration + if err := json.Unmarshal(corsBytes, &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal CORS configuration: %v", err) + } + + return &config, nil +} + +// Delete deletes CORS configuration from the filer +func (s *Storage) Delete(bucket string) error { + bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName) + + entry, err := s.entryGetter.GetEntry("", bucketMetadataPath) + if err != nil || entry == nil { + return nil // Already deleted or doesn't exist + } + + var metadata map[string]interface{} + if len(entry.Content) > 0 { + if err := json.Unmarshal(entry.Content, &metadata); err != nil { + return fmt.Errorf("failed to unmarshal metadata: %v", err) + } + } else { + return nil // No metadata to delete + } + + // Remove CORS configuration + delete(metadata, "cors") + + metadataBytes, err := json.Marshal(metadata) + if err != nil { + return fmt.Errorf("failed to marshal metadata: %v", err) + } + + // Update metadata + return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { + request := &filer_pb.CreateEntryRequest{ + Directory: s.bucketsPath + "/" + bucket, + Entry: &filer_pb.Entry{ + Name: S3MetadataFileName, + IsDirectory: false, + Attributes: &filer_pb.FuseAttributes{ + Crtime: time.Now().Unix(), + Mtime: time.Now().Unix(), + FileMode: 0644, + }, + Content: metadataBytes, + }, + } + + _, err := client.CreateEntry(context.Background(), request) + return err + }) +} |
