aboutsummaryrefslogtreecommitdiff
path: root/weed/s3api/cors/cors.go
diff options
context:
space:
mode:
Diffstat (limited to 'weed/s3api/cors/cors.go')
-rw-r--r--weed/s3api/cors/cors.go649
1 files changed, 649 insertions, 0 deletions
diff --git a/weed/s3api/cors/cors.go b/weed/s3api/cors/cors.go
new file mode 100644
index 000000000..1eef71b72
--- /dev/null
+++ b/weed/s3api/cors/cors.go
@@ -0,0 +1,649 @@
+package cors
+
+import (
+ "context"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
+ "net/http"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/seaweedfs/seaweedfs/weed/glog"
+ "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
+)
+
+// S3 metadata file name constant to avoid typos and reduce duplication
+const S3MetadataFileName = ".s3metadata"
+
+// CORSRule represents a single CORS rule
+type CORSRule struct {
+ ID string `xml:"ID,omitempty" json:"ID,omitempty"`
+ AllowedMethods []string `xml:"AllowedMethod" json:"AllowedMethods"`
+ AllowedOrigins []string `xml:"AllowedOrigin" json:"AllowedOrigins"`
+ AllowedHeaders []string `xml:"AllowedHeader,omitempty" json:"AllowedHeaders,omitempty"`
+ ExposeHeaders []string `xml:"ExposeHeader,omitempty" json:"ExposeHeaders,omitempty"`
+ MaxAgeSeconds *int `xml:"MaxAgeSeconds,omitempty" json:"MaxAgeSeconds,omitempty"`
+}
+
+// CORSConfiguration represents the CORS configuration for a bucket
+type CORSConfiguration struct {
+ XMLName xml.Name `xml:"CORSConfiguration"`
+ CORSRules []CORSRule `xml:"CORSRule" json:"CORSRules"`
+}
+
+// CORSRequest represents a CORS request
+type CORSRequest struct {
+ Origin string
+ Method string
+ RequestHeaders []string
+ IsPreflightRequest bool
+ AccessControlRequestMethod string
+ AccessControlRequestHeaders []string
+}
+
+// CORSResponse represents CORS response headers
+type CORSResponse struct {
+ AllowOrigin string
+ AllowMethods string
+ AllowHeaders string
+ ExposeHeaders string
+ MaxAge string
+ AllowCredentials bool
+}
+
+// ValidateConfiguration validates a CORS configuration
+func ValidateConfiguration(config *CORSConfiguration) error {
+ if config == nil {
+ return fmt.Errorf("CORS configuration cannot be nil")
+ }
+
+ if len(config.CORSRules) == 0 {
+ return fmt.Errorf("CORS configuration must have at least one rule")
+ }
+
+ if len(config.CORSRules) > 100 {
+ return fmt.Errorf("CORS configuration cannot have more than 100 rules")
+ }
+
+ for i, rule := range config.CORSRules {
+ if err := validateRule(&rule); err != nil {
+ return fmt.Errorf("invalid CORS rule at index %d: %v", i, err)
+ }
+ }
+
+ return nil
+}
+
+// validateRule validates a single CORS rule
+func validateRule(rule *CORSRule) error {
+ if len(rule.AllowedMethods) == 0 {
+ return fmt.Errorf("AllowedMethods cannot be empty")
+ }
+
+ if len(rule.AllowedOrigins) == 0 {
+ return fmt.Errorf("AllowedOrigins cannot be empty")
+ }
+
+ // Validate allowed methods
+ validMethods := map[string]bool{
+ "GET": true,
+ "PUT": true,
+ "POST": true,
+ "DELETE": true,
+ "HEAD": true,
+ }
+
+ for _, method := range rule.AllowedMethods {
+ if !validMethods[method] {
+ return fmt.Errorf("invalid HTTP method: %s", method)
+ }
+ }
+
+ // Validate origins
+ for _, origin := range rule.AllowedOrigins {
+ if origin == "*" {
+ continue
+ }
+ if err := validateOrigin(origin); err != nil {
+ return fmt.Errorf("invalid origin %s: %v", origin, err)
+ }
+ }
+
+ // Validate MaxAgeSeconds
+ if rule.MaxAgeSeconds != nil && *rule.MaxAgeSeconds < 0 {
+ return fmt.Errorf("MaxAgeSeconds cannot be negative")
+ }
+
+ return nil
+}
+
+// validateOrigin validates an origin string
+func validateOrigin(origin string) error {
+ if origin == "" {
+ return fmt.Errorf("origin cannot be empty")
+ }
+
+ // Special case: "*" is always valid
+ if origin == "*" {
+ return nil
+ }
+
+ // Count wildcards
+ wildcardCount := strings.Count(origin, "*")
+ if wildcardCount > 1 {
+ return fmt.Errorf("origin can contain at most one wildcard")
+ }
+
+ // If there's a wildcard, it should be in a valid position
+ if wildcardCount == 1 {
+ // Must be in the format: http://*.example.com or https://*.example.com
+ if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
+ return fmt.Errorf("origin with wildcard must start with http:// or https://")
+ }
+ }
+
+ return nil
+}
+
+// ParseRequest parses an HTTP request to extract CORS information
+func ParseRequest(r *http.Request) *CORSRequest {
+ corsReq := &CORSRequest{
+ Origin: r.Header.Get("Origin"),
+ Method: r.Method,
+ }
+
+ // Check if this is a preflight request
+ if r.Method == "OPTIONS" {
+ corsReq.IsPreflightRequest = true
+ corsReq.AccessControlRequestMethod = r.Header.Get("Access-Control-Request-Method")
+
+ if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
+ corsReq.AccessControlRequestHeaders = strings.Split(headers, ",")
+ for i := range corsReq.AccessControlRequestHeaders {
+ corsReq.AccessControlRequestHeaders[i] = strings.TrimSpace(corsReq.AccessControlRequestHeaders[i])
+ }
+ }
+ }
+
+ return corsReq
+}
+
+// EvaluateRequest evaluates a CORS request against a CORS configuration
+func EvaluateRequest(config *CORSConfiguration, corsReq *CORSRequest) (*CORSResponse, error) {
+ if config == nil || corsReq == nil {
+ return nil, fmt.Errorf("config and corsReq cannot be nil")
+ }
+
+ if corsReq.Origin == "" {
+ return nil, fmt.Errorf("origin header is required for CORS requests")
+ }
+
+ // Find the first rule that matches the origin
+ for _, rule := range config.CORSRules {
+ if matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
+ // For preflight requests, we need more detailed validation
+ if corsReq.IsPreflightRequest {
+ return buildPreflightResponse(&rule, corsReq), nil
+ } else {
+ // For actual requests, check method
+ if contains(rule.AllowedMethods, corsReq.Method) {
+ return buildResponse(&rule, corsReq), nil
+ }
+ }
+ }
+ }
+
+ return nil, fmt.Errorf("no matching CORS rule found")
+}
+
+// matchesRule checks if a CORS request matches a CORS rule
+func matchesRule(rule *CORSRule, corsReq *CORSRequest) bool {
+ // Check origin - this is the primary matching criterion
+ if !matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
+ return false
+ }
+
+ // For preflight requests, we need to validate both the requested method and headers
+ if corsReq.IsPreflightRequest {
+ // Check if the requested method is allowed
+ if corsReq.AccessControlRequestMethod != "" {
+ if !contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod) {
+ return false
+ }
+ }
+
+ // Check if all requested headers are allowed
+ if len(corsReq.AccessControlRequestHeaders) > 0 {
+ for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
+ if !matchesHeader(rule.AllowedHeaders, requestedHeader) {
+ return false
+ }
+ }
+ }
+
+ return true
+ }
+
+ // For non-preflight requests, check method matching
+ method := corsReq.Method
+ if !contains(rule.AllowedMethods, method) {
+ return false
+ }
+
+ return true
+}
+
+// matchesOrigin checks if an origin matches any of the allowed origins
+func matchesOrigin(allowedOrigins []string, origin string) bool {
+ for _, allowedOrigin := range allowedOrigins {
+ if allowedOrigin == "*" {
+ return true
+ }
+
+ if allowedOrigin == origin {
+ return true
+ }
+
+ // Check wildcard matching
+ if strings.Contains(allowedOrigin, "*") {
+ if matchesWildcard(allowedOrigin, origin) {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+// matchesWildcard checks if an origin matches a wildcard pattern
+// Uses string manipulation instead of regex for better performance
+func matchesWildcard(pattern, origin string) bool {
+ // Handle simple cases first
+ if pattern == "*" {
+ return true
+ }
+ if pattern == origin {
+ return true
+ }
+
+ // For CORS, we typically only deal with * wildcards (not ? wildcards)
+ // Use string manipulation for * wildcards only (more efficient than regex)
+
+ // Split pattern by wildcards
+ parts := strings.Split(pattern, "*")
+ if len(parts) == 1 {
+ // No wildcards, exact match
+ return pattern == origin
+ }
+
+ // Check if string starts with first part
+ if len(parts[0]) > 0 && !strings.HasPrefix(origin, parts[0]) {
+ return false
+ }
+
+ // Check if string ends with last part
+ if len(parts[len(parts)-1]) > 0 && !strings.HasSuffix(origin, parts[len(parts)-1]) {
+ return false
+ }
+
+ // Check middle parts
+ searchStr := origin
+ if len(parts[0]) > 0 {
+ searchStr = searchStr[len(parts[0]):]
+ }
+ if len(parts[len(parts)-1]) > 0 {
+ searchStr = searchStr[:len(searchStr)-len(parts[len(parts)-1])]
+ }
+
+ for i := 1; i < len(parts)-1; i++ {
+ if len(parts[i]) > 0 {
+ index := strings.Index(searchStr, parts[i])
+ if index == -1 {
+ return false
+ }
+ searchStr = searchStr[index+len(parts[i]):]
+ }
+ }
+
+ return true
+}
+
+// matchesHeader checks if a header matches allowed headers
+func matchesHeader(allowedHeaders []string, header string) bool {
+ if len(allowedHeaders) == 0 {
+ return true // No restrictions
+ }
+
+ for _, allowedHeader := range allowedHeaders {
+ if allowedHeader == "*" {
+ return true
+ }
+
+ if strings.EqualFold(allowedHeader, header) {
+ return true
+ }
+
+ // Check wildcard matching for headers
+ if strings.Contains(allowedHeader, "*") {
+ if matchesWildcard(strings.ToLower(allowedHeader), strings.ToLower(header)) {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+// buildPreflightResponse builds a CORS response for preflight requests
+// This function allows partial matches - origin can match while methods/headers may not
+func buildPreflightResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
+ response := &CORSResponse{
+ AllowOrigin: corsReq.Origin,
+ }
+
+ // Check if the requested method is allowed
+ methodAllowed := corsReq.AccessControlRequestMethod == "" || contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod)
+
+ // Check requested headers
+ var allowedRequestHeaders []string
+ allHeadersAllowed := true
+
+ if len(corsReq.AccessControlRequestHeaders) > 0 {
+ // Check if wildcard is allowed
+ hasWildcard := false
+ for _, header := range rule.AllowedHeaders {
+ if header == "*" {
+ hasWildcard = true
+ break
+ }
+ }
+
+ if hasWildcard {
+ // All requested headers are allowed with wildcard
+ allowedRequestHeaders = corsReq.AccessControlRequestHeaders
+ } else {
+ // Check each requested header individually
+ for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
+ if matchesHeader(rule.AllowedHeaders, requestedHeader) {
+ allowedRequestHeaders = append(allowedRequestHeaders, requestedHeader)
+ } else {
+ allHeadersAllowed = false
+ }
+ }
+ }
+ }
+
+ // Only set method and header info if both method and ALL headers are allowed
+ if methodAllowed && allHeadersAllowed {
+ response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
+
+ if len(allowedRequestHeaders) > 0 {
+ response.AllowHeaders = strings.Join(allowedRequestHeaders, ", ")
+ }
+
+ // Set exposed headers
+ if len(rule.ExposeHeaders) > 0 {
+ response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
+ }
+
+ // Set max age
+ if rule.MaxAgeSeconds != nil {
+ response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
+ }
+ }
+
+ return response
+}
+
+// buildResponse builds a CORS response from a matching rule
+func buildResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
+ response := &CORSResponse{
+ AllowOrigin: corsReq.Origin,
+ }
+
+ // Set allowed methods - for preflight requests, return all allowed methods
+ if corsReq.IsPreflightRequest {
+ response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
+ } else {
+ // For non-preflight requests, return all allowed methods
+ response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
+ }
+
+ // Set allowed headers
+ if corsReq.IsPreflightRequest && len(rule.AllowedHeaders) > 0 {
+ // For preflight requests, check if wildcard is allowed
+ hasWildcard := false
+ for _, header := range rule.AllowedHeaders {
+ if header == "*" {
+ hasWildcard = true
+ break
+ }
+ }
+
+ if hasWildcard && len(corsReq.AccessControlRequestHeaders) > 0 {
+ // Return the specific headers that were requested when wildcard is allowed
+ response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ")
+ } else if len(corsReq.AccessControlRequestHeaders) > 0 {
+ // For non-wildcard cases, return the requested headers (preserving case)
+ // since we already validated they are allowed in matchesRule
+ response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ")
+ } else {
+ // Fallback to configured headers if no specific headers were requested
+ response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ")
+ }
+ } else if len(rule.AllowedHeaders) > 0 {
+ // For non-preflight requests, return the allowed headers from the rule
+ response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ")
+ }
+
+ // Set exposed headers
+ if len(rule.ExposeHeaders) > 0 {
+ response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
+ }
+
+ // Set max age
+ if rule.MaxAgeSeconds != nil {
+ response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
+ }
+
+ return response
+}
+
+// contains checks if a slice contains a string
+func contains(slice []string, item string) bool {
+ for _, s := range slice {
+ if s == item {
+ return true
+ }
+ }
+ return false
+}
+
+// ApplyHeaders applies CORS headers to an HTTP response
+func ApplyHeaders(w http.ResponseWriter, corsResp *CORSResponse) {
+ if corsResp == nil {
+ return
+ }
+
+ if corsResp.AllowOrigin != "" {
+ w.Header().Set("Access-Control-Allow-Origin", corsResp.AllowOrigin)
+ }
+
+ if corsResp.AllowMethods != "" {
+ w.Header().Set("Access-Control-Allow-Methods", corsResp.AllowMethods)
+ }
+
+ if corsResp.AllowHeaders != "" {
+ w.Header().Set("Access-Control-Allow-Headers", corsResp.AllowHeaders)
+ }
+
+ if corsResp.ExposeHeaders != "" {
+ w.Header().Set("Access-Control-Expose-Headers", corsResp.ExposeHeaders)
+ }
+
+ if corsResp.MaxAge != "" {
+ w.Header().Set("Access-Control-Max-Age", corsResp.MaxAge)
+ }
+
+ if corsResp.AllowCredentials {
+ w.Header().Set("Access-Control-Allow-Credentials", "true")
+ }
+}
+
+// FilerClient interface for dependency injection
+type FilerClient interface {
+ WithFilerClient(streamingMode bool, fn func(filer_pb.SeaweedFilerClient) error) error
+}
+
+// EntryGetter interface for getting filer entries
+type EntryGetter interface {
+ GetEntry(directory, name string) (*filer_pb.Entry, error)
+}
+
+// Storage provides CORS configuration storage operations
+type Storage struct {
+ filerClient FilerClient
+ entryGetter EntryGetter
+ bucketsPath string
+}
+
+// NewStorage creates a new CORS storage instance
+func NewStorage(filerClient FilerClient, entryGetter EntryGetter, bucketsPath string) *Storage {
+ return &Storage{
+ filerClient: filerClient,
+ entryGetter: entryGetter,
+ bucketsPath: bucketsPath,
+ }
+}
+
+// Store stores CORS configuration in the filer
+func (s *Storage) Store(bucket string, config *CORSConfiguration) error {
+ // Store in bucket metadata
+ bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
+
+ // Get existing metadata
+ existingEntry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
+ var metadata map[string]interface{}
+
+ if err == nil && existingEntry != nil && len(existingEntry.Content) > 0 {
+ if err := json.Unmarshal(existingEntry.Content, &metadata); err != nil {
+ glog.V(1).Infof("Failed to unmarshal existing metadata: %v", err)
+ metadata = make(map[string]interface{})
+ }
+ } else {
+ metadata = make(map[string]interface{})
+ }
+
+ metadata["cors"] = config
+
+ metadataBytes, err := json.Marshal(metadata)
+ if err != nil {
+ return fmt.Errorf("failed to marshal bucket metadata: %v", err)
+ }
+
+ // Store metadata
+ return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
+ request := &filer_pb.CreateEntryRequest{
+ Directory: s.bucketsPath + "/" + bucket,
+ Entry: &filer_pb.Entry{
+ Name: S3MetadataFileName,
+ IsDirectory: false,
+ Attributes: &filer_pb.FuseAttributes{
+ Crtime: time.Now().Unix(),
+ Mtime: time.Now().Unix(),
+ FileMode: 0644,
+ },
+ Content: metadataBytes,
+ },
+ }
+
+ _, err := client.CreateEntry(context.Background(), request)
+ return err
+ })
+}
+
+// Load loads CORS configuration from the filer
+func (s *Storage) Load(bucket string) (*CORSConfiguration, error) {
+ bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
+
+ entry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
+ if err != nil || entry == nil {
+ return nil, fmt.Errorf("no CORS configuration found")
+ }
+
+ if len(entry.Content) == 0 {
+ return nil, fmt.Errorf("no CORS configuration found")
+ }
+
+ var metadata map[string]interface{}
+ if err := json.Unmarshal(entry.Content, &metadata); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal metadata: %v", err)
+ }
+
+ corsData, exists := metadata["cors"]
+ if !exists {
+ return nil, fmt.Errorf("no CORS configuration found")
+ }
+
+ // Convert back to CORSConfiguration
+ corsBytes, err := json.Marshal(corsData)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal CORS data: %v", err)
+ }
+
+ var config CORSConfiguration
+ if err := json.Unmarshal(corsBytes, &config); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal CORS configuration: %v", err)
+ }
+
+ return &config, nil
+}
+
+// Delete deletes CORS configuration from the filer
+func (s *Storage) Delete(bucket string) error {
+ bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
+
+ entry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
+ if err != nil || entry == nil {
+ return nil // Already deleted or doesn't exist
+ }
+
+ var metadata map[string]interface{}
+ if len(entry.Content) > 0 {
+ if err := json.Unmarshal(entry.Content, &metadata); err != nil {
+ return fmt.Errorf("failed to unmarshal metadata: %v", err)
+ }
+ } else {
+ return nil // No metadata to delete
+ }
+
+ // Remove CORS configuration
+ delete(metadata, "cors")
+
+ metadataBytes, err := json.Marshal(metadata)
+ if err != nil {
+ return fmt.Errorf("failed to marshal metadata: %v", err)
+ }
+
+ // Update metadata
+ return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
+ request := &filer_pb.CreateEntryRequest{
+ Directory: s.bucketsPath + "/" + bucket,
+ Entry: &filer_pb.Entry{
+ Name: S3MetadataFileName,
+ IsDirectory: false,
+ Attributes: &filer_pb.FuseAttributes{
+ Crtime: time.Now().Unix(),
+ Mtime: time.Now().Unix(),
+ FileMode: 0644,
+ },
+ Content: metadataBytes,
+ },
+ }
+
+ _, err := client.CreateEntry(context.Background(), request)
+ return err
+ })
+}